diff --git a/.gitignore b/.gitignore index 68380e97ab92a0632675a709836d19be669de89d..9db2912c07bc2d6abb01c322a25519ac0ff158fa 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ .DS_Store build/ +build_fpga/ .idea/ @@ -71,6 +72,9 @@ build cmake-build-debug cmake-build-release +# vscode +.vscode + # ios tools/libomp.a diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2a7235a9d653b0da544a006dda6f9a9c957364f4..f83d99d862ca22508996ebdbf52f4b2fe1a57cc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: entry: bash ./tools/codestyle/cpplint_pre_commit.hook language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ - exclude: ^(mobile/) + exclude: ^(mobile/|metal/|web/) #- repo: local #hooks: #- id: pylint-doc-string diff --git a/CMakeLists.txt b/CMakeLists.txt index 03275b1a8d9943f66246463ea80081dc6bc6b0db..1ec5352fa4009144b9f572ecbe061aba11e884d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,33 +47,19 @@ include(simd) ################################ Exposed Configurations ####################################### lite_option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) lite_option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ON IF ${AVX_FOUND}) -lite_option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) lite_option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) lite_option(WITH_MKL "Compile PaddlePaddle with MKL support." ON IF ${AVX_FOUND}) lite_option(WITH_ARM_DOTPROD "Compile PaddlePaddle with ARM dot product" ON) lite_option(WITH_SYSTEM_BLAS "Use system blas library" OFF) -# TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter. -if(ANDROID OR IOS OR ARMLINUX) - set(WITH_GPU OFF CACHE STRING - "Disable GPU when cross-compiling for Android and iOS" FORCE) - set(WITH_DSO OFF CACHE STRING - "Disable DSO when cross-compiling for Android and iOS" FORCE) - set(WITH_AVX OFF CACHE STRING - "Disable AVX when cross-compiling for Android and iOS" FORCE) - set(WITH_PYTHON OFF CACHE STRING - "Disable PYTHON when cross-compiling for Android and iOS" FORCE) - set(WITH_RDMA OFF CACHE STRING - "Disable RDMA when cross-compiling for Android and iOS" FORCE) - set(WITH_MKL OFF CACHE STRING - "Disable MKL when cross-compiling for Android and iOS" FORCE) -endif() # for lite, both server and mobile framework. lite_option(LITE_WITH_JAVA "Enable Java JNI lib in lite mode" OFF) +lite_option(LITE_WITH_PYTHON "Enable Python api lib in lite mode" OFF) lite_option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF) lite_option(LITE_WITH_X86 "Enable X86 in lite mode" ON) lite_option(LITE_WITH_ARM "Enable ARM in lite mode" OFF) lite_option(LITE_WITH_NPU "Enable NPU in lite mode" OFF) +lite_option(LITE_WITH_XPU "Enable XPU in lite mode" OFF) lite_option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON) lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF) lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF) @@ -82,8 +68,29 @@ lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF IF LITE_WITH_PROFILE) lite_option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib."
OFF) +lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF) # publish options lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF) +lite_option(LITE_BUILD_TAILOR "Enable tailoring library according to model" OFF) + +# TODO(Superjomn) Remove WITH_ANAKIN option if not needed later. +if(ANDROID OR IOS OR ARMLINUX) + set(WITH_GPU OFF CACHE STRING + "Disable GPU when cross-compiling for Android and iOS" FORCE) + set(WITH_DSO OFF CACHE STRING + "Disable DSO when cross-compiling for Android and iOS" FORCE) + set(WITH_AVX OFF CACHE STRING + "Disable AVX when cross-compiling for Android and iOS" FORCE) + set(WITH_RDMA OFF CACHE STRING + "Disable RDMA when cross-compiling for Android and iOS" FORCE) + set(WITH_MKL OFF CACHE STRING + "Disable MKL when cross-compiling for Android and iOS" FORCE) +endif() + +if(ANDROID OR IOS) + set(LITE_WITH_PYTHON OFF CACHE STRING + "Disable PYTHON when cross-compiling for Android and iOS" FORCE) +endif() set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING "A path setting third party libraries download & build directories.") @@ -94,6 +101,7 @@ if(NOT CMAKE_BUILD_TYPE) "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE) endif() +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") # check options if (LITE_ON_TINY_PUBLISH) @@ -104,6 +112,15 @@ if (LITE_ON_TINY_PUBLISH) endif() include_directories("${PADDLE_SOURCE_DIR}") +# the generated header files. +set(LITE_GENERATED_INCLUDE_DIR "${CMAKE_BINARY_DIR}") +include_directories("${LITE_GENERATED_INCLUDE_DIR}") + +if (LITE_WITH_PYTHON) + include(external/python) # download, build, install python + include(external/pybind11) # download, build, install pybind11 +endif() + # for mobile if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -168,10 +185,15 @@ if(LITE_WITH_CUDA) include(cuda) endif() +if(LITE_WITH_XPU) + include(xpu) +endif() + include(generic) # simplify cmake module include(ccache) # set ccache for compilation include(util) # set unittest and link libs include(version) # set PADDLE_VERSION +include(flags) set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") diff --git a/README.md b/README.md index e32840a21dba66cc698b47ff7ee6436ab2b0124b..23974beee9a8af5ee7e2c454575efff2e3d96ee2 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,14 @@ # Paddle Lite -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://github.com/PaddlePaddle/Paddle-Lite/wiki) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://paddlepaddle.github.io/Paddle-Lite/) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) Paddle Lite is an updated version of Paddle-Mobile, an open-source deep learning framework designed to make it easy to perform inference on mobile, embedded, and IoT devices. It is compatible with PaddlePaddle and pre-trained models from other sources. -For tutorials, please see [PaddleLite Wiki](https://github.com/PaddlePaddle/Paddle-Lite/wiki). +For tutorials, please see [PaddleLite Documentation](https://paddlepaddle.github.io/Paddle-Lite/). ## Key Features @@ -30,7 +30,7 @@ It also supports INT8 quantizations with [PaddleSlim model compression tools](ht On Huawei NPU and FPGA, the performance is also boosted.
-The latest benchmark is located at [benchmark](https://github.com/PaddlePaddle/Paddle-Lite/wiki/benchmark) +The latest benchmark is located at [benchmark](https://paddlepaddle.github.io/Paddle-Lite/develop/benchmark/) ### High Compatibility diff --git a/README_cn.md b/README_cn.md index d2111786b13b6d2b1ee25a5678809a9097e39466..99d38c47ffbbaa3b8593801701e3528167899f97 100644 --- a/README_cn.md +++ b/README_cn.md @@ -1,13 +1,13 @@ # Paddle Lite -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://github.com/PaddlePaddle/Paddle-Lite/wiki) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://paddlepaddle.github.io/Paddle-Lite/) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) Paddle Lite为Paddle-Mobile的升级版,定位支持包括手机移动端在内更多场景的轻量化高效预测,支持更广泛的硬件和平台,是一个高性能、轻量级的深度学习预测引擎。在保持和PaddlePaddle无缝对接外,也兼容支持其他训练框架产出的模型。 -完整使用文档位于 [PaddleLite Wiki](https://github.com/PaddlePaddle/Paddle-Lite/wiki) 。 +完整使用文档位于 [PaddleLite 文档](https://paddlepaddle.github.io/Paddle-Lite/) 。 ## 特性 @@ -21,7 +21,7 @@ Paddle Lite为Paddle-Mobile的升级版,定位支持包括手机移动端在 支持INT8量化计算,结合 [PaddleSlim 模型压缩工具](https://github.com/PaddlePaddle/models/tree/v1.5/PaddleSlim) 中 INT8量化训练功能,可以提供高精度高性能的预测能力。 在Huawei NPU, FPGA上也具有很好的性能表现。 -最新 Benchmark 位于 [benchmark](https://github.com/PaddlePaddle/Paddle-Lite/wiki/benchmark)。 +最新 Benchmark 位于 [benchmark](https://paddlepaddle.github.io/Paddle-Lite/develop/benchmark/)。 ### 通用性 硬件方面,Paddle Lite 的架构设计为多硬件兼容支持做了良好设计。除了支持ARM CPU、Mali GPU、Adreno GPU,还特别支持了华为 NPU,以及 FPGA 等边缘设备广泛使用的硬件。即将支持包括寒武纪、比特大陆等AI芯片,未来会增加对更多硬件的支持。 diff --git a/cmake/configure.cmake b/cmake/configure.cmake index b919c147c7064f39e964b0d30e522303168c291b..5dbb7f3fca4a2ecdab943cd49f34ee97f9bac9b0 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -34,33 +34,6 @@ elseif(SSE3_FOUND) set(SIMD_FLAG ${SSE3_FLAG}) endif() -if(WIN32) - # windows header option for all targets. - add_definitions(-D_XKEYCHECK_H) - # Use symbols instead of absolute path, reduce the cmake link command length. - SET(CMAKE_C_USE_RESPONSE_FILE_FOR_LIBRARIES 1) - SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_LIBRARIES 1) - SET(CMAKE_C_USE_RESPONSE_FILE_FOR_OBJECTS 1) - SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_OBJECTS 1) - SET(CMAKE_C_USE_RESPONSE_FILE_FOR_INCLUDES 1) - SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_INCLUDES 1) - SET(CMAKE_C_RESPONSE_FILE_LINK_FLAG "@") - SET(CMAKE_CXX_RESPONSE_FILE_LINK_FLAG "@") - - # Specify the program to use when building static libraries - SET(CMAKE_C_CREATE_STATIC_LIBRARY " lib ") - SET(CMAKE_CXX_CREATE_STATIC_LIBRARY " lib ") - - # set defination for the dll export - if (NOT MSVC) - message(FATAL "Windows build only support msvc.
Which was binded by the nvcc compiler of NVIDIA.") - endif(NOT MSVC) -endif(WIN32) - -if(WITH_PSLIB) - add_definitions(-DPADDLE_WITH_PSLIB) -endif() - if(LITE_WITH_CUDA) add_definitions(-DLITE_WITH_CUDA) add_definitions(-DEIGEN_USE_GPU) @@ -154,6 +127,10 @@ if (LITE_WITH_NPU) add_definitions("-DLITE_WITH_NPU") endif() +if (LITE_WITH_XPU) + add_definitions("-DLITE_WITH_XPU") +endif() + if (LITE_WITH_OPENCL) add_definitions("-DLITE_WITH_OPENCL") endif() @@ -180,3 +157,8 @@ endif() if (LITE_ON_TINY_PUBLISH) add_definitions("-DLITE_ON_TINY_PUBLISH") endif() + +if (LITE_ON_MODEL_OPTIMIZE_TOOL) + add_definitions("-DLITE_ON_MODEL_OPTIMIZE_TOOL") +endif(LITE_ON_MODEL_OPTIMIZE_TOOL) + diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index 11a803ff031706a10f282f21024915be68444546..4fc59ccd62671c5862a298832b1ec03d4e96d05a 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -18,6 +18,7 @@ endif() set(ANDROID TRUE) add_definitions(-DLITE_WITH_LINUX) +add_definitions(-DLITE_WITH_ANDROID) if(NOT DEFINED ANDROID_NDK) set(ANDROID_NDK $ENV{NDK_ROOT}) @@ -32,7 +33,10 @@ if(ARM_TARGET_LANG STREQUAL "gcc") endif() if(NOT DEFINED ANDROID_API_LEVEL) - set(ANDROID_API_LEVEL "22") + set(ANDROID_API_LEVEL "23") + if(ARM_TARGET_ARCH_ABI STREQUAL "armv7") + set(ANDROID_API_LEVEL "22") + endif() endif() # then check input arm abi diff --git a/cmake/cross_compiling/npu.cmake b/cmake/cross_compiling/npu.cmake index 863200986c93ea09d3fa3049fe684b32c2fb52dd..25aa4d2bc8c1c145e7a103c9164e1c9e231a8f9e 100644 --- a/cmake/cross_compiling/npu.cmake +++ b/cmake/cross_compiling/npu.cmake @@ -50,9 +50,6 @@ find_library(NPU_DDK_IR_FILE NAMES hiai_ir find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}) -find_library(NPU_DDK_PROTO_FILE NAMES protobuf-lite - PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}) - if(NOT NPU_DDK_HIAI_FILE) message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}") else() @@ -77,14 +74,8 @@ else() set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE}) endif() -if(NOT NPU_DDK_PROTO_FILE) - message(FATAL_ERROR "Can not find NPU_DDK_PROTO_FILE in ${NPU_DDK_ROOT}") -else() - message(STATUS "Found NPU_DDK Protobuf Library: ${NPU_DDK_PROTO_FILE}") - add_library(npu_ddk_proto SHARED IMPORTED GLOBAL) - set_property(TARGET npu_ddk_proto PROPERTY IMPORTED_LOCATION ${NPU_DDK_PROTO_FILE}) -endif() +set(npu_runtime_libs npu_ddk_hiai CACHE INTERNAL "npu ddk runtime libs") +set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs") -set(npu_ddk_libs npu_ddk_hiai npu_ddk_ir npu_ddk_ir_build npu_ddk_proto CACHE INTERNAL "npu ddk libs") diff --git a/cmake/cross_compiling/postproject.cmake b/cmake/cross_compiling/postproject.cmake index 33254df03c43c2648fb33effe491e5956edf60a9..88ac3e101a686cb49ef5a4c3b1879c15b8f7b57b 100644 --- a/cmake/cross_compiling/postproject.cmake +++ b/cmake/cross_compiling/postproject.cmake @@ -26,6 +26,8 @@ if(ANDROID) endif() if(ARMLINUX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") if(ARMLINUX_ARCH_ABI STREQUAL "armv8") set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}") set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}") @@ -57,7 +59,10 @@ function(check_linker_flag) endfunction() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") if (LITE_ON_TINY_PUBLISH) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fno-exceptions 
-fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables") + if(NOT LITE_WITH_PYTHON) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections") check_linker_flag(-Wl,--gc-sections) endif() diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 1e6f34a62129e2ca0a717ceb489d98b56b78d47a..9ff908a4c87d55e87468a06ae0e6085ac165a1b1 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -4,9 +4,9 @@ endif() set(paddle_known_gpu_archs "30 35 50 52 60 61 70") set(paddle_known_gpu_archs7 "30 35 50 52") -set(paddle_known_gpu_archs8 "30 35 50 52 60 61") -set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") -set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") +set(paddle_known_gpu_archs8 "30 35 50 52 53 60 61 62") +set(paddle_known_gpu_archs9 "30 35 50 52 53 60 61 62 70") +set(paddle_known_gpu_archs10 "30 35 50 52 53 60 61 62 70 72 75") ###################################################################################### # A function for automatic detection of GPUs installed (if autodetection is enabled) @@ -174,6 +174,16 @@ if(NOT WITH_DSO) endif(WIN32) endif(NOT WITH_DSO) +get_filename_component(CUDA_LIB_PATH ${CUDA_curand_LIBRARY} DIRECTORY) +function(import_static_library alias path) + add_library(${alias} STATIC IMPORTED GLOBAL) + set_property(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${path}) +endfunction() +import_static_library(cudart_static ${CUDA_LIB_PATH}/libcudart_static.a) +import_static_library(cublas_static ${CUDA_LIB_PATH}/libcublas_static.a) +import_static_library(curand_static ${CUDA_LIB_PATH}/libcurand_static.a) +import_static_library(culibos_static ${CUDA_LIB_PATH}/libculibos.a) + # setting nvcc arch flags select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 3775d6cc2bdaa617f225b4cff9a03092bd9a19cc..842b94d47e75b4bab577a1150cb3d198eb42ebaf 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -34,6 +34,14 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 ) + +if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0)) + find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH) + set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY}) +else() + set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES}) +endif() + set(CUDNN_LIB_NAME "libcudnn.so") if(WIN32) @@ -45,11 +53,10 @@ if(APPLE) set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so") endif(APPLE) -find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a +find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist} NO_DEFAULT_PATH - DOC "Path to cuDNN library.") - + DOC "Path to cuDNN dynamic library.") if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY) set(CUDNN_FOUND ON) @@ -61,6 +68,9 @@ if(CUDNN_FOUND) file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY) + add_library(cudnn_static STATIC IMPORTED GLOBAL) + set_property(TARGET cudnn_static PROPERTY IMPORTED_LOCATION + "${CUDNN_LIB_PATH}/libcudnn_static.a") string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)" CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") diff --git 
a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 84be88226f4144c30840fe5a37d35d54b357630c..76cc7b21deab41a40869a68df3a4dce359177c21 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -109,8 +109,7 @@ macro(PROMPT_PROTOBUF_LIB) ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL) SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY}) - - ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) +ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) @@ -185,6 +184,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) SET(SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/protobuf-host") IF(BUILD_FOR_HOST) + # set for server compile. + if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + set(HOST_C_COMPILER "${CMAKE_C_COMPILER}") + set(HOST_CXX_COMPILER "${CMAKE_CXX_COMPILER}") + endif() + SET(OPTIONAL_ARGS "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}" "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}" @@ -247,6 +252,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) GIT_REPOSITORY "" GIT_TAG ${PROTOBUF_TAG} SOURCE_DIR ${SOURCE_DIR} + BUILD_ALWAYS 1 CONFIGURE_COMMAND ${CMAKE_COMMAND} ${SOURCE_DIR}/cmake ${OPTIONAL_ARGS} -Dprotobuf_BUILD_TESTS=OFF @@ -276,7 +282,11 @@ IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ENDIF() IF(NOT PROTOBUF_FOUND) - build_protobuf(extern_protobuf FALSE) + if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + build_protobuf(extern_protobuf FALSE) + else() + build_protobuf(extern_protobuf TRUE) + endif() SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR} CACHE PATH "protobuf include directory." FORCE) diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake new file mode 100644 index 0000000000000000000000000000000000000000..df8562dff531bc7effbc3978a97fcaabacdce02b --- /dev/null +++ b/cmake/external/pybind11.cmake @@ -0,0 +1,46 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(NOT LITE_WITH_PYTHON) + return() +endif() + +include(ExternalProject) + +set(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind) + +include_directories(${PYBIND_SOURCE_DIR}/src/extern_pybind/include) + +ExternalProject_Add( + extern_pybind + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/pybind/pybind11.git" + GIT_TAG "v2.2.4" + PREFIX ${PYBIND_SOURCE_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) + +if(${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/pybind_dummy.c) + file(WRITE ${dummyfile} "const char * dummy_pybind = \"${dummyfile}\";") + add_library(pybind STATIC ${dummyfile}) +else() + add_library(pybind INTERFACE) +endif() + +add_dependencies(pybind extern_pybind) diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake new file mode 100644 index 0000000000000000000000000000000000000000..ae99f4df9a3676ae8f5b2c4c01305ead9b7a8254 --- /dev/null +++ b/cmake/external/python.cmake @@ -0,0 +1,83 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +IF(NOT LITE_WITH_PYTHON) + return() +ENDIF() + +INCLUDE(python_module) + +FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED) +FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED) + +if(WIN32) + execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" +"from distutils import sysconfig as s;import sys;import struct; +print(sys.prefix); +print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION')); +" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE _PYTHON_VALUES + ERROR_VARIABLE _PYTHON_ERROR_VALUE) + + if(NOT _PYTHON_SUCCESS MATCHES 0) + set(PYTHONLIBS_FOUND FALSE) + return() + endif() + + # Convert the process output into a list + string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES}) + string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES}) + list(GET _PYTHON_VALUES 0 PYTHON_PREFIX) + list(GET _PYTHON_VALUES 1 PYTHON_LIBRARY_SUFFIX) + + # Make sure all directory separators are '/' + string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX}) + + set(PYTHON_LIBRARY + "${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") + + # when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the + # original python installation. They may be found relative to PYTHON_INCLUDE_DIR. + if(NOT EXISTS "${PYTHON_LIBRARY}") + get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY) + set(PYTHON_LIBRARY + "${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") + endif() + + # raise an error if the python libs are still not found. + if(NOT EXISTS "${PYTHON_LIBRARY}") + message(FATAL_ERROR "Python libraries not found") + endif() + SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}") +endif(WIN32) + +# Fixme: Maybe find a static library. Get SHARED/STATIC by FIND_PACKAGE. 
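+# Illustrative sketch only (hypothetical consumer target, not part of this build): the imported +# "python" target defined below is linked like any regular target, e.g. +#   add_library(my_ext MODULE my_ext.cc)                            # hypothetical module and source +#   target_link_libraries(my_ext python)                            # pulls in ${PYTHON_LIBRARIES} found above +#   set_target_properties(my_ext PROPERTIES PREFIX "" SUFFIX ".so") # name it like a Python extension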
+ADD_LIBRARY(python SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET python PROPERTY IMPORTED_LOCATION ${PYTHON_LIBRARIES}) + +SET(py_env "") +IF(PYTHONINTERP_FOUND) + find_python_module(pip REQUIRED) + find_python_module(numpy REQUIRED) + #find_python_module(wheel REQUIRED) + #find_python_module(google.protobuf REQUIRED) + FIND_PACKAGE(NumPy REQUIRED) + #IF(${PY_GOOGLE.PROTOBUF_VERSION} AND ${PY_GOOGLE.PROTOBUF_VERSION} VERSION_LESS "3.0.0") + # MESSAGE(FATAL_ERROR "Found Python Protobuf ${PY_GOOGLE.PROTOBUF_VERSION} < 3.0.0, " + # "please use pip to upgrade protobuf. pip install -U protobuf") + #ENDIF() +ENDIF(PYTHONINTERP_FOUND) +INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) +INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR}) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 36b533aa4f7815896fb48c33fefad892b8d0d29c..903c70fbbff285bc90697281f9703b544fd00186 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -146,8 +146,11 @@ set(GPU_COMMON_FLAGS -Wno-error=unused-local-typedefs -Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=array-bounds # Warnings in Eigen::array + -gencode arch=compute_62,code=sm_62 ) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +if(NOT LITE_WITH_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +endif() endif(NOT WIN32) if (APPLE) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 9a4a278668fe6f83520b86d771f67dd7acac44ec..415eb451a986cd7e59829b9a8f2c744ecf464bd6 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -105,8 +105,8 @@ set_property(GLOBAL PROPERTY FLUID_MODULES "") function(find_fluid_modules TARGET_NAME) get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE) string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path}) - string(FIND "${__target_path}" "fluid" pos) - if(pos GREATER 1) + string(FIND "${__target_path}" "lite" pos) + if((pos GREATER 0) OR (pos EQUAL 0)) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) set(fluid_modules ${fluid_modules} ${TARGET_NAME}) set_property(GLOBAL PROPERTY FLUID_MODULES "${fluid_modules}") @@ -303,10 +303,12 @@ function(cc_library TARGET_NAME) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) endif() - if(${source_file} MATCHES "framework.pb.cc") + if(${source_file} MATCHES "__generated_code__.cc") list(APPEND full_path_src ${source_file}) else() - list(APPEND full_path_src ${CMAKE_CURRENT_SOURCE_DIR}/${source_file}) + if(NOT ${source_file} MATCHES "framework.pb.cc" AND NOT ${source_file} MATCHES "__generated_code__.cc") + list(APPEND full_path_src ${CMAKE_CURRENT_SOURCE_DIR}/${source_file}) + endif() endif() endforeach() set(__lite_cc_files ${__lite_cc_files} ${full_path_src} CACHE INTERNAL "") @@ -371,6 +373,7 @@ function(cc_binary TARGET_NAME) endif() get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(${TARGET_NAME} ${os_dependency_modules}) + find_fluid_modules(${TARGET_NAME}) endfunction(cc_binary) function(cc_test TARGET_NAME) @@ -503,17 +506,14 @@ function(nv_test TARGET_NAME) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) - target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog ${os_dependency_modules}) - add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog) + 
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES}) + add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog) common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) endif() - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) endif() endfunction(nv_test) diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 2c839d36e27429672b1098bae4d5cbed16731115..9b6fab3f6261ff13361bda35cfa9cd681075c77d 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -22,7 +22,7 @@ endfunction() function (lite_deps TARGET) set(options "") set(oneValueArgs "") - set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS ARGS) + set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS ARGS) cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(deps ${lite_deps_DEPS}) @@ -83,6 +83,12 @@ function (lite_deps TARGET) endforeach(var) endif() + if (LITE_WITH_XPU) + foreach(var ${lite_deps_XPU_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + set(${TARGET} ${deps} PARENT_SCOPE) endfunction() @@ -107,7 +113,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean function(lite_cc_library TARGET) set(options SHARED shared STATIC static MODULE module) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS XPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -118,6 +124,7 @@ function(lite_cc_library TARGET) CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} @@ -126,12 +133,12 @@ function(lite_cc_library TARGET) ) if (args_SHARED OR ARGS_shared) - cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS} SHARED) + cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} SHARED) elseif (args_MODULE OR ARGS_module) add_library(${TARGET} MODULE ${args_SRCS}) add_dependencies(${TARGET} ${deps} ${args_DEPS}) else() - cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS}) + cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps}) endif() target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers) @@ -163,8 +170,17 @@ function(lite_cc_binary TARGET) LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} ) - cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS}) + cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps}) target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers) + if (NOT APPLE) + # strip binary target to reduce size + if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + add_custom_command(TARGET ${TARGET} POST_BUILD + COMMAND "${CMAKE_STRIP}" -s + "${TARGET}" + COMMENT "Strip debug symbols from the final executable file.") + endif() + endif() # collect targets that need to be compiled for lite if (NOT args_EXCLUDE_COMPILE_DEPS) add_dependencies(lite_compile_deps ${TARGET}) @@ -207,6 +223,13 @@ function(lite_cc_test TARGET) HVY_DEPS ${args_HVY_DEPS} ) _lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS}) + # strip binary target to reduce size + if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + add_custom_command(TARGET ${TARGET} POST_BUILD + COMMAND "${CMAKE_STRIP}" -s + "${TARGET}" + COMMENT "Strip debug symbols from the final executable file.") + endif() target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers) file(APPEND ${offline_test_registry_file} "${TARGET}\n") @@ -220,11 +243,16 @@ set(arm_kernels CACHE INTERNAL "arm kernels") set(x86_kernels CACHE INTERNAL "x86 kernels") set(fpga_kernels CACHE INTERNAL "fpga kernels") set(npu_kernels CACHE INTERNAL "npu kernels") +set(xpu_kernels CACHE INTERNAL "xpu kernels") set(opencl_kernels CACHE INTERNAL "opencl kernels") set(host_kernels CACHE INTERNAL "host kernels") set(kernels_src_list "${CMAKE_BINARY_DIR}/kernels_src_list.txt") file(WRITE ${kernels_src_list} "") # clean +if(LITE_BUILD_TAILOR) + set(tailored_kernels_list_path "${LITE_OPTMODEL_DIR}/.tailored_kernels_source_list") + file(STRINGS ${tailored_kernels_list_path} tailored_kernels_list) +endif() # add a kernel for some specific device # device: one of (Host, ARM, X86, NPU, FPGA, OPENCL, CUDA) # level: one of (basic, extra) @@ -236,10 +264,34 @@ function(add_kernel TARGET device level) ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if(LITE_BUILD_TAILOR) + foreach(src ${args_SRCS}) + list(FIND tailored_kernels_list ${src} _index) + if (${_index} EQUAL -1) + return() + endif() + endforeach() + endif() + if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if (LITE_ON_MODEL_OPTIMIZE_TOOL) + # when compiling model_optimize_tool, a source file with all the fake kernel definitions is generated instead, + # so only collect the kernel sources for that generation and skip compiling the real kernels. + foreach(src ${args_SRCS}) + file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endforeach() + return() + endif(LITE_ON_MODEL_OPTIMIZE_TOOL) + + if ("${device}" STREQUAL "Host") set(host_kernels "${host_kernels};${TARGET}" CACHE INTERNAL "") endif() @@ -261,6 +313,12 @@ function(add_kernel TARGET device level) endif() set(npu_kernels "${npu_kernels};${TARGET}" CACHE INTERNAL "") endif() + if ("${device}" STREQUAL "XPU") + if (NOT LITE_WITH_XPU) + return() + endif() + set(xpu_kernels "${xpu_kernels};${TARGET}" CACHE INTERNAL "") + endif() if ("${device}" STREQUAL "FPGA") if (NOT LITE_WITH_FPGA) return() endif() @@ -274,6 +332,19 @@ set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "") endif() + if ("${device}" STREQUAL "CUDA") + if (NOT LITE_WITH_CUDA) + return() + endif() + set(cuda_kernels "${cuda_kernels};${TARGET}" CACHE INTERNAL "") + foreach(src ${args_SRCS}) + file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endforeach() + nv_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS}) + return() + endif() + + # the source list will be collected for paddle_use_kernel.h code generation.
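+ # Illustrative call (hypothetical target, source, and dep names): a typical registration through + # this function looks like + #   add_kernel(softmax_compute_arm ARM basic SRCS softmax_compute.cc DEPS math_arm) + # and, with LITE_BUILD_TAILOR=ON, is silently skipped unless softmax_compute.cc appears in + # ${LITE_OPTMODEL_DIR}/.tailored_kernels_source_list.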
foreach(src ${args_SRCS}) file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") endforeach() @@ -281,6 +352,7 @@ function(add_kernel TARGET device level) lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} + XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} @@ -294,6 +366,10 @@ endfunction() set(ops CACHE INTERNAL "ops") set(ops_src_list "${CMAKE_BINARY_DIR}/ops_src_list.txt") file(WRITE ${ops_src_list} "") # clean +if(LITE_BUILD_TAILOR) + set(tailored_ops_list_path "${LITE_OPTMODEL_DIR}/.tailored_ops_source_list") + file(STRINGS ${tailored_ops_list_path} tailored_ops_list) +endif() # add an operator # level: one of (basic, extra) function(add_operator TARGET level) @@ -304,19 +380,28 @@ function(add_operator TARGET level) ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() - set(ops "${ops};${TARGET}" CACHE INTERNAL "source") foreach(src ${args_SRCS}) + if(LITE_BUILD_TAILOR) + list(FIND tailored_ops_list ${src} _index) + if (${_index} EQUAL -1) + return() + endif() + endif() file(APPEND ${ops_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") endforeach() + set(ops "${ops};${TARGET}" CACHE INTERNAL "source") + lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} + XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} @@ -331,6 +416,8 @@ endfunction() # Bundle several static libraries into one. function(bundle_static_library tgt_name bundled_tgt_name fake_target) list(APPEND static_libs ${tgt_name}) +# for x86 + add_dependencies(lite_compile_deps ${fake_target}) function(_recursively_collect_dependencies input_target) set(_input_link_libraries LINK_LIBRARIES) diff --git a/cmake/python_module.cmake b/cmake/python_module.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1412b7f7f20600acf95a4a899f5e6529c3b67a35 --- /dev/null +++ b/cmake/python_module.cmake @@ -0,0 +1,43 @@ +# Find if a Python module is installed +# Found at http://www.cmake.org/pipermail/cmake/2011-January/041666.html +# To use do: find_python_module(PyQt4 REQUIRED) +function(find_python_module module) + string(TOUPPER ${module} module_upper) + if(NOT PY_${module_upper}) + if(ARGC GREATER 1 AND ARGV1 STREQUAL "REQUIRED") + set(${module}_FIND_REQUIRED TRUE) + else() + set(${module}_FIND_REQUIRED FALSE) + endif() + # A module's location is usually a directory, but for binary modules + # it's a .so file. 
+ execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "import re, ${module}; print(re.compile('/__init__.py.*').sub('',${module}.__file__))" + RESULT_VARIABLE _${module}_status + OUTPUT_VARIABLE _${module}_location + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _${module}_status) + set(PY_${module_upper} ${_${module}_location} CACHE STRING + "Location of Python module ${module}") + endif(NOT _${module}_status) + endif(NOT PY_${module_upper}) + find_package_handle_standard_args(PY_${module} DEFAULT_MSG PY_${module_upper}) + if(NOT PY_${module_upper}_FOUND AND ${module}_FIND_REQUIRED) + message(FATAL_ERROR "python module ${module} is not found") + endif() + + execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "import sys, ${module}; sys.stdout.write(${module}.__version__)" + OUTPUT_VARIABLE _${module}_version + RESULT_VARIABLE _${module}_status + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _${module}_status) + set(PY_${module_upper}_VERSION ${_${module}_version} CACHE STRING + "Version of Python module ${module}") + endif(NOT _${module}_status) + + set(PY_${module_upper}_FOUND ${PY_${module_upper}_FOUND} PARENT_SCOPE) + set(PY_${module_upper}_VERSION ${PY_${module_upper}_VERSION} PARENT_SCOPE) +endfunction(find_python_module) diff --git a/cmake/xpu.cmake b/cmake/xpu.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8d99343c3041351102820cb20890031fa3f5807e --- /dev/null +++ b/cmake/xpu.cmake @@ -0,0 +1,105 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(NOT LITE_WITH_XPU) + return() +endif() + +if(NOT DEFINED XPU_SDK_ROOT) + set(XPU_SDK_ROOT $ENV{XPU_SDK_ROOT}) + if(NOT XPU_SDK_ROOT) + message(FATAL_ERROR "Must set XPU_SDK_ROOT or env XPU_SDK_ROOT when LITE_WITH_XPU=ON") + endif() +endif() + +message(STATUS "XPU_SDK_ROOT: ${XPU_SDK_ROOT}") +find_path(XPU_SDK_INC NAMES xtcl.h + PATHS ${XPU_SDK_ROOT}/XTCL/include/xtcl NO_DEFAULT_PATH) +if(NOT XPU_SDK_INC) + message(FATAL_ERROR "Cannot find xtcl.h in ${XPU_SDK_ROOT}/XTCL/include/xtcl") +endif() + +include_directories("${XPU_SDK_ROOT}/XTCL/include") +include_directories("${XPU_SDK_ROOT}/XTDK/include") + +find_library(XPU_SDK_XTCL_FILE NAMES xtcl + PATHS ${XPU_SDK_ROOT}/XTCL/so) + +if(NOT XPU_SDK_XTCL_FILE) + message(FATAL_ERROR "Cannot find XPU XTCL Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU XTCL Library: ${XPU_SDK_XTCL_FILE}") + add_library(xpu_sdk_xtcl SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_xtcl PROPERTY IMPORTED_LOCATION ${XPU_SDK_XTCL_FILE}) +endif() + +find_library(XPU_SDK_TVM_FILE NAMES tvm + PATHS ${XPU_SDK_ROOT}/XTCL/so) + +if(NOT XPU_SDK_TVM_FILE) + message(FATAL_ERROR "Cannot find XPU TVM Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU TVM Library: ${XPU_SDK_TVM_FILE}") + add_library(xpu_sdk_tvm SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_tvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_TVM_FILE}) +endif() + +find_library(XPU_SDK_XPU_API_FILE NAMES xpuapi + PATHS ${XPU_SDK_ROOT}/XTDK/shlib) + +if(NOT XPU_SDK_XPU_API_FILE) + message(FATAL_ERROR "Cannot find XPU API Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU API Library: ${XPU_SDK_XPU_API_FILE}") + add_library(xpu_sdk_xpu_api SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_xpu_api PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_API_FILE}) +endif() + +find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt + PATHS ${XPU_SDK_ROOT}/XTDK/shlib) + +if(NOT XPU_SDK_XPU_RT_FILE) + message(FATAL_ERROR "Cannot find XPU RT Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU RT Library: ${XPU_SDK_XPU_RT_FILE}") + add_library(xpu_sdk_xpu_rt SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_xpu_rt PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_RT_FILE}) +endif() + +find_library(XPU_SDK_XPU_JITC_FILE NAMES xpujitc + PATHS ${XPU_SDK_ROOT}/XTDK/shlib) + +if(NOT XPU_SDK_XPU_JITC_FILE) + message(FATAL_ERROR "Cannot find XPU JITC Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU JITC Library: ${XPU_SDK_XPU_JITC_FILE}") + add_library(xpu_sdk_xpu_jitc SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_xpu_jitc PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_JITC_FILE}) +endif() + +find_library(XPU_SDK_LLVM_FILE NAMES LLVM-8 + PATHS ${XPU_SDK_ROOT}/XTDK/shlib) + +if(NOT XPU_SDK_LLVM_FILE) + message(FATAL_ERROR "Cannot find LLVM Library in ${XPU_SDK_ROOT}") +else() + message(STATUS "Found XPU LLVM Library: ${XPU_SDK_LLVM_FILE}") + add_library(xpu_sdk_llvm SHARED IMPORTED GLOBAL) + set_property(TARGET xpu_sdk_llvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_LLVM_FILE}) +endif() + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1") + +set(xpu_runtime_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu runtime libs") +set(xpu_builder_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu builder libs") diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index cc958f1b59d439e57e1b0ec093ffad9345687476..fa55e27255fcd82a72ac1489741e9e69db1fe933 100644 ---
a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -6,12 +6,14 @@ message(STATUS "LITE_WITH_CUDA:\t${LITE_WITH_CUDA}") message(STATUS "LITE_WITH_X86:\t${LITE_WITH_X86}") message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") message(STATUS "LITE_WITH_NPU:\t${LITE_WITH_NPU}") +message(STATUS "LITE_WITH_XPU:\t${LITE_WITH_XPU}") message(STATUS "LITE_WITH_FPGA:\t${LITE_WITH_FPGA}") message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") set(LITE_ON_MOBILE ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}) +add_subdirectory(backends) add_subdirectory(utils) add_subdirectory(operators) add_subdirectory(kernels) @@ -19,7 +21,6 @@ add_subdirectory(core) add_subdirectory(model_parser) add_subdirectory(api) add_subdirectory(fluid) -add_subdirectory(backends) if (NOT LITE_ON_TINY_PUBLISH) add_subdirectory(tests) @@ -44,9 +45,13 @@ if (WITH_TESTING) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") + lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz") endif() endif() +# ----------------------------- PUBLISH ----------------------------- +# The final target for publish lite lib +add_custom_target(publish_inference) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) # for publish set(INFER_LITE_PUBLISH_ROOT "${CMAKE_BINARY_DIR}/inference_lite_lib.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}") @@ -56,10 +61,62 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) if (LITE_WITH_NPU) set(INFER_LITE_PUBLISH_ROOT "${INFER_LITE_PUBLISH_ROOT}.npu") endif(LITE_WITH_NPU) - message(STATUS "publish inference lib to ${INFER_LITE_PUBLISH_ROOT}") + if (LITE_WITH_FPGA) + set(INFER_LITE_PUBLISH_ROOT "${INFER_LITE_PUBLISH_ROOT}.fpga") + endif(LITE_WITH_FPGA) +else() + set(INFER_LITE_PUBLISH_ROOT "${CMAKE_BINARY_DIR}/inference_lite_lib") +endif() +message(STATUS "publish inference lib to ${INFER_LITE_PUBLISH_ROOT}") + +# add python lib +if (LITE_WITH_PYTHON) + add_custom_target(publish_inference_python_lib ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/python/lib" + COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/python/pybind/liblite_pybind.so" "${INFER_LITE_PUBLISH_ROOT}/python/lib/lite_core.so") + add_custom_target(publish_inference_python_light_demo ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/python" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/python/mobilenetv1_light_api.py" "${INFER_LITE_PUBLISH_ROOT}/demo/python/") + if (NOT LITE_ON_TINY_PUBLISH) + add_custom_target(publish_inference_python_full_demo ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/python" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/python/mobilenetv1_full_api.py" "${INFER_LITE_PUBLISH_ROOT}/demo/python/") + add_dependencies(publish_inference publish_inference_python_full_demo) + endif() + add_dependencies(publish_inference_python_lib lite_pybind) + add_dependencies(publish_inference publish_inference_python_lib) + add_dependencies(publish_inference publish_inference_python_light_demo) +endif() + +if (LITE_WITH_X86) + add_custom_target(publish_inference_x86_cxx_lib ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/bin" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp 
"${CMAKE_BINARY_DIR}/libpaddle_api_full_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/*.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin" + ) + add_dependencies(publish_inference_x86_cxx_lib bundle_full_api) + add_dependencies(publish_inference_x86_cxx_lib bundle_light_api) + add_dependencies(publish_inference_x86_cxx_lib test_model_bin) + add_dependencies(publish_inference_x86_cxx_lib paddle_full_api_shared) + add_dependencies(publish_inference_x86_cxx_lib paddle_light_api_shared) + add_dependencies(publish_inference publish_inference_x86_cxx_lib) + + add_custom_target(publish_inference_x86_cxx_demos ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/third_party" + COMMAND cp -r "${CMAKE_BINARY_DIR}/third_party/install/*" "${INFER_LITE_PUBLISH_ROOT}/third_party" + COMMAND cp -r "${CMAKE_BINARY_DIR}/third_party/eigen3" "${INFER_LITE_PUBLISH_ROOT}/third_party" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + ) + add_dependencies(publish_inference_x86_cxx_lib publish_inference_x86_cxx_demos) + add_dependencies(publish_inference_x86_cxx_demos paddle_full_api_shared eigen3) +endif() - # The final target for publish lite lib - add_custom_target(publish_inference) +if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) if (NOT LITE_ON_TINY_PUBLISH) # add cxx lib add_custom_target(publish_inference_cxx_lib ${TARGET} @@ -69,22 +126,28 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_full_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" - COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin" + #COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin" COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin" COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin" ) if(NOT IOS) - add_dependencies(publish_inference_cxx_lib model_optimize_tool) + #add_dependencies(publish_inference_cxx_lib model_optimize_tool) add_dependencies(publish_inference_cxx_lib paddle_code_generator) add_dependencies(publish_inference_cxx_lib bundle_full_api) add_dependencies(publish_inference_cxx_lib bundle_light_api) add_dependencies(publish_inference_cxx_lib test_model_bin) + if (ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux") + add_dependencies(publish_inference_cxx_lib paddle_full_api_shared) + add_dependencies(publish_inference paddle_light_api_shared) + add_custom_command(TARGET publish_inference_cxx_lib + COMMAND cp ${CMAKE_BINARY_DIR}/lite/api/*.so ${INFER_LITE_PUBLISH_ROOT}/cxx/lib) + endif() add_dependencies(publish_inference publish_inference_cxx_lib) add_custom_command(TARGET publish_inference_cxx_lib POST_BUILD COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a) endif() else() - if (IOS OR (ARM_TARGET_OS STREQUAL "armlinux")) + if (IOS) add_custom_target(tiny_publish_lib ${TARGET} COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/lib" COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/include" @@ -93,6 +156,18 @@ if 
(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) ) add_dependencies(tiny_publish_lib bundle_light_api) add_dependencies(publish_inference tiny_publish_lib) + else() + if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux")) + add_custom_target(tiny_publish_cxx_lib ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/libpaddle_light_api_shared.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" + ) + add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared) + add_dependencies(publish_inference tiny_publish_cxx_lib) + endif() endif() endif() @@ -130,6 +205,16 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) ) add_dependencies(publish_inference_android_cxx_demos logging gflags) add_dependencies(publish_inference_cxx_lib publish_inference_android_cxx_demos) + else() + # copy + add_custom_target(publish_inference_android_cxx_demos ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/Makefile.def" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/README.md" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile" + ) + add_dependencies(tiny_publish_cxx_lib publish_inference_android_cxx_demos) endif() if (LITE_WITH_JAVA) diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index dc31164c0eed754c6599abd25a46a1b8c83eaea6..bf930ed0e20bf0c1a2e313fd33ad7d87b734c42c 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -4,12 +4,53 @@ else() lite_cc_library(place SRCS paddle_place.cc DEPS glog) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +if (LITE_ON_TINY_PUBLISH) + set(CMAKE_CXX_FLAGS_RELEASE "-Os -DNDEBUG") + set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG") +endif() +set(light_lib_DEPS light_api paddle_api paddle_api_light optimizer) +if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux")) + #full api dynamic library + add_library(paddle_full_api_shared SHARED "") + target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc) + add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto) + target_link_libraries(paddle_full_api_shared framework_proto) + if(LITE_WITH_X86) + add_dependencies(paddle_full_api_shared xxhash) + target_link_libraries(paddle_full_api_shared xxhash) + endif() + + #light api dynamic library + lite_cc_library(paddle_light_api_shared MODULE + SRCS light_api_shared.cc + DEPS ${light_lib_DEPS} + ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels}) + target_link_libraries(paddle_light_api_shared ${light_lib_DEPS} ${arm_kernels} ${npu_kernels}) + if (LITE_WITH_NPU) + # Strips the symbols of our protobuf functions to fix the conflicts during + # loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so) + set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") + set_target_properties(paddle_light_api_shared PROPERTIES LINK_FLAGS 
"${LINK_FLAGS}") + endif() +else() + if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux")) + add_library(paddle_light_api_shared SHARED "") + target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc) + add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) + if (LITE_WITH_NPU) + # Need to add HIAI runtime libs (libhiai.so) dependency + target_link_libraries(paddle_light_api_shared ${npu_runtime_libs}) + endif() + endif() +endif() + if (WITH_TESTING) lite_cc_library(lite_api_test_helper SRCS lite_api_test_helper.cc DEPS scope optimizer target_wrapper_host model_parser program ${ops} ${host_kernels} CUDA_DEPS ${cuda_kernels} - X86_DEPS ${x86_kernels}) + X86_DEPS ${x86_kernels} + XPU_DEPS ${xpu_kernels}) endif() if(LITE_WITH_FPGA) set(light_api_deps ${light_api_deps} ${fpga_deps}) @@ -21,6 +62,7 @@ message(STATUS "get X86 kernels ${x86_kernels}") message(STATUS "get Host kernels ${host_kernels}") message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get NPU kernels ${npu_kernels}") +message(STATUS "get XPU kernels ${xpu_kernels}") message(STATUS "get FPGA kernels ${fpga_kernels}") # for full api @@ -33,6 +75,7 @@ if (NOT LITE_ON_TINY_PUBLISH) X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass + XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass CL_DEPS ${opencl_kenrels} FPGA_DEPS ${fpga_kenrels}) endif() @@ -42,6 +85,8 @@ set(light_api_deps scope target_wrapper_host model_parser program) if(LITE_WITH_CUDA) set(light_api_deps ${light_api_deps} target_wrapper_cuda) + set(cuda_static_deps cudart_static cublas_static curand_static + cudnn_static culibos_static) endif() lite_cc_library(light_api SRCS light_api.cc DEPS scope target_wrapper_host model_parser @@ -49,7 +94,8 @@ lite_cc_library(light_api SRCS light_api.cc CUDA_DEPS ${cuda_kernels} X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} - NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass + NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kenrels} FPGA_DEPS ${fpga_kenrels}) @@ -64,6 +110,7 @@ if(WITH_TESTING) X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} EXCLUDE_COMPILE_DEPS "ON" @@ -72,25 +119,35 @@ if(WITH_TESTING) add_dependencies(test_cxx_api extern_lite_download_lite_naive_model_tar_gz) if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) lite_cc_test(test_googlenet SRCS test_googlenet_lite.cc - DEPS cxx_api mir_passes lite_api_test_helper + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/googlenet) add_dependencies(test_googlenet extern_lite_download_GoogleNet_inference_tar_gz) lite_cc_test(test_mobilenetv1_lite_x86 SRCS test_mobilenetv1_lite_x86.cc - DEPS cxx_api mir_passes lite_api_test_helper + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1) add_dependencies(test_mobilenetv1_lite_x86 extern_lite_download_mobilenet_v1_tar_gz) lite_cc_test(test_mobilenetv2_lite_x86 SRCS test_mobilenetv2_lite_x86.cc - DEPS cxx_api mir_passes lite_api_test_helper + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v2_relu) add_dependencies(test_mobilenetv2_lite_x86 
extern_lite_download_mobilenet_v2_relu_tar_gz) lite_cc_test(test_inceptionv4_lite_x86 SRCS test_inceptionv4_lite_x86.cc - DEPS cxx_api mir_passes lite_api_test_helper + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/inception_v4_simple) add_dependencies(test_inceptionv4_lite_x86 extern_lite_download_inception_v4_simple_tar_gz) + lite_cc_test(test_resnet50_lite_x86 SRCS test_resnet50_lite_x86.cc + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/resnet50) + add_dependencies(test_resnet50_lite_x86 extern_lite_download_resnet50_tar_gz) + lite_cc_test(test_step_rnn_lite_x86 SRCS test_step_rnn_lite_x86.cc + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/step_rnn) + add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz) endif() endif() @@ -150,23 +207,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) # FPGA_DEPS ${fpga_kernels}) endif() -# These tests needs CLI arguments, and is not supported in ARM CI. -# TODO(Superjomn) support latter. -lite_cc_test(test_light_api SRCS light_api_test.cc - DEPS light_api program mir_passes - CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels} - ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) - -lite_cc_test(test_apis SRCS apis_test.cc - DEPS cxx_api light_api ${ops} - CL_DEPS ${opencl_kernels} - X86_DEPS ${x86_kernels} - FPGA_DEPS ${fpga_kernels} - ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model - --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) - -lite_cc_library(paddle_api SRCS paddle_api.cc DEPS op_params tensor) +lite_cc_library(paddle_api SRCS paddle_api.cc DEPS op_params tensor device_info) #----------------------------------------------------------------------------------------------------- # The final inference library for both CxxConfig and MobileConfig. @@ -184,21 +225,53 @@ if (NOT LITE_ON_TINY_PUBLISH) FPGA_DEPS ${fpga_kernels}) # The final inference library for just MobileConfig. bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api) + get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) + cc_library(api_full_static SRCS DEPS paddle_api_full cxx_api paddle_api light_api ${cxx_api_deps} ${ops} ${host_kernels} ${cuda_kernels} program tensor memory naive_buffer types ${fluid_modules} protobuf ${cuda_static_deps}) endif() bundle_static_library(paddle_api_light paddle_api_light_bundled bundle_light_api) #----------------------------------------------------------------------------------------------------- +# These tests need CLI arguments, and are not supported in ARM CI. +# TODO(Superjomn) support later.
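+# For reference (illustrative path), such a test binary is invoked by hand with the same flag +# that the ARGS clause below passes to the registered test run: +#   ./test_light_api --optimized_model=<path>/lite_naive_model_opt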
+lite_cc_test(test_light_api SRCS light_api_test.cc + DEPS light_api program mir_passes paddle_api_light + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) + +lite_cc_test(test_apis SRCS apis_test.cc + DEPS cxx_api light_api ${ops} paddle_api_light + CL_DEPS ${opencl_kernels} + X86_DEPS ${x86_kernels} + XPU_DEPS ${xpu_kernels} + FPGA_DEPS ${fpga_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model + --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) + if (LITE_WITH_JAVA AND LITE_WITH_ARM) add_subdirectory(android) endif() +if (LITE_WITH_PYTHON) + add_subdirectory(python) +endif() + if (LITE_ON_TINY_PUBLISH) return() endif() + +if (LITE_ON_MODEL_OPTIMIZE_TOOL) + message(STATUS "Compiling model_optimize_tool") + lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc + DEPS gflags kernel op optimizer mir_passes utils) + add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc) +endif(LITE_ON_MODEL_OPTIMIZE_TOOL) + lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light ${ops} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} X86_DEPS ${x86_kernels} FPGA_DEPS ${fpga_kernels} @@ -209,17 +282,19 @@ endif() # Some bins if(NOT IOS) - lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags - ${ops} + lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} X86_DEPS ${x86_kernels}) - lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags - ${ops} + lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} X86_DEPS ${x86_kernels}) @@ -229,7 +304,3 @@ endif() #X86_DEPS operator #DEPS light_api model_parser target_wrapper_host mir_passes #ARM_DEPS ${arm_kernels}) NPU_DEPS ${npu_kernels}) - -lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc - DEPS paddle_api_full gflags - CL_DEPS ${opencl_kernels}) diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h deleted file mode 100644 index d22a1ae75e6883e2a147902e30773cb9f797ef58..0000000000000000000000000000000000000000 --- a/lite/api/_paddle_use_kernels.h +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - * ATTENTION this header file can only include in .cc file. 
- */ - -#pragma once -#include "paddle_lite_factory_helper.h" // NOLINT -#ifndef LITE_WITH_FPGA -USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); -USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); -USE_LITE_KERNEL(flatten, kHost, kAny, kAny, def); -USE_LITE_KERNEL(flatten2, kHost, kAny, kAny, def); -#else -USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); -#endif - -// host kernels -USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); -USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); -USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def); - -#ifdef LITE_WITH_ARM -USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(matmul, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(lrn, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(decode_bboxes, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(box_coder, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(elementwise_div, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_div_activation, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_add_activation, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_mul_activation, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_max_activation, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(split, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(relu6, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(power, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(shuffle_channel, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(yolo_box, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(axpy, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(leaky_relu, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(relu_clipped, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(prelu, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(sigmoid, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(tanh, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(swish, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(log, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(exp, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(pad2d, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(prior_box, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(density_prior_box, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(negative, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(crop, kARM, kFloat, kNCHW, def); - -USE_LITE_KERNEL(norm, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(sequence_softmax, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(im2sequence, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(bilinear_interp, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(nearest_interp, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(logical_xor, kARM, kFloat, kNCHW, def); 
-USE_LITE_KERNEL(logical_and, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(less_than, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(top_k, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(increment, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(write_to_array, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(read_from_array, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(sequence_expand, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(shape, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(fill_constant, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(cast, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(affine_channel, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(anchor_generator, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle -USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle -USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle -USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(box_clip, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(stack, kARM, kFloat, kNCHW, def) - -USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); -USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32); -USE_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW, fp32_to_int8); -USE_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW, int8_to_fp32); -USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out); -USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out); -USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out); -USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out); -USE_LITE_KERNEL(gru_unit, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(gru, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(beam_search_decode, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(beam_search, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(while, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(lod_reset, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(is_empty, kARM, kFloat, kNCHW, def) -USE_LITE_KERNEL(assign, kARM, kFloat, kNCHW, def); -#endif - -#ifdef LITE_WITH_X86 -// NOTE all the X86 kernels are disabled temporarily for kernel are changed. 
-// USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def); -// USE_LITE_KERNEL(batch_norm, kX86, kFloat, kNCHW, def); -#endif - -#ifdef LITE_WITH_CUDA -USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); -USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); -USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); -USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); -USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def); -USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def); -#endif - -#ifdef LITE_WITH_OPENCL -USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, device_to_host); -USE_LITE_KERNEL(io_copy_once, kOpenCL, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy_once, kOpenCL, kAny, kAny, device_to_host); - -USE_LITE_KERNEL(fc, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_add_activation, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(relu, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def); -#endif - -#ifdef LITE_WITH_NPU -USE_LITE_KERNEL(graph_op, kNPU, kFloat, kNCHW, def); -#endif -#ifdef LITE_WITH_FPGA -USE_LITE_KERNEL(relu, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(conv2d, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(elementwise_add, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(fusion_elementwise_add_activation, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(fc, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(pool2d, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(scale, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(softmax, kFPGA, kFP16, kNHWC, def); -USE_LITE_KERNEL(io_copy, kFPGA, kAny, kAny, host_to_device); -USE_LITE_KERNEL(io_copy, kFPGA, kAny, kAny, device_to_host); -USE_LITE_KERNEL(io_copy_once, kFPGA, kAny, kAny, host_to_device_once); -USE_LITE_KERNEL(io_copy_once, kFPGA, kAny, kAny, device_to_host_once); -USE_LITE_KERNEL(calib, kFPGA, kFP16, kNHWC, fp32_to_fp16_fpga); -USE_LITE_KERNEL(calib, kFPGA, kFP16, kNHWC, fp16_to_fp32_fpga); -USE_LITE_KERNEL(calib_once, kFPGA, kFP16, kNHWC, fp32_to_fp16_fpga); -USE_LITE_KERNEL(calib_once, kFPGA, kFP16, kNHWC, fp16_to_fp32_fpga); -USE_LITE_KERNEL(layout, kFPGA, kAny, kNHWC, hwc_to_chw_fpga_fp16); -USE_LITE_KERNEL(layout, kFPGA, kAny, kNHWC, chw_to_hwc_fpga_fp16); -USE_LITE_KERNEL(layout_once, kFPGA, kAny, kNHWC, hwc_to_chw_fpga_fp16); -USE_LITE_KERNEL(layout_once, kFPGA, kAny, kNHWC, chw_to_hwc_fpga_fp16); -#endif diff --git 
a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h index 94971618963a6e7e3c49b16d71b7eea3c148b424..bdccfab5df67e485b9fef110dc6cc1e9d74b21c3 100644
--- a/lite/api/_paddle_use_ops.h
+++ b/lite/api/_paddle_use_ops.h
@@ -21,6 +21,7 @@
USE_LITE_OP(mul);
USE_LITE_OP(matmul);
USE_LITE_OP(fc);
+USE_LITE_OP(assign);
USE_LITE_OP(relu);
USE_LITE_OP(relu6);
USE_LITE_OP(scale);
@@ -51,7 +52,7 @@ USE_LITE_OP(batch_norm)
USE_LITE_OP(fusion_elementwise_sub_activation)
USE_LITE_OP(transpose)
USE_LITE_OP(transpose2)
-USE_LITE_OP(argmax)
+USE_LITE_OP(arg_max)
USE_LITE_OP(axpy)
USE_LITE_OP(leaky_relu)
USE_LITE_OP(relu_clipped)
@@ -118,8 +119,13 @@ USE_LITE_OP(cast)
USE_LITE_OP(affine_channel)
USE_LITE_OP(anchor_generator)
USE_LITE_OP(generate_proposals)
-USE_LITE_OP(squeeze) // for x2paddle
-USE_LITE_OP(squeeze2) // for x2paddle
-USE_LITE_OP(expand) // for x2paddle
+USE_LITE_OP(squeeze) // for x2paddle
+USE_LITE_OP(squeeze2) // for x2paddle
+USE_LITE_OP(unsqueeze) // for x2paddle
+USE_LITE_OP(unsqueeze2) // for x2paddle
+USE_LITE_OP(expand) // for x2paddle
USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip)
+USE_LITE_OP(assign_value)
+USE_LITE_OP(hard_sigmoid)
+USE_LITE_OP(rsqrt)
diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt index afe051a437f4de83931bdaa3f2d03427b78d13ad..3efa980332f25d786d5c880fab9b3ba5af0a1013 100644
--- a/lite/api/android/jni/native/CMakeLists.txt
+++ b/lite/api/android/jni/native/CMakeLists.txt
@@ -17,10 +17,20 @@ if (NOT LITE_ON_TINY_PUBLISH)
# Unlike static library, module library has to link target to be able to work
# as a single .so lib.
target_link_libraries(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels})
+ if (LITE_WITH_NPU)
+ # Strips the symbols of our protobuf functions to fix the conflicts during
+ # loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so)
+ set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
+ set_target_properties(paddle_lite_jni PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+ endif()
else()
add_library(paddle_lite_jni SHARED "")
target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc)
add_dependencies(paddle_lite_jni op_list_h kernel_list_h)
+ if (LITE_WITH_NPU)
+ # Need to add HIAI runtime libs (libhiai.so) dependency
+ target_link_libraries(paddle_lite_jni ${npu_runtime_libs})
+ endif()
endif()
if (APPLE)
diff --git a/lite/api/android/jni/native/convert_util_jni.h b/lite/api/android/jni/native/convert_util_jni.h index ae987c330dd0ad415a2da783366483c58789c56e..5e5d3723e43eb311f64b85f7507a12497d724109 100644
--- a/lite/api/android/jni/native/convert_util_jni.h
+++ b/lite/api/android/jni/native/convert_util_jni.h
@@ -49,6 +49,27 @@ inline std::string jstring_to_cpp_string(JNIEnv *env, jstring jstr) {
return ret;
}
+inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) {
+ auto *data = str.c_str();
+ jclass strClass = env->FindClass("java/lang/String");
+ jmethodID strClassInitMethodID =
+ env->GetMethodID(strClass, "<init>", "([BLjava/lang/String;)V");
+
+ jbyteArray bytes = env->NewByteArray(strlen(data));
+ env->SetByteArrayRegion(
+ bytes, 0, strlen(data), reinterpret_cast<const jbyte *>(data));
+
+ jstring encoding = env->NewStringUTF("UTF-8");
+ jstring res = (jstring)(
+ env->NewObject(strClass, strClassInitMethodID, bytes, encoding));
+
+ env->DeleteLocalRef(strClass);
+ env->DeleteLocalRef(encoding);
+ env->DeleteLocalRef(bytes);
+
+ return res;
+}
+
inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf, int64_t
len) {
@@ -124,8 +145,6 @@ inline CxxConfig jcxxconfig_to_cpp_cxxconfig(JNIEnv *env, jobject jcxxconfig) {
jmethodID model_dir_method =
env->GetMethodID(cxxconfig_jclazz, "getModelDir", "()Ljava/lang/String;");
- jmethodID preferred_place_method = env->GetMethodID(
- cxxconfig_jclazz, "getPreferredPlace", "()Lcom/baidu/paddle/lite/Place;");
jmethodID valid_places_method = env->GetMethodID(
cxxconfig_jclazz, "getValidPlaces", "()[Lcom/baidu/paddle/lite/Place;");
@@ -138,13 +157,6 @@ inline CxxConfig jcxxconfig_to_cpp_cxxconfig(JNIEnv *env, jobject jcxxconfig) {
config.set_model_dir(cpp_model_dir);
}
- jobject java_preferred_place =
- env->CallObjectMethod(jcxxconfig, preferred_place_method);
- if (java_preferred_place != nullptr) {
- Place cpp_preferred_place = jplace_to_cpp_place(env, java_preferred_place);
- config.set_preferred_place(cpp_preferred_place);
- }
-
jobject object_valid_places =
env->CallObjectMethod(jcxxconfig, valid_places_method);
jobjectArray *java_valid_places =
diff --git a/lite/api/android/jni/native/paddle_lite_jni.cc b/lite/api/android/jni/native/paddle_lite_jni.cc index aa4ece68189f002c9e183a042510021fcb602f75..d0d2d603a9f1c7f9c8308a5540c96b551b4845b7 100644
--- a/lite/api/android/jni/native/paddle_lite_jni.cc
+++ b/lite/api/android/jni/native/paddle_lite_jni.cc
@@ -50,6 +50,16 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_PaddlePredictor_run(
return JNI_TRUE;
}
+JNIEXPORT jstring JNICALL Java_com_baidu_paddle_lite_PaddlePredictor_getVersion(
+ JNIEnv *env, jobject jpaddle_predictor) {
+ std::shared_ptr<PaddlePredictor> *predictor =
+ getPaddlePredictorPointer(env, jpaddle_predictor);
+ if (predictor == nullptr || (*predictor == nullptr)) {
+ return cpp_string_to_jstring(env, "");
+ }
+ return cpp_string_to_jstring(env, (*predictor)->GetVersion());
+}
+
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_saveOptimizedModel(
JNIEnv *env, jobject jpaddle_predictor, jstring model_dir) {
diff --git a/lite/api/android/jni/native/paddle_lite_jni.h b/lite/api/android/jni/native/paddle_lite_jni.h index 913e9a4c3a87ca3e649b86d020c3a4a8fd458a0b..f447ce105a1ca7b2d94a00287d2b699f920a09af 100644
--- a/lite/api/android/jni/native/paddle_lite_jni.h
+++ b/lite/api/android/jni/native/paddle_lite_jni.h
@@ -37,6 +37,14 @@ namespace lite_api {
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_run(JNIEnv *, jobject);
+/*
+ * Class: com_baidu_paddle_lite_PaddlePredictor
+ * Method: getVersion
+ * Signature: ()Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getVersion(JNIEnv *, jobject);
+
/*
* Class: com_baidu_paddle_lite_PaddlePredictor
* Method: saveOptimizedModel
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java b/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java index 906293c92fe379caf7e05c805cbbf9a55f0896bd..3f68ef89228d44e41f8d1d5a0ba65791484bb0aa 100644
--- a/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java
@@ -18,17 +18,8 @@ package com.baidu.paddle.lite;
*/
public class CxxConfig extends ConfigBase {
- protected Place preferredPlace;
protected Place[] validPlaces;
- public Place getPreferredPlace() {
- return preferredPlace;
- }
-
- public void setPreferredPlace(Place preferredPlace) {
- this.preferredPlace = preferredPlace;
- }
-
public Place[] getValidPlaces() {
return validPlaces;
}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java index d022fd7d61816e3cc0e01dbac227210e1061099e..efd35d23a1207c8920a8ed3d33af6abf6ba97d5a 100644 --- a/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java +++ b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java @@ -82,6 +82,13 @@ public class PaddlePredictor { */ public native boolean run(); + /** + * Get c++ lib's version information. + * + * @return C++ lib's version information. + */ + public native String getVersion(); + /** * Saves the optimized model. It is available only for {@link CxxConfig} * diff --git a/lite/api/apis_test.cc b/lite/api/apis_test.cc index 3dc02240846ed4fc6dc310e3a27725792463da6e..ac2c385d53ea0a1785393cd488d115d20c4264f1 100644 --- a/lite/api/apis_test.cc +++ b/lite/api/apis_test.cc @@ -51,17 +51,12 @@ bool CompareTensors(const std::string& name, TEST(CXXApi_LightApi, optim_model) { lite::Predictor cxx_api; std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kX86), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, // Both works on X86 and ARM }); // On ARM devices, the preferred X86 target not works, but it can still // select ARM kernels. - cxx_api.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kX86), PRECISION(kFloat)}, - valid_places); + cxx_api.Build(FLAGS_model_dir, "", "", valid_places); cxx_api.SaveModel(FLAGS_optimized_model); } @@ -72,17 +67,12 @@ TEST(CXXApi_LightApi, save_and_load_model) { // CXXAPi { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kX86), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, // Both works on X86 and ARM }); // On ARM devices, the preferred X86 target not works, but it can still // select ARM kernels. 
- cxx_api.Build(FLAGS_model_dir,
- "",
- "",
- Place{TARGET(kX86), PRECISION(kFloat)},
- valid_places);
+ cxx_api.Build(FLAGS_model_dir, "", "", valid_places);

auto* x = cxx_api.GetInput(0);
SetConstInput(x);
diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index ca7bfe7fe6cb57a0f10ad6ca0ed1909a5d82eac1..462a5e2381acf3cc86ca81002a282933f01ee049 100644
--- a/lite/api/benchmark.cc
+++ b/lite/api/benchmark.cc
@@ -32,7 +32,9 @@ DEFINE_string(input_shape,
DEFINE_string(result_filename, "", "save test result");
DEFINE_bool(run_model_optimize,
false,
- "apply model_optimize_tool to model, use optimized model to test");
+ "if set true, apply model_optimize_tool to model, use optimized "
+ "model to test");
+DEFINE_bool(is_quantized_model, false, "if set true, test the quantized model");
namespace paddle {
namespace lite_api {
@@ -42,11 +44,14 @@ void OutputOptModel(const std::string& load_model_dir,
const std::vector<std::vector<int64_t>>& input_shapes) {
lite_api::CxxConfig config;
config.set_model_dir(load_model_dir);
- config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)});
- config.set_valid_places({
- Place{TARGET(kX86), PRECISION(kFloat)},
- Place{TARGET(kARM), PRECISION(kFloat)},
- });
+ std::vector<Place> valid_places = {Place{TARGET(kARM), PRECISION(kFloat)},
+ Place{TARGET(kX86), PRECISION(kFloat)},
+ Place{TARGET(kOpenCL), PRECISION(kFloat)}};
+ if (FLAGS_is_quantized_model) {
+ valid_places.insert(valid_places.begin(),
+ Place{TARGET(kARM), PRECISION(kInt8)});
+ }
+ config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(config);
int ret = system(
@@ -70,11 +75,7 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string model_name) {
lite_api::MobileConfig config;
config.set_threads(thread_num);
- if (thread_num == 1) {
- config.set_power_mode(LITE_POWER_HIGH);
- } else {
- config.set_power_mode(LITE_POWER_NO_BIND);
- }
+ config.set_power_mode(LITE_POWER_NO_BIND);
config.set_model_dir(model_dir);
auto predictor = lite_api::CreatePaddlePredictor(config);
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index eeba68630146870fd43bac3cd7eeaa1d9c576eac..a2b538aa77e0603f439b6b23aab875103fdbbff0 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -13,20 +13,27 @@
// limitations under the License.
#include "lite/api/cxx_api.h" +#include #include +#include #include #include #include #include "lite/utils/io.h" -#ifdef LITE_WITH_NPU -#include "lite/backends/npu/npu_helper.h" -#endif namespace paddle { namespace lite { +static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] = + ".tailored_ops_source_list"; +static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list"; +static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] = + ".tailored_kernels_source_list"; +static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list"; + void Predictor::SaveModel(const std::string &dir, - lite_api::LiteModelType model_type) { + lite_api::LiteModelType model_type, + bool record_info) { if (!program_) { GenRuntimeProgram(); } @@ -42,41 +49,142 @@ void Predictor::SaveModel(const std::string &dir, default: LOG(FATAL) << "Unknown model type"; } -#ifdef LITE_WITH_NPU - for (auto name : npu::DeviceInfo::Global().AllClientNames()) { - // the npu offline model is saved in current dir - // so just copy to dst dir - CHECK_EQ( - system(string_format("cp -r %s %s", name.c_str(), dir.c_str()).c_str()), - 0) - << "Failed copy NPU model to " << dir; + if (record_info) { + SaveOpKernelInfo(dir); } -#endif +} + +void Predictor::SaveOpKernelInfo(const std::string &model_dir) { + std::set ops_info; + std::set kernels_info; + const auto &instructions_ = program_->instructions(); + for (auto &node : instructions_) { + // parse op type infomation + auto op = node.op()->op_info(); + ops_info.insert(op->Type()); + // parse kernel type information + std::string kernel_type_str = + node.kernel()->op_type() + "," + TargetRepr(node.kernel()->target()) + + "," + PrecisionRepr(node.kernel()->precision()) + "," + + DataLayoutRepr(node.kernel()->layout()) + "," + node.kernel()->alias(); + kernels_info.insert(kernel_type_str); + } + + // get souce_file name from op type and kernel type + auto op2pathmap = OpKernelInfoCollector::Global().GetOp2PathDict(); + auto kernel2pathmap = OpKernelInfoCollector::Global().GetKernel2PathDict(); + + // write used op and kernel info into files + std::string opf_path = model_dir + "/" + TAILORD_OPS_LIST_NAME; + std::string opf_source_path = + model_dir + "/" + TAILORD_OPS_SOURCE_LIST_FILENAME; + std::string kpf_path = model_dir + "/" + TAILORD_KERNELS_LIST_NAME; + std::string kpf_source_path = + model_dir + "/" + TAILORD_KERNELS_SOURCE_LIST_FILENAME; + std::map op2path; + + std::FILE *opf = std::fopen(opf_path.c_str(), "w"); + std::FILE *opf_source = std::fopen(opf_source_path.c_str(), "w"); + std::FILE *kpf = std::fopen(kpf_path.c_str(), "w"); + std::FILE *kpf_source = std::fopen(kpf_source_path.c_str(), "w"); + std::vector opcompile; + std::vector kernelcompile; + + if (nullptr == opf || nullptr == opf_source || nullptr == opf || + nullptr == kpf_source) { + LOG(FATAL) << "failed to create info file into: " << model_dir; + } + for (auto op_info = ops_info.begin(); op_info != ops_info.end(); op_info++) { + fputs(op_info->c_str(), opf); + fputc('\n', opf); + std::string op_path = op2pathmap[*op_info]; + fputs(op_path.c_str(), opf_source); + fputc('\n', opf_source); + } + std::fclose(opf_source); + std::fclose(opf); + LOG(INFO) << "operators information of tailored model is stored into: " + << opf_path; + + // write Kernel_type and Kernel_path into file + for (auto kernel_info = kernels_info.begin(); + kernel_info != kernels_info.end(); + kernel_info++) { + fputs(kernel_info->c_str(), kpf); + fputc('\n', kpf); + std::string kernel_path = kernel2pathmap[*kernel_info]; + 
+ fputs(kernel_path.c_str(), kpf_source);
+ fputc('\n', kpf_source);
+ if (kernel_path == "conv_compute.cc") {
+ fputs(
+ "conv_depthwise.cc\nconv_direct.cc\nconv_gemmlike.cc\nconv_"
+ "winograd.cc\n",
+ kpf_source);
+ }
+ }
+ std::fclose(kpf_source);
+ std::fclose(kpf);
+ LOG(INFO) << "kernels information of tailored model is stored into: "
+ << kpf_path;
}

lite::Tensor *Predictor::GetInput(size_t offset) {
- auto *_feed_list = exec_scope_->FindVar("feed");
- CHECK(_feed_list) << "no feed variable in exec_scope";
- auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
- if (offset >= feed_list->size()) {
- feed_list->resize(offset + 1);
+ CHECK(input_names_.size() > offset)
+ << "The network has " << input_names_.size() << " inputs"
+ << ", the offset should be less than this.";
+ auto *in_var = exec_scope_->FindVar(input_names_[offset]);
+ CHECK(in_var) << "no input variable " << input_names_[offset]
+ << " in exec_scope";
+ return in_var->GetMutable<lite::Tensor>();
+}
+
+// get input names
+std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
+// get output names
+std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
+// append the names of inputs and outputs into input_names_ and output_names_
+void Predictor::PrepareFeedFetch() {
+ auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
+ std::vector<cpp::OpDesc *> feeds;
+ std::vector<cpp::OpDesc *> fetchs;
+ for (size_t i = 0; i < current_block->OpsSize(); i++) {
+ auto op = current_block->GetOp<cpp::OpDesc>(i);
+ if (op->Type() == "feed") {
+ feeds.push_back(op);
+ } else if (op->Type() == "fetch") {
+ fetchs.push_back(op);
+ }
+ }
+ input_names_.resize(feeds.size());
+ output_names_.resize(fetchs.size());
+ for (size_t i = 0; i < feeds.size(); i++) {
+ input_names_[feeds[i]->GetAttr<int>("col")] =
+ feeds[i]->Output("Out").front();
+ }
+ for (size_t i = 0; i < fetchs.size(); i++) {
+ output_names_[fetchs[i]->GetAttr<int>("col")] =
+ fetchs[i]->Input("X").front();
}
- return &feed_list->at(offset);
}

const lite::Tensor *Predictor::GetOutput(size_t offset) const {
- auto *_fetch_list = exec_scope_->FindVar("fetch");
- CHECK(_fetch_list) << "no fatch variable in exec_scope";
- auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
- CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
- return &fetch_list.at(offset);
+ CHECK(output_names_.size() > offset)
+ << "The network has " << output_names_.size() << " outputs"
+ << ", the offset should be less than this.";
+ const std::string name = output_names_.at(offset);
+ auto *out_var = exec_scope_->FindVar(name);
+ CHECK(out_var) << "no output variable " << name << " in exec_scope";
+ return out_var->GetMutable<lite::Tensor>();
}

-const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
- auto *_fetch_list = exec_scope_->FindVar("fetch");
- CHECK(_fetch_list) << "no fatch variable in exec_scope";
- auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
- return &fetch_list;
+std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
+ std::vector<const lite::Tensor *> outputs;
+ size_t out_size = output_names_.size();
+ for (size_t i = 0; i < out_size; i++) {
+ const std::string name = output_names_.at(i);
+ outputs.push_back(GetTensor(name));
+ }
+ return outputs;
}

const cpp::ProgramDesc &Predictor::program_desc() const {
@@ -91,14 +199,12 @@ void Predictor::Build(const lite_api::CxxConfig &config,
const std::string &model_path = config.model_dir();
const std::string &model_file = config.model_file();
const std::string &param_file = config.param_file();
- const Place prefer_place = config.preferred_place();
const bool model_from_memory = config.model_from_memory();
LOG(INFO) << "load from memory "
<< model_from_memory; Build(model_path, model_file, param_file, - prefer_place, valid_places, passes, model_type, @@ -107,7 +213,6 @@ void Predictor::Build(const lite_api::CxxConfig &config, void Predictor::Build(const std::string &model_path, const std::string &model_file, const std::string ¶m_file, - const Place &prefer_place, const std::vector &valid_places, const std::vector &passes, lite_api::LiteModelType model_type, @@ -134,21 +239,26 @@ void Predictor::Build(const std::string &model_path, default: LOG(FATAL) << "Unknown model type"; } - Build(program_desc_, prefer_place, valid_places, passes); + Build(program_desc_, valid_places, passes); } void Predictor::Build(const cpp::ProgramDesc &desc, - const Place &prefer_place, const std::vector &valid_places, const std::vector &passes) { program_desc_ = desc; - Program program(desc, scope_, valid_places); - optimizer_.KernelPickPreferPlace(prefer_place); + std::vector inner_places = valid_places; + inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); + inner_places.emplace_back( + TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + Program program(desc, scope_, inner_places); + /// The first place in valid_places is core::KernelPickFactor factor; factor.ConsiderTarget(); factor.ConsiderPrecision(); - optimizer_.Run(std::move(program), valid_places, factor, passes); + factor.ConsiderDataLayout(); + optimizer_.Run(std::move(program), inner_places, factor, passes); exec_scope_ = optimizer_.exec_scope(); + PrepareFeedFetch(); } void Predictor::GenRuntimeProgram() { @@ -161,6 +271,21 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const { auto *var = exec_scope_->FindVar(name); return &var->Get(); } +// get input by name +lite::Tensor *Predictor::GetInputByName(const std::string &name) { + auto element = std::find(input_names_.begin(), input_names_.end(), name); + if (element == input_names_.end()) { + LOG(ERROR) << "Model do not have input named with: [" << name + << "], model's inputs include:"; + for (size_t i = 0; i < input_names_.size(); i++) { + LOG(ERROR) << "[" << input_names_[i] << "]"; + } + return nullptr; + } else { + int position = std::distance(input_names_.begin(), element); + return GetInput(position); + } +} #ifdef LITE_WITH_TRAIN void Predictor::FeedVars(const std::vector &tensors) { diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 2506ae47b0ddbce683d8f4b12e000bb3ea19d497..502ce812e1f4a7f520e89e6eaff020c5853f5308 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -13,7 +13,9 @@ // limitations under the License. #pragma once +#include #include +#include //NOLINT #include #include #include @@ -49,14 +51,12 @@ class LITE_API Predictor { const std::string& model_path, const std::string& model_file_path, const std::string& param_file_path, - const Place& prefer_place, const std::vector& valid_places, const std::vector& passes = {}, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, bool memory_from_memory = false); void Build(const cpp::ProgramDesc& desc, - const Place& prefer_place, const std::vector& valid_places, const std::vector& passes = {}); @@ -68,15 +68,20 @@ class LITE_API Predictor { GenRuntimeProgram(); } program_->Run(); - LOG(INFO) << "running"; } // Get offset-th col of feed inputs. lite::Tensor* GetInput(size_t offset); + // get input by name. + lite::Tensor* GetInputByName(const std::string& name); + // get inputnames and get outputnames. 
+ std::vector<std::string> GetInputNames();
+ std::vector<std::string> GetOutputNames();
+ void PrepareFeedFetch();
// Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const;
- const std::vector<lite::Tensor>* GetOutputs() const;
+ std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const;
const lite::Tensor* GetTensor(const std::string& name) const;
@@ -85,7 +90,9 @@ class LITE_API Predictor {
// This method is disabled in mobile, for unnecessary dependencies required.
void SaveModel(
const std::string& dir,
- lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+ lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
+ bool record_info = false);
+ void SaveOpKernelInfo(const std::string& model_dir);

#ifdef LITE_WITH_TRAIN
void Run(const std::vector& tensors) {
@@ -103,6 +110,47 @@ class LITE_API Predictor {
const Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false};
+ std::vector<std::string> input_names_;
+ std::vector<std::string> output_names_;
+};
+
+class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
+ public:
+ CxxPaddleApiImpl() {}
+
+ /// Create a new predictor from a config.
+ void Init(const lite_api::CxxConfig& config);
+
+ std::unique_ptr<lite_api::Tensor> GetInput(int i) override;
+
+ std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override;
+
+ void Run() override;
+
+ std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
+
+ std::string GetVersion() const override;
+
+ // get input names and output names
+ std::vector<std::string> GetInputNames() override;
+ std::vector<std::string> GetOutputNames() override;
+
+ std::unique_ptr<const lite_api::Tensor> GetTensor(
+ const std::string& name) const override;
+
+ // Get input tensor by name
+ std::unique_ptr<lite_api::Tensor> GetInputByName(
+ const std::string& name) override;
+
+ void SaveOptimizedModel(
+ const std::string& model_dir,
+ lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
+ bool record_info = false) override;
+
+ private:
+ Predictor raw_predictor_;
+ lite_api::CxxConfig config_;
+ std::mutex mutex_;
};

/*
@@ -123,10 +171,8 @@ class LITE_API Predictor {
class LITE_API CXXTrainer {
public:
CXXTrainer(const std::shared_ptr<Scope>& root_scope,
- const Place& preferred_place,
const std::vector<Place>& valid_places)
: scope_(root_scope),
- preferred_place_(preferred_place),
valid_places_(valid_places),
main_program_executor_(Predictor(scope_)) {}

@@ -135,7 +181,7 @@ class LITE_API CXXTrainer {
// NOTE Just support to execute the 0-th block currently.
Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
int block_id = 0) {
- main_program_executor_.Build(desc, preferred_place_, valid_places_);
+ main_program_executor_.Build(desc, valid_places_);
return main_program_executor_;
}

@@ -153,14 +199,12 @@ class LITE_API CXXTrainer {
void RunStartupProgram(const framework::proto::ProgramDesc& desc,
int block_id = 0) {
Predictor exe(scope_);
- exe.Build(desc, preferred_place_, valid_places_);
+ exe.Build(desc, valid_places_);
exe.Run();
}

private:
std::shared_ptr<Scope> scope_;
-
- Place preferred_place_;
std::vector<Place> valid_places_;

// The training program.
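To make the reworked API concrete, here is a minimal usage sketch written against the headers above (an illustration only, not part of the patch; the model paths are placeholders). Note that set_preferred_place() is gone — the first entry of valid_places now carries the highest priority — and that passing record_info = true to SaveOptimizedModel additionally writes the .tailored_* list files consumed by the LITE_BUILD_TAILOR option:

#include <iostream>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"  // NOLINT
#include "lite/api/paddle_use_ops.h"      // NOLINT
#include "lite/api/paddle_use_passes.h"   // NOLINT

int main() {
  using namespace paddle::lite_api;  // NOLINT
  CxxConfig config;
  config.set_model_dir("./mobilenet_v1");  // placeholder model dir
  // The first place is the preferred one; no set_preferred_place() any more.
  config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
  auto predictor = CreatePaddlePredictor(config);

  // Name-based introspection introduced by this patch.
  for (const auto &name : predictor->GetInputNames()) {
    std::cout << "input: " << name << std::endl;
  }
  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 224, 224});
  float *data = input->mutable_data<float>();
  for (int i = 0; i < 3 * 224 * 224; ++i) data[i] = 1.f;

  predictor->Run();
  auto output = predictor->GetOutput(0);
  std::cout << "out[0] = " << output->data<float>()[0] << std::endl;

  // record_info = true also dumps the .tailored_ops_list and
  // .tailored_kernels_list files (plus the *_source_list variants)
  // next to the saved model, for tailored builds.
  predictor->SaveOptimizedModel(
      "./mobilenet_v1_opt", LiteModelType::kNaiveBuffer, true);
  return 0;
}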
diff --git a/lite/api/cxx_api_bin.cc b/lite/api/cxx_api_bin.cc index 000e94307ca4acaa3a57597f4a7b0e44a57e0031..8c929e9c8700a65c868e2facd763b0ec36719e23 100644 --- a/lite/api/cxx_api_bin.cc +++ b/lite/api/cxx_api_bin.cc @@ -35,13 +35,11 @@ void Run(const char* model_dir, int repeat) { #endif lite::Predictor predictor; std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kInt8)}, + Place{TARGET(kARM), PRECISION(kFloat)}, }); - predictor.Build( - model_dir, "", "", Place{TARGET(kARM), PRECISION(kInt8)}, valid_places); + predictor.Build(model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index b8c92a8f96afefa7a2de6b844980f9c0f769f6a9..6fa400db6da9f029c38b496cd70d593a876628c9 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -13,41 +13,26 @@ // limitations under the License. #include "lite/api/cxx_api.h" +#include +#include //NOLINT +#include #include "lite/api/paddle_api.h" +#include "lite/core/device_info.h" +#include "lite/core/version.h" namespace paddle { namespace lite { -class CxxPaddleApiImpl : public lite_api::PaddlePredictor { - public: - CxxPaddleApiImpl(); - - /// Create a new predictor from a config. - void Init(const lite_api::CxxConfig &config); - - std::unique_ptr GetInput(int i) override; - - std::unique_ptr GetOutput(int i) const override; - - void Run() override; - - std::unique_ptr GetTensor( - const std::string &name) const override; - - void SaveOptimizedModel(const std::string &model_dir, - lite_api::LiteModelType model_type = - lite_api::LiteModelType::kProtobuf) override; - - private: - Predictor raw_predictor_; -}; - -CxxPaddleApiImpl::CxxPaddleApiImpl() {} - void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { + config_ = config; +#ifdef LITE_WITH_CUDA + Env::Init(); +#endif auto places = config.valid_places(); - places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); raw_predictor_.Build(config, places); + + mode_ = config.power_mode(); + threads_ = config.threads(); } std::unique_ptr CxxPaddleApiImpl::GetInput(int i) { @@ -61,7 +46,29 @@ std::unique_ptr CxxPaddleApiImpl::GetOutput( return std::unique_ptr(new lite_api::Tensor(x)); } -void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); } +std::vector CxxPaddleApiImpl::GetInputNames() { + return raw_predictor_.GetInputNames(); +} + +std::vector CxxPaddleApiImpl::GetOutputNames() { + return raw_predictor_.GetOutputNames(); +} + +void CxxPaddleApiImpl::Run() { +#ifdef LITE_WITH_ARM + lite::DeviceInfo::Global().SetRunMode(mode_, threads_); +#endif + raw_predictor_.Run(); +} + +std::shared_ptr CxxPaddleApiImpl::Clone() { + std::lock_guard lock(mutex_); + auto predictor = std::make_shared(); + predictor->Init(config_); + return predictor; +} + +std::string CxxPaddleApiImpl::GetVersion() const { return version(); } std::unique_ptr CxxPaddleApiImpl::GetTensor( const std::string &name) const { @@ -69,9 +76,16 @@ std::unique_ptr CxxPaddleApiImpl::GetTensor( return std::unique_ptr(new lite_api::Tensor(x)); } +std::unique_ptr CxxPaddleApiImpl::GetInputByName( + const std::string &name) { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_.GetInputByName(name))); +} + void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir, - lite_api::LiteModelType model_type) { - raw_predictor_.SaveModel(model_dir, model_type); + 
lite_api::LiteModelType model_type, + bool record_info) { + raw_predictor_.SaveModel(model_dir, model_type, record_info); } } // namespace lite diff --git a/lite/api/cxx_api_test.cc b/lite/api/cxx_api_test.cc index c562b9f0801c55630bb8f4108a27e7b927c62514..4d711302cb5880247f4a7b7082185c500b9ad6e9 100644 --- a/lite/api/cxx_api_test.cc +++ b/lite/api/cxx_api_test.cc @@ -43,13 +43,8 @@ TEST(CXXApi, test) { TEST(CXXApi, save_model) { lite::Predictor predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kX86), PRECISION(kFloat)}}); - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kCUDA), PRECISION(kFloat)}, - valid_places); + std::vector valid_places({Place{TARGET(kX86), PRECISION(kFloat)}}); + predictor.Build(FLAGS_model_dir, "", "", valid_places); LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; predictor.SaveModel(FLAGS_optimized_model, @@ -59,11 +54,11 @@ TEST(CXXApi, save_model) { } /*TEST(CXXTrainer, train) { - Place prefer_place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}); - std::vector valid_places({prefer_place}); + Place place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}); + std::vector valid_places({place}); auto scope = std::make_shared(); - CXXTrainer trainer(scope, prefer_place, valid_places); + CXXTrainer trainer(scope, valid_places); std::string main_program_pb, startup_program_pb; ReadBinaryFile(FLAGS_main_program_path, &main_program_pb); @@ -94,13 +89,8 @@ TEST(CXXApi, save_model) { #ifdef LITE_WITH_ARM TEST(CXXApi, save_model) { lite::Predictor predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kARM), PRECISION(kFloat)}, - valid_places); + std::vector valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); + predictor.Build(FLAGS_model_dir, "", "", valid_places); LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; predictor.SaveModel(FLAGS_optimized_model); @@ -110,12 +100,10 @@ TEST(CXXApi, save_model) { TEST(CXXApi, load_model_naive) { lite::Predictor predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); predictor.Build(FLAGS_optimized_model + ".naive", "", "", - Place{TARGET(kARM), PRECISION(kFloat)}, valid_places, {}, lite_api::LiteModelType::kNaiveBuffer); diff --git a/lite/api/detection_model_test.cc b/lite/api/detection_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c14acbac411aad526cf9271c22891cf7279f3ade --- /dev/null +++ b/lite/api/detection_model_test.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +DEFINE_bool(is_run_model_optimize, + false, + "apply model_optimize_tool to model, use optimized model to test"); + +namespace paddle { +namespace lite_api { + +void OutputOptModel(const std::string& load_model_dir, + const std::string& save_optimized_model_dir) { + lite_api::CxxConfig config; + config.set_model_dir(load_model_dir); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + auto predictor = lite_api::CreatePaddlePredictor(config); + + int ret = system( + paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str()) + .c_str()); + if (ret == 0) { + LOG(INFO) << "delete old optimized model " << save_optimized_model_dir; + } + predictor->SaveOptimizedModel(save_optimized_model_dir, + LiteModelType::kNaiveBuffer); + LOG(INFO) << "Load model from " << load_model_dir; + LOG(INFO) << "Save optimized model to " << save_optimized_model_dir; +} + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +void Run(const std::string& model_dir, + const int repeat, + const int warmup_times, + const int thread_num) { + // set config and create predictor + lite_api::MobileConfig config; + config.set_model_dir(model_dir); + config.set_threads(thread_num); + if (thread_num == 1) { + config.set_power_mode(LITE_POWER_HIGH); + } else { + config.set_power_mode(LITE_POWER_NO_BIND); + } + + auto predictor = lite_api::CreatePaddlePredictor(config); + + // set input + auto input_image = predictor->GetInput(0); + input_image->Resize({1, 3, 300, 300}); + auto input_image_data = input_image->mutable_data(); + std::ifstream read_file("/data/local/tmp/pjc/ssd_img.txt"); + if (!read_file.is_open()) { + LOG(INFO) << "read image file fail"; + return; + } + auto input_shape = input_image->shape(); + int64_t input_image_size = 1; + for (auto t : input_shape) { + input_image_size *= t; + } + for (int i = 0; i < input_image_size; i++) { + read_file >> input_image_data[i]; + } + + // warmup and run + for (int i = 0; i < warmup_times; ++i) { + predictor->Run(); + } + + auto start = lite::GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + predictor->Run(); + } + + // show result + auto end = lite::GetCurrentUS(); + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (end - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + auto out = predictor->GetOutput(0); + auto out_data = out->data(); + LOG(INFO) << "output shape:"; + auto out_shape = out->shape(); + for (auto t : out_shape) { + LOG(INFO) << t; + } + LOG(INFO) << "output data:"; + int output_len = 20; + for (int i = 0; i < output_len; i++) { + LOG(INFO) << out_data[i]; + } +} +#endif + +} // namespace lite_api +} // namespace paddle + +TEST(Faster_RCNN, test_arm) { + std::string save_optimized_model_dir; + if (FLAGS_is_run_model_optimize) { + save_optimized_model_dir = FLAGS_model_dir + "opt"; + paddle::lite_api::OutputOptModel(FLAGS_model_dir, save_optimized_model_dir); + } + std::string run_model_dir = + FLAGS_is_run_model_optimize ? 
save_optimized_model_dir : FLAGS_model_dir;
+ paddle::lite_api::Run(
+ run_model_dir, FLAGS_repeats, FLAGS_warmup, FLAGS_threads);
+}
diff --git a/lite/api/efficientnet_b0_test.cc b/lite/api/efficientnet_b0_test.cc index fa16a6be817f2a6160fd2eaf8fd48d9fa9e1aa1a..61d74eb35412291398d4491057013c514ff5e1de 100644
--- a/lite/api/efficientnet_b0_test.cc
+++ b/lite/api/efficientnet_b0_test.cc
@@ -25,13 +25,12 @@ namespace paddle {
namespace lite {
-void TestModel(const std::vector<Place> &valid_places,
- const Place &preferred_place) {
+void TestModel(const std::vector<Place> &valid_places) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;

- predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
+ predictor.Build(FLAGS_model_dir, "", "", valid_places);

auto *input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
@@ -80,22 +79,20 @@ void TestModel(const std::vector<Place> &valid_places,
TEST(EfficientNetB0, test_arm) {
std::vector<Place> valid_places({
- Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
// Place{TARGET(kOpenCL), PRECISION(kFloat)},
});

- TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+ TestModel(valid_places);
}

TEST(EfficientNetB0, test_opencl) {
std::vector<Place> valid_places({
- Place{TARGET(kHost), PRECISION(kFloat)},
- Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kOpenCL), PRECISION(kFloat)},
+ Place{TARGET(kARM), PRECISION(kFloat)},
});

- TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)}));
+ TestModel(valid_places);
}

} // namespace lite
diff --git a/lite/api/faster_rcnn_test.cc b/lite/api/faster_rcnn_test.cc deleted file mode 100644 index ac5ced0dec5b81c899b35eef60ecd5b756283848..0000000000000000000000000000000000000000
--- a/lite/api/faster_rcnn_test.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include -#include -#include -#include -#include "lite/api/cxx_api.h" -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" -#include "lite/api/test_helper.h" -#include "lite/core/op_registry.h" - -namespace paddle { -namespace lite { - -#ifdef LITE_WITH_ARM -void TestModel(const std::vector& valid_places, - const Place& preferred_place) { - DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); - lite::Predictor predictor; - - predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); - - auto* input_image = predictor.GetInput(0); - input_image->Resize({1, 3, 1333, 800}); - auto* input_image_data = input_image->mutable_data(); - std::ifstream read_file("/data/local/tmp/pjc/faster_rcnn_img.txt"); - for (int i = 0; i < input_image->numel(); i++) { - read_file >> input_image_data[i]; - } - read_file.close(); - LOG(INFO) << "image data:" << input_image_data[0] << " " - << input_image_data[input_image->numel() - 1]; - - auto* im_info = predictor.GetInput(1); - im_info->Resize({1, 3}); - auto* im_info_data = im_info->mutable_data(); - im_info_data[0] = 1333; - im_info_data[1] = 800; - im_info_data[2] = 1; - - auto* im_shape = predictor.GetInput(2); - im_shape->Resize({1, 3}); - auto* im_shape_data = im_shape->mutable_data(); - im_shape_data[0] = 1333; - im_shape_data[1] = 800; - im_shape_data[2] = 1; - - for (int i = 0; i < FLAGS_warmup; ++i) { - predictor.Run(); - } - - auto start = GetCurrentUS(); - for (int i = 0; i < FLAGS_repeats; ++i) { - predictor.Run(); - } - - LOG(INFO) << "================== Speed Report ==================="; - LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads - << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats - << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 - << " ms in average."; - - auto* out = predictor.GetOutput(0); - auto* out_data = out->data(); - LOG(INFO) << "==========output data==============="; - LOG(INFO) << out->dims(); - for (int i = 0; i < out->numel(); i++) { - LOG(INFO) << out_data[i]; - } -} - -TEST(Faster_RCNN, test_arm) { - std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, - }); - - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); -} - -#endif // LITE_WITH_ARM - -} // namespace lite -} // namespace paddle diff --git a/lite/api/inceptionv4_test.cc b/lite/api/inceptionv4_test.cc index ae772dbba560b855f7f835f7513451713f1099b8..95ad5121caafd70b6b0111bab9c2e76bce75c742 100644 --- a/lite/api/inceptionv4_test.cc +++ b/lite/api/inceptionv4_test.cc @@ -30,14 +30,9 @@ TEST(InceptionV4, test) { DeviceInfo::Init(); DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kARM), PRECISION(kFloat)}, - valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index 98b79e58aa349436ad64dcc0a54256d5d9ead3df..a0c4b7e5e375d9d004de63345ba5013ee6c252b9 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -13,6 +13,7 @@ // 
limitations under the License. #include "lite/api/light_api.h" +#include namespace paddle { namespace lite { @@ -22,44 +23,94 @@ void LightPredictor::Build(const std::string& model_dir, const std::string& param_buffer, lite_api::LiteModelType model_type, bool model_from_memory) { - cpp::ProgramDesc desc; switch (model_type) { #ifndef LITE_ON_TINY_PUBLISH case lite_api::LiteModelType::kProtobuf: - LoadModelPb(model_dir, "", "", scope_.get(), &desc); + LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_); break; #endif case lite_api::LiteModelType::kNaiveBuffer: { if (model_from_memory) { LoadModelNaiveFromMemory( - model_buffer, param_buffer, scope_.get(), &desc); + model_buffer, param_buffer, scope_.get(), &cpp_program_desc_); } else { - LoadModelNaive(model_dir, scope_.get(), &desc); + LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_); } break; } default: LOG(FATAL) << "Unknown model type"; } - BuildRuntimeProgram(desc); + BuildRuntimeProgram(cpp_program_desc_); + PrepareFeedFetch(); } Tensor* LightPredictor::GetInput(size_t offset) { - auto* _feed_list = program_->exec_scope()->FindVar("feed"); - CHECK(_feed_list) << "no feed variable in exec_scope"; - auto* feed_list = _feed_list->GetMutable>(); - if (offset >= feed_list->size()) { - feed_list->resize(offset + 1); + CHECK(input_names_.size() > offset) + << "The network has " << input_names_.size() << " inputs" + << ", the offset should be less than this."; + auto* in_var = program_->exec_scope()->FindVar(input_names_[offset]); + CHECK(in_var) << "no fatch variable " << input_names_[offset] + << " in exec_scope"; + return in_var->GetMutable(); +} + +// get input by name +Tensor* LightPredictor::GetInputByName(const std::string& name) { + auto element = std::find(input_names_.begin(), input_names_.end(), name); + if (element == input_names_.end()) { + LOG(ERROR) << "Model do not have input named with: [" << name + << "], model's inputs include:"; + for (int i = 0; i < input_names_.size(); i++) { + LOG(ERROR) << "[" << input_names_[i] << "]"; + } + return nullptr; + } else { + int position = std::distance(input_names_.begin(), element); + return GetInput(position); } - return &feed_list->at(offset); } const Tensor* LightPredictor::GetOutput(size_t offset) { - auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); - CHECK(_fetch_list) << "no fatch variable in exec_scope"; - auto& fetch_list = *_fetch_list->GetMutable>(); - CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; - return &fetch_list.at(offset); + CHECK(output_names_.size() > offset) + << "The network has " << output_names_.size() << " outputs" + << ", the offset should be less than this."; + auto* out_var = program_->exec_scope()->FindVar(output_names_.at(offset)); + CHECK(out_var) << "no fatch variable " << output_names_.at(offset) + << " in exec_scope"; + return out_var->GetMutable(); +} +// get inputs names +std::vector LightPredictor::GetInputNames() { + return input_names_; +} +// get outputnames +std::vector LightPredictor::GetOutputNames() { + return output_names_; +} +// append the names of inputs and outputs into input_names_ and output_names_ +void LightPredictor::PrepareFeedFetch() { + auto current_block = cpp_program_desc_.GetBlock(0); + std::vector feeds; + std::vector fetchs; + for (int i = 0; i < current_block->OpsSize(); i++) { + auto op = current_block->GetOp(i); + if (op->Type() == "feed") { + feeds.push_back(op); + } else if (op->Type() == "fetch") { + fetchs.push_back(op); + } + } + 
input_names_.resize(feeds.size()); + output_names_.resize(fetches.size()); + for (int i = 0; i < feeds.size(); i++) { + input_names_[feeds[i]->GetAttr<int>("col")] = + feeds[i]->Output("Out").front(); + } + for (int i = 0; i < fetches.size(); i++) { + output_names_[fetches[i]->GetAttr<int>("col")] = + fetches[i]->Input("X").front(); + } } void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { @@ -84,9 +135,11 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { }); CHECK(it != kernels.end()); (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); + insts.emplace_back(op, std::move(*it)); } program_.reset(new RuntimeProgram(std::move(insts))); + CHECK(program.exec_scope()); program_->set_exec_scope(program.exec_scope()); } diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 241540174489d40c0688cba2ce7911f11b5b5832..3781bc4d674db5d2e8794edaf33f00627b9977bb 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -18,6 +18,7 @@ */ #pragma once +#include #include #include #include @@ -52,7 +53,8 @@ class LITE_API LightPredictor { // Get offset-th col of feed inputs. Tensor* GetInput(size_t offset); - + // Get input by name. + Tensor* GetInputByName(const std::string& name); // Get offset-th col of fetch outputs. const Tensor* GetOutput(size_t offset); @@ -61,6 +63,11 @@ class LITE_API LightPredictor { return &var->Get(); } + // Get input names and output names. + std::vector<std::string> GetInputNames(); + std::vector<std::string> GetOutputNames(); + void PrepareFeedFetch(); + private: void Build( const std::string& model_dir, @@ -74,6 +81,37 @@ class LITE_API LightPredictor { private: std::shared_ptr<Scope> scope_; std::unique_ptr<RuntimeProgram> program_; + cpp::ProgramDesc cpp_program_desc_; + std::vector<std::string> input_names_; + std::vector<std::string> output_names_; +}; + +class LightPredictorImpl : public lite_api::PaddlePredictor { + public: + LightPredictorImpl() = default; + + std::unique_ptr<lite_api::Tensor> GetInput(int i) override; + + std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override; + + void Run() override; + + std::shared_ptr<lite_api::PaddlePredictor> Clone() override; + + std::string GetVersion() const override; + std::vector<std::string> GetInputNames() override; + std::vector<std::string> GetOutputNames() override; + + std::unique_ptr<const lite_api::Tensor> GetTensor( + const std::string& name) const override; + // Get input tensor by name + std::unique_ptr<lite_api::Tensor> GetInputByName( + const std::string& name) override; + + void Init(const lite_api::MobileConfig& config); + + private: + std::unique_ptr<LightPredictor> raw_predictor_; }; } // namespace lite
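For orientation, here is a hedged usage sketch of the name-based I/O API declared above; the model path and the input name are hypothetical, and error handling is elided:

using paddle::lite_api::MobileConfig;
using paddle::lite_api::CreatePaddlePredictor;

// Sketch only: assumes an optimized naive-buffer model in a hypothetical dir.
MobileConfig config;
config.set_model_dir("./mobilenet_v1_opt");
auto predictor = CreatePaddlePredictor<MobileConfig>(config);
// PrepareFeedFetch() recorded the feed/fetch names when the model was loaded.
for (const auto& name : predictor->GetInputNames()) {
  LOG(INFO) << "input: " << name;
}
auto image = predictor->GetInputByName("image");  // hypothetical input name
image->Resize({1, 3, 224, 224});

diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index 6075f1a36f6803b7e5090697802dcb47fafa0d0d..a0ae28df0958403237114a3d4b94031829019339 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -13,64 +13,78 @@ // limitations under the License.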
#include "lite/api/light_api.h" +#include #include "lite/api/paddle_api.h" +#include "lite/core/version.h" #include "lite/model_parser/model_parser.h" namespace paddle { -namespace lite_api { - -class LightPredictorImpl : public PaddlePredictor { - public: - LightPredictorImpl() = default; - - std::unique_ptr GetInput(int i) override; - - std::unique_ptr GetOutput(int i) const override; - - void Run() override; - - std::unique_ptr GetTensor( - const std::string& name) const override; +namespace lite { + +void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { + // LightPredictor Only support NaiveBuffer backend in publish lib + raw_predictor_.reset( + new LightPredictor(config.model_dir(), + config.model_buffer(), + config.param_buffer(), + config.model_from_memory(), + lite_api::LiteModelType::kNaiveBuffer)); + + mode_ = config.power_mode(); + threads_ = config.threads(); +} - void Init(const MobileConfig& config); +std::unique_ptr LightPredictorImpl::GetInput(int i) { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_->GetInput(i))); +} - private: - std::unique_ptr raw_predictor_; -}; +std::unique_ptr LightPredictorImpl::GetOutput( + int i) const { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_->GetOutput(i))); +} -void LightPredictorImpl::Init(const MobileConfig& config) { -// LightPredictor Only support NaiveBuffer backend in publish lib +void LightPredictorImpl::Run() { #ifdef LITE_WITH_ARM - lite::DeviceInfo::Init(); - lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads()); + lite::DeviceInfo::Global().SetRunMode(mode_, threads_); #endif - raw_predictor_.reset(new lite::LightPredictor(config.model_dir(), - config.model_buffer(), - config.param_buffer(), - config.model_from_memory(), - LiteModelType::kNaiveBuffer)); + raw_predictor_->Run(); } -std::unique_ptr LightPredictorImpl::GetInput(int i) { - return std::unique_ptr(new Tensor(raw_predictor_->GetInput(i))); +std::shared_ptr LightPredictorImpl::Clone() { + LOG(FATAL) << "The Clone API is not supported in LigthPredictor"; } -std::unique_ptr LightPredictorImpl::GetOutput(int i) const { - return std::unique_ptr(new Tensor(raw_predictor_->GetOutput(i))); +std::string LightPredictorImpl::GetVersion() const { return lite::version(); } + +std::unique_ptr LightPredictorImpl::GetTensor( + const std::string& name) const { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_->GetTensor(name))); +} +std::unique_ptr LightPredictorImpl::GetInputByName( + const std::string& name) { + return std::unique_ptr( + new lite_api::Tensor(raw_predictor_->GetInputByName(name))); } -void LightPredictorImpl::Run() { raw_predictor_->Run(); } +std::vector LightPredictorImpl::GetInputNames() { + return raw_predictor_->GetInputNames(); +} -std::unique_ptr LightPredictorImpl::GetTensor( - const std::string& name) const { - return std::unique_ptr( - new Tensor(raw_predictor_->GetTensor(name))); +std::vector LightPredictorImpl::GetOutputNames() { + return raw_predictor_->GetOutputNames(); } +} // namespace lite + +namespace lite_api { + template <> std::shared_ptr CreatePaddlePredictor( const MobileConfig& config) { - auto x = std::make_shared(); + auto x = std::make_shared(); x->Init(config); return x; } diff --git a/lite/api/light_api_shared.cc b/lite/api/light_api_shared.cc new file mode 100644 index 0000000000000000000000000000000000000000..557804bfa56787fa8a83bfbfc3046df08be010f8 --- /dev/null +++ b/lite/api/light_api_shared.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2019 PaddlePaddle 
Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#ifndef LITE_ON_TINY_PUBLISH +#include "lite/api/paddle_use_passes.h" +#endif + +namespace paddle { +namespace lite_api { + +void RunModel() { + // 1. Set MobileConfig + MobileConfig mobile_config; + + // 2. Create PaddlePredictor by MobileConfig + std::shared_ptr mobile_predictor = + CreatePaddlePredictor(mobile_config); +} + +} // namespace lite_api +} // namespace paddle diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc index 8e2fc420bc3be91e35047b823e628b80f2175496..7d322530f624c43737018d8ece98fb24d48bc16a 100644 --- a/lite/api/light_api_test.cc +++ b/lite/api/light_api_test.cc @@ -36,6 +36,18 @@ TEST(LightAPI, load) { data[i] = i; } + predictor.PrepareFeedFetch(); + const std::vector inputs = predictor.GetInputNames(); + + LOG(INFO) << "input size: " << inputs.size(); + for (int i = 0; i < inputs.size(); i++) { + LOG(INFO) << "inputnames: " << inputs[i]; + } + const std::vector outputs = predictor.GetOutputNames(); + for (int i = 0; i < outputs.size(); i++) { + LOG(INFO) << "outputnames: " << outputs[i]; + } + predictor.Run(); const auto* output = predictor.GetOutput(0); diff --git a/lite/api/lite_api_test_helper.cc b/lite/api/lite_api_test_helper.cc index cd576998d3472a8a8c08a77765a03adce7490827..802f6d4b52082ea45867c63a544256ae4b567040 100644 --- a/lite/api/lite_api_test_helper.cc +++ b/lite/api/lite_api_test_helper.cc @@ -24,24 +24,16 @@ namespace lite { const lite::Tensor* RunHvyModel() { lite::Predictor predictor; #ifndef LITE_WITH_CUDA - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kX86), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kX86), PRECISION(kFloat)}}); #else std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, - Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, - Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, }); #endif - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda - valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({100, 100}))); diff --git a/lite/api/mobilenetv1_int8_test.cc b/lite/api/mobilenetv1_int8_test.cc index 769f195d19a0e4bff1d4a6da515afcda6cc366cc..fb4a98084c7f7a5935a5ca655af4ddff13152460 100644 --- a/lite/api/mobilenetv1_int8_test.cc +++ b/lite/api/mobilenetv1_int8_test.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "lite/api/cxx_api.h" #include "lite/api/paddle_use_kernels.h" @@ -22,23 +23,36 @@ #include "lite/api/test_helper.h" #include "lite/core/op_registry.h" +DEFINE_string(input_img_txt_path, + "", + "if 
set, read the input data from the given text file."); + namespace paddle { namespace lite { -void TestModel(const std::vector<Place>& valid_places, - const Place& preferred_place) { +void TestModel(const std::vector<Place>& valid_places) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224}))); auto* data = input_tensor->mutable_data<float>(); auto item_size = input_tensor->dims().production(); - for (int i = 0; i < item_size; i++) { - data[i] = 1; + if (FLAGS_input_img_txt_path.empty()) { + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + } else { + std::fstream fs(FLAGS_input_img_txt_path, std::ios::in); + if (!fs.is_open()) { + LOG(FATAL) << "Failed to open the input_img_txt_path file."; + } + for (int i = 0; i < item_size; i++) { + fs >> data[i]; + } } for (int i = 0; i < FLAGS_warmup; ++i) { @@ -58,8 +72,9 @@ void TestModel(const std::vector<Place>& valid_places, std::vector<std::vector<float>> results; // i = 1 + // ground truth result from fluid results.emplace_back(std::vector<float>( - {0.000227548, 0.000262385, 0.000260347, 0.000293865, 0.00025008})); + {0.0002451055, 0.0002585023, 0.0002659616, 0.0002823})); auto* out = predictor.GetOutput(0); ASSERT_EQ(out->dims().size(), 2); ASSERT_EQ(out->dims()[0], 1); @@ -73,16 +88,30 @@ void TestModel(const std::vector<Place>& valid_places, 1e-6); } } + + auto* out_data = out->data<float>(); + LOG(INFO) << "output data:"; + for (int i = 0; i < out->numel(); i += step) { + LOG(INFO) << out_data[i]; + } + float max_val = out_data[0]; + int max_val_arg = 0; + for (int i = 1; i < out->numel(); i++) { + if (max_val < out_data[i]) { + max_val = out_data[i]; + max_val_arg = i; + } + } + LOG(INFO) << "max val:" << max_val << ", max_val_arg:" << max_val_arg; } TEST(MobileNetV1, test_arm) { std::vector<Place> valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kInt8)}, + Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kInt8)})); + TestModel(valid_places); } } // namespace lite diff --git a/lite/api/mobilenetv1_ssd_test.cc b/lite/api/mobilenetv1_ssd_test.cc index e37e180f9b27424c59d6549515af0d6c8e929eea..8eacbe2619c6c55594fd8a280bb1ab2901f24c51 100644 --- a/lite/api/mobilenetv1_ssd_test.cc +++ b/lite/api/mobilenetv1_ssd_test.cc @@ -26,13 +26,12 @@ namespace paddle { namespace lite { #ifdef LITE_WITH_ARM -void TestModel(const std::vector<Place>& valid_places, - const Place& preferred_place) { +void TestModel(const std::vector<Place>& valid_places) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 300, 300}))); @@ -99,7 +98,6 @@ void TestModel(const std::vector<Place>& valid_places, TEST(MobileNetV1_SSD, test_arm) { std::vector<Place> valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); diff --git a/lite/api/mobilenetv1_test.cc 
b/lite/api/mobilenetv1_test.cc index 91d1828a94dfb943f10e55054baf7d8038525a13..63a401745b325654f81c3af93402703395264c0d 100644 --- a/lite/api/mobilenetv1_test.cc +++ b/lite/api/mobilenetv1_test.cc @@ -28,14 +28,13 @@ namespace paddle { namespace lite { void TestModel(const std::vector& valid_places, - const Place& preferred_place, const std::string& model_dir = FLAGS_model_dir, bool save_model = false) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - predictor.Build(model_dir, "", "", preferred_place, valid_places); + predictor.Build(model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); @@ -103,41 +102,32 @@ void TestModel(const std::vector& valid_places, #ifdef LITE_WITH_NPU TEST(MobileNetV1, test_npu) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kNPU), PRECISION(kFloat)}, }); - TestModel(valid_places, - Place({TARGET(kARM), PRECISION(kFloat)}), - FLAGS_model_dir, - true /* save_model*/); + TestModel(valid_places, FLAGS_model_dir, true /* save_model*/); - TestModel(valid_places, - Place({TARGET(kARM), PRECISION(kFloat)}), - FLAGS_optimized_model, - false /* save model */); + TestModel(valid_places, FLAGS_optimized_model, false /* save model */); } #endif // LITE_WITH_NPU TEST(MobileNetV1, test_arm) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); + TestModel(valid_places); } #ifdef LITE_WITH_OPENCL TEST(MobileNetV1, test_opencl) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kOpenCL), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)})); + TestModel(valid_places); } #endif // LITE_WITH_OPENCL diff --git a/lite/api/mobilenetv1_yolov3_test.cc b/lite/api/mobilenetv1_yolov3_test.cc index 3a12203b710fb42b910cdcd381095958175cd280..09f9b6d11a10fb8eb66e939716aaea4ceaf7f418 100644 --- a/lite/api/mobilenetv1_yolov3_test.cc +++ b/lite/api/mobilenetv1_yolov3_test.cc @@ -26,13 +26,12 @@ namespace paddle { namespace lite { #ifdef LITE_WITH_ARM -void TestModel(const std::vector& valid_places, - const Place& preferred_place) { +void TestModel(const std::vector& valid_places) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 608, 608}))); @@ -106,11 +105,10 @@ void TestModel(const std::vector& valid_places, TEST(MobileNetV1_YoloV3, test_arm) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); + TestModel(valid_places); } #endif // LITE_WITH_ARM diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc index 
ca36943cb9056ffd87c5862b32845f8962cee3df..84bd27e352f549d619cfa51f9127f973023e6d45 100644 --- a/lite/api/mobilenetv2_test.cc +++ b/lite/api/mobilenetv2_test.cc @@ -29,14 +29,13 @@ namespace lite { #ifdef LITE_WITH_ARM void TestModel(const std::vector& valid_places, - const Place& preferred_place, const std::string& model_dir = FLAGS_model_dir, bool save_model = false) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - predictor.Build(model_dir, "", "", preferred_place, valid_places); + predictor.Build(model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); @@ -103,41 +102,32 @@ void TestModel(const std::vector& valid_places, #ifdef LITE_WITH_NPU TEST(MobileNetV2, test_npu) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kNPU), PRECISION(kFloat)}, }); - TestModel(valid_places, - Place({TARGET(kARM), PRECISION(kFloat)}), - FLAGS_model_dir, - true /* save_model*/); + TestModel(valid_places, FLAGS_model_dir, true /* save_model*/); - TestModel(valid_places, - Place({TARGET(kARM), PRECISION(kFloat)}), - FLAGS_optimized_model, - false /* save model */); + TestModel(valid_places, FLAGS_optimized_model, false /* save model */); } #endif // LITE_WITH_NPU TEST(MobileNetV2, test_arm) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); + TestModel(valid_places); } #ifdef LITE_WITH_OPENCL TEST(MobileNetV2, test_opencl) { std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kOpenCL), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)})); + TestModel(valid_places); } #endif // LITE_WITH_OPENCL diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc index 6286a3398f2d5ca585f1ee02a456f0d901156b0b..1aef522b2a6bb95f895449469f3c13e4a713179a 100644 --- a/lite/api/model_optimize_tool.cc +++ b/lite/api/model_optimize_tool.cc @@ -16,10 +16,14 @@ #ifdef PADDLE_WITH_TESTING #include #endif +// "all_kernel_faked.cc" and "kernel_src_map.h" are created automatically during +// model_optimize_tool's compiling period +#include "all_kernel_faked.cc" // NOLINT +#include "kernel_src_map.h" // NOLINT #include "lite/api/paddle_api.h" -#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_passes.h" +#include "lite/core/op_registry.h" #include "lite/utils/cp_logging.h" #include "lite/utils/string.h" @@ -33,6 +37,12 @@ DEFINE_string( optimize_out_type, "protobuf", "store type of the output optimized model. protobuf/naive_buffer"); +DEFINE_bool(display_kernels, false, "Display kernel information"); +DEFINE_bool(record_tailoring_info, + false, + "Record kernels and operators information of the optimized model " + "for tailoring compiling, information are stored into optimized " + "model path as hidden files"); DEFINE_string(optimize_out, "", "path of the output optimized model"); DEFINE_string(valid_targets, "arm", @@ -43,12 +53,22 @@ DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); namespace paddle { namespace lite_api { +//! 
Display the kernel information. +void DisplayKernels() { + LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString(); +} + void Main() { if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) { LOG(WARNING) << "Load combined-param model. Option model_dir will be ignored"; } + if (FLAGS_display_kernels) { + DisplayKernels(); + exit(0); + } + lite_api::CxxConfig config; config.set_model_dir(FLAGS_model_dir); config.set_model_file(FLAGS_model_file); @@ -74,10 +94,11 @@ void Main() { CHECK(!valid_places.empty()) << "At least one target should be set; check the " "command argument 'valid_targets'."; + if (FLAGS_prefer_int8_kernel) { LOG(WARNING) << "Int8 mode is only supported by the ARM target"; - valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)}); - config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)}); + valid_places.insert(valid_places.begin(), + Place{TARGET(kARM), PRECISION(kInt8)}); } config.set_valid_places(valid_places); @@ -91,8 +112,14 @@ void Main() { } else { LOG(FATAL) << "Unsupported model type: " << FLAGS_optimize_out_type; } + OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map); - predictor->SaveOptimizedModel(FLAGS_optimize_out, model_type); + predictor->SaveOptimizedModel( + FLAGS_optimize_out, model_type, FLAGS_record_tailoring_info); + if (FLAGS_record_tailoring_info) { + LOG(INFO) << "Recorded the tailoring information of the optimized model into: " + << FLAGS_optimize_out; + } } } // namespace lite_api diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc index 099a74ed7fbf54da2d632150c4438f9ad894bb1d..72f6212445a7c3f016e3c67d00d8485ca7087692 100644 --- a/lite/api/model_run_test_image.cc +++ b/lite/api/model_run_test_image.cc @@ -28,18 +28,16 @@ namespace lite { TEST(model, test) { #ifdef LITE_WITH_ARM DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads); lite::Predictor predictor; - std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, + std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kInt8)}}); auto precision = PRECISION(kFloat); if (FLAGS_int8) { precision = PRECISION(kInt8); } - predictor.Build( - FLAGS_model_dir, "", "", Place{TARGET(kARM), precision}, valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); int im_width = FLAGS_im_width; int im_height = FLAGS_im_height; auto* input_tensor = predictor.GetInput(0); @@ -60,11 +58,11 @@ TEST(model, test) { for (int i = 0; i < FLAGS_repeats; ++i) { predictor.Run(); } - auto* output_tensors = predictor.GetOutputs(); + auto output_tensors = predictor.GetOutputs(); LOG(INFO) << "======output:========"; - for (auto t : *output_tensors) { - LOG(INFO) << t; + for (auto* t : output_tensors) { + LOG(INFO) << *t; } LOG(INFO) << "=====RUN_finished!!============= Speed Report ==================="; diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc index 6e0a249a81c8c2476a9a0685ab6492da3d4013a6..1358267000991c81b80453669cf46638449b8a7b 100644 --- a/lite/api/model_test.cc +++ b/lite/api/model_test.cc @@ -21,13 +21,23 @@ #include "lite/api/paddle_use_passes.h" #include "lite/api/test_helper.h" #include "lite/core/device_info.h" +#include "lite/tests/utils/timer.h" #include "lite/utils/cp_logging.h" #include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/basic_profiler.h" +#endif // LITE_WITH_PROFILE + 
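With preferred_place gone, the order of valid_places is what expresses kernel preference; prefer_int8_kernel above simply inserts the Int8 place at the front of the list. A hedged sketch of the equivalent explicit configuration, with a hypothetical model path:

using paddle::lite_api::CxxConfig;
using paddle::lite_api::Place;

CxxConfig config;
config.set_model_dir("./model");  // hypothetical
config.set_valid_places({
    Place{TARGET(kARM), PRECISION(kInt8)},   // listed first => preferred
    Place{TARGET(kARM), PRECISION(kFloat)},  // fallback
});
auto predictor =
    paddle::lite_api::CreatePaddlePredictor<CxxConfig>(config);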
+using paddle::lite::Timer; DEFINE_string(input_shape, "1,3,224,224", "input shapes, separated by colon and comma"); +DEFINE_bool(use_optimize_nb, + false, + "optimized & naive buffer model for mobile devices"); + namespace paddle { namespace lite_api { @@ -36,7 +46,6 @@ void OutputOptModel(const std::string& load_model_dir, const std::vector>& input_shapes) { lite_api::CxxConfig config; config.set_model_dir(load_model_dir); - config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)}); config.set_valid_places({ Place{TARGET(kX86), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, @@ -59,15 +68,18 @@ void OutputOptModel(const std::string& load_model_dir, #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK void Run(const std::vector>& input_shapes, const std::string& model_dir, - const int repeat, + const PowerMode power_mode, const int thread_num, + const int repeat, const int warmup_times = 0) { -#ifdef LITE_WITH_ARM - lite::DeviceInfo::Init(); - lite::DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, thread_num); +#ifdef LITE_WITH_PROFILE + lite::profile::BasicProfiler::Global().SetWarmup( + warmup_times); #endif lite_api::MobileConfig config; config.set_model_dir(model_dir); + config.set_power_mode(power_mode); + config.set_threads(thread_num); auto predictor = lite_api::CreatePaddlePredictor(config); @@ -88,17 +100,22 @@ void Run(const std::vector>& input_shapes, predictor->Run(); } - auto start = lite::GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { + Timer ti; + for (int j = 0; j < repeat; ++j) { + ti.start(); predictor->Run(); + ti.end(); + LOG(INFO) << "iter: " << j << ", time: " << ti.latest_time() << " ms"; } - auto end = lite::GetCurrentUS(); LOG(INFO) << "================== Speed Report ==================="; - LOG(INFO) << "Model: " << model_dir << ", threads num " << thread_num - << ", warmup: " << warmup_times << ", repeats: " << repeat - << ", spend " << (end - start) / repeat / 1000.0 - << " ms in average."; + LOG(INFO) << "Model: " << model_dir + << ", power_mode: " << static_cast(power_mode) + << ", threads num " << thread_num << ", warmup: " << warmup_times + << ", repeats: " << repeat << ", avg time: " << ti.get_average_ms() + << " ms" + << ", min time: " << ti.get_min_time() << " ms" + << ", max time: " << ti.get_max_time() << " ms."; auto output = predictor->GetOutput(0); auto out = output->data(); @@ -123,7 +140,12 @@ int main(int argc, char** argv) { << "--model_dir /path/to/your/model"; exit(0); } - std::string save_optimized_model_dir = FLAGS_model_dir + "opt2"; + std::string save_optimized_model_dir = ""; + if (FLAGS_use_optimize_nb) { + save_optimized_model_dir = FLAGS_model_dir; + } else { + save_optimized_model_dir = FLAGS_model_dir + "opt2"; + } auto split_string = [](const std::string& str_in) -> std::vector { @@ -165,17 +187,21 @@ int main(int argc, char** argv) { input_shapes.push_back(get_shape(str_input_shapes[i])); } - // Output optimized model - paddle::lite_api::OutputOptModel( - FLAGS_model_dir, save_optimized_model_dir, input_shapes); + if (!FLAGS_use_optimize_nb) { + // Output optimized model + paddle::lite_api::OutputOptModel( + FLAGS_model_dir, save_optimized_model_dir, input_shapes); + } #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK // Run inference using optimized model - paddle::lite_api::Run(input_shapes, - save_optimized_model_dir, - FLAGS_repeats, - FLAGS_threads, - FLAGS_warmup); + paddle::lite_api::Run( + input_shapes, + save_optimized_model_dir, + static_cast(FLAGS_power_mode), + FLAGS_threads, + FLAGS_repeats, + 
FLAGS_warmup); #endif return 0; } diff --git a/lite/api/ocr_attention_test.cc b/lite/api/ocr_attention_test.cc index 89cf6a3e8d3fa29b25d617afdec3df3980755424..5e39c5437c18990be9c6414695a94c6f2c9fcf20 100644 --- a/lite/api/ocr_attention_test.cc +++ b/lite/api/ocr_attention_test.cc @@ -25,14 +25,12 @@ namespace paddle { namespace lite { -void TestModel(const std::vector<Place>& valid_places, - const Place& preferred_place, - bool use_npu = false) { +void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) { DeviceInfo::Init(); DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; - predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 48, 512}))); @@ -104,11 +102,10 @@ void TestModel(const std::vector<Place>& valid_places, TEST(OcrAttention, test_arm) { std::vector<Place> valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); - TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); + TestModel(valid_places); } } // namespace lite diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index fee4ebf6dcea418f66068da53e49004e28752ad5..f148096bb69a3a249521bcb847d5beae3f8297f9 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -13,8 +13,14 @@ // limitations under the License. #include "lite/api/paddle_api.h" +#include "lite/core/device_info.h" +#include "lite/core/target_wrapper.h" #include "lite/core/tensor.h" +#ifdef LITE_WITH_CUDA +#include "lite/backends/cuda/target_wrapper.h" +#endif + namespace paddle { namespace lite_api { @@ -40,26 +46,115 @@ template <> const int8_t *Tensor::data() const { return ctensor(raw_tensor_)->data<int8_t>(); } +template <> +const int64_t *Tensor::data() const { + return ctensor(raw_tensor_)->data<int64_t>(); +} template <> -float *Tensor::mutable_data() const { - return tensor(raw_tensor_)->mutable_data<float>(); +const int32_t *Tensor::data() const { + return ctensor(raw_tensor_)->data<int32_t>(); } + template <> -int8_t *Tensor::mutable_data() const { - return tensor(raw_tensor_)->mutable_data<int8_t>(); +int *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data<int>(type); +} +template <> +float *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data<float>(type); +} +template <> +int8_t *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data<int8_t>(type); +} +template <> +int64_t *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data<int64_t>(type); +} + +template <typename T, TargetType type> +void Tensor::CopyFromCpu(const T *src_data) { + T *data = tensor(raw_tensor_)->mutable_data<T>(type); + int64_t num = tensor(raw_tensor_)->numel(); + CHECK(num > 0) << "You should call the Resize interface first"; + if (type == TargetType::kHost || type == TargetType::kARM) { + lite::TargetWrapperHost::MemcpySync( + data, src_data, num * sizeof(T), lite::IoDirection::HtoH); + } else if (type == TargetType::kCUDA) { +#ifdef LITE_WITH_CUDA + lite::TargetWrapperCuda::MemcpySync( + data, src_data, num * sizeof(T), lite::IoDirection::HtoD); +#else + LOG(FATAL) << "Please compile the lib with CUDA."; +#endif + } else { + LOG(FATAL) << "The CopyFromCpu interface only supports kHost, kARM and kCUDA"; + } +} +template <typename T> +void Tensor::CopyToCpu(T *data) { + const T *src_data = tensor(raw_tensor_)->data<T>(); + int64_t num = tensor(raw_tensor_)->numel(); + CHECK(num > 0) << "You should call the Resize interface first"; + auto type = tensor(raw_tensor_)->target(); + if (type == TargetType::kHost || type == TargetType::kARM) { + lite::TargetWrapperHost::MemcpySync( + data, src_data, num * sizeof(T), lite::IoDirection::HtoH); + } else if (type == TargetType::kCUDA) { +#ifdef LITE_WITH_CUDA + lite::TargetWrapperCuda::MemcpySync( + data, src_data, num * sizeof(T), lite::IoDirection::DtoH); +#else + LOG(FATAL) << "Please compile the lib with CUDA."; +#endif + } else { + LOG(FATAL) << "The CopyToCpu interface only supports kHost, kARM and kCUDA"; + } } +template void Tensor::CopyFromCpu<int, TargetType::kHost>(const int *); +template void Tensor::CopyFromCpu<float, TargetType::kHost>(const float *); +template void Tensor::CopyFromCpu<int8_t, TargetType::kHost>(const int8_t *); + +template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *); +template void Tensor::CopyFromCpu<float, TargetType::kARM>(const float *); +template void Tensor::CopyFromCpu<int8_t, TargetType::kARM>(const int8_t *); +template void Tensor::CopyFromCpu<int, TargetType::kCUDA>(const int *); +template void Tensor::CopyFromCpu<float, TargetType::kCUDA>(const float *); +template void Tensor::CopyFromCpu<int8_t, TargetType::kCUDA>(const int8_t *); + +template void Tensor::CopyToCpu(int8_t *); +template void Tensor::CopyToCpu(float *); +template void Tensor::CopyToCpu(int *); + shape_t Tensor::shape() const { return ctensor(raw_tensor_)->dims().Vectorize(); } +TargetType Tensor::target() const { + auto type = ctensor(raw_tensor_)->target(); + if (type == TargetType::kUnk) { + CHECK(false) << "This tensor was not initialized."; + } + return type; +} + +PrecisionType Tensor::precision() const { + auto precision = ctensor(raw_tensor_)->precision(); + if (precision == PrecisionType::kUnk) { + CHECK(false) << "This tensor was not initialized."; + } + return precision; +} + lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); } void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); } void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir, - LiteModelType model_type) { + LiteModelType model_type, + bool record_info) { LOG(FATAL) << "The SaveOptimizedModel API is only supported by CxxConfig predictor."; } @@ -69,5 +164,30 @@ std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &) { return std::shared_ptr<PaddlePredictor>(); } +ConfigBase::ConfigBase(PowerMode mode, int threads) { +#ifdef LITE_WITH_ARM + lite::DeviceInfo::Init(); + lite::DeviceInfo::Global().SetRunMode(mode, threads); + mode_ = lite::DeviceInfo::Global().mode(); + threads_ = lite::DeviceInfo::Global().threads(); +#endif +} + +void ConfigBase::set_power_mode(paddle::lite_api::PowerMode mode) { +#ifdef LITE_WITH_ARM + lite::DeviceInfo::Global().SetRunMode(mode, threads_); + mode_ = lite::DeviceInfo::Global().mode(); + threads_ = lite::DeviceInfo::Global().threads(); +#endif +} + +void ConfigBase::set_threads(int threads) { +#ifdef LITE_WITH_ARM + lite::DeviceInfo::Global().SetRunMode(mode_, threads); + mode_ = lite::DeviceInfo::Global().mode(); + threads_ = lite::DeviceInfo::Global().threads(); +#endif +} + } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index b1a8b21935bfbab603c7f27e233cc6115414dc7e..42b455da811fe1a21277d38f2e1237000276b1ff 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -43,10 +43,17 @@ struct LITE_API Tensor { const T* data() const; template <typename T> - T* mutable_data() const; + T* mutable_data(TargetType type = TargetType::kHost) const; + template <typename T, TargetType type = TargetType::kHost> + void CopyFromCpu(const T* data); + + template <typename T> + void CopyToCpu(T* data);
shape_t shape() const; + TargetType target() const; + PrecisionType precision() const; // LoD of the tensor lod_t lod() const; @@ -71,6 +78,17 @@ class LITE_API PaddlePredictor { virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0; virtual void Run() = 0; + virtual std::shared_ptr<PaddlePredictor> Clone() = 0; + + virtual std::string GetVersion() const = 0; + + // Get input names + virtual std::vector<std::string> GetInputNames() = 0; + // Get output names + virtual std::vector<std::string> GetOutputNames() = 0; + + // Get input by name + virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0; /// Get a read-only tensor; returns null if no tensor named `name` exists. virtual std::unique_ptr<const Tensor> GetTensor( @@ -80,31 +98,43 @@ /// CxxConfig, and the persisted model can be reused for MobileConfig. virtual void SaveOptimizedModel( const std::string& model_dir, - LiteModelType model_type = LiteModelType::kProtobuf); + LiteModelType model_type = LiteModelType::kProtobuf, + bool record_info = false); virtual ~PaddlePredictor() = default; + + protected: + int threads_{1}; + lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND}; }; /// Base class for all the configs. class LITE_API ConfigBase { std::string model_dir_; + int threads_{1}; + PowerMode mode_{LITE_POWER_NO_BIND}; public: + explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1); + // Set the model directory void set_model_dir(const std::string& x) { model_dir_ = x; } - const std::string& model_dir() const { return model_dir_; } + // Set the power mode + void set_power_mode(PowerMode mode); + PowerMode power_mode() const { return mode_; } + // Set the number of threads + void set_threads(int threads); + int threads() const { return threads_; } }; /// CxxConfig is the config for the full-featured predictor. class LITE_API CxxConfig : public ConfigBase { - Place preferred_place_; std::vector<Place> valid_places_; std::string model_file_; std::string param_file_; bool model_from_memory_{false}; public: - void set_preferred_place(const Place& x) { preferred_place_ = x; } void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; } void set_model_file(const std::string& path) { model_file_ = path; } void set_param_file(const std::string& path) { param_file_ = path; } @@ -117,7 +147,6 @@ class LITE_API CxxConfig : public ConfigBase { model_from_memory_ = true; } - const Place& preferred_place() const { return preferred_place_; } const std::vector<Place>& valid_places() const { return valid_places_; } std::string model_file() const { return model_file_; } std::string param_file() const { return param_file_; }
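The record_info argument added to SaveOptimizedModel pairs with model_optimize_tool's --record_tailoring_info. A hedged sketch of producing a tailoring-ready naive-buffer model, with a hypothetical output path:

predictor->SaveOptimizedModel(
    "./model_opt",  // hypothetical output path
    paddle::lite_api::LiteModelType::kNaiveBuffer,
    /*record_info=*/true);
// The kernel and operator lists are stored as hidden files under the output
// path; they are meant to feed tailored (LITE_BUILD_TAILOR) builds.

@@ -127,21 +156,11 @@ /// MobileConfig is the config for the lightweight predictor; it skips /// IR optimization and other unnecessary stages.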
class LITE_API MobileConfig : public ConfigBase { - PowerMode mode_{LITE_POWER_HIGH}; - int threads_{1}; std::string model_buffer_; std::string param_buffer_; bool model_from_memory_{false}; public: - MobileConfig(Place preferred_place = Place(TARGET(kARM), - PRECISION(kFloat), - DATALAYOUT(kNCHW)), - PowerMode mode = LITE_POWER_HIGH, - int threads = 1) - : mode_(mode), threads_(threads) {} - void set_power_mode(PowerMode mode) { mode_ = mode; } - void set_threads(int threads) { threads_ = threads; } void set_model_buffer(const char* model_buffer, size_t model_buffer_size, const char* param_buffer, @@ -151,8 +170,6 @@ class LITE_API MobileConfig : public ConfigBase { model_from_memory_ = true; } - PowerMode power_mode() const { return mode_; } - int threads() const { return threads_; } bool model_from_memory() const { return model_from_memory_; } const std::string& model_buffer() const { return model_buffer_; } const std::string& param_buffer() const { return param_buffer_; } diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc index 02502ff9c80f3ee3c5a23f8ef6909353d839ea9e..69d544c3decac9f312bc9eb03cdc6c3702c5032b 100644 --- a/lite/api/paddle_api_test.cc +++ b/lite/api/paddle_api_test.cc @@ -28,7 +28,6 @@ namespace lite_api { TEST(CxxApi, run) { lite_api::CxxConfig config; config.set_model_dir(FLAGS_model_dir); - config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)}); config.set_valid_places({ Place{TARGET(kX86), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, @@ -36,7 +35,18 @@ TEST(CxxApi, run) { auto predictor = lite_api::CreatePaddlePredictor(config); - auto input_tensor = predictor->GetInput(0); + LOG(INFO) << "Version: " << predictor->GetVersion(); + + auto inputs = predictor->GetInputNames(); + LOG(INFO) << "input size: " << inputs.size(); + for (int i = 0; i < inputs.size(); i++) { + LOG(INFO) << "inputnames: " << inputs[i]; + } + auto outputs = predictor->GetOutputNames(); + for (int i = 0; i < outputs.size(); i++) { + LOG(INFO) << "outputnames: " << outputs[i]; + } + auto input_tensor = predictor->GetInputByName(inputs[0]); input_tensor->Resize(std::vector({100, 100})); auto* data = input_tensor->mutable_data(); for (int i = 0; i < 100 * 100; i++) { @@ -45,7 +55,7 @@ TEST(CxxApi, run) { predictor->Run(); - auto output = predictor->GetOutput(0); + auto output = predictor->GetTensor(outputs[0]); auto* out = output->data(); LOG(INFO) << out[0]; LOG(INFO) << out[1]; @@ -54,8 +64,8 @@ TEST(CxxApi, run) { EXPECT_NEAR(out[1], -28.8729, 1e-3); predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2"); - predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2.naive", - LiteModelType::kNaiveBuffer); + predictor->SaveOptimizedModel( + FLAGS_model_dir + ".opt2.naive", LiteModelType::kNaiveBuffer, true); } // Demo1 for Mobile Devices :Load model from file and run @@ -66,6 +76,18 @@ TEST(LightApi, run) { auto predictor = lite_api::CreatePaddlePredictor(config); + auto inputs = predictor->GetInputNames(); + LOG(INFO) << "input size: " << inputs.size(); + for (int i = 0; i < inputs.size(); i++) { + LOG(INFO) << "inputnames: " << inputs.at(i); + } + auto outputs = predictor->GetOutputNames(); + for (int i = 0; i < outputs.size(); i++) { + LOG(INFO) << "outputnames: " << outputs.at(i); + } + + LOG(INFO) << "Version: " << predictor->GetVersion(); + auto input_tensor = predictor->GetInput(0); input_tensor->Resize(std::vector({100, 100})); auto* data = input_tensor->mutable_data(); diff --git a/lite/api/paddle_lite_factory_helper.h 
b/lite/api/paddle_lite_factory_helper.h index 544cd0e313034ef4a8c378298f4e86c9597d6a98..e99127e233bc4adf159a6a567dfb15f6fd784a27 100644 --- a/lite/api/paddle_lite_factory_helper.h +++ b/lite/api/paddle_lite_factory_helper.h @@ -25,7 +25,7 @@ #define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \ extern int touch_##op_type__##target__##precision__##layout__##alias__(); \ - int op_type__##target__##precision__##layout__##alias__ \ + int op_type__##target__##precision__##layout__##alias__##__use_lite_kernel \ __attribute__((unused)) = \ touch_##op_type__##target__##precision__##layout__##alias__(); diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc index dbdf9ff269b372cd3dcd59769b15526b7631a5e5..894d839185ea9e1b6b47b87c398f249f044c2b51 100644 --- a/lite/api/paddle_place.cc +++ b/lite/api/paddle_place.cc @@ -46,8 +46,16 @@ std::string Place::DebugString() const { } const std::string& TargetToStr(TargetType target) { - static const std::string target2string[] = { - "unk", "host", "x86", "cuda", "arm", "opencl", "any", "fpga", "npu"}; + static const std::string target2string[] = {"unk", + "host", + "x86", + "cuda", + "arm", + "opencl", + "any", + "fpga", + "npu", + "xpu"}; auto x = static_cast(target); CHECK_LT(x, static_cast(TARGET(NUM))); return target2string[x]; @@ -84,7 +92,8 @@ const std::string& TargetRepr(TargetType target) { "kOpenCL", "kAny", "kFPGA", - "kNPU"}; + "kNPU", + "kXPU"}; auto x = static_cast(target); CHECK_LT(x, static_cast(TARGET(NUM))); return target2string[x]; @@ -113,5 +122,37 @@ const std::string& DataLayoutRepr(DataLayoutType layout) { return datalayout2string[x]; } +std::set ExpandValidTargets(TargetType target) { + static const std::set valid_set({TARGET(kX86), + TARGET(kCUDA), + TARGET(kARM), + TARGET(kOpenCL), + TARGET(kNPU), + TARGET(kXPU), + TARGET(kFPGA)}); + if (target == TARGET(kAny)) { + return valid_set; + } + return std::set({target}); +} + +std::set ExpandValidPrecisions(PrecisionType precision) { + static const std::set valid_set( + {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); + if (precision == PRECISION(kAny)) { + return valid_set; + } + return std::set({precision}); +} + +std::set ExpandValidLayouts(DataLayoutType layout) { + static const std::set valid_set( + {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); + if (layout == DATALAYOUT(kAny)) { + return valid_set; + } + return std::set({layout}); +} + } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 5e4f2ed21c8298ac15a912672e3d15633d0a3ecb..07284be095c05e5dfa069b0973d5982cf1f07c8a 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include // Generic helper definitions for shared library support @@ -50,8 +51,9 @@ enum class TargetType : int { kOpenCL = 5, kFPGA = 7, kNPU = 8, + kXPU = 9, kAny = 6, // any target - NUM = 9, // number of fields. + NUM = 10, // number of fields. }; enum class PrecisionType : int { kUnk = 0, @@ -101,6 +103,8 @@ static size_t PrecisionTypeLength(PrecisionType type) { return 1; case PrecisionType::kInt32: return 4; + case PrecisionType::kInt64: + return 8; case PrecisionType::kFP16: return 2; default: @@ -124,6 +128,17 @@ const std::string& PrecisionRepr(PrecisionType precision); const std::string& DataLayoutRepr(DataLayoutType layout); +// Get a set of all the elements represented by the target. 
+std::set ExpandValidTargets(TargetType target = TARGET(kAny)); + +// Get a set of all the elements represented by the precision. +std::set ExpandValidPrecisions( + PrecisionType precision = PRECISION(kAny)); + +// Get a set of all the elements represented by the layout. +std::set ExpandValidLayouts( + DataLayoutType layout = DATALAYOUT(kAny)); + /* * Place specifies the execution context of a Kernel or input/output for a * kernel. It is used to make the analysis of the MIR more clear and accurate. diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index bc2e59f38727a8a418da2b8829a9c3d8937884a3..70355fdf890eb63cd5bedd5bab42a2dd69af0927 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -31,6 +31,7 @@ USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_shuffle_channel_fuse_pass); USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); +USE_MIR_PASS(lite_interpolate_fuse_pass); USE_MIR_PASS(identity_scale_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_fuse_pass); USE_MIR_PASS(lite_conv_activation_fuse_pass); @@ -38,3 +39,4 @@ USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_layout_cast_pass); +USE_MIR_PASS(memory_optimize_pass); diff --git a/lite/api/python/CMakeLists.txt b/lite/api/python/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..43178a37c663bb09acb7c025e021cbc91bf0cc5d --- /dev/null +++ b/lite/api/python/CMakeLists.txt @@ -0,0 +1,7 @@ +if (NOT LITE_WITH_PYTHON) + return() +endif() + + +add_subdirectory(pybind) +#add_subdirectory(interface) diff --git a/lite/api/python/pybind/CMakeLists.txt b/lite/api/python/pybind/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..178f167e6a1627d01df13b2e105e0af36b20601a --- /dev/null +++ b/lite/api/python/pybind/CMakeLists.txt @@ -0,0 +1,6 @@ +set(PYBIND_DEPS pybind python paddle_api_light paddle_api) +if (NOT LITE_ON_TINY_PUBLISH) + set(PYBIND_DEPS ${PYBIND_DEPS} paddle_api_full) +endif() + +lite_cc_library(lite_pybind SHARED SRCS pybind.cc DEPS ${PYBIND_DEPS}) diff --git a/lite/api/python/pybind/pybind.cc b/lite/api/python/pybind/pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..2df2e8f8f8aa56bb71b0e1cb293df2ecbbafd0bb --- /dev/null +++ b/lite/api/python/pybind/pybind.cc @@ -0,0 +1,262 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
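Referring back to the ExpandValid* helpers added to paddle_place above: kAny acts as a wildcard that expands to every concrete candidate, while a concrete value expands to a singleton set. A hedged sketch:

using paddle::lite_api::TargetType;
using paddle::lite_api::ExpandValidTargets;

std::set<TargetType> all = ExpandValidTargets(TargetType::kAny);
// all == {kX86, kCUDA, kARM, kOpenCL, kNPU, kXPU, kFPGA}
std::set<TargetType> arm = ExpandValidTargets(TargetType::kARM);
// arm == {kARM}; ExpandValidPrecisions and ExpandValidLayouts behave the same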
+ +#include "lite/api/python/pybind/pybind.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef LITE_ON_TINY_PUBLISH +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_passes.h" +#endif + +#include "lite/api/light_api.h" +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/tensor.h" + +namespace py = pybind11; + +namespace paddle { +namespace lite { +namespace pybind { + +using lite_api::Tensor; +using lite_api::CxxConfig; +using lite_api::MobileConfig; +using lite_api::PowerMode; +using lite_api::TargetType; +using lite_api::PrecisionType; +using lite_api::DataLayoutType; +using lite_api::Place; +using lite::LightPredictorImpl; + +#ifndef LITE_ON_TINY_PUBLISH +using lite::CxxPaddleApiImpl; +static void BindLiteCxxPredictor(py::module *m); +#endif +static void BindLiteLightPredictor(py::module *m); +static void BindLiteCxxConfig(py::module *m); +static void BindLiteMobileConfig(py::module *m); +static void BindLitePowerMode(py::module *m); +static void BindLitePlace(py::module *m); +static void BindLiteTensor(py::module *m); + +void BindLiteApi(py::module *m) { + BindLiteCxxConfig(m); + BindLiteMobileConfig(m); + BindLitePowerMode(m); + BindLitePlace(m); + BindLiteTensor(m); +#ifndef LITE_ON_TINY_PUBLISH + BindLiteCxxPredictor(m); +#endif + BindLiteLightPredictor(m); +// Global helper methods +#ifndef LITE_ON_TINY_PUBLISH + m->def("create_paddle_predictor", + [](const CxxConfig &config) -> std::unique_ptr { + auto x = std::unique_ptr(new CxxPaddleApiImpl()); + x->Init(config); + return std::move(x); + }); +#endif + m->def("create_paddle_predictor", + [](const MobileConfig &config) -> std::unique_ptr { + auto x = + std::unique_ptr(new LightPredictorImpl()); + x->Init(config); + return std::move(x); + }); +} + +void BindLiteCxxConfig(py::module *m) { + py::class_ cxx_config(*m, "CxxConfig"); + + cxx_config.def(py::init<>()) + .def("set_model_dir", &CxxConfig::set_model_dir) + .def("model_dir", &CxxConfig::model_dir) + .def("set_model_file", &CxxConfig::set_model_file) + .def("model_file", &CxxConfig::model_file) + .def("set_param_file", &CxxConfig::set_param_file) + .def("param_file", &CxxConfig::param_file) + .def("set_valid_places", &CxxConfig::set_valid_places) + .def("set_model_buffer", &CxxConfig::set_model_buffer) + .def("model_from_memory", &CxxConfig::model_from_memory); +#ifdef LITE_WITH_ARM + cxx_config.def("set_threads", &CxxConfig::set_threads) + .def("threads", &CxxConfig::threads) + .def("set_power_mode", &CxxConfig::set_power_mode) + .def("power_mode", &CxxConfig::power_mode); +#endif +} + +// TODO(sangoly): Should MobileConfig be renamed to LightConfig ?? 
+void BindLiteMobileConfig(py::module *m) { + py::class_<MobileConfig> mobile_config(*m, "MobileConfig"); + + mobile_config.def(py::init<>()) + .def("set_model_dir", &MobileConfig::set_model_dir) + .def("model_dir", &MobileConfig::model_dir) + .def("set_model_buffer", &MobileConfig::set_model_buffer) + .def("model_from_memory", &MobileConfig::model_from_memory); +#ifdef LITE_WITH_ARM + mobile_config.def("set_threads", &MobileConfig::set_threads) + .def("threads", &MobileConfig::threads) + .def("set_power_mode", &MobileConfig::set_power_mode) + .def("power_mode", &MobileConfig::power_mode); +#endif +} + +void BindLitePowerMode(py::module *m) { + py::enum_<PowerMode>(*m, "PowerMode") + .value("LITE_POWER_HIGH", PowerMode::LITE_POWER_HIGH) + .value("LITE_POWER_LOW", PowerMode::LITE_POWER_LOW) + .value("LITE_POWER_FULL", PowerMode::LITE_POWER_FULL) + .value("LITE_POWER_NO_BIND", PowerMode::LITE_POWER_NO_BIND) + .value("LITE_POWER_RAND_HIGH", PowerMode::LITE_POWER_RAND_HIGH) + .value("LITE_POWER_RAND_LOW", PowerMode::LITE_POWER_RAND_LOW); +} + +void BindLitePlace(py::module *m) { + // TargetType + py::enum_<TargetType>(*m, "TargetType") + .value("Host", TargetType::kHost) + .value("X86", TargetType::kX86) + .value("CUDA", TargetType::kCUDA) + .value("ARM", TargetType::kARM) + .value("OpenCL", TargetType::kOpenCL) + .value("FPGA", TargetType::kFPGA) + .value("NPU", TargetType::kNPU) + .value("Any", TargetType::kAny); + + // PrecisionType + py::enum_<PrecisionType>(*m, "PrecisionType") + .value("FP16", PrecisionType::kFP16) + .value("FP32", PrecisionType::kFloat) + .value("INT8", PrecisionType::kInt8) + .value("INT16", PrecisionType::kInt16) + .value("INT32", PrecisionType::kInt32) + .value("INT64", PrecisionType::kInt64) + .value("BOOL", PrecisionType::kBool) + .value("Any", PrecisionType::kAny); + + // DataLayoutType + py::enum_<DataLayoutType>(*m, "DataLayoutType") + .value("NCHW", DataLayoutType::kNCHW) + .value("NHWC", DataLayoutType::kNHWC) + .value("Any", DataLayoutType::kAny); + + // Place + py::class_<Place>(*m, "Place") + .def(py::init<TargetType, PrecisionType, DataLayoutType, int16_t>(), + py::arg("target"), + py::arg("precision") = PrecisionType::kFloat, + py::arg("layout") = DataLayoutType::kNCHW, + py::arg("device") = 0) + .def("is_valid", &Place::is_valid); +} + +void BindLiteTensor(py::module *m) { + auto data_size_func = [](const std::vector<int64_t> &shape) -> int64_t { + int64_t res = 1; + for (size_t i = 0; i < shape.size(); i++) { + res *= shape[i]; + } + return res; + }; + + py::class_<Tensor> tensor(*m, "Tensor"); + + tensor.def("resize", &Tensor::Resize) + .def("shape", &Tensor::shape) + .def("target", &Tensor::target) + .def("precision", &Tensor::precision) + .def("lod", &Tensor::lod) + .def("set_lod", &Tensor::SetLoD); + +#define DO_GETTER_ONCE(data_type__, name__) \ + tensor.def(#name__, [=](Tensor &self) -> std::vector<data_type__> { \ + std::vector<data_type__> data; \ + auto shape = self.shape(); \ + int64_t num = data_size_func(shape); \ + data.resize(num); \ + self.CopyToCpu(data.data()); \ + return data; \ + }); + +#define DO_SETTER_ONCE(data_type__, name__) \ + tensor.def( \ + #name__, \ + [](Tensor &self, \ + const std::vector<data_type__> &data, \ + TargetType type = TargetType::kHost) { \ + if (type == TargetType::kHost || type == TargetType::kARM) { \ + self.CopyFromCpu<data_type__, TargetType::kHost>(data.data()); \ + } else if (type == TargetType::kCUDA) { \ + self.CopyFromCpu<data_type__, TargetType::kCUDA>(data.data()); \ + } \ + }, \ + py::arg("data"), \ + py::arg("type") = TargetType::kHost); + +#define DATA_GETTER_SETTER_ONCE(data_type__, name__) \ + DO_SETTER_ONCE(data_type__, set_##name__##_data) \ + DO_GETTER_ONCE(data_type__, name__##_data) + + DATA_GETTER_SETTER_ONCE(int8_t, int8); + 
DATA_GETTER_SETTER_ONCE(int32_t, int32); + DATA_GETTER_SETTER_ONCE(float, float); +#undef DO_GETTER_ONCE +#undef DO_SETTER_ONCE +#undef DATA_GETTER_SETTER_ONCE +} + +#ifndef LITE_ON_TINY_PUBLISH +void BindLiteCxxPredictor(py::module *m) { + py::class_(*m, "CxxPredictor") + .def(py::init<>()) + .def("get_input", &CxxPaddleApiImpl::GetInput) + .def("get_output", &CxxPaddleApiImpl::GetOutput) + .def("run", &CxxPaddleApiImpl::Run) + .def("get_version", &CxxPaddleApiImpl::GetVersion) + .def("save_optimized_model", + [](CxxPaddleApiImpl &self, const std::string &output_dir) { + self.SaveOptimizedModel(output_dir, + lite_api::LiteModelType::kNaiveBuffer); + }); +} +#endif + +void BindLiteLightPredictor(py::module *m) { + py::class_(*m, "LightPredictor") + .def(py::init<>()) + .def("get_input", &LightPredictorImpl::GetInput) + .def("get_output", &LightPredictorImpl::GetOutput) + .def("run", &LightPredictorImpl::Run) + .def("get_version", &LightPredictorImpl::GetVersion); +} + +} // namespace pybind +} // namespace lite +} // namespace paddle diff --git a/lite/api/python/pybind/pybind.h b/lite/api/python/pybind/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..ca05f24b32fd0b0418d9cf595fe6134b34fa725f --- /dev/null +++ b/lite/api/python/pybind/pybind.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+namespace paddle {
+namespace lite {
+namespace pybind {
+
+void BindLiteApi(pybind11::module *m);
+
+PYBIND11_MODULE(lite_core, m) {
+  m.doc() = "C++ core of Paddle-Lite";
+
+  BindLiteApi(&m);
+}
+
+}  // namespace pybind
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/resnet18_test.cc b/lite/api/resnet18_test.cc
index c003dc1dba6500e37d4b0d6b724d743c45ebeebf..5a50367006a8c3eeea0cfa6fe46f393463763ca9 100644
--- a/lite/api/resnet18_test.cc
+++ b/lite/api/resnet18_test.cc
@@ -28,14 +28,9 @@ namespace lite {
 #ifdef LITE_WITH_ARM
 TEST(ResNet18, test) {
   lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kARM), PRECISION(kFloat)}});
+  std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
 
-  predictor.Build(FLAGS_model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kARM), PRECISION(kFloat)},
-                  valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
diff --git a/lite/api/resnet50_test.cc b/lite/api/resnet50_test.cc
index 6e78d12be07b0887ec9942e8b8c1d2c530b6fc35..3e5a725b9001da760670976666ef624e5dac416b 100644
--- a/lite/api/resnet50_test.cc
+++ b/lite/api/resnet50_test.cc
@@ -26,13 +26,12 @@ namespace paddle {
 namespace lite {
 
 #ifdef LITE_WITH_ARM
-void TestModel(const std::vector<Place>& valid_places,
-               const Place& preferred_place) {
+void TestModel(const std::vector<Place>& valid_places) {
   DeviceInfo::Init();
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
@@ -82,22 +81,20 @@ void TestModel(const std::vector<Place>& valid_places,
 
 TEST(ResNet50, test_arm) {
   std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat)},
       Place{TARGET(kARM), PRECISION(kFloat)},
   });
 
-  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+  TestModel(valid_places);
 }
 
 #ifdef LITE_WITH_OPENCL
 TEST(ResNet50, test_opencl) {
   std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat)},
-      Place{TARGET(kARM), PRECISION(kFloat)},
       Place{TARGET(kOpenCL), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
   });
 
-  TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)}));
+  TestModel(valid_places);
 }
 #endif  // LITE_WITH_OPENCL
diff --git a/lite/api/resnet50_test_fpga.cc b/lite/api/resnet50_test_fpga.cc
index 7ea81cc746411c86e6f7a882e3f040cfab98503c..ab647f96998f1c0e73476369611218d0a7930c57 100644
--- a/lite/api/resnet50_test_fpga.cc
+++ b/lite/api/resnet50_test_fpga.cc
@@ -29,8 +29,7 @@ namespace lite {
 TEST(ResNet50, test) {
   lite::Predictor predictor;
   std::vector<Place> valid_places(
-      {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
-       Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNHWC)}});
+      {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}});
 
   predictor.Build(FLAGS_model_dir,
                   "",
diff --git a/lite/api/shufflenetv2_test.cc b/lite/api/shufflenetv2_test.cc
index f67bc8c6cfcc5ad545c43f2ee91a799c295e5838..2c1247997c2dcaa33e5c11af37996cab1e287fa4 100644
--- a/lite/api/shufflenetv2_test.cc
+++ b/lite/api/shufflenetv2_test.cc
@@ -25,13 +25,12 @@ namespace paddle {
 namespace lite {
 
-void TestModel(const std::vector<Place>& valid_places,
-               const Place& preferred_place) {
+void TestModel(const std::vector<Place>& valid_places) {
   DeviceInfo::Init();
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim((std::vector<int64_t>({1, 3, 224, 224}))));
@@ -80,12 +79,11 @@ void TestModel(const std::vector<Place>& valid_places,
 
 TEST(ShuffleNetV2, test_arm) {
   std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat)},
       Place{TARGET(kARM), PRECISION(kFloat)},
       // Place{TARGET(kOpenCL), PRECISION(kFloat)},
   });
 
-  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+  TestModel(valid_places);
 }
 
 }  // namespace lite
diff --git a/lite/api/test_googlenet_lite.cc b/lite/api/test_googlenet_lite.cc
index 4c9ecd90c6962ac390dc6db2f37710615b2c60d8..8ff7a49af9cbce09d205bb8633a913410beb91c3 100644
--- a/lite/api/test_googlenet_lite.cc
+++ b/lite/api/test_googlenet_lite.cc
@@ -12,56 +12,54 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include <vector>
-#include "lite/api/cxx_api.h"
 #include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
-
-// for googlenet
-DEFINE_string(model_dir, "", "");
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
 
 namespace paddle {
 namespace lite {
 #ifdef LITE_WITH_X86
 TEST(CXXApi, test_lite_googlenet) {
-  lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
-
-  // LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_eval_googlenet_dir;
-  std::string model_dir = FLAGS_model_dir;
-  std::vector<std::string> passes({"static_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "type_target_cast_pass",
-                                   "variable_place_inference_pass",
-                                   "io_copy_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "runtime_context_assign_pass"});
-  predictor.Build(model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kX86), PRECISION(kFloat)},
-                  valid_places,
-                  passes);
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
 
-  auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
   auto* data = input_tensor->mutable_data<float>();
-  for (int i = 0; i < input_tensor->dims().production(); i++) {
+  int input_num = 1;
+  for (int i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
     data[i] = 1;
   }
 
-  predictor.Run();
-  auto* out = predictor.GetOutput(0);
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor->Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor->Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+  auto out = predictor->GetOutput(0);
   std::vector<float> results(
       {0.00034298553, 0.0008200012, 0.0005046297, 0.000839279,
        0.00052616704, 0.0003447803, 0.0010877076, 0.00081762316,
@@ -71,9 +69,9 @@ TEST(CXXApi, test_lite_googlenet) {
   for (size_t i = 0; i < results.size(); ++i) {
     EXPECT_NEAR(out->data<float>()[i * 51], results[i], 1e-5);
   }
-  ASSERT_EQ(out->dims().size(), 2);
-  ASSERT_EQ(out->dims()[0], 1);
-  ASSERT_EQ(out->dims()[1], 1000);
+  ASSERT_EQ(out->shape().size(), 2);
+  ASSERT_EQ(out->shape()[0], 1);
+  ASSERT_EQ(out->shape()[1], 1000);
 }
 #endif
 }  // namespace lite
diff --git a/lite/api/test_helper.h b/lite/api/test_helper.h
index d835c030f03a3c95575217020cd298dabbf1a15a..71752c942bb53e7f2ed289ac0d965ae1d1007c55 100644
--- a/lite/api/test_helper.h
+++ b/lite/api/test_helper.h
@@ -22,6 +22,13 @@
 DEFINE_string(model_dir, "", "model dir");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
+DEFINE_int32(power_mode,
+             3,
+             "arm power mode: "
+             "0 for big cluster, "
+             "1 for little cluster, "
+             "2 for all cores, "
+             "3 for no bind");
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_int32(im_width, 224, "image width");
 DEFINE_int32(im_height, 224, "image height");
diff --git a/lite/api/test_inceptionv4_lite_x86.cc b/lite/api/test_inceptionv4_lite_x86.cc
index 5d1dbbe1448433eb5bdde0818229d5e1793ae39c..e986784809951390889e17f766302fc5ea459465 100644
--- a/lite/api/test_inceptionv4_lite_x86.cc
+++ b/lite/api/test_inceptionv4_lite_x86.cc
@@ -12,70 +12,46 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include <vector>
-#include "lite/api/cxx_api.h"
 #include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
 #include "lite/api/test_helper.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
+#include "lite/utils/cp_logging.h"
 
 namespace paddle {
 namespace lite {
 
 TEST(InceptionV4, test_inceptionv4_lite_x86) {
-  lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
-
-  // LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_eval_googlenet_dir;
-  std::string model_dir = FLAGS_model_dir;
-  std::vector<std::string> passes({"static_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "type_target_cast_pass",
-                                   "variable_place_inference_pass",
-                                   "io_copy_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "runtime_context_assign_pass"});
-  predictor.Build(model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kX86), PRECISION(kFloat)},
-                  valid_places,
-                  passes);
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
 
-  auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
   auto* data = input_tensor->mutable_data<float>();
-  for (int i = 0; i < input_tensor->dims().production(); i++) {
+  int input_num = 1;
+  for (int i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
    data[i] = 1;
   }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
-    predictor.Run();
+    predictor->Run();
   }
 
   auto start = GetCurrentUS();
   for (int i = 0; i < FLAGS_repeats; ++i) {
-    predictor.Run();
+    predictor->Run();
   }
 
   LOG(INFO) << "================== Speed Report ===================";
@@ -83,7 +59,6 @@ TEST(InceptionV4, test_inceptionv4_lite_x86) {
             << ", repeats: " << FLAGS_repeats << ", spend "
             << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
-
   std::vector<std::vector<float>> results;
   // i = 1
   results.emplace_back(std::vector<float>(
@@ -93,15 +68,15 @@
        0.0009782845, 0.0009230255, 0.0010548076, 0.0010974824,
        0.0010612885, 0.00089107914, 0.0010112736, 0.00097655767}));
 
-  auto* out = predictor.GetOutput(0);
-  ASSERT_EQ(out->dims().size(), 2);
-  ASSERT_EQ(out->dims()[0], 1);
-  ASSERT_EQ(out->dims()[1], 1000);
+  auto out = predictor->GetOutput(0);
+  ASSERT_EQ(out->shape().size(), 2);
+  ASSERT_EQ(out->shape()[0], 1);
+  ASSERT_EQ(out->shape()[1], 1000);
 
   int step = 50;
   for (int i = 0; i < results.size(); ++i) {
     for (int j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
                   results[i][j],
                   1e-6);
     }
diff --git a/lite/api/test_mobilenetv1_lite_x86.cc b/lite/api/test_mobilenetv1_lite_x86.cc
index d755410b6a8816cee1de60504e93e1eae5eedd4b..67dc1b2436988c7d0d853c945fecce27ef2d329f 100644
--- a/lite/api/test_mobilenetv1_lite_x86.cc
+++ b/lite/api/test_mobilenetv1_lite_x86.cc
@@ -12,68 +12,46 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include <vector>
-#include "lite/api/cxx_api.h"
 #include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
 #include "lite/api/test_helper.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
+#include "lite/utils/cp_logging.h"
 
 namespace paddle {
 namespace lite {
 
 TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
-  lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
 
-  std::string model_dir = FLAGS_model_dir;
-  std::vector<std::string> passes({"static_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "type_target_cast_pass",
-                                   "variable_place_inference_pass",
-                                   "io_copy_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "runtime_context_assign_pass"});
-  predictor.Build(model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kX86), PRECISION(kFloat)},
-                  valid_places,
-                  passes);
-  auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
   auto* data = input_tensor->mutable_data<float>();
-  for (int i = 0; i < input_tensor->dims().production(); i++) {
+  int input_num = 1;
+  for (int i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
     data[i] = 1;
   }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
-    predictor.Run();
+    predictor->Run();
   }
 
   auto start = GetCurrentUS();
   for (int i = 0; i < FLAGS_repeats; ++i) {
-    predictor.Run();
+    predictor->Run();
   }
 
   LOG(INFO) << "================== Speed Report ===================";
@@ -81,7 +59,6 @@ TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
             << ", repeats: " << FLAGS_repeats << ", spend "
             << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
-
   std::vector<std::vector<float>> results;
   // i = 1
   results.emplace_back(std::vector<float>(
@@ -90,15 +67,15 @@ TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
        0.0010323516, 0.00010079765, 0.00011006987, 0.0017364529,
        0.0048292773, 0.0013995157, 0.0018453331, 0.0002428986,
        0.00020211363, 0.00013668182, 0.0005855956, 0.00025901722}));
-  auto* out = predictor.GetOutput(0);
-  ASSERT_EQ(out->dims().size(), 2);
-  ASSERT_EQ(out->dims()[0], 1);
-  ASSERT_EQ(out->dims()[1], 1000);
+  auto out = predictor->GetOutput(0);
+  ASSERT_EQ(out->shape().size(), 2);
+  ASSERT_EQ(out->shape()[0], 1);
+  ASSERT_EQ(out->shape()[1], 1000);
 
   int step = 50;
   for (int i = 0; i < results.size(); ++i) {
     for (int j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
                   results[i][j],
                   1e-6);
     }
diff --git a/lite/api/test_mobilenetv2_lite_x86.cc b/lite/api/test_mobilenetv2_lite_x86.cc
index b1090cc6f260ba1b67c5cca8730a2915900f695f..95e88abcc8e59c6808ea2dc44cf7d1bdd53ac9d0 100644
--- a/lite/api/test_mobilenetv2_lite_x86.cc
+++ b/lite/api/test_mobilenetv2_lite_x86.cc
@@ -12,71 +12,47 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include <vector>
-#include "lite/api/cxx_api.h"
 #include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
 #include "lite/api/test_helper.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
+#include "lite/utils/cp_logging.h"
 
 // for googlenet
 namespace paddle {
 namespace lite {
 
 TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
-  lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
-
-  // LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_eval_googlenet_dir;
-  std::string model_dir = FLAGS_model_dir;
-  std::vector<std::string> passes({"static_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "type_target_cast_pass",
-                                   "variable_place_inference_pass",
-                                   "io_copy_kernel_pick_pass",
-                                   "variable_place_inference_pass",
-                                   "runtime_context_assign_pass"});
-  predictor.Build(model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kX86), PRECISION(kFloat)},
-                  valid_places,
-                  passes);
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
 
-  auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
   auto* data = input_tensor->mutable_data<float>();
-  for (int i = 0; i < input_tensor->dims().production(); i++) {
+  int input_num = 1;
+  for (int i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
     data[i] = 1;
   }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
-    predictor.Run();
+    predictor->Run();
  }
 
   auto start = GetCurrentUS();
   for (int i = 0; i < FLAGS_repeats; ++i) {
-    predictor.Run();
+    predictor->Run();
   }
 
   LOG(INFO) << "================== Speed Report ===================";
@@ -84,7 +60,6 @@ TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
             << ", repeats: " << FLAGS_repeats << ", spend "
             << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
-
   std::vector<std::vector<float>> results;
   // i = 1
   results.emplace_back(std::vector<float>(
@@ -93,15 +68,15 @@ TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
        0.0009059976, 9.5378724e-05, 5.386537e-05, 0.0006427285,
        0.0070957416, 0.0016094646, 0.0018807327, 0.00010506048,
        6.823785e-05, 0.00012269315, 0.0007806194, 0.00022354358}));
-  auto* out = predictor.GetOutput(0);
-  ASSERT_EQ(out->dims().size(), 2);
-  ASSERT_EQ(out->dims()[0], 1);
-  ASSERT_EQ(out->dims()[1], 1000);
+  auto out = predictor->GetOutput(0);
+  ASSERT_EQ(out->shape().size(), 2);
+  ASSERT_EQ(out->shape()[0], 1);
+  ASSERT_EQ(out->shape()[1], 1000);
 
   int step = 50;
   for (int i = 0; i < results.size(); ++i) {
     for (int j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
                   results[i][j],
                   1e-6);
     }
diff --git a/lite/api/test_resnet50_lite_x86.cc b/lite/api/test_resnet50_lite_x86.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3f9b59d714de611ef0a84cfc3b283d0dddd5c294
--- /dev/null
+++ b/lite/api/test_resnet50_lite_x86.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
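+
+// NOTE: Predictor::Build in this patch no longer takes a separate
+// preferred_place argument, so the ordering of the places handed to
+// set_valid_places below is what carries the target preference: kX86
+// first, with kHost as the fallback. (This is inferred from the other
+// test changes in this patch, not from separate documentation.)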
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(Resnet50, test_resnet50_lite_x86) {
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
+
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
+  auto* data = input_tensor->mutable_data<float>();
+  int input_num = 1;
+  for (int i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor->Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor->Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.00024139918, 0.00020566184, 0.00022418296, 0.00041731037,
+       0.0005366107, 0.00016948722, 0.00028638865, 0.0009257241,
+       0.00072681636, 8.531815e-05, 0.0002129998, 0.0021168243,
+       0.006387163, 0.0037145028, 0.0012812682, 0.00045948103,
+       0.00013535398, 0.0002483765, 0.00076759676, 0.0002773295}));
+  auto out = predictor->GetOutput(0);
+  ASSERT_EQ(out->shape().size(), 2);
+  ASSERT_EQ(out->shape()[0], 1);
+  ASSERT_EQ(out->shape()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c483373dc745f6520d51ece3936448ada71990d3
--- /dev/null
+++ b/lite/api/test_step_rnn_lite_x86.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(Step_rnn, test_step_rnn_lite_x86) {
+  std::string model_dir = FLAGS_model_dir;
+  lite_api::CxxConfig config;
+  config.set_model_dir(model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
+                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  auto predictor = lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
+
+  std::vector<std::string> target_names = {"item_type_id",
+                                           "mthid_id",
+                                           "source_id_id",
+                                           "layout_id",
+                                           "mark_id",
+                                           "category_id",
+                                           "subcategory_id",
+                                           "score_segment_id",
+                                           "item_attention_id",
+                                           "queue_num_id",
+                                           "micro_video_id",
+                                           "vertical_type_id"};
+
+  for (int i = 0; i < target_names.size(); ++i) {
+    auto input_tensor = predictor->GetInput(i);
+    int size = 0;
+    if (i == 6 || i == 8) {
+      input_tensor->Resize(std::vector<int64_t>{5, 1});
+      input_tensor->SetLoD({{0, 5}});
+      size = 5;
+    } else {
+      input_tensor->Resize(std::vector<int64_t>{1, 1});
+      input_tensor->SetLoD({{0, 1}});
+      size = 1;
+    }
+    auto* data = input_tensor->mutable_data<int64_t>();
+    for (int j = 0; j < size; j++) data[j] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor->Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor->Run();
+  }
+
+  // LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>({0.5030127, 0.496987}));
+  auto out = predictor->GetOutput(0);
+
+  std::vector<int64_t> out_shape = out->shape();
+
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(
+          out->data<float>()[j + (out_shape[1] * i)], results[i][j], 1e-6);
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/transform_test.cc b/lite/api/transform_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8e51f3778d30ba9fcfde493c3e27ecc973e66a59
--- /dev/null
+++ b/lite/api/transform_test.cc
@@ -0,0 +1,258 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
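+
+// PadBatchInput below pads a batch of token-id lines to a common length.
+// A small worked example (assuming pad_idx == 1, which is what the call
+// site passes via eos_idx, and two input lines "3 5 7" and "2 4", giving
+// max_len = 3):
+//   src_word: 3 5 7 | 2 4 1    (the short line is padded with pad_idx)
+//   src_pos:  0 1 2 | 0 1 0    (token positions, 0 at padded slots)
+//   src_attn_bias / trg_bias: 0 for real tokens, -1e9 at padded positions,
+//   so padded tokens are effectively masked out of the attention softmax.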
+ +#include +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +DEFINE_string(input, "", "input_data"); +DEFINE_int32(batch, 1, "batch"); + +namespace paddle { +namespace lite { +namespace test_transformer { + +std::vector inputed_lines; + +void LoadInputLines(const char* filename) { + static const int max_line_buf_size = 100 * 1024 * 1024; + char* line_buffer = (char*)calloc(max_line_buf_size, sizeof(char)); // NOLINT + FILE* input_file = fopen(filename, "r"); + + while (fgets(line_buffer, max_line_buf_size, input_file)) { + // trim newline at end + char* pos = NULL; + if ((pos = strchr(line_buffer, '\n')) != NULL) { + *pos = 0; + } + inputed_lines.push_back(line_buffer); + } + free(line_buffer); + line_buffer = NULL; + fclose(input_file); +} +void Split2(const std::string& main_str, + std::vector& str_list, // NOLINT + const std::string& delimiter) { + size_t pre_pos = 0; + size_t position = 0; + std::string tmp_str; + + str_list.clear(); + if (main_str.empty()) { + return; + } + + while ((position = main_str.find(delimiter, pre_pos)) != std::string::npos) { + tmp_str.assign(main_str, pre_pos, position - pre_pos); + str_list.push_back(tmp_str); + pre_pos = position + 1; + } + + tmp_str.assign(main_str, pre_pos, main_str.length() - pre_pos); + + if (!tmp_str.empty()) { + str_list.push_back(tmp_str); + } +} +} // NOLINT + +void PadBatchInput(std::vector& input_lines, // NOLINT + int pad_idx, + int n_head, + Tensor* src_word, + Tensor* src_pos, + Tensor* src_attn_bias, + Tensor* trg_word, + Tensor* init_scores, + Tensor* init_idx, + Tensor* trg_bias, + int line_start, + int batch_size, + int bos_idx) { + int max_len = 0; + int max_line = input_lines.size(); + + std::vector> batch_lines; + for (int i = line_start; i < line_start + batch_size; ++i) { + int i_index = i % max_line; + std::string cur_line = input_lines[i_index]; + + std::vector split_str; + + test_transformer::Split2(cur_line, split_str, " "); + + batch_lines.push_back(split_str); + max_len = max_len >= split_str.size() ? 
max_len : split_str.size(); + } + + src_word->Resize(std::vector({batch_size, max_len, 1})); + src_pos->Resize(std::vector({batch_size, max_len, 1})); + src_attn_bias->Resize( + std::vector({batch_size, n_head, max_len, max_len})); + trg_bias->Resize( + std::vector({batch_size, n_head, 1, max_len})); + float* src_word_data = src_word->mutable_data(); + float* src_pos_data = src_pos->mutable_data(); + float* src_bias_data = src_attn_bias->mutable_data(); + float* trg_bias_data = trg_bias->mutable_data(); + for (int i = 0; i < batch_size; ++i) { + std::vector cur_words = batch_lines[i]; + int fill_len = cur_words.size(); + int src_bias_start = i * n_head * max_len * max_len; + int trg_bias_start = i * n_head * max_len; + for (int j = 0; j < fill_len; ++j) { + src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str())); + src_pos_data[i * max_len + j] = j; + src_bias_data[src_bias_start + j] = 0; + trg_bias_data[trg_bias_start + j] = 0; + } + for (int j = fill_len; j < max_len; ++j) { + src_word_data[i * max_len + j] = pad_idx; + src_pos_data[i * max_len + j] = 0; + src_bias_data[src_bias_start + j] = -1000000000; + trg_bias_data[trg_bias_start + j] = -1000000000; + } + for (int j = src_bias_start; + j < src_bias_start + n_head * max_len * max_len; + ++j) { + int value_ind = j % max_len + src_bias_start; + src_bias_data[j] = src_bias_data[value_ind]; + } + for (int j = trg_bias_start; j < trg_bias_start + n_head * max_len; ++j) { + int value_ind = j % max_len + trg_bias_start; + trg_bias_data[j] = trg_bias_data[value_ind]; + } + } + + trg_word->Resize(std::vector({batch_size, 1, 1})); + auto* trg_word_data = trg_word->mutable_data(); + for (int i = 0; i < batch_size; ++i) { + trg_word_data[i] = bos_idx; + } + + init_scores->Resize(std::vector({batch_size, 1})); + init_idx->Resize(std::vector({batch_size})); + float* score_data = init_scores->mutable_data(); + float* idx_data = init_idx->mutable_data(); + for (int i = 0; i < init_scores->numel(); ++i) { + score_data[i] = 0; + } + std::vector> lod_s; + lod_s.resize(2); + for (int i = 0; i < batch_size; ++i) { + lod_s[0].push_back(i); + lod_s[1].push_back(i); + idx_data[i] = i; + } + lod_s[0].push_back(batch_size); + lod_s[1].push_back(batch_size); + auto score_lod = init_scores->mutable_lod(); + *score_lod = lod_s; + + auto trg_word_lod = trg_word->mutable_lod(); + *trg_word_lod = lod_s; +} + +void TestModel(const std::vector& valid_places, + const Place& preferred_place, + bool use_npu = false) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + std::string test_data_path = FLAGS_input; + + predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); + + int n_head = 8; + int batch_size = FLAGS_batch; + int bos_idx = 0; + int eos_idx = 1; + LOG(INFO) << "reading"; + + test_transformer::LoadInputLines(test_data_path.c_str()); + LOG(INFO) << "reading finished"; + + auto* trg_bias = predictor.GetInput(6); + auto* src_word = predictor.GetInput(0); + auto* src_pos = predictor.GetInput(1); + auto* src_bias = predictor.GetInput(2); + auto* trg_word = predictor.GetInput(3); + auto* init_score = predictor.GetInput(4); + auto* init_idx = predictor.GetInput(5); + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + auto start_i = GetCurrentUS(); + PadBatchInput(test_transformer::inputed_lines, + eos_idx, + n_head, + src_word, // src_word + src_pos, // src_pos + src_bias, 
// src_bias + trg_word, // trg_word + init_score, // init_score + init_idx, // init_idx + trg_bias, // trg_bias + i * batch_size, + batch_size, + bos_idx); + LOG(INFO) << "src_word:" << src_word->dims(); + auto start_ii = GetCurrentUS(); + LOG(INFO) << i << "->ii:" << (start_ii - start_i) / 1000.0; + predictor.Run(); + auto start_iii = GetCurrentUS(); + LOG(INFO) << i << "->iii:" << (start_iii - start_ii) / 1000.0; + auto* outs = predictor.GetOutputs(); + LOG(INFO) << "out:" << (*outs)[0].dims(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + auto* outs = predictor.GetOutputs(); + for (auto out : *outs) { + LOG(INFO) << "======" + << "here"; + LOG(INFO) << out; + } + LOG(INFO) << "======" + << "hereggg"; +} + +TEST(OcrAttention, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/unet_test.cc b/lite/api/unet_test.cc index aae5f493eb0f67e3d09c7b48eb823dda8b343159..697280f28883138d2603f796c1952c655cd085d8 100644 --- a/lite/api/unet_test.cc +++ b/lite/api/unet_test.cc @@ -30,14 +30,9 @@ TEST(unet, test) { DeviceInfo::Init(); DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kARM), PRECISION(kFloat)}, - valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 512, 512}))); diff --git a/lite/backends/CMakeLists.txt b/lite/backends/CMakeLists.txt index 80dc574de894280575837584dadd8024660c6dc6..dec63e6efa0e4c4548646ebdd6f6de24f046d6d0 100644 --- a/lite/backends/CMakeLists.txt +++ b/lite/backends/CMakeLists.txt @@ -1,7 +1,8 @@ +add_subdirectory(opencl) add_subdirectory(arm) add_subdirectory(x86) add_subdirectory(cuda) add_subdirectory(fpga) add_subdirectory(host) -add_subdirectory(opencl) add_subdirectory(npu) +add_subdirectory(xpu) diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index f17928cc2935089fd4d9925d9791f0190f1e5c85..cbbcf49a5fd55dabd6b072bc6b3b2e3f9bb91a13 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -6,6 +6,17 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) return() endif() +set(script_dir ${CMAKE_CURRENT_SOURCE_DIR}/../../../tools/) +message(STATUS "generating arm dotprod code") +find_package(PythonInterp REQUIRED) +execute_process(COMMAND ${PYTHON_EXECUTABLE} ${script_dir}/convert_arm_sdot_to_machine_code.py + "--input_file=${CMAKE_CURRENT_SOURCE_DIR}/dotprod/__gemm_sdot_meta__.h" + "--output_file=${CMAKE_CURRENT_SOURCE_DIR}/dotprod/gemm_sdot.h" + RESULT_VARIABLE gen_code_ret) +if (NOT ${gen_code_ret} STREQUAL "0") + message(FATAL_ERROR "generating dotprod code quit with error: ${gen_code_ret}") +endif () + set(HAS_ARM_MATH_LIB_DIR OFF) # will search name as 
"libmath_arm.${os}.${abi}.${lang}.a" if(ARM_MATH_LIB_DIR AND EXISTS "${ARM_MATH_LIB_DIR}") @@ -50,6 +61,27 @@ if (NOT HAS_ARM_MATH_LIB_DIR) funcs.cc packed_sgemm.cc sgemm.cc + gemm_prepacked_int8.cc + gemm_s8.cc + sgemv.cc + gemv_arm_int8.cc + conv3x3s1_direct_fp32.cc + conv3x3s2_direct_fp32.cc + conv3x3s1_depthwise_fp32.cc + conv3x3s2_depthwise_fp32.cc + conv3x3s1_direct_int8.cc + conv3x3s2_direct_int8.cc + conv3x3s1_depthwise_int8.cc + conv3x3s2_depthwise_int8.cc + conv5x5s1_depthwise_int8.cc + conv5x5s1_depthwise_fp32.cc + conv5x5s2_depthwise_fp32.cc + conv_depthwise_3x3p0.cc + conv_depthwise_3x3p1.cc + conv_depthwise_3x3s1.cc + conv_depthwise_3x3s2.cc + conv_winograd_3x3.cc + conv_impl.cc softmax.cc scale.cc pooling.cc @@ -57,32 +89,13 @@ if (NOT HAS_ARM_MATH_LIB_DIR) lrn.cc decode_bboxes.cc concat.cc - sgemv.cc type_trans.cc box_coder.cc - conv_impl.cc - conv_direct_3x3s1.cc - conv_direct_3x3s2.cc - conv_direct.cc - conv_depthwise_3x3_int8.cc - conv_depthwise_5x5s1_int8.cc - conv_depthwise_3x3p0.cc - conv_depthwise_3x3p1.cc - conv_depthwise_5x5s1.cc - conv_depthwise_5x5s2.cc - conv_depthwise.cc - conv_gemmlike.cc - conv_winograd_3x3.cc - conv_winograd.cc split.cc shuffle_channel.cc activation.cc yolo_box.cc dropout.cc - gemm_prepacked_int8.cc - gemv_arm_int8.cc - conv3x3s1_direct_int8.cc - conv3x3s2_direct_int8.cc power.cc interpolate.cc argmax.cc @@ -104,8 +117,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR) slice.cc reduce_mean.cc stack.cc - affine_channel.cc - anchor_generator.cc - DEPS ${lite_kernel_deps}) + affine_channel.cc + anchor_generator.cc + DEPS ${lite_kernel_deps} context tensor) endif() diff --git a/lite/backends/arm/math/activation.cc b/lite/backends/arm/math/activation.cc index 938057ce9d8bb52614477689fba0d819c3ea3eda..634021cc3ce82bbb5fba72123b38457ab0c7ac06 100644 --- a/lite/backends/arm/math/activation.cc +++ b/lite/backends/arm/math/activation.cc @@ -471,7 +471,7 @@ void act_prelu(const float* din, } template <> -void act_sigmoid(const float* din, float* dout, int size, int threads) { +void act_sigmoid(const float* din, float* dout, int size, int threads) { int nums_per_thread = size / threads; int remain = size - threads * nums_per_thread; int neon_loop_cnt_dim4 = nums_per_thread >> 2; @@ -595,15 +595,11 @@ void act_swish( } template <> -void act_log(const float* din, float* dout, int size, int threads) { +void act_log(const float* din, float* dout, int size, int threads) { int nums_per_thread = size / threads; int remain = size - threads * nums_per_thread; int neon_loop_cnt_dim4 = nums_per_thread >> 2; int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); - LOG(INFO) << "nums_per_thread" << nums_per_thread; - LOG(INFO) << "remain" << remain; - LOG(INFO) << "neon_loop_cnt_dim4" << neon_loop_cnt_dim4; - LOG(INFO) << "neon_loop_remian_dim4" << neon_loop_remain_dim4; float32x4_t vzero = vdupq_n_f32(0.f); #pragma omp parallel for @@ -633,7 +629,7 @@ void act_log(const float* din, float* dout, int size, int threads) { } template <> -void act_exp(const float* din, float* dout, int size, int threads) { +void act_exp(const float* din, float* dout, int size, int threads) { int nums_per_thread = size / threads; int remain = size - threads * nums_per_thread; int neon_loop_cnt_dim4 = nums_per_thread >> 2; @@ -677,6 +673,33 @@ void act_floor(const float* din, float* dout, int size, int threads) { } } +template <> +void act_hard_sigmoid(const float* din, + float* dout, + const int64_t size, + const float slope, + const float offset, + int threads) { + for (int64_t i = 0; 
i < size; ++i) {
+    dout[0] = din[0] * slope + offset;
+    dout[0] = dout[0] < 1.0f ? dout[0] : 1.0f;
+    dout[0] = dout[0] > 0.0f ? dout[0] : 0.0f;
+    ++din;
+    ++dout;
+  }
+}
+
+template <>
+void act_rsqrt<float>(const float* din, float* dout, int size, int threads) {
+  const float* ptr_in = din;
+  float* ptr_out = dout;
+  for (int i = 0; i < size; ++i) {
+    ptr_out[0] = 1.0 / sqrtf(ptr_in[0]);
+    ptr_in++;
+    ptr_out++;
+  }
+}
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/activation.h b/lite/backends/arm/math/activation.h
index e115777d0242a42773dd5db65f869b42fab48876..bb8189eef0d81a92caf2aaf73e401e20d9c80155 100644
--- a/lite/backends/arm/math/activation.h
+++ b/lite/backends/arm/math/activation.h
@@ -58,6 +58,17 @@ void act_exp(const T* din, T* dout, int size, int threads);
 template <typename T>
 void act_floor(const T* din, T* dout, int size, int threads);
 
+template <typename T>
+void act_hard_sigmoid(const T* din,
+                      T* dout,
+                      const int64_t size,
+                      const float slope,
+                      const float offset,
+                      int threads);
+
+template <typename T>
+void act_rsqrt(const T* din, T* dout, int size, int threads);
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc
new file mode 100644
index 0000000000000000000000000000000000000000..99aeea8bdea2a50795dcdca18464a196ee877291
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc
@@ -0,0 +1,538 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
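+
+// Tiling sketch for conv_3x3s1_depthwise_fp32 below, derived from the block
+// constants in this file rather than from separate documentation: the input
+// is prepacked into channel blocks of 4 (prepack_input_nxwc4_dw), and each
+// inner step computes a 4-channel x 2-row x 4-column output tile. Partial
+// channel or row tiles are redirected to a scratch buffer and discarded,
+// while the column tail is first written to a small pre_out buffer and only
+// the valid elements are copied back with memcpy.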
+ +#include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void conv_3x3s1_depthwise_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + int threads = ctx->threads(); + const int pad_h = param.paddings[0]; + const int pad_w = param.paddings[1]; + const int out_c_block = 4; + const int out_h_kernel = 2; + const int out_w_kernel = 4; + const int win_ext = ow + 2; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh + 2; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = + threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + + /// get workspace + float* ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + float32x4_t vbias = vld1q_f32(bias_local); +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc00 = dout_c00 + h * ow; + float* outc01 = outc00 + ow; + float* outc10 = outc00 + size_out_channel; + float* outc11 = outc10 + ow; + float* outc20 = outc10 + size_out_channel; + float* outc21 = outc20 + ow; + float* outc30 = outc20 + size_out_channel; + float* outc31 = outc30 + ow; + const float* inr0 = pre_din + h * row_len; + const 
float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: + outc10 = ptr_write; + outc11 = ptr_write; + case 2: + outc20 = ptr_write; + outc21 = ptr_write; + case 1: + outc30 = ptr_write; + outc31 = ptr_write; + default: + break; + } + } + if (h + out_h_kernel > oh) { + outc01 = ptr_write; + outc11 = ptr_write; + outc21 = ptr_write; + outc31 = ptr_write; + } + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(flag_relu)}; + void* outl_ptr = reinterpret_cast(outl); + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + float* out0 = pre_out; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ + "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ + "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ + "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ + /* r0, r1, mul w0, get out r0, r1 */ + "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ + "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ + "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ + "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ + "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ + "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ + "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ + "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ + /* r0, r1, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ + "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ + "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ + "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ + "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ + "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ + "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ + "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ + /* r0, r1, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ + "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ + "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ + "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ + "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ + "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ + "fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ + "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ + /* r1, r2, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ + "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ + "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ + "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ + "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ + "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ + /* r1, r2, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ + "ldp q6, q7, [%[inr3]], #32\n" 
/* load input r3*/ + "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ + "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ + "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ + "ldp x0, x1, [%[outl]] \n" + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ + "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ + "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ + "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ + /* r1, r2, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ + "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ + "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ + "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ + "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ + "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ + "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ + "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ + /* r2, r3, mul w6, get out r0, r1 */ + "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ + "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ + "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ + "ldp x2, x3, [%[outl], #16] \n" + "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ + "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ + "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ + "fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ + /* r2, r3, mul w7, get out r0, r1 */ + "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ + "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ + "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ + "ldp x4, x5, [%[outl], #32] \n" + "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ + "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ + "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ + "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ + /* r2, r3, mul w8, get out r0, r1 */ + "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ + "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ + "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ + "ldp x6, x7, [%[outl], #48] \n" + "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ + "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ + "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ + "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ + + "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ + "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ + "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ + "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ + "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ + "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ + "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ + "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ + + /* transpose */ + "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ + "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ + "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ + "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ + "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ + "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ + "trn1 v6.4s, 
v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ + "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ + "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ + "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ + "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ + "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ + "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ + "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ + "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ + "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ + + "cbz %w[flag_relu], 0f\n" /* skip relu*/ + "movi v0.4s, #0\n" /* for relu */ + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "0:\n" + "cbnz %w[flag_mask], 1f\n" + "str q15, [x0]\n" /* save outc00 */ + "str q16, [x4]\n" /* save outc01 */ + "str q17, [x1]\n" /* save outc10 */ + "str q18, [x5]\n" /* save outc11 */ + "str q19, [x2]\n" /* save outc20 */ + "str q20, [x6]\n" /* save outc21 */ + "str q21, [x3]\n" /* save outc30 */ + "str q22, [x7]\n" /* save outc31 */ + "b 2f\n" + "1:\n" + "str q15, [%[out]], #16 \n" /* save remain to pre_out */ + "str q17, [%[out]], #16 \n" /* save remain to pre_out */ + "str q19, [%[out]], #16 \n" /* save remain to pre_out */ + "str q21, [%[out]], #16 \n" /* save remain to pre_out */ + "str q16, [%[out]], #16 \n" /* save remain to pre_out */ + "str q18, [%[out]], #16 \n" /* save remain to pre_out */ + "str q20, [%[out]], #16 \n" /* save remain to pre_out */ + "str q22, [%[out]], #16 \n" /* save remain to pre_out */ + "2:\n" + :[inr0] "+r"(inr0), [inr1] "+r"(inr1), + [inr2] "+r"(inr2), [inr3] "+r"(inr3), + [out]"+r"(out0) + :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), + [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), + [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), + [vbias]"w" (vbias), [outl] "r" (outl_ptr), + [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu) + : "cc", "memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8", "v9", "v10", "v11", "v15", + "v16","v17","v18","v19","v20","v21","v22", + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" + ); +#else + asm volatile( + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" + /* load r0, r1 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" + "vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" + /* main loop */ + "0: @ main loop\n" + /* mul r0 with w0, w1, w2, get out r0 */ + "vmul.f32 q8, q5, q0 @ w0 * inr00\n" + "vmul.f32 q9, q5, q1 @ w0 * inr01\n" + "vmul.f32 q10, q5, q2 @ w0 * inr02\n" + "vmul.f32 q11, q5, q3 @ w0 * inr03\n" + "vmla.f32 q8, q6, q1 @ w1 * inr01\n" + "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" + "vmla.f32 q9, q6, q2 @ w1 * inr02\n" + "vmla.f32 q10, q6, q3 @ w1 * inr03\n" + "vmla.f32 q11, q6, q0 @ w1 * inr04\n" + "vmla.f32 q8, q7, q2 @ w2 * inr02\n" + "vmla.f32 q9, q7, q3 @ w2 * inr03\n" + "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" + "vmla.f32 q10, q7, q0 @ w2 * inr04\n" + "vmla.f32 q11, q7, q1 @ w2 * inr05\n" + "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" + "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" + /* mul r1 with w0-w5, get out r0, r1 */ + "vmul.f32 q12, q5, q2 @ w0 * inr10\n" + "vmul.f32 q13, q5, q3 @ w0 * inr11\n" + "vmul.f32 q14, q5, q0 @ w0 * inr12\n" + "vmul.f32 q15, q5, q1 @ w0 * inr13\n" + "vld1.32 {d10-d11}, [%[wc0]]! 
@ load w4 to q5\n" + "vmla.f32 q8, q4, q2 @ w3 * inr10\n" + "vmla.f32 q9, q4, q3 @ w3 * inr11\n" + "vmla.f32 q10, q4, q0 @ w3 * inr12\n" + "vmla.f32 q11, q4, q1 @ w3 * inr13\n" + /* mul r1 with w1, w4, get out r1, r0 */ + "vmla.f32 q8, q5, q3 @ w4 * inr11\n" + "vmla.f32 q12, q6, q3 @ w1 * inr11\n" + "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" + "vmla.f32 q9, q5, q0 @ w4 * inr12\n" + "vmla.f32 q13, q6, q0 @ w1 * inr12\n" + "vmla.f32 q10, q5, q1 @ w4 * inr13\n" + "vmla.f32 q14, q6, q1 @ w1 * inr13\n" + "vmla.f32 q11, q5, q2 @ w4 * inr14\n" + "vmla.f32 q15, q6, q2 @ w1 * inr14\n" + "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" + /* mul r1 with w2, w5, get out r1, r0 */ + "vmla.f32 q12, q7, q0 @ w2 * inr12\n" + "vmla.f32 q13, q7, q1 @ w2 * inr13\n" + "vmla.f32 q8, q6, q0 @ w5 * inr12\n" + "vmla.f32 q9, q6, q1 @ w5 * inr13\n" + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" + "vmla.f32 q14, q7, q2 @ w2 * inr14\n" + "vmla.f32 q15, q7, q3 @ w2 * inr15\n" + "vmla.f32 q10, q6, q2 @ w5 * inr14\n" + "vmla.f32 q11, q6, q3 @ w5 * inr15\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" + /* mul r2 with w3-w8, get out r0, r1 */ + "vmla.f32 q12, q4, q0 @ w3 * inr20\n" + "vmla.f32 q13, q4, q1 @ w3 * inr21\n" + "vmla.f32 q14, q4, q2 @ w3 * inr22\n" + "vmla.f32 q15, q4, q3 @ w3 * inr23\n" + "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" + "vmla.f32 q8, q7, q0 @ w6 * inr20\n" + "vmla.f32 q9, q7, q1 @ w6 * inr21\n" + "vmla.f32 q10, q7, q2 @ w6 * inr22\n" + "vmla.f32 q11, q7, q3 @ w6 * inr23\n" + /* mul r2 with w4, w7, get out r1, r0 */ + "vmla.f32 q8, q4, q1 @ w7 * inr21\n" + "vmla.f32 q12, q5, q1 @ w4 * inr21\n" + "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" + "vmla.f32 q9, q4, q2 @ w7 * inr22\n" + "vmla.f32 q13, q5, q2 @ w4 * inr22\n" + "vmla.f32 q10, q4, q3 @ w7 * inr23\n" + "vmla.f32 q14, q5, q3 @ w4 * inr23\n" + "vmla.f32 q11, q4, q0 @ w7 * inr24\n" + "vmla.f32 q15, q5, q0 @ w4 * inr24\n" + "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" + /* mul r1 with w5, w8, get out r1, r0 */ + "vmla.f32 q12, q6, q2 @ w5 * inr22\n" + "vmla.f32 q13, q6, q3 @ w5 * inr23\n" + "vmla.f32 q8, q5, q2 @ w8 * inr22\n" + "vmla.f32 q9, q5, q3 @ w8 * inr23\n" + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" + "ldr r4, [%[outl], #32] @ load bias addr to r4\n" + "vmla.f32 q14, q6, q0 @ w5 * inr24\n" + "vmla.f32 q15, q6, q1 @ w5 * inr25\n" + "vmla.f32 q10, q5, q0 @ w8 * inr24\n" + "vmla.f32 q11, q5, q1 @ w8 * inr25\n" + "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, q0, q1\n" + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" + /* mul r3 with w6, w7, w8, get out r1 */ + "vmla.f32 q12, q7, q2 @ w6 * inr30\n" + "vmla.f32 q13, q7, q3 @ w6 * inr31\n" + "vmla.f32 q14, q7, q0 @ w6 * inr32\n" + "vmla.f32 q15, q7, q1 @ w6 * inr33\n" + "vmla.f32 q12, q4, q3 @ w7 * inr31\n" + "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" + "vld1.32 {d12-d13}, [r4] @ load bias\n" + "vmla.f32 q13, q4, q0 @ w7 * inr32\n" + "vmla.f32 q14, q4, q1 @ w7 * inr33\n" + "vmla.f32 q15, q4, q2 @ w7 * inr34\n" + "ldr r0, [%[outl]] @ load outc00 to r0\n" + "vmla.f32 q12, q5, q0 @ w8 * inr32\n" + "vmla.f32 q13, q5, q1 @ w8 * inr33\n" + "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" + "vmla.f32 q14, q5, q2 @ w8 * inr34\n" + "vmla.f32 q15, q5, q3 @ w8 * inr35\n" + "ldr r1, [%[outl], #4] @ load outc10 to r1\n" + "vadd.f32 q8, q8, q6 @ r00 add bias\n" + "vadd.f32 q9, q9, q6 @ r01 add bias\n" + "vadd.f32 q10, q10, q6 @ r02 add bias\n" + "vadd.f32 q11, q11, q6 @ r03 add bias\n" + "ldr r2, [%[outl], #8] @ load outc20 to r2\n" + "vadd.f32 q12, q12, q6 @ r10 add bias\n" + "vadd.f32 q13, q13, q6 @ r11 add bias\n" + "vadd.f32 q14, q14, q6 @ r12 add bias\n" + "vadd.f32 q15, q15, q6 @ r13 add bias\n" + "ldr r3, [%[outl], #12] @ load outc30 to r3\n" + "vmov.u32 q7, #0 @ mov zero to q7\n" + "cmp r5, #0 @ cmp flag relu\n" + "beq 1f @ skip relu\n" + "vmax.f32 q8, q8, q7 @ r00 relu\n" + "vmax.f32 q9, q9, q7 @ r01 relu\n" + "vmax.f32 q10, q10, q7 @ r02 relu\n" + "vmax.f32 q11, q11, q7 @ r03 relu\n" + "vmax.f32 q12, q12, q7 @ r10 relu\n" + "vmax.f32 q13, q13, q7 @ r11 relu\n" + "vmax.f32 q14, q14, q7 @ r12 relu\n" + "vmax.f32 q15, q15, q7 @ r13 relu\n" + "1:\n" + "ldr r4, [%[outl], #16] @ load outc01 to r4\n" + "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" + "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" + "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" + "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" + "ldr r5, [%[outl], #20] @ load outc11 to r5\n" + "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" + "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" + "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" + "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" + "cmp %[flag_mask], #0 @ cmp flag mask\n" + "bne 2f\n" + "vst1.32 {d16-d17}, [r0] @ save outc00\n" + "vst1.32 {d18-d19}, [r1] @ save outc10\n" + "vst1.32 {d20-d21}, [r2] @ save outc20\n" + "vst1.32 {d22-d23}, [r3] @ save outc30\n" + "vst1.32 {d24-d25}, [r4] @ save outc01\n" + "vst1.32 {d26-d27}, [r5] @ save outc11\n" + "ldr r0, [%[outl], #24] @ load outc21 to r0\n" + "ldr r1, [%[outl], #28] @ load outc31 to r1\n" + "vst1.32 {d28-d29}, [r0] @ save outc21\n" + "vst1.32 {d30-d31}, [r1] @ save outc31\n" + "b 3f @ branch end\n" + "2: \n" + "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d30-d31}, [%[out0]]! 
@ save remain to pre_out\n"
+              "3:\n"
+              : [r0] "+r"(inr0), [r1] "+r"(inr1),
+                [r2] "+r"(inr2), [r3] "+r"(inr3),
+                [out0] "+r"(out0), [wc0] "+r"(weight_c)
+              : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+              : "cc", "memory",
+                "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+                "q10", "q11", "q12", "q13", "q14", "q15",
+                "r0", "r1", "r2", "r3", "r4", "r5");
+#endif  // __aarch64__
+          // clang-format on
+          outl[0] += 4;
+          outl[1] += 4;
+          outl[2] += 4;
+          outl[3] += 4;
+          outl[4] += 4;
+          outl[5] += 4;
+          outl[6] += 4;
+          outl[7] += 4;
+          if (flag_mask) {
+            memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
+            memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
+            memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
+            memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
+            memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
+            memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
+            memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
+            memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
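A note on the tail handling that closes this kernel: each inner step writes a full 4-wide vector per output channel, so when fewer than 4 output columns remain the results are staged in the `pre_out` scratch buffer and only the valid `remain` floats are copied back. A minimal sketch of that idiom (illustration only, not part of the patch; the helper name is made up):

#include <cstring>

// Stage-and-copy tail store: the vector path always produces 4 lanes in
// `pre_out`; only `remain` (1..3) of them are valid for this row, so copy
// back just those and never write past the end of the output row.
static void store_remain_cols(float* row_ptr, const float* pre_out, int remain) {
  std::memcpy(row_ptr, pre_out, remain * sizeof(float));
}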
diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bc2097b9286dbce4430739a0784f2691c62d37a1
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc
@@ -0,0 +1,483 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_depthwise.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
+
+template <typename Dtype>
+void conv_depthwise_3x3s1_int8(Dtype* dout,
+                               const int8_t* din,
+                               const int8_t* weights,
+                               const float* scale,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int padw,
+                               int padh,
+                               ARMContext* ctx) {
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / 4;
+
+  const int hout_c_block = 8;
+  const int hout_r_kernel = 1;
+  const int wout_block = 4;
+  const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
+  const int win_round = wout_round + 2;
+
+  //! get h block
+  //! llc_size = threads * win_round * hout_c_block * hin_r_block *
+  //! sizeof(int8_t)
+  //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! win_round = wout_round + 2
+  //! hin_r_block = hout_r_block + 2
+  int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
+                     (win_round * threads * hout_c_block +
+                      hout_c_block * wout_round * threads * 4);
+  hout_r_block = hout_r_block > hout ? hout : hout_r_block;
+  hout_r_block =
+      ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
+  hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+
+  const int hin_r_block = hout_r_block + 2;
+
+  auto tmp_work_space = ctx->workspace_data<int8_t>();
+  int8_t ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(int8_t) * win_round);
+  Dtype ptr_write[wout_round];  // NOLINT
+
+  int in_len = win_round * hout_c_block;
+  int pre_in_size = hin_r_block * in_len;
+  pre_in_size = ROUNDUP(pre_in_size, 4);
+  int pre_out_size = hout_c_block * hout_r_block * wout_round;
+
+  int8_t* tmp_din = tmp_work_space;
+
+  int size_in_channel = win * hin;
+  int size_out_channel = wout * hout;
+  int w_stride = 9;  // kernel_w * kernel_h
+
+  int ws = -padw;
+  int we = ws + win_round;
+  int w_loop = wout_round / 4;
+  int chout = chin;
+
+  int out_row_stride = hout_c_block * wout_round;
+
+  for (int n = 0; n < num; ++n) {
+    const int8_t* din_batch = din + n * chin * size_in_channel;
+    int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
+                         n * chout * size_out_channel * sizeof(Dtype);
+    for (int h = 0; h < hout; h += hout_r_block) {
+      int h_kernel = hout_r_block;
+      if (h + hout_r_block > hout) {
+        h_kernel = hout - h;
+      }
+      int hs = h - padh;
+      int he = hs + h_kernel + 2;
+
+#pragma omp parallel for num_threads(threads)
+      for (int c = 0; c < chout; c += hout_c_block) {
+#ifdef ARM_WITH_OMP
+        int8_t* pre_din =
+            tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
+        int32_t* pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size);
+#else
+        int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
+        auto pre_din = tmp_din;
+#endif
+        prepack_input_nxwc8_int8_dw(
+            din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
+
+        const int8_t* block_inr0 = pre_din;
+        const int8_t* block_inr1 = block_inr0 + in_len;
+        const int8_t* block_inr2 = block_inr1 + in_len;
+
+        const int8_t* weight_c = weights + c * w_stride;
+        float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+        if (flag_bias) {
+          bias_local[0] = bias[c];
+          bias_local[1] = bias[c + 1];
+          bias_local[2] = bias[c + 2];
+          bias_local[3] = bias[c + 3];
+          bias_local[4] = bias[c + 4];
+          bias_local[5] = bias[c + 5];
+          bias_local[6] = bias[c + 6];
+          bias_local[7] = bias[c + 7];
+        }
+#ifdef __aarch64__
+        int8x8_t vw0 = vld1_s8(weight_c);
+        int8x8_t vw1 = vld1_s8(weight_c + 8);
+        int8x8_t vw2 = vld1_s8(weight_c + 16);
+        int8x8_t vw3 = vld1_s8(weight_c + 24);
+        int8x8_t vw4 = vld1_s8(weight_c + 32);
+        int8x8_t vw5 = vld1_s8(weight_c + 40);
+        int8x8_t vw6 = vld1_s8(weight_c + 48);
+        int8x8_t vw7 = vld1_s8(weight_c + 56);
+        int8x8_t vw8 = vld1_s8(weight_c + 64);
+#endif
+        for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
+          int cnt = w_loop;
+          const int8_t* inr0 = block_inr0;
+          const int8_t* inr1 = block_inr1;
+          const int8_t* inr2 = block_inr2;
+          int32_t* ptr_out0 = pre_out + hk * out_row_stride;
+#ifdef __aarch64__
+          asm volatile(
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n"
+              "1:\n"
+              /* inr0 -> outr0 */
+              "ldp d4, d5, [%[r0]]\n"           /* load r0, 4 */
+              "smull v20.8h, v0.8b, %[w0].8b\n" /* int16, out0 */
+              "smull v21.8h, v1.8b, %[w0].8b\n" /* int16, out1 */
+              "smull v22.8h, v2.8b, %[w0].8b\n" /* int16, out2 */
+              "smull v23.8h, v3.8b, %[w0].8b\n" /* int16, out3 */
+              "smlal v20.8h, v1.8b, %[w1].8b\n" /* int16, out0 */
+              "smlal v21.8h, v2.8b, %[w1].8b\n" /* int16, out1 */
+              "smlal v22.8h, v3.8b, %[w1].8b\n" /* int16, out2 */
+              "smlal v23.8h, v4.8b, %[w1].8b\n" /* int16, out3 */
+              "ldp d0, d1, [%[r1]], #16\n"      /* load r1, 0,1 */
+              "sxtl
v24.4s, v20.4h\n" + "sxtl2 v25.4s, v20.8h\n" + "sxtl v26.4s, v21.4h\n" + "sxtl2 v27.4s, v21.8h\n" + "sxtl v28.4s, v22.4h\n" + "sxtl2 v29.4s, v22.8h\n" + "sxtl v30.4s, v23.4h\n" + "sxtl2 v31.4s, v23.8h\n" + "smull v20.8h, v2.8b, %[w2].8b\n" /* int16, out0 */ + "smull v21.8h, v3.8b, %[w2].8b\n" /* int16, out1 */ + "smull v22.8h, v4.8b, %[w2].8b\n" /* int16, out2 */ + "smull v23.8h, v5.8b, %[w2].8b\n" /* int16, out3 */ + "ldp d2, d3, [%[r1]], #16\n" /* load r1, 2,3 */ + "smlal v20.8h, v0.8b, %[w3].8b\n" /* int16, out0 */ + "smlal v21.8h, v1.8b, %[w3].8b\n" /* int16, out1 */ + "smlal v22.8h, v2.8b, %[w3].8b\n" /* int16, out2 */ + "smlal v23.8h, v3.8b, %[w3].8b\n" /* int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldp d4, d5, [%[r1]]\n" /* load r1, 4,5 */ + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull v20.8h, v1.8b, %[w4].8b\n" /* int16, out0 */ + "smull v21.8h, v2.8b, %[w4].8b\n" /* int16, out1 */ + "smull v22.8h, v3.8b, %[w4].8b\n" /* int16, out1 */ + "smull v23.8h, v4.8b, %[w4].8b\n" /* int16, out1 */ + "ldp d0, d1, [%[r2]], #16\n" /* load r2, 0,1 */ + "smlal v20.8h, v2.8b, %[w5].8b\n" /* int16, out0 */ + "smlal v21.8h, v3.8b, %[w5].8b\n" /* int16, out1 */ + "smlal v22.8h, v4.8b, %[w5].8b\n" /* int16, out2 */ + "smlal v23.8h, v5.8b, %[w5].8b\n" /* int16, out3 */ + "ldp d2, d3, [%[r2]], #16\n" /* load r2, 2,3 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldp d4, d5, [%[r2]]\n" /* load r2 */ + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull v20.8h, v0.8b, %[w6].8b\n" /* int16, out0 */ + "smull v21.8h, v1.8b, %[w6].8b\n" /* int16, out1 */ + "smull v22.8h, v2.8b, %[w6].8b\n" /* int16, out1 */ + "smull v23.8h, v3.8b, %[w6].8b\n" /* int16, out1 */ + "smlal v20.8h, v1.8b, %[w7].8b\n" /* int16, out0 */ + "smlal v21.8h, v2.8b, %[w7].8b\n" /* int16, out1 */ + "smlal v22.8h, v3.8b, %[w7].8b\n" /* int16, out1 */ + "smlal v23.8h, v4.8b, %[w7].8b\n" /* int16, out1 */ + "ldp d0, d1, [%[r0]], #16\n" /* load r0, 0,1 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull v20.8h, v2.8b, %[w8].8b\n" /* int16, out0 */ + "smull v21.8h, v3.8b, %[w8].8b\n" /* int16, out1 */ + "smull v22.8h, v4.8b, %[w8].8b\n" /* int16, out1 */ + "smull v23.8h, v5.8b, %[w8].8b\n" /* int16, out1 */ + "ldp d2, d3, [%[r0]], #16\n" /* load r0, 2,3 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "stp q24, q25, [%[ptr_out0]], #32\n" + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "stp q26, q27, [%[ptr_out0]], #32\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "subs %w[cnt], %w[cnt], #1\n" + "stp q28, q29, [%[ptr_out0]], #32\n" + "stp q30, q31, [%[ptr_out0]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [ptr_out0] "+r"(ptr_out0) + : [w0] "w"(vw0), + [w1] "w"(vw1), + [w2] "w"(vw2), + [w3] 
"w"(vw3), + [w4] "w"(vw4), + [w5] "w"(vw5), + [w6] "w"(vw6), + [w7] "w"(vw7), + [w8] "w"(vw8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31" + + ); +#else + auto wptr = weight_c; + asm volatile( + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-4 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "1:\n" + /* inr0 -> outr0 */ + "vld1.32 {d4-d5}, [%[r0]]\n" /* load r0, 5-6 */ + "vmull.s8 q4, d0, d6\n" /* int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* int16, out3 */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w2 */ + "vmlal.s8 q4, d1, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* int16, out3 */ + "vld1.32 {d7}, [%[wptr]]!\n" /* load w3 */ + "vmovl.s16 q8, d8\n" + "vmovl.s16 q9, d9\n" + "vmovl.s16 q10, d10\n" + "vmovl.s16 q11, d11\n" + "vld1.32 {d0-d1}, [%[r1]]!\n" /* load r1, 0-1 */ + "vmovl.s16 q12, d12\n" + "vmovl.s16 q13, d13\n" + "vmovl.s16 q14, d14\n" + "vmovl.s16 q15, d15\n" + "vmull.s8 q4, d2, d6\n" /* int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* int16, out1 */ + "vld1.32 {d2-d3}, [%[r1]]!\n" /* load r1, 2-3 */ + "vmull.s8 q6, d4, d6\n" /* int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* int16, out3 */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w4 */ + /* inr1 -> outr0 */ + "vmlal.s8 q4, d0, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d1, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* int16, out3 */ + "vld1.32 {d4-d5}, [%[r1]]\n" /* load r1, 4-5 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vld1.32 {d7}, [%[wptr]]!\n" /* load w5 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "vmull.s8 q4, d1, d6\n" /* int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* int16, out1 */ + "vmull.s8 q6, d3, d6\n" /* int16, out2 */ + "vmull.s8 q7, d4, d6\n" /* int16, out3 */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w6 */ + "vld1.32 {d0-d1}, [%[r2]]!\n" /* load r2, 0-1 */ + "vmlal.s8 q4, d2, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d3, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d5, d7\n" /* int16, out3 */ + "vld1.32 {d7}, [%[wptr]]!\n" /* load w7 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vld1.32 {d2-d3}, [%[r2]]!\n" /* load r2, 2-3 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "vld1.32 {d4-d5}, [%[r2]]\n" /* load r2, 4-5 */ + /* inr2 -> outr0 */ + "vmull.s8 q4, d0, d6\n" /* int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* int16, out3 */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w8 */ + "vmlal.s8 q4, d1, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* int16, out3 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vld1.32 {d0-d1}, [%[r0]]!\n" /* load r0, 0-1 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "sub %[wptr], %[wptr], 
#72\n"
+              "vmull.s8 q4, d2, d6\n"       /* int16, out0 */
+              "vmull.s8 q5, d3, d6\n"       /* int16, out1 */
+              "vmull.s8 q6, d4, d6\n"       /* int16, out2 */
+              "vmull.s8 q7, d5, d6\n"       /* int16, out3 */
+              "vld1.32 {d2-d3}, [%[r0]]!\n" /* load r0, 2-3 */
+              "vaddw.s16 q8, q8, d8\n"
+              "vaddw.s16 q9, q9, d9\n"
+              "vaddw.s16 q10, q10, d10\n"
+              "vaddw.s16 q11, q11, d11\n"
+              "vst1.32 {d16-d19}, [%[ptr_out0]]!\n"
+              "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
+              "vaddw.s16 q12, q12, d12\n"
+              "vaddw.s16 q13, q13, d13\n"
+              "vst1.32 {d20-d23}, [%[ptr_out0]]!\n"
+              "vaddw.s16 q14, q14, d14\n"
+              "vaddw.s16 q15, q15, d15\n"
+              "subs %[cnt], #1\n"
+              "vst1.32 {d24-d27}, [%[ptr_out0]]!\n"
+              "vst1.32 {d28-d31}, [%[ptr_out0]]!\n"
+              "bne 1b\n"
+              : [cnt] "+r"(cnt),
+                [r0] "+r"(inr0),
+                [r1] "+r"(inr1),
+                [r2] "+r"(inr2),
+                [ptr_out0] "+r"(ptr_out0),
+                [wptr] "+r"(wptr)
+              :
+              : "cc", "memory",
+                "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+                "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
+#endif
+          block_inr0 = block_inr1;
+          block_inr1 = block_inr2;
+          block_inr2 = block_inr1 + in_len;
+        }
+        write_int32_nchwc8_to_nchw(pre_out,
+                                   reinterpret_cast<Dtype*>(dout_batch),
+                                   c,
+                                   c + hout_c_block,
+                                   h,
+                                   h + h_kernel,
+                                   0,
+                                   wout_round,
+                                   chout,
+                                   hout,
+                                   wout,
+                                   flag_relu,
+                                   bias_local,
+                                   flag_bias,
+                                   ptr_write,
+                                   scale + c);
+      }
+    }
+  }
+}
+
+template void conv_depthwise_3x3s1_int8<int8_t>(int8_t* dout,
+                                                const int8_t* din,
+                                                const int8_t* weights,
+                                                const float* scale,
+                                                const float* bias,
+                                                bool flag_bias,
+                                                bool flag_relu,
+                                                int num,
+                                                int chin,
+                                                int hin,
+                                                int win,
+                                                int hout,
+                                                int wout,
+                                                int padw,
+                                                int padh,
+                                                ARMContext* ctx);
+
+template void conv_depthwise_3x3s1_int8<float>(float* dout,
+                                               const int8_t* din,
+                                               const int8_t* weights,
+                                               const float* scale,
+                                               const float* bias,
+                                               bool flag_bias,
+                                               bool flag_relu,
+                                               int num,
+                                               int chin,
+                                               int hin,
+                                               int win,
+                                               int hout,
+                                               int wout,
+                                               int padw,
+                                               int padh,
+                                               ARMContext* ctx);
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
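The int8 kernels in this file keep their accumulation in two stages: int8 products go into int16 registers (smull/smlal on aarch64, vmull.s8/vmlal.s8 above), and only then are widened into the int32 accumulators (saddw/vaddw.s16). Two fused rows are safe because 127 * 127 * 2 still fits in int16. A minimal sketch of the same scheme with NEON intrinsics (illustration only; the helper name is made up):

#include <arm_neon.h>

// Two int8 product rows fused in int16, then widened once into int32.
static inline void mac_rows_s8(int32x4_t* acc_lo, int32x4_t* acc_hi,
                               int8x8_t in0, int8x8_t w0,
                               int8x8_t in1, int8x8_t w1) {
  int16x8_t sum16 = vmull_s8(in0, w0);  // int8 x int8 -> int16 products
  sum16 = vmlal_s8(sum16, in1, w1);     // fuse a second row, still int16
  *acc_lo = vaddw_s16(*acc_lo, vget_low_s16(sum16));   // widen-accumulate
  *acc_hi = vaddw_s16(*acc_hi, vget_high_s16(sum16));  // into int32
}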
diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6a1fa37681585883280625a22c15aec43c6554af
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc
@@ -0,0 +1,790 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+const int OUT_C_BLOCK = 4;
+const int OUT_H_BLOCK = 2;
+const int OUT_W_BLOCK = 4;
+
+size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
+                                       ARMContext* ctx) {
+  auto dim_in = param.x->dims();
+  auto dim_out = param.output->dims();
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / sizeof(float);
+  const int pad_w = param.paddings[1];
+  const int pad_h = param.paddings[0];
+  int ow = dim_out[3];
+  int oh = dim_out[2];
+  int ic = dim_in[1];
+  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
+  const int win_round = wout_round + 2;
+
+  int hout_r_block = (llc_size - 2 * win_round * ic) /
+                     (win_round * ic + OUT_C_BLOCK * wout_round * threads);
+  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
+  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
+  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
+
+  const int hin_r_block = hout_r_block + 2;
+
+  int in_len = win_round * ic;
+  int pre_in_size = hin_r_block * in_len;
+  int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
+
+  return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size);
+}
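The function above sizes the scratch area so that the pre-packed input rows plus one per-thread output tile fit in an L2-sized budget; the kernel below assumes the caller has reserved that many bytes in the ARM context workspace first. A usage sketch, assuming the context exposes a workspace-growing call (`ExtendWorkspace` here is my assumption; substitute the real workspace API):

// Hypothetical caller: reserve the workspace, then run the kernel.
size_t ws_bytes = conv3x3s1_direct_workspace_size(param, ctx);
ctx->ExtendWorkspace(ws_bytes);  // assumed context API
conv_3x3s1_direct_fp32(i_data, o_data, bs, oc, oh, ow, ic, ih, win,
                       weights, bias, param, ctx);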
+
+void conv_3x3s1_direct_fp32(const float* i_data,
+                            float* o_data,
+                            int bs,
+                            int oc,
+                            int oh,
+                            int ow,
+                            int ic,
+                            int ih,
+                            int win,
+                            const float* weights,
+                            const float* bias,
+                            const operators::ConvParam& param,
+                            ARMContext* ctx) {
+  const int threads = ctx->threads();
+  int l2_size = ctx->llc_size() / sizeof(float);
+
+  const int pad_h = param.paddings[0];
+  const int pad_w = param.paddings[1];
+  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
+  const int win_round = wout_round + 2;
+  bool flag_relu = param.fuse_relu;
+  bool flag_bias = param.bias != nullptr;
+
+  int hout_r_block = (l2_size - 2 * win_round * ic) /
+                     (win_round * ic + OUT_C_BLOCK * wout_round * threads);
+  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
+  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
+  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
+
+  const int hin_r_block = hout_r_block + 2;
+
+  float* tmp_work_space = ctx->workspace_data<float>();
+  float ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float ptr_write[wout_round];  // NOLINT
+
+  int in_len = win_round * ic;
+  int pre_in_size = hin_r_block * in_len;
+  int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
+
+  float* pre_din = tmp_work_space;
+
+  int size_in_channel = win * ih;
+  int size_out_channel = ow * oh;
+  int w_stride = ic * 9;                 // kernel_w * kernel_h
+  int w_stride_chin = OUT_C_BLOCK * 9;   // kernel_w * kernel_h *
+
+  int ws = -pad_w;
+  int we = ws + win_round;
+  int w_loop = wout_round / 4;
+
+  int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
+  int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
+
+  int out_row_stride = OUT_C_BLOCK * wout_round;
+  for (int n = 0; n < bs; ++n) {
+    const float* din_batch = i_data + n * ic * size_in_channel;
+    float* dout_batch = o_data + n * oc * size_out_channel;
+    for (int h = 0; h < oh; h += hout_r_block) {
+      int h_kernel = hout_r_block;
+      if (h + hout_r_block > oh) {
+        h_kernel = oh - h;
+      }
+      int hs = h - pad_h;
+      int he = hs + h_kernel + 2;
+      prepack_input_nxw(
+          din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero);
+#pragma omp parallel for num_threads(threads)
+      for (int c = 0; c < oc - (OUT_C_BLOCK - 1); c += OUT_C_BLOCK) {
+#ifdef ARM_WITH_OMP
+        float* pre_out =
+            pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
+#else
+        float* pre_out = pre_din + pre_in_size;
+#endif
+        const float* block_inr0 = pre_din;
+        const float* block_inr1 = block_inr0 + in_len;
+        const float* block_inr2 = block_inr1 + in_len;
+        const float* block_inr3 = block_inr2 + in_len;
+
+        const float* weight_c = weights + c * w_stride;
+        const float* bias_ptr = ptr_zero;
+        if (flag_bias) {
+          bias_ptr = bias + c;
+        }
+        fill_packed_biasc4(
+            pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel);
+
+        for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
+          const float* wc0 = weight_c;
+
+          const float* inr0 = block_inr0;
+          const float* inr1 = block_inr1;
+          const float* inr2 = block_inr2;
+          const float* inr3 = block_inr3;
+
+          float* pre_out0 = pre_out + hk * out_row_stride;
+          float* pre_out1 = pre_out0 + out_row_stride;
+#ifdef __aarch64__
+          for (int i = 0; i < ic; ++i) {
+            float* ptr_out0 = pre_out0;
+            float* ptr_out1 = pre_out1;
+
+            float32x4_t w0 = vld1q_f32(wc0);       // w0, v23
+            float32x4_t w1 = vld1q_f32(wc0 + 4);   // w1, v24
+            float32x4_t w2 = vld1q_f32(wc0 + 8);   // w2, v25
+            float32x4_t w3 = vld1q_f32(wc0 + 12);  // w3, v26
+            float32x4_t w4 = vld1q_f32(wc0 + 16);  // w4, v27
+            float32x4_t w5 = vld1q_f32(wc0 + 20);  // w5, v28
+            float32x4_t w6 = vld1q_f32(wc0 + 24);  // w6, v29
+            float32x4_t w7 = vld1q_f32(wc0 + 28);  // w7, v30
+            float32x4_t w8 = vld1q_f32(wc0 + 32);  // w8, v31
+
+            const float* r0 = inr0;
+            const float* r1 = inr1;
+            const float* r2 = inr2;
+            const float* r3 = inr3;
+
+            int cnt = w_loop;
+            // clang-format off
+            asm volatile(
+              "ldp q15, q16, [%[ptr_out0]]\n"      /* load outr00, outr01 */
+              "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03 */
+              "ldp q19, q20, [%[ptr_out1]]\n"      /* load outr10, outr11 */
+              "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13 */
+              "ldp q0, q1, [%[r0]], #16\n"         /* load input r0 */
+              "ldp q2, q3, [%[r1]], #16\n"         /* load input r1 */
+              "2:\n"                               /* main loop */
+              /* r0, r1, mul w0, get out r0, r1 */
+              "fmla v15.4s, %[w0].4s, v0.s[0]\n"   /* outr00 = w0 * r0[0] */
+              "fmla v16.4s, %[w0].4s, v0.s[1]\n"   /* outr01 = w0 * r0[1] */
+
"fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ + "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ + "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ + "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ + "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ + "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ + /* r0, r1, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ + "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ + "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ + "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ + "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ + "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ + "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + /* r0, r1, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ + "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ + "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ + "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ + "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ + "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ + "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ + /* r1, r2, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ + "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ + "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ + "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ + "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ + "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ + "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ + /* r1, r2, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ + "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ + "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ + "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ + "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ + "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ + "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + /* r1, r2, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ + "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ + "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ + "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ + "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ + "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ + "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ + /* r2, r3, mul w6, get out r0, r1 */ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ + "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ + "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ + "fmla v19.4s 
, %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ + "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ + "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ + "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ + /* r2, r3, mul w7, get out r0, r1 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ + "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ + "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ + "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ + "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ + "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ + "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + /* r2, r3, mul w8, get out r0, r1 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ + "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ + "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), + [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3", + "v4","v5","v6","v7","v15","v16", + "v17","v18","v19","v20","v21","v22" + ); + // clang-format on + + wc0 += 9 * OUT_C_BLOCK; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + const float* wc0 = weight_c + i * w_stride_chin; + + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + + int cnt = w_loop; + // clang-format off + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + /* load r0, r1 */ + "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n" + "vld1.32 {d2}, [%[r0]] @ load r0\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" + /* main loop */ + "0: @ main loop\n" + /* mul r0 with w0, w1, w2, get out r0 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! 
@ load outr1\n" + "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n" + "vmla.f32 q9, q5, d0[1] @ w0 * inr01\n" + "vmla.f32 q10, q5, d1[0] @ w0 * inr02\n" + "vmla.f32 q11, q5, d1[1] @ w0 * inr03\n" + "vld1.32 {d3-d4}, [%[r1]]! @ load r1\n" + "vmla.f32 q8, q6, d0[1] @ w1 * inr01\n" + "vmla.f32 q9, q6, d1[0] @ w1 * inr02\n" + "vmla.f32 q10, q6, d1[1] @ w1 * inr03\n" + "vmla.f32 q11, q6, d2[0] @ w1 * inr04\n" + "vld1.32 {d5}, [%[r1]] @ load r0\n" + "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n" + "vmla.f32 q9, q7, d1[1] @ w2 * inr03\n" + "vmla.f32 q10, q7, d2[0] @ w2 * inr04\n" + "vmla.f32 q11, q7, d2[1] @ w2 * inr05\n" + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n" + /* mul r1 with w0, w1, w2, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w0 * inr10\n" + "vmla.f32 q13, q5, d3[1] @ w0 * inr11\n" + "vmla.f32 q14, q5, d4[0] @ w0 * inr12\n" + "vmla.f32 q15, q5, d4[1] @ w0 * inr13\n" + "vmla.f32 q12, q6, d3[1] @ w1 * inr11\n" + "vmla.f32 q13, q6, d4[0] @ w1 * inr12\n" + "vmla.f32 q14, q6, d4[1] @ w1 * inr13\n" + "vmla.f32 q15, q6, d5[0] @ w1 * inr14\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4\n" + "vmla.f32 q12, q7, d4[0] @ w2 * inr12\n" + "vmla.f32 q13, q7, d4[1] @ w2 * inr13\n" + "vmla.f32 q14, q7, d5[0] @ w2 * inr14\n" + "vmla.f32 q15, q7, d5[1] @ w2 * inr15\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5\n" + /* mul r1 with w3, w4, w5, get out r0 */ + "vmla.f32 q8, q5, d3[0] @ w3 * inr10\n" + "vmla.f32 q9, q5, d3[1] @ w3 * inr11\n" + "vmla.f32 q10, q5, d4[0] @ w3 * inr12\n" + "vmla.f32 q11, q5, d4[1] @ w3 * inr13\n" + "vld1.32 {d0-d1}, [%[r2]]! @ load r2\n" + "vmla.f32 q8, q6, d3[1] @ w4 * inr11\n" + "vmla.f32 q9, q6, d4[0] @ w4 * inr12\n" + "vmla.f32 q10, q6, d4[1] @ w4 * inr13\n" + "vmla.f32 q11, q6, d5[0] @ w4 * inr14\n" + "vld1.32 {d2}, [%[r2]] @ load r2\n" + "vmla.f32 q8, q7, d4[0] @ w5 * inr12\n" + "vmla.f32 q9, q7, d4[1] @ w5 * inr13\n" + "vmla.f32 q10, q7, d5[0] @ w5 * inr14\n" + "vmla.f32 q11, q7, d5[1] @ w5 * inr15\n" + /* mul r2 with w3, w4, w5, get out r1 */ + "vmla.f32 q12, q5, d0[0] @ w3 * inr20\n" + "vmla.f32 q13, q5, d0[1] @ w3 * inr21\n" + "vmla.f32 q14, q5, d1[0] @ w3 * inr22\n" + "vmla.f32 q15, q5, d1[1] @ w3 * inr23\n" + "vmla.f32 q12, q6, d0[1] @ w4 * inr21\n" + "vmla.f32 q13, q6, d1[0] @ w4 * inr22\n" + "vmla.f32 q14, q6, d1[1] @ w4 * inr23\n" + "vmla.f32 q15, q6, d2[0] @ w4 * inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n" + "vmla.f32 q12, q7, d1[0] @ w5 * inr22\n" + "vmla.f32 q13, q7, d1[1] @ w5 * inr23\n" + "vmla.f32 q14, q7, d2[0] @ w5 * inr24\n" + "vmla.f32 q15, q7, d2[1] @ w5 * inr25\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n" + "sub %[wc0], %[wc0], #144 @ wc0 - 144\n" + /* mul r2 with w6, w7, w8, get out r0 */ + "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n" + "vmla.f32 q9, q5, d0[1] @ w6 * inr21\n" + "vld1.32 {d3-d4}, [%[r3]]! @ load r3\n" + "vmla.f32 q10, q5, d1[0] @ w6 * inr22\n" + "vmla.f32 q11, q5, d1[1] @ w6 * inr23\n" + "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n" + "vmla.f32 q9, q6, d1[0] @ w7 * inr22\n" + "vld1.32 {d5}, [%[r3]] @ load r3\n" + "vmla.f32 q10, q6, d1[1] @ w7 * inr23\n" + "vmla.f32 q11, q6, d2[0] @ w7 * inr24\n" + "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n" + "vmla.f32 q9, q7, d1[1] @ w8 * inr23\n" + "vld1.32 {d0-d1}, [%[r0]]! 
@ load r0\n" + "vmla.f32 q10, q7, d2[0] @ w8 * inr24\n" + "vmla.f32 q11, q7, d2[1] @ w8 * inr25\n" + "vld1.32 {d2}, [%[r0]] @ load r0\n" + /* mul r3 with w6, w7, w8, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w6 * inr20\n" + "vmla.f32 q13, q5, d3[1] @ w6 * inr21\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n" + "vmla.f32 q14, q5, d4[0] @ w6 * inr22\n" + "vmla.f32 q15, q5, d4[1] @ w6 * inr23\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n" + "vmla.f32 q12, q6, d3[1] @ w7 * inr21\n" + "vmla.f32 q13, q6, d4[0] @ w7 * inr22\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vmla.f32 q14, q6, d4[1] @ w7 * inr23\n" + "vmla.f32 q15, q6, d5[0] @ w7 * inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vmla.f32 q12, q7, d4[0] @ w8 * inr22\n" + "vmla.f32 q13, q7, d4[1] @ w8 * inr23\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + "vmla.f32 q14, q7, d5[0] @ w8 * inr24\n" + "vmla.f32 q15, q7, d5[1] @ w8 * inr25\n" + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" + "subs %[cnt], #1 @ loop count--\n" + "bne 0b @ jump to main loop\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(wc0) + : + : "cc","memory","q0","q1","q2","q3", + "q4","q5","q6","q7","q8","q9", + "q10","q11","q12","q13","q14","q15"); + // clang-format on + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + OUT_C_BLOCK, + h, + h + h_kernel, + 0, + wout_round, + oc, + oh, + ow, + flag_relu, + ptr_write); + } + const float* weight_remain_ptr = weights + c_round_down * w_stride; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < c_remain; ++c) { +#ifdef ARM_WITH_OMP + float* pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float* pre_out = pre_din + pre_in_size; +#endif + + int c_idx = c_round_down + c; + + int h_kernel = hout_r_block; + if (h + hout_r_block > oh) { + h_kernel = oh - h; + } + + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + + const float* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c_idx; + } + fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { + const float* wc0 = weight_remain_ptr; + + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + + float* pre_out0 = pre_out + hk * wout_round; + float* pre_out1 = pre_out0 + wout_round; +#ifdef __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29 + 
float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + + int cnt = w_loop; + // clang-format off + asm volatile( + "ldr q21, [%[ptr_out0]]\n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "2: \n" /* main loop*/ + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ + "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ + "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ + "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ + "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ + "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ + "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ + "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ + "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ + "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ + "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ + "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ + "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ + "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ + "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ + "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ + "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ + "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ + "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ + "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ + "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ + "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ + "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), + [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0", + "v1","v2","v3","v4","v5","v6", + "v7","v8","v9","v10","v11","v12", + "v13","v14","v15","v21","v22" + ); + // clang-format on + wc0 += 9 * OUT_C_BLOCK; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + //! 
get valid weights of current output channel + float w_tmp[10] = {wc0[c], + wc0[c + 4], + wc0[c + 8], + wc0[c + 12], + wc0[c + 16], + wc0[c + 20], + wc0[c + 24], + wc0[c + 28], + wc0[c + 32], + 0.f}; + float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 + float32x4_t w1 = vld1q_f32(w_tmp + 3); // w3, w4, w5, q1 + float32x4_t w2 = vld1q_f32(w_tmp + 6); // w6, w7, w8, q2 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + int cnt = w_loop / 2; + if (cnt > 0) { + // clang-format off + asm volatile( + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n" + "vld1.32 {d6-d9}, [%[r0]]! @ load r0\n" + "vld1.32 {d10}, [%[r0]] @ load r0\n" + /* main loop */ + "0: @ main loop\n" + /* r0 * w0, w1, w2, get out r0*/ + "vld1.32 {d28-d31}, [%[ptr_out1]]@ load or10 or11\n" + "vext.32 q8, q3, q4, #1 @ r0, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r0, shift left 1\n" + "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0\n" + "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0\n" + "vext.32 q10, q3, q4, #2 @ r0, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r0, shift left 2\n" + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n" + "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0\n" + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8\n" + "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0\n" + "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0\n" + "vld1.32 {d10}, [%[r1]] @ load r1\n" + /* r1 * w3, w4, w5, get out r0*/ + /* r1 * w0, w1, w2, get out r1*/ + "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1\n" + "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1\n" + "vext.32 q8, q3, q4, #1 @ r1, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r1, shift left 1\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1\n" + "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1\n" + "vext.32 q10, q3, q4, #2 @ r1, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r1, shift left 2\n" + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n" + "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1\n" + "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1\n" + "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1\n" + "vld1.32 {d6-d9}, [%[r2]]! @ load r2\n" + "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1\n" + "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1\n" + "vld1.32 {d10}, [%[r2]] @ load r2\n" + /* r2 * w6, w7, w8, get out r0*/ + /* r2 * w3, w4, w5, get out r1*/ + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n" + "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2\n" + "vext.32 q8, q3, q4, #1 @ r2, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r2, shift left 1\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2\n" + "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2\n" + "vext.32 q10, q3, q4, #2 @ r2, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r2, shift left 2\n" + "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2\n" + "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2\n" + "vld1.32 {d6-d9}, [%[r3]]! @ load r3\n" + "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2\n" + "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2\n" + "vld1.32 {d10}, [%[r3]] @ load r3\n" + /* r3 * w6, w7, w8, get out r1*/ + "vext.32 q8, q3, q4, #1 @ r3, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r3, shift left 1\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3\n" + "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3\n" + "vst1.32 {d24-d27}, [%[ptr_out0]]! 
@ save or00, or01\n"
+              "vext.32  q10, q3, q4, #2           @ r3, shift left 2\n"
+              "vext.32  q11, q4, q5, #2           @ r3, shift left 2\n"
+              "vmla.f32 q14, q8, %e[w2][1]        @ w21 * r3\n"
+              "vmla.f32 q15, q9, %e[w2][1]        @ w21 * r3\n"
+              "vld1.32  {d24-d27}, [%[ptr_out0]]  @ load or00, or01\n"
+              "vld1.32  {d6-d9}, [%[r0]]!         @ load r0\n"
+              "vmla.f32 q14, q10, %f[w2][0]       @ w22 * r3\n"
+              "vmla.f32 q15, q11, %f[w2][0]       @ w22 * r3\n"
+              "vld1.32  {d10}, [%[r0]]            @ load r0\n"
+              "vst1.32  {d28-d31}, [%[ptr_out1]]! @ save or10, or11\n"
+              "subs %[cnt], #1                    @ loop count -1\n"
+              "bne 0b                             @ jump to main loop\n"
+              : [cnt] "+r"(cnt),
+                [r0] "+r"(r0), [r1] "+r"(r1),
+                [r2] "+r"(r2), [r3] "+r"(r3),
+                [ptr_out0] "+r"(ptr_out0),
+                [ptr_out1] "+r"(ptr_out1)
+              : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
+              : "cc", "memory", "q3", "q4",
+                "q5", "q6", "q7", "q8", "q9", "q10",
+                "q11", "q12", "q13", "q14", "q15");
+              // clang-format on
+              r0 -= 8;
+            }
+            //! deal with remain ow
+            if (w_loop & 1) {
+              ptr_out0[0] +=
+                  r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] +
+                  r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] +
+                  r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8];
+
+              ptr_out0[1] +=
+                  r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] +
+                  r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] +
+                  r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8];
+
+              ptr_out0[2] +=
+                  r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] +
+                  r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] +
+                  r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8];
+
+              ptr_out0[3] +=
+                  r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] +
+                  r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] +
+                  r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8];
+
+              ptr_out1[0] +=
+                  r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] +
+                  r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] +
+                  r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8];
+
+              ptr_out1[1] +=
+                  r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] +
+                  r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] +
+                  r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8];
+
+              ptr_out1[2] +=
+                  r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] +
+                  r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] +
+                  r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8];
+
+              ptr_out1[3] +=
+                  r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] +
+                  r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] +
+                  r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8];
+            }
+
+            wc0 += 36;
+            inr0 += win_round;
+            inr1 += win_round;
+            inr2 += win_round;
+            inr3 += win_round;
+          }
+#endif  // __aarch64__
+          block_inr0 = block_inr2;
+          block_inr1 = block_inr3;
+          block_inr2 = block_inr1 + in_len;
+          block_inr3 = block_inr2 + in_len;
+        }
+        write_to_output_c1_fp32(pre_out,
+                                dout_batch,
+                                c_idx,
+                                c_idx + 1,
+                                h,
+                                h + h_kernel,
+                                0,
+                                wout_round,
+                                oc,
+                                oh,
+                                ow,
+                                flag_relu,
+                                ptr_write);
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
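The next hunk rewrites the int8 direct kernel: the runtime `PrecisionType out_type` switch is replaced by a `Dtype` template parameter, and the int32 bias/output pair becomes float bias plus per-channel scales. A minimal sketch of that dispatch pattern (illustration only; `store_pixel` is a made-up name, and the ±127 saturation is my assumption for symmetric int8 quantization):

#include <cmath>
#include <cstdint>

// Output precision chosen at compile time, one instantiation per type,
// instead of branching on a PrecisionType flag inside the kernel.
template <typename Dtype>
void store_pixel(Dtype* out, int32_t acc, float scale);

template <>
void store_pixel<float>(float* out, int32_t acc, float scale) {
  *out = acc * scale;  // dequantize the int32 accumulator to fp32
}

template <>
void store_pixel<int8_t>(int8_t* out, int32_t acc, float scale) {
  float v = acc * scale;
  v = v > 127.f ? 127.f : (v < -127.f ? -127.f : v);  // saturate
  *out = static_cast<int8_t>(std::roundf(v));
}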
diff --git a/lite/backends/arm/math/conv3x3s1_direct_int8.cc b/lite/backends/arm/math/conv3x3s1_direct_int8.cc
index d44d911131dd7c096acfe9b59003b00655505ec0..f966313e118acf3f74124aca1d16aa3c50009bb8
--- a/lite/backends/arm/math/conv3x3s1_direct_int8.cc
+++ b/lite/backends/arm/math/conv3x3s1_direct_int8.cc
@@ -26,9 +26,9 @@ namespace lite {
 namespace arm {
 namespace math {
 
-#ifdef __aarch64__
+template <typename Dtype>
 void conv_3x3s1_direct_int8(const int8_t* din,
-                            int32_t* dout,
+                            Dtype* dout,
                             int num,
                             int chout,
                             int hout,
@@ -37,62 +37,74 @@ void conv_3x3s1_direct_int8(const int8_t* din,
                             int hin,
                             int win,
                             const int8_t* weights,
-                            const int32_t* bias,
+                            const float* bias,
                             const operators::ConvParam& param,
                             Context<TARGET(kARM)>* ctx,
-                            PrecisionType out_type,
                             const float* scale) {
-  const int hin_r_block = 4;
-  const int hout_c_block = 4;  // 8;
-  const int hout_r_block = 2;
-
-  int stride_w = param.strides[1];
-  int pad_w = param.paddings[1];
-  int pad_h = param.paddings[0];
   bool flag_relu = param.fuse_relu;
-  bool flag_bias = (param.bias != nullptr);
-
-  int wout_round = ((wout + 3) / 4) * 4;
-  int win_round = wout_round * stride_w + 4;
-
-  int threads = ctx->threads();
+  bool flag_bias = param.bias;
+  int pad_h = param.paddings[0];
+  int pad_w = param.paddings[1];
 
-  int* tmp_work_space = ctx->workspace_data<int>();
-  int* ptr_zero = tmp_work_space;
-  memset(ptr_zero, 0, sizeof(int) * win_round);
-  int* ptr_write = ptr_zero + win_round;
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / 4;
+
+  const int hout_c_block = 4;
+  const int hout_r_kernel = 2;
+  const int wout_block = 4;
+  const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
+  const int win_round = wout_round + 2;
+
+  //! get h block
+  //! llc_size = win_round * chin * hin_r_block * sizeof(int8_t) + wout_round *
+  //! hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! win_round = wout_round + 2
+  //! hin_r_block = hout_r_block + 2
+  int hout_r_block =
+      (llc_size - 2 * win_round * chin) /
+      (win_round * chin + hout_c_block * wout_round * threads * 4);
+  hout_r_block = hout_r_block > hout ? hout : hout_r_block;
+  hout_r_block =
+      ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
+  hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+
+  const int hin_r_block = hout_r_block + 2;
+
+  auto tmp_work_space = ctx->workspace_data<int8_t>();
+  int8_t ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(int8_t) * win_round);
+  Dtype ptr_write[wout_round];  // NOLINT
 
   int in_len = win_round * chin;
   int pre_in_size = hin_r_block * in_len;
+  pre_in_size = ROUNDUP(pre_in_size, 4);
   int pre_out_size = hout_c_block * hout_r_block * wout_round;
 
-  signed char* pre_din = reinterpret_cast<signed char*>(ptr_write + wout_round);
+  int8_t* pre_din = tmp_work_space;
 
   int size_in_channel = win * hin;
   int size_out_channel = wout * hout;
-  int w_stride = chin * 9;
+  int w_stride = chin * 9;               // kernel_w * kernel_h;
+  int w_stride_chin = hout_c_block * 9;  // kernel_w * kernel_h *
 
   int ws = -pad_w;
   int we = ws + win_round;
   int w_loop = wout_round / 4;
 
-  int size_out = wout_round * hout_c_block;
+  int out_row_stride = hout_c_block * wout_round;
 
-  // printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round,
-  // wout_round, ws, we);
-  // here
   for (int n = 0; n < num; ++n) {
-    const signed char* din_batch =
-        static_cast<const signed char*>(din) + n * chin * size_in_channel;
-    signed char* dout_batch =
-        reinterpret_cast<signed char*>(dout) +
-        n * chout * size_out_channel * PrecisionTypeLength(out_type);
+    const int8_t* din_batch = din + n * chin * size_in_channel;
+    Dtype* dout_batch = dout + n * chout * size_out_channel;
+    for (int h = 0; h < hout; h += hout_r_block) {
+      int h_kernel = hout_r_block;
+      if (h + hout_r_block > hout) {
+        h_kernel = hout - h;
+      }
 
-    for (int h = 0; h < hout; h += 2) {
       int hs = h - pad_h;
-      int he = hs + 4;
-      // printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin,
-      // hin, win);
+      int he = hs + h_kernel + 2;
+
       prepack_input_nxw(din_batch,
                         pre_din,
                         0,
@@ -104,701 +116,370 @@ void conv_3x3s1_direct_int8(const int8_t*
din, chin, win, hin, - (signed char*)ptr_zero); + ptr_zero); #pragma omp parallel for num_threads(threads) for (int c = 0; c < chout; c += hout_c_block) { #ifdef ARM_WITH_OMP - int* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + - omp_get_thread_num() * pre_out_size; + int32_t* pre_out = reinterpret_cast(pre_din + pre_in_size) + + omp_get_thread_num() * pre_out_size; #else - int* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); + auto pre_out = reinterpret_cast(pre_din + pre_in_size); #endif - // printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x, - // pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din, - // pre_out); - const signed char* inr0 = pre_din; - const signed char* inr1 = inr0 + in_len; - const signed char* inr2 = inr1 + in_len; - const signed char* inr3 = inr2 + in_len; - - const signed char* wc0 = - static_cast(weights) + c * w_stride; + const int8_t* block_inr0 = pre_din; + const int8_t* block_inr1 = block_inr0 + in_len; + const int8_t* block_inr2 = block_inr1 + in_len; + const int8_t* block_inr3 = block_inr2 + in_len; - const int* bias_ptr = ptr_zero; + const int8_t* weight_c = weights + c * w_stride; + float bias_local[4] = {0, 0, 0, 0}; if (flag_bias) { - bias_ptr = static_cast(bias) + c; - } - // hout_r_block * wout_round * hout_c_block - fill_packed_bias_nxmw_int8( - bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round); - - for (int i = 0; i < chin; ++i) { - const signed char* r0 = inr0; - const signed char* r1 = inr1; - const signed char* r2 = inr2; - const signed char* r3 = inr3; - - int* ptr_out0 = pre_out; - int* ptr_out1 = pre_out + size_out; - - int cnt = w_loop; - const signed char* ptr_wc0 = wc0; - - asm volatile( - "ldp q4, q5, [%[wc0]] \n" /* w4 w5 w6 w7 */ - "ldr q6, [%[wc0], #32] \n" /* w8 */ - "SXTL v11.8h, v4.8b \n" /* w to int16 */ - "SXTL2 v12.8h, v4.16b \n" /* w to int16 */ - "SXTL v13.8h, v5.8b \n" /* to int16 */ - "SXTL2 v14.8h, v5.16b \n" /* to int16 */ - "SXTL v15.8h, v6.8b \n" /* to int16 */ - "1: \n" /* main loop*/ - "ldr d0, [%[r0]] \n" /* load data din0-dinn7*/ - "SXTL v1.8h, v0.8b \n" /* to int16 */ - - /*output 1st row*/ - "smull v16.4s, v11.4h, v1.h[0] \n" /* */ - "smull v17.4s, v11.4h, v1.h[1] \n" /* */ - "smull v18.4s, v11.4h, v1.h[2] \n" /* */ - "smull v19.4s, v11.4h, v1.h[3] \n" /* */ - - "add %[r0], %[r0], #4\n" - - /*output 1st row*/ - "smlal2 v16.4s, v11.8h, v1.h[1] \n" /* */ - "smlal2 v17.4s, v11.8h, v1.h[2] \n" /* */ - "smlal2 v18.4s, v11.8h, v1.h[3] \n" /* */ - "smlal2 v19.4s, v11.8h, v1.h[4] \n" /* */ - - "ldr d0, [%[r1]] \n" /* load data */ - - /*output 1st row*/ - "smlal v16.4s, v12.4h, v1.h[2] \n" /* */ - "smlal v17.4s, v12.4h, v1.h[3] \n" /* */ - "SXTL v2.8h, v0.8b \n" /* to int16 */ - "smlal v18.4s, v12.4h, v1.h[4] \n" /* */ - "smlal v19.4s, v12.4h, v1.h[5] \n" /* */ - - "add %[r1], %[r1], #4 \n" - - /*output 1st row*/ - "smlal2 v16.4s, v12.8h, v2.h[0] \n" /* */ - "smlal2 v17.4s, v12.8h, v2.h[1] \n" /* */ - "smlal2 v18.4s, v12.8h, v2.h[2] \n" /* */ - "smlal2 v19.4s, v12.8h, v2.h[3] \n" /* */ - - /*output 1st row*/ - "smlal v16.4s, v13.4h, v2.h[1] \n" /* */ - "smlal v17.4s, v13.4h, v2.h[2] \n" /* */ - "smlal v18.4s, v13.4h, v2.h[3] \n" /* */ - "smlal v19.4s, v13.4h, v2.h[4] \n" /* */ - - /*output 1st row*/ - "smlal2 v16.4s, v13.8h, v2.h[2] \n" /* */ - "smlal2 v17.4s, v13.8h, v2.h[3] \n" /* */ - "smlal2 v18.4s, v13.8h, v2.h[4] \n" /* */ - "smlal2 v19.4s, v13.8h, v2.h[5] \n" /* */ - - /*output 2rd row*/ - "smull v24.4s, v11.4h, v2.h[0] \n" /* */ - "smull v25.4s, 
v11.4h, v2.h[1] \n" /* */ - "smull v26.4s, v11.4h, v2.h[2] \n" /* */ - "smull v27.4s, v11.4h, v2.h[3] \n" /* */ - - /*output 2rd row*/ - "smlal2 v24.4s, v11.8h, v2.h[1] \n" /* */ - "smlal2 v25.4s, v11.8h, v2.h[2] \n" /* */ - "smlal2 v26.4s, v11.8h, v2.h[3] \n" /* */ - "smlal2 v27.4s, v11.8h, v2.h[4] \n" /* */ - - "ldr d0, [%[r2]] \n" /* load data */ - - /*output 2rd row*/ - "smlal v24.4s, v12.4h, v2.h[2] \n" /* */ - "smlal v25.4s, v12.4h, v2.h[3] \n" /* */ - "SXTL v1.8h, v0.8b \n" /* to int16 */ - "smlal v26.4s, v12.4h, v2.h[4] \n" /* */ - "smlal v27.4s, v12.4h, v2.h[5] \n" /* */ - - /*output 1st row*/ - "smlal v16.4s, v14.4h, v1.h[0] \n" /* */ - "smlal v17.4s, v14.4h, v1.h[1] \n" /* */ - "smlal v18.4s, v14.4h, v1.h[2] \n" /* */ - "smlal v19.4s, v14.4h, v1.h[3] \n" /* */ - - "add %[r2], %[r2], #4 \n" - - /*output 1st row*/ - "smlal2 v16.4s, v14.8h, v1.h[1] \n" /* */ - "smlal2 v17.4s, v14.8h, v1.h[2] \n" /* */ - "smlal2 v18.4s, v14.8h, v1.h[3] \n" /* */ - "smlal2 v19.4s, v14.8h, v1.h[4] \n" /* */ - - "ldp q3, q4, [%[ptr_out0]] \n" - "ldp q5, q6, [%[ptr_out0], #32] \n" - - /*output 1st row*/ - "smlal v16.4s, v15.4h, v1.h[2] \n" /* */ - "smlal v17.4s, v15.4h, v1.h[3] \n" /* */ - "smlal v18.4s, v15.4h, v1.h[4] \n" /* */ - "smlal v19.4s, v15.4h, v1.h[5] \n" /* */ - - "ADD v3.4s, v16.4s, v3.4s \n" - "ADD v4.4s, v17.4s, v4.4s \n" - "ADD v5.4s, v18.4s, v5.4s \n" - "ADD v6.4s, v19.4s, v6.4s \n" - - "stp q3, q4, [%[ptr_out0]], #32 \n" /* save to - output*/ - "stp q5, q6, [%[ptr_out0]], #32 \n" /* save to - output*/ - - /*output 2rd row*/ - "smlal2 v24.4s, v12.8h, v1.h[0] \n" /* */ - "smlal2 v25.4s, v12.8h, v1.h[1] \n" /* */ - "smlal2 v26.4s, v12.8h, v1.h[2] \n" /* */ - "smlal2 v27.4s, v12.8h, v1.h[3] \n" /* */ - - /*output 2rd row*/ - "smlal v24.4s, v13.4h, v1.h[1] \n" /* */ - "smlal v25.4s, v13.4h, v1.h[2] \n" /* */ - "smlal v26.4s, v13.4h, v1.h[3] \n" /* */ - "smlal v27.4s, v13.4h, v1.h[4] \n" /* */ - - "ldr d0, [%[r3]] \n" /* load data */ - - /*output 2rd row*/ - "smlal2 v24.4s, v13.8h, v1.h[2] \n" /* */ - "smlal2 v25.4s, v13.8h, v1.h[3] \n" /* */ - "SXTL v2.8h, v0.8b \n" /* to int16 */ - "smlal2 v26.4s, v13.8h, v1.h[4] \n" /* */ - "smlal2 v27.4s, v13.8h, v1.h[5] \n" /* */ - - /*output 2rd row*/ - "smlal v24.4s, v14.4h, v2.h[0] \n" /* */ - "smlal v25.4s, v14.4h, v2.h[1] \n" /* */ - "smlal v26.4s, v14.4h, v2.h[2] \n" /* */ - "smlal v27.4s, v14.4h, v2.h[3] \n" /* */ - - "add %[r3], %[r3], #4 \n" - - /*output 2rd row*/ - "smlal2 v24.4s, v14.8h, v2.h[1] \n" /* */ - "smlal2 v25.4s, v14.8h, v2.h[2] \n" /* */ - "smlal2 v26.4s, v14.8h, v2.h[3] \n" /* */ - "smlal2 v27.4s, v14.8h, v2.h[4] \n" /* */ - - "ldp q3, q4, [%[ptr_out1]] \n" - "ldp q5, q6, [%[ptr_out1], #32] \n" - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ - - /*output 2rd row*/ - "smlal v24.4s, v15.4h, v2.h[2] \n" /* */ - "smlal v25.4s, v15.4h, v2.h[3] \n" /* */ - "smlal v26.4s, v15.4h, v2.h[4] \n" /* */ - "smlal v27.4s, v15.4h, v2.h[5] \n" /* */ - - "ADD v3.4s, v24.4s, v3.4s \n" - "ADD v4.4s, v25.4s, v4.4s \n" - "ADD v5.4s, v26.4s, v5.4s \n" - "ADD v6.4s, v27.4s, v6.4s \n" - - "stp q3, q4, [%[ptr_out1]], #32 \n" /* save to output*/ - "stp q5, q6, [%[ptr_out1]], #32 \n" /* save to output*/ - - "bne 1b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [wc0] "+r"(ptr_wc0), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", 
- "v18", - "v19", - "v24", - "v25", - "v26", - "v27" - - ); - - wc0 += 9 * hout_c_block; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - } - if (out_type == PRECISION(kFloat)) { - write_to_output_c4_int32_1(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - hout_r_block, - c, - c + 4, - h, - h + 2, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else if (out_type == PRECISION(kInt8)) { - write_to_output_c4_int32_1(pre_out, - dout_batch, - hout_c_block, - hout_r_block, - c, - c + 4, - h, - h + 2, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else { // int32 - write_to_output_c4_int32(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - hout_r_block, - c, - c + 4, - h, - h + 2, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - ptr_write); + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; } - } - } - } -} - -#else - -void conv_3x3s1_direct_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale) { - // printf("conv2_3x3s1_direct_int8 \n"); - - const int hin_r_block = 4; - const int hout_c_block = 4; // 8 - const int hout_r_block = 2; - - int stride_w = param.strides[1]; - int pad_w = param.paddings[1]; - int pad_h = param.paddings[0]; - bool flag_relu = param.fuse_relu; - bool flag_bias = (param.bias != nullptr); - - int wout_round = ((wout + 3) / 4) * 4; - int win_round = wout_round * stride_w + 4; - - int threads = ctx->threads(); - - int* tmp_work_space = ctx->workspace_data(); - int* ptr_zero = tmp_work_space; - memset(ptr_zero, 0, sizeof(int) * win_round); - int* ptr_write = ptr_zero + win_round; - - int in_len = win_round * chin; - int pre_in_size = hin_r_block * in_len; - int pre_out_size = hout_c_block * hout_r_block * wout_round; - - signed char* pre_din = reinterpret_cast(ptr_write + wout_round); - - int size_in_channel = win * hin; - int size_out_channel = wout * hout; - int w_stride = chin * 9; - - int ws = -pad_w; - int we = ws + win_round; - int w_loop = wout_round / 4; - - int size_out = wout_round * hout_c_block; - - // printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round, - // wout_round, ws, we); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = - static_cast(din) + n * chin * size_in_channel; - signed char* dout_batch = - reinterpret_cast(dout) + - n * chout * size_out_channel * PrecisionTypeLength(out_type); - - for (int h = 0; h < hout; h += 2) { - int hs = h - pad_h; - int he = hs + 4; - // printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin, - // hin, win); - prepack_input_nxw(din_batch, - pre_din, - 0, - chin, - hs, - he, - ws, - we, - chin, - win, - hin, - (signed char*)ptr_zero); + memset(pre_out, 0, pre_out_size * sizeof(int32_t)); + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const int8_t* wc0 = weight_c; + + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + const int8_t* inr3 = block_inr3; + + int32_t* pre_out0 = pre_out + hk * out_row_stride; + int32_t* pre_out1 = pre_out0 + out_row_stride; + + for (int i = 0; i < chin; ++i) { + int32_t* ptr_out0 = pre_out0; + 
int32_t* ptr_out1 = pre_out1; + + const signed char* r0 = inr0; + const signed char* r1 = inr1; + const signed char* r2 = inr2; + const signed char* r3 = inr3; + + int cnt = w_loop; + const int8_t* ptr_wc0 = wc0; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "ldp q4, q5, [%[wc0]]\n" + "ldr d6, [%[wc0], #32]\n" + "sxtl v11.8h, v4.8b\n" + "sxtl2 v12.8h, v4.16b\n" + "sxtl v13.8h, v5.8b\n" + "sxtl2 v14.8h, v5.16b\n" + "sxtl v15.8h, v6.8b\n" + "ldp q16, q17, [%[ptr_out0]]\n" + "ldp q18, q19, [%[ptr_out0], #32]\n" + "ldr d0, [%[r1]], #4\n" /* load r1 */ + "ldr d1, [%[r2]], #4\n" /* load r2 */ + "sxtl v2.8h, v0.8b\n" /* r1, cvt to int16 */ + "sxtl v3.8h, v1.8b\n" /* r2, cvt to int16 */ + "1:\n" + /* inr1 -> outr0, outr1 */ + "ldp q20, q21, [%[ptr_out1]]\n" + "ldr d0, [%[r0]], #4\n" /* load r0 */ + "smlal2 v16.4s, v12.8h, v2.h[0]\n" /* out00, w10 * r10 */ + "smlal2 v17.4s, v12.8h, v2.h[1]\n" /* out01, w10 * r11 */ + "smlal2 v18.4s, v12.8h, v2.h[2]\n" /* out02, w10 * r12 */ + "smlal2 v19.4s, v12.8h, v2.h[3]\n" /* out03, w10 * r13 */ + "ldp q22, q23, [%[ptr_out1], #32]\n" + "smlal v16.4s, v13.4h, v2.h[1]\n" /* out00, w11 * r11 */ + "smlal v17.4s, v13.4h, v2.h[2]\n" /* out01, w11 * r12 */ + "smlal v18.4s, v13.4h, v2.h[3]\n" /* out02, w11 * r13 */ + "smlal v19.4s, v13.4h, v2.h[4]\n" /* out03, w11 * r14 */ + "smlal2 v16.4s, v13.8h, v2.h[2]\n" /* out00, w12 * r12 */ + "smlal2 v17.4s, v13.8h, v2.h[3]\n" /* out01, w12 * r13 */ + "smlal2 v18.4s, v13.8h, v2.h[4]\n" /* out02, w12 * r14 */ + "smlal2 v19.4s, v13.8h, v2.h[5]\n" /* out03, w12 * r15 */ + "smlal v20.4s, v11.4h, v2.h[0]\n" /* out10, w00 * r10 */ + "smlal v21.4s, v11.4h, v2.h[1]\n" /* out11, w00 * r11 */ + "smlal v22.4s, v11.4h, v2.h[2]\n" /* out12, w00 * r12 */ + "smlal v23.4s, v11.4h, v2.h[3]\n" /* out13, w00 * r13 */ + "smlal2 v20.4s, v11.8h, v2.h[1]\n" /* out10, w01 * r11 */ + "smlal2 v21.4s, v11.8h, v2.h[2]\n" /* out11, w01 * r12 */ + "smlal2 v22.4s, v11.8h, v2.h[3]\n" /* out12, w01 * r13 */ + "smlal2 v23.4s, v11.8h, v2.h[4]\n" /* out13, w01 * r14 */ + "smlal v20.4s, v12.4h, v2.h[2]\n" /* out10, w02 * r12 */ + "smlal v21.4s, v12.4h, v2.h[3]\n" /* out11, w02 * r13 */ + "smlal v22.4s, v12.4h, v2.h[4]\n" /* out12, w02 * r14 */ + "smlal v23.4s, v12.4h, v2.h[5]\n" /* out13, w02 * r15 */ + "sxtl v2.8h, v0.8b\n" /* r0, cvt to int16 */ + /* inr2 -> outr0, outr1 */ + "ldr d1, [%[r3]], #4\n" /* load r3 */ + "smlal v16.4s, v14.4h, v3.h[0]\n" /* out00, w20 * r20 */ + "smlal v17.4s, v14.4h, v3.h[1]\n" /* out01, w20 * r21 */ + "smlal v18.4s, v14.4h, v3.h[2]\n" /* out02, w20 * r22 */ + "smlal v19.4s, v14.4h, v3.h[3]\n" /* out03, w20 * r23 */ + "smlal2 v20.4s, v12.8h, v3.h[0]\n" /* out10, w10 * r20 */ + "smlal2 v21.4s, v12.8h, v3.h[1]\n" /* out11, w10 * r21 */ + "smlal2 v22.4s, v12.8h, v3.h[2]\n" /* out12, w10 * r22 */ + "smlal2 v23.4s, v12.8h, v3.h[3]\n" /* out13, w10 * r23 */ + "smlal2 v16.4s, v14.8h, v3.h[1]\n" /* out00, w21 * r21 */ + "smlal2 v17.4s, v14.8h, v3.h[2]\n" /* out01, w21 * r22 */ + "smlal2 v18.4s, v14.8h, v3.h[3]\n" /* out02, w21 * r23 */ + "smlal2 v19.4s, v14.8h, v3.h[4]\n" /* out03, w21 * r24 */ + "smlal v20.4s, v13.4h, v3.h[1]\n" /* out10, w11 * r21 */ + "smlal v21.4s, v13.4h, v3.h[2]\n" /* out11, w11 * r22 */ + "smlal v22.4s, v13.4h, v3.h[3]\n" /* out12, w11 * r23 */ + "smlal v23.4s, v13.4h, v3.h[4]\n" /* out13, w11 * r24 */ + "smlal v16.4s, v15.4h, v3.h[2]\n" /* out00, w22 * r22 */ + "smlal v17.4s, v15.4h, v3.h[3]\n" /* out01, w22 * r23 */ + "smlal v18.4s, v15.4h, v3.h[4]\n" /* out02, w22 * r24 */ + "smlal v19.4s, v15.4h, 
v3.h[5]\n" /* out03, w22 * r25 */ + "smlal2 v20.4s, v13.8h, v3.h[2]\n" /* out10, w12 * r22 */ + "smlal2 v21.4s, v13.8h, v3.h[3]\n" /* out11, w12 * r23 */ + "smlal2 v22.4s, v13.8h, v3.h[4]\n" /* out12, w12 * r24 */ + "smlal2 v23.4s, v13.8h, v3.h[5]\n" /* out13, w12 * r25 */ + "sxtl v3.8h, v1.8b\n" /* r0, cvt to int16 */ + /* inr0 -> outr0 */ + "ldr d0, [%[r1]], #4\n" /* load r1 */ + "smlal v16.4s, v11.4h, v2.h[0]\n" /* out00, w00 * r00 */ + "smlal v17.4s, v11.4h, v2.h[1]\n" /* out01, w00 * r01 */ + "smlal v18.4s, v11.4h, v2.h[2]\n" /* out02, w00 * r02 */ + "smlal v19.4s, v11.4h, v2.h[3]\n" /* out03, w00 * r03 */ + "smlal2 v16.4s, v11.8h, v2.h[1]\n" /* out00, w01 * r01 */ + "smlal2 v17.4s, v11.8h, v2.h[2]\n" /* out01, w01 * r02 */ + "smlal2 v18.4s, v11.8h, v2.h[3]\n" /* out02, w01 * r03 */ + "smlal2 v19.4s, v11.8h, v2.h[4]\n" /* out03, w01 * r04 */ + "smlal v16.4s, v12.4h, v2.h[2]\n" /* out00, w02 * r02 */ + "smlal v17.4s, v12.4h, v2.h[3]\n" /* out01, w02 * r03 */ + "smlal v18.4s, v12.4h, v2.h[4]\n" /* out02, w02 * r04 */ + "smlal v19.4s, v12.4h, v2.h[5]\n" /* out03, w02 * r05 */ + "sxtl v2.8h, v0.8b\n" /* r0, cvt to int16 */ + /* inr3 -> outr1 */ + "ldr d1, [%[r2]], #4\n" /* load r2 */ + "stp q16, q17, [%[ptr_out0]], #32\n" + "smlal v20.4s, v14.4h, v3.h[0]\n" /* out10, w20 * r30 */ + "smlal v21.4s, v14.4h, v3.h[1]\n" /* out11, w20 * r31 */ + "smlal v22.4s, v14.4h, v3.h[2]\n" /* out12, w20 * r32 */ + "smlal v23.4s, v14.4h, v3.h[3]\n" /* out13, w20 * r33 */ + "stp q18, q19, [%[ptr_out0]], #32\n" + "ldp q16, q17, [%[ptr_out0]]\n" + "smlal2 v20.4s, v14.8h, v3.h[1]\n" /* out10, w21 * r31 */ + "smlal2 v21.4s, v14.8h, v3.h[2]\n" /* out11, w21 * r32 */ + "smlal2 v22.4s, v14.8h, v3.h[3]\n" /* out12, w21 * r33 */ + "smlal2 v23.4s, v14.8h, v3.h[4]\n" /* out13, w21 * r34 */ + "ldp q18, q19, [%[ptr_out0], #32]\n" + "smlal v20.4s, v15.4h, v3.h[2]\n" /* out10, w22 * r32 */ + "smlal v21.4s, v15.4h, v3.h[3]\n" /* out11, w22 * r33 */ + "smlal v22.4s, v15.4h, v3.h[4]\n" /* out12, w22 * r34 */ + "smlal v23.4s, v15.4h, v3.h[5]\n" /* out13, w22 * r35 */ + "sxtl v3.8h, v1.8b\n" /* r0, cvt to int16 */ + "subs %w[cnt], %w[cnt], #1\n" + "stp q20, q21, [%[ptr_out1]], #32\n" + "stp q22, q23, [%[ptr_out1]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [wc0] "+r"(ptr_wc0), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", + "v5", "v6", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20","v21", "v22", "v23" + + ); -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < chout; c += hout_c_block) { // 4 -#ifdef ARM_WITH_OMP - int* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + - omp_get_thread_num() * pre_out_size; #else - int* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); -#endif - // printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x, - // pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din, - // pre_out); - const signed char* inr0 = pre_din; - const signed char* inr1 = inr0 + in_len; - const signed char* inr2 = inr1 + in_len; - const signed char* inr3 = inr2 + in_len; - - const signed char* wc0 = - static_cast(weights) + c * w_stride; - - const int* bias_ptr = ptr_zero; - if (flag_bias) { - bias_ptr = static_cast(bias) + c; - } - // hout_r_block * wout_round * hout_c_block - fill_packed_bias_nxmw_int8( - bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round); - - for (int i = 0; i < chin; ++i) { 
- const signed char* r0 = inr0; - const signed char* r1 = inr1; - const signed char* r2 = inr2; - const signed char* r3 = inr3; - - int* ptr_out0 = pre_out; - int* ptr_out1 = pre_out + size_out; - - int cnt = w_loop; - const signed char* ptr_wc = wc0; - - asm volatile( - "vld1.s8 {d0-d3}, [%[wc0]]! \n" /* wc0, wc1, wc2, wc3, wc4, - wc5, wc6, wc7*/ - "vld1.s8 {d4}, [%[wc0]]! \n" /* wc8 */ - "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ - "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ - "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ - "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ - "vmovl.s8 q7, d4 \n" /* q7 = w8 */ - - "1: \n" /* main loop*/ - "vld1.s32 {d0}, [%[r0]] \n" /* load data din0-dinn7*/ - "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ - /*output 1st row*/ - "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ - "vmull.s16 q9, d6, d0[1] \n" /* q9 = w0 * r0[2] */ - "vmull.s16 q10, d6, d0[2] \n" /* q10 = w0 * r0[4] */ - "vmull.s16 q11, d6, d0[3] \n" /* q11 = w0 * r0[6] */ - - "add %[r0], #4 \n" - - /*output 1st row*/ - "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ - "vmlal.s16 q9, d7, d0[2] \n" /* q9 = w1 * r0[2] */ - "vmlal.s16 q10, d7, d0[3] \n" /* q10 = w1 * r0[3] */ - "vmlal.s16 q11, d7, d1[0] \n" /* q11 = w1 * r0[4] */ - - "vld1.s32 {d2}, [%[r1]] \n" /* load input r1 -> d2 */ - "vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ - - /*output 1st row*/ - "vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ - "vmlal.s16 q9, d8, d0[3] \n" /* q9 = w2 * r0[3] */ - "vmlal.s16 q10, d8, d1[0] \n" /* q10 = w2 * r0[4] */ - "vmlal.s16 q11, d8, d1[1] \n" /* q11 = w2 * r0[5] */ - - /*output 1st row*/ - "vmlal.s16 q8, d9, d2[0] \n" /* */ - "vmlal.s16 q9, d9, d2[1] \n" /* */ - "vmlal.s16 q10, d9, d2[2] \n" /* */ - "vmlal.s16 q11, d9, d2[3] \n" /* */ - - "add %[r1], #4 \n" - - /*output 1st row*/ - "vmlal.s16 q8, d10, d2[1] \n" /* */ - "vmlal.s16 q9, d10, d2[2] \n" /* */ - "vmlal.s16 q10, d10, d2[3] \n" /* */ - "vmlal.s16 q11, d10, d3[0] \n" /* */ - - /*output 1st row*/ - "vmlal.s16 q8, d11, d2[2] \n" /* */ - "vmlal.s16 q9, d11, d2[3] \n" /* */ - "vmlal.s16 q10, d11, d3[0] \n" /* */ - "vmlal.s16 q11, d11, d3[1] \n" /* */ - - /*output 2rd row*/ - "vmull.s16 q12, d6, d2[0] \n" /* */ - "vmull.s16 q13, d6, d2[1] \n" /* */ - "vmull.s16 q14, d6, d2[2] \n" /* */ - "vmull.s16 q15, d6, d2[3] \n" /* */ - - "vld1.s32 {d0}, [%[r2]] \n" /* load input r2 -> d2 */ - "vmovl.s8 q0, d0 \n" /* movl d2 -> q1 */ - - /*output 2rd row*/ - "vmlal.s16 q12, d7, d2[1] \n" /* */ - "vmlal.s16 q13, d7, d2[2] \n" /* */ - "vmlal.s16 q14, d7, d2[3] \n" /* */ - "vmlal.s16 q15, d7, d3[0] \n" /* */ - - /*output 2rd row*/ - "vmlal.s16 q12, d8, d2[2] \n" /* */ - "vmlal.s16 q13, d8, d2[3] \n" /* */ - "vmlal.s16 q14, d8, d3[0] \n" /* */ - "vmlal.s16 q15, d8, d3[1] \n" /* */ - - "add %[r2], #4 \n" - - /*output 1st row*/ - "vmlal.s16 q8, d12, d0[0] \n" /* */ - "vmlal.s16 q9, d12, d0[1] \n" /* */ - "vmlal.s16 q10, d12, d0[2] \n" /* */ - "vmlal.s16 q11, d12, d0[3] \n" /* */ - - /*output 1st row*/ - "vmlal.s16 q8, d13, d0[1] \n" /* */ - "vmlal.s16 q9, d13, d0[2] \n" /* */ - "vmlal.s16 q10, d13, d0[3] \n" /* */ - "vmlal.s16 q11, d13, d1[0] \n" /* */ - - "vld1.32 {d2-d5}, [%[ptr_out0]] \n" /* load ptr_out -> q, q - */ - - /*output 1st row*/ - "vmlal.s16 q8, d14, d0[2] \n" /* */ - "vmlal.s16 q9, d14, d0[3] \n" /* */ - "vmlal.s16 q10, d14, d1[0] \n" /* */ - "vmlal.s16 q11, d14, d1[1] \n" /* */ - - /*load & store output 1st row*/ - "vadd.s32 q1, q8, q1 \n" /* out[0] += q8 */ - "vadd.s32 q2, q9, q2 \n" /* out[0] += q8 */ - "vst1.s32 {d2-d5}, [%[ptr_out0]]! 
\n" - - /*output 2rd row*/ - "vmlal.s16 q12, d9, d0[0] \n" /* */ - "vmlal.s16 q13, d9, d0[1] \n" /* */ - "vmlal.s16 q14, d9, d0[2] \n" /* */ - "vmlal.s16 q15, d9, d0[3] \n" /* */ - - "vld1.32 {d2-d5}, [%[ptr_out0]] \n" /* load ptr_out -> q2, q3 - */ - - /*output 2rd row */ - "vmlal.s16 q12, d10, d0[1] \n" /* */ - "vmlal.s16 q13, d10, d0[2] \n" /* */ - "vadd.s32 q1, q10, q1 \n" /* out[0] += q */ - "vadd.s32 q2, q11, q2 \n" /* out[1] += q */ - - "vmlal.s16 q14, d10, d0[3] \n" /* */ - "vst1.s32 {d2-d5}, [%[ptr_out0]]! \n" - "vmlal.s16 q15, d10, d1[0] \n" /* */ - - /*output 2rd row */ - "vmlal.s16 q12, d11, d0[2] \n" /* */ - "vmlal.s16 q13, d11, d0[3] \n" /* */ - - "vld1.s32 {d4}, [%[r3]] \n" /* load input r2 -> d2 - */ - "vmovl.s8 q2, d4 \n" /* movl d2 -> q2 */ - - "vmlal.s16 q14, d11, d1[0] \n" /* */ - "vmlal.s16 q15, d11, d1[1] \n" /* */ - - "add %[r3], #4 \n" - - /*output 2rd row */ - "vmlal.s16 q12, d12, d4[0] \n" /* */ - "vmlal.s16 q13, d12, d4[1] \n" /* */ - "vmlal.s16 q14, d12, d4[2] \n" /* */ - "vmlal.s16 q15, d12, d4[3] \n" /* */ - - "vld1.32 {d0-d3}, [%[ptr_out1]] \n" /* */ - - /*output 2rd row */ - "vmlal.s16 q12, d13, d4[1] \n" /* */ - "vmlal.s16 q13, d13, d4[2] \n" /* */ - "vmlal.s16 q14, d13, d4[3] \n" /* */ - "vmlal.s16 q15, d13, d5[0] \n" /* */ - - "subs %[cnt], #1 \n" - - /*output 2rd row */ - "vmlal.s16 q12, d14, d4[2] \n" /* */ - "vmlal.s16 q13, d14, d4[3] \n" /* */ - "vmlal.s16 q14, d14, d5[0] \n" /* */ - "vmlal.s16 q15, d14, d5[1] \n" /* */ - - /*output 2rd row*/ - "vadd.s32 q0, q12, q0 \n" /* */ - "vadd.s32 q1, q13, q1 \n" /* */ - "vst1.s32 {d0-d3}, [%[ptr_out1]]! \n" - - "vld1.32 {d0-d3}, [%[ptr_out1]] \n" /* */ - "vadd.s32 q0, q14, q0 \n" /* */ - "vadd.s32 q1, q15, q1 \n" /* */ - "vst1.s32 {d0-d3}, [%[ptr_out1]]! \n" - - "bne 1b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1), - [wc0] "+r"(ptr_wc) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - wc0 += 9 * hout_c_block; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; + asm volatile( + "vld1.32 {d0-d3}, [%[wc0]]!\n" + "vld1.32 {d4}, [%[wc0]]!\n" + "vmovl.s8 q3, d0\n" /* q3 = w0, w1 */ + "vmovl.s8 q4, d1\n" /* q4 = w2 ,w3 */ + "vmovl.s8 q5, d2\n" /* q5 = w4, w5 */ + "vmovl.s8 q6, d3\n" /* q6 = w6, w7 */ + "vmovl.s8 q7, d4\n" /* q7 = w8 */ + "vld1.32 d0, [%[r1]]\n" + "vld1.32 d1, [%[r2]]\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]!\n" + "vld1.32 {d20-d23}, [%[ptr_out0]]\n" + "vmovl.s8 q1, d0\n" + "vmovl.s8 q2, d1\n" + "1:\n" + /* inr1 -> outr0, outr1 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]!\n" + "vld1.32 d0, [%[r0]]\n" /* load r0 */ + "vmlal.s16 q8, d9, d2[0]\n" /* out00, w10 * r10 */ + "vmlal.s16 q9, d9, d2[1]\n" /* out01, w10 * r11 */ + "vmlal.s16 q10, d9, d2[2]\n" /* out02, w10 * r12 */ + "vmlal.s16 q11, d9, d2[3]\n" /* out03, w10 * r13 */ + "vld1.32 {d28-d31}, [%[ptr_out1]]\n" + "vmlal.s16 q8, d10, d2[1]\n" /* out00, w11 * r11 */ + "vmlal.s16 q9, d10, d2[2]\n" /* out01, w11 * r12 */ + "vmlal.s16 q10, d10, d2[3]\n" /* out02, w11 * r13 */ + "vmlal.s16 q11, d10, d3[0]\n" /* out03, w11 * r14 */ + "sub %[ptr_out0], %[ptr_out0], #32\n" + "vmlal.s16 q8, d11, d2[2]\n" /* out00, w12 * r12 */ + "vmlal.s16 q9, d11, d2[3]\n" /* out01, w12 * r13 */ + "vmlal.s16 q10, d11, d3[0]\n" /* out02, w12 * r14 */ + "vmlal.s16 q11, d11, d3[1]\n" /* out03, w12 * r15 */ + 
"vmlal.s16 q12, d6, d2[0]\n" /* out10, w00 * r10 */ + "vmlal.s16 q13, d6, d2[1]\n" /* out11, w00 * r11 */ + "vmlal.s16 q14, d6, d2[2]\n" /* out12, w00 * r12 */ + "vmlal.s16 q15, d6, d2[3]\n" /* out13, w00 * r13 */ + "add %[r1], %[r1], #4\n" + "vmlal.s16 q12, d7, d2[1]\n" /* out10, w01 * r11 */ + "vmlal.s16 q13, d7, d2[2]\n" /* out11, w01 * r12 */ + "vmlal.s16 q14, d7, d2[3]\n" /* out12, w01 * r13 */ + "vmlal.s16 q15, d7, d3[0]\n" /* out13, w01 * r14 */ + "sub %[ptr_out1], %[ptr_out1], #32\n" + "vmlal.s16 q12, d8, d2[2]\n" /* out10, w02 * r12 */ + "vmlal.s16 q13, d8, d2[3]\n" /* out11, w02 * r13 */ + "vmlal.s16 q14, d8, d3[0]\n" /* out12, w02 * r14 */ + "vmlal.s16 q15, d8, d3[1]\n" /* out13, w02 * r15 */ + "vmovl.s8 q1, d0\n" /* r0, cvt to int16 */ + /* inr2 -> outr0, outr1 */ + "vld1.32 d1, [%[r3]]\n" /* load r3 */ + "vmlal.s16 q8, d12, d4[0]\n" /* out00, w20 * r20 */ + "vmlal.s16 q9, d12, d4[1]\n" /* out01, w20 * r21 */ + "vmlal.s16 q10, d12, d4[2]\n" /* out02, w20 * r22 */ + "vmlal.s16 q11, d12, d4[3]\n" /* out03, w20 * r23 */ + "add %[r2], %[r2], #4\n" + "vmlal.s16 q12, d9, d4[0]\n" /* out10, w10 * r20 */ + "vmlal.s16 q13, d9, d4[1]\n" /* out11, w10 * r21 */ + "vmlal.s16 q14, d9, d4[2]\n" /* out12, w10 * r22 */ + "vmlal.s16 q15, d9, d4[3]\n" /* out13, w10 * r23 */ + "vmlal.s16 q8, d13, d4[1]\n" /* out00, w21 * r21 */ + "vmlal.s16 q9, d13, d4[2]\n" /* out01, w21 * r22 */ + "vmlal.s16 q10, d13, d4[3]\n" /* out02, w21 * r23 */ + "vmlal.s16 q11, d13, d5[0]\n" /* out03, w21 * r24 */ + "add %[r0], %[r0], #4\n" + "vmlal.s16 q12, d10, d4[1]\n" /* out10, w11 * r21 */ + "vmlal.s16 q13, d10, d4[2]\n" /* out11, w11 * r22 */ + "vmlal.s16 q14, d10, d4[3]\n" /* out12, w11 * r23 */ + "vmlal.s16 q15, d10, d5[0]\n" /* out13, w11 * r24 */ + "vmlal.s16 q8, d14, d4[2]\n" /* out00, w22 * r22 */ + "vmlal.s16 q9, d14, d4[3]\n" /* out01, w22 * r23 */ + "vmlal.s16 q10, d14, d5[0]\n" /* out02, w22 * r24 */ + "vmlal.s16 q11, d14, d5[1]\n" /* out03, w22 * r25 */ + "add %[r3], %[r3], #4\n" + "vmlal.s16 q12, d11, d4[2]\n" /* out10, w12 * r22 */ + "vmlal.s16 q13, d11, d4[3]\n" /* out11, w12 * r23 */ + "vmlal.s16 q14, d11, d5[0]\n" /* out12, w12 * r24 */ + "vmlal.s16 q15, d11, d5[1]\n" /* out13, w12 * r25 */ + "vmovl.s8 q2, d1\n" /* r3, cvt to int16 */ + /* inr0 -> outr0 */ + "vld1.32 d0, [%[r1]]\n" /* load r1 */ + "vmlal.s16 q8, d6, d2[0]\n" /* out00, w00 * r00 */ + "vmlal.s16 q9, d6, d2[1]\n" /* out01, w00 * r01 */ + "vmlal.s16 q10, d6, d2[2]\n" /* out02, w00 * r02 */ + "vmlal.s16 q11, d6, d2[3]\n" /* out03, w00 * r03 */ + "vmlal.s16 q8, d7, d2[1]\n" /* out00, w01 * r01 */ + "vmlal.s16 q9, d7, d2[2]\n" /* out01, w01 * r02 */ + "vmlal.s16 q10, d7, d2[3]\n" /* out02, w01 * r03 */ + "vmlal.s16 q11, d7, d3[0]\n" /* out03, w01 * r04 */ + "vmlal.s16 q8, d8, d2[2]\n" /* out00, w02 * r02 */ + "vmlal.s16 q9, d8, d2[3]\n" /* out01, w02 * r03 */ + "vmlal.s16 q10, d8, d3[0]\n" /* out02, w02 * r04 */ + "vmlal.s16 q11, d8, d3[1]\n" /* out03, w02 * r05 */ + "vmovl.s8 q1, d0\n" /* r1, cvt to int16 */ + /* inr3 -> outr1 */ + "vld1.32 {d1}, [%[r2]]\n" /* load r2 */ + "vst1.32 {d16-d19}, [%[ptr_out0]]!\n" + "vmlal.s16 q12, d12, d4[0]\n" /* out10, w20 * r30 */ + "vmlal.s16 q13, d12, d4[1]\n" /* out11, w20 * r31 */ + "vmlal.s16 q14, d12, d4[2]\n" /* out12, w20 * r32 */ + "vmlal.s16 q15, d12, d4[3]\n" /* out13, w20 * r33 */ + "vst1.32 {d20-d23}, [%[ptr_out0]]!\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]!\n" + "vmlal.s16 q12, d13, d4[1]\n" /* out10, w21 * r31 */ + "vmlal.s16 q13, d13, d4[2]\n" /* out11, w21 * r32 */ + "vmlal.s16 q14, d13, 
d4[3]\n" /* out12, w21 * r33 */ + "vmlal.s16 q15, d13, d5[0]\n" /* out13, w21 * r34 */ + "vld1.32 {d20-d23}, [%[ptr_out0]]\n" + "vmlal.s16 q12, d14, d4[2]\n" /* out10, w22 * r32 */ + "vmlal.s16 q13, d14, d4[3]\n" /* out11, w22 * r33 */ + "vmlal.s16 q14, d14, d5[0]\n" /* out12, w22 * r34 */ + "vmlal.s16 q15, d14, d5[1]\n" /* out13, w22 * r35 */ + "vmovl.s8 q2, d1\n" /* r2, cvt to int16 */ + "subs %[cnt], #1\n" + "vst1.32 {d24-d27}, [%[ptr_out1]]!\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]!\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(ptr_wc0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15" + ); +#endif // __aarch64__ + // clang-format on + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; } - - if (out_type == PRECISION(kFloat)) { - write_to_output_c4_int32_1(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - hout_r_block, - c, - c + 4, - h, - h + 2, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else if (out_type == PRECISION(kInt8)) { - write_to_output_c4_int32_1(pre_out, - dout_batch, - hout_c_block, - hout_r_block, - c, - c + 4, - h, - h + 2, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else { // int32 - write_to_output_c4_int32(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - hout_r_block, + write_int32_nchwc4_to_nchw(pre_out, + dout_batch, c, c + 4, h, - h + 2, + h + hout_r_block, 0, wout_round, chout, hout, wout, flag_relu, - ptr_write); - } + bias_local, + flag_bias, + ptr_write, + scale + c); } } } } -#endif // __aarch64__ +template void conv_3x3s1_direct_int8(const int8_t* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const float* scale); + +template void conv_3x3s1_direct_int8(const int8_t* din, + int8_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const float* scale); } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d75323a9677f1cfbed726a1a28920dd77131688 --- /dev/null +++ b/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc @@ -0,0 +1,361 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_3x3s2_depthwise_fp32(const float* i_data,
+                               float* o_data,
+                               int bs,
+                               int oc,
+                               int oh,
+                               int ow,
+                               int ic,
+                               int ih,
+                               int win,
+                               const float* weights,
+                               const float* bias,
+                               const operators::ConvParam& param,
+                               ARMContext* ctx) {
+  int threads = ctx->threads();
+  const int pad_h = param.paddings[0];
+  const int pad_w = param.paddings[1];
+  const int out_c_block = 4;
+  const int out_h_kernel = 1;
+  const int out_w_kernel = 4;
+  const int win_ext = ow * 2 + 1;
+  const int ow_round = ROUNDUP(ow, 4);
+  const int win_round = ROUNDUP(win_ext, 4);
+  const int hin_round = oh * 2 + 1;
+  const int prein_size = win_round * hin_round * out_c_block;
+  auto workspace_size =
+      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
+  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
+
+  bool flag_relu = param.fuse_relu;
+  bool flag_bias = param.bias != nullptr;
+
+  /// get workspace
+  auto ptr_zero = ctx->workspace_data<float>();
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float* ptr_write = ptr_zero + win_round;
+
+  int size_in_channel = win * ih;
+  int size_out_channel = ow * oh;
+
+  int ws = -pad_w;
+  int we = ws + win_round;
+  int hs = -pad_h;
+  int he = hs + hin_round;
+  int w_loop = ow_round / 4;
+  auto remain = w_loop * 4 - ow;
+  bool flag_remain = remain > 0;
+  remain = 4 - remain;
+  remain = remain > 0 ? remain : 0;
+  int row_len = win_round * out_c_block;
+
+  for (int n = 0; n < bs; ++n) {
+    const float* din_batch = i_data + n * ic * size_in_channel;
+    float* dout_batch = o_data + n * oc * size_out_channel;
+#pragma omp parallel for num_threads(threads)
+    for (int c = 0; c < oc; c += out_c_block) {
+#ifdef ARM_WITH_OMP
+      float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
+#else
+      float* pre_din = ptr_write + ow_round;
+#endif
+      /// const array size
+      prepack_input_nxwc4_dw(
+          din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
+      const float* weight_c = weights + c * 9;  // kernel_w * kernel_h
+      float* dout_c00 = dout_batch + c * size_out_channel;
+      float bias_local[4] = {0, 0, 0, 0};
+      if (flag_bias) {
+        bias_local[0] = bias[c];
+        bias_local[1] = bias[c + 1];
+        bias_local[2] = bias[c + 2];
+        bias_local[3] = bias[c + 3];
+      }
+#ifdef __aarch64__
+      float32x4_t w0 = vld1q_f32(weight_c);       // w0, v23
+      float32x4_t w1 = vld1q_f32(weight_c + 4);   // w1, v24
+      float32x4_t w2 = vld1q_f32(weight_c + 8);   // w2, v25
+      float32x4_t w3 = vld1q_f32(weight_c + 12);  // w3, v26
+      float32x4_t w4 = vld1q_f32(weight_c + 16);  // w4, v27
+      float32x4_t w5 = vld1q_f32(weight_c + 20);  // w5, v28
+      float32x4_t w6 = vld1q_f32(weight_c + 24);  // w6, v29
+      float32x4_t w7 = vld1q_f32(weight_c + 28);  // w7, v30
+      float32x4_t w8 = vld1q_f32(weight_c + 32);  // w8, v31
+#endif
+      for (int h = 0; h < oh; h += out_h_kernel) {
+        float* outc0 = dout_c00 + h * ow;
+        float* outc1 = outc0 + size_out_channel;
+        float* outc2 = outc1 + size_out_channel;
+        float* outc3 = outc2 + size_out_channel;
+        const float* inr0 = pre_din + h * 2 * row_len;
+        const float* inr1 = inr0 + row_len;
+        const float* inr2 = inr1 + row_len;
+        if (c + out_c_block > oc) {
+          switch (c + out_c_block -
oc) { + case 3: + outc1 = ptr_write; + case 2: + outc2 = ptr_write; + case 1: + outc3 = ptr_write; + default: + break; + } + } + auto c0 = outc0; + auto c1 = outc1; + auto c2 = outc2; + auto c3 = outc3; + float pre_out[16]; + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + if (flag_mask) { + c0 = outc0; + c1 = outc1; + c2 = outc2; + c3 = outc3; + outc0 = pre_out; + outc1 = pre_out + 4; + outc2 = pre_out + 8; + outc3 = pre_out + 12; + } +// clang-format off +#ifdef __aarch64__ + asm volatile( + "ldr q8, [%[bias]]\n" /* load bias */ + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ + "and v19.16b, v8.16b, v8.16b\n" + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ + "and v20.16b, v8.16b, v8.16b\n" + "ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ + "and v21.16b, v8.16b, v8.16b\n" + "ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ + "and v22.16b, v8.16b, v8.16b\n" + "ldr q8, [%[inr0]]\n" /* load input r0*/ + /* r0 mul w0-w2, get out */ + "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ + "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ + "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ + "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ + "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ + "ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ + "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ + "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ + "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ + "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ + "ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ + "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ + "ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ + "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ + "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ + "ldr q8, [%[inr1]]\n" /* load input r1*/ + /* r1, mul w3-w5, get out */ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ + "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ + "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ + "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ + "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ + "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ + "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ + "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ + "ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ + "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ + "ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ + "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ + "ldr q8, [%[inr2]]\n" /* load input r2*/ + /* r2, mul w6-w8, get out r0, r1 */ + "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ + "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ + "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ + "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ + "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ + "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ + "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ + "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ + "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ + "fmla 
v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ + "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/ + "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ + /* transpose */ + "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ + "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ + "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ + "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ + "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ + "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ + "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ + "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ + /* relu */ + "cbz %w[flag_relu], 0f\n" /* skip relu*/ + "movi v0.4s, #0\n" /* for relu */ + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + /* save result */ + "0:\n" + "str q19, [%[outc0]], #16\n" + "str q20, [%[outc1]], #16\n" + "str q21, [%[outc2]], #16\n" + "str q22, [%[outc3]], #16\n" + :[inr0] "+r"(inr0), [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0]"+r"(outc0), [outc1]"+r"(outc1), + [outc2]"+r"(outc2), [outc3]"+r"(outc3) + :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), + [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), + [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), + [bias] "r" (bias_local), [flag_relu]"r"(flag_relu) + : "cc", "memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8", "v19","v20","v21","v22" + ); +#else + asm volatile( + /* fill with bias */ + "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ + /* load weights */ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ + "vand.i32 q12, q8, q8\n" + "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ + "vand.i32 q13, q8, q8\n" + "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ + "vand.i32 q14, q8, q8\n" + "vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ + "vand.i32 q15, q8, q8\n" + "vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ + /* mul r0 with w0, w1, w2 */ + "vmla.f32 q12, q9, q0 @ w0 * inr0\n" + "vmla.f32 q13, q9, q2 @ w0 * inr2\n" + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ + "vmla.f32 q14, q9, q4 @ w0 * inr4\n" + "vmla.f32 q15, q9, q6 @ w0 * inr6\n" + "vmla.f32 q12, q10, q1 @ w1 * inr1\n" + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" + "vmla.f32 q13, q10, q3 @ w1 * inr3\n" + "vmla.f32 q14, q10, q5 @ w1 * inr5\n" + "vmla.f32 q15, q10, q7 @ w1 * inr7\n" + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ + "vmla.f32 q12, q11, q2 @ w2 * inr2\n" + "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" + "vmla.f32 q13, q11, q4 @ w2 * inr4\n" + "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" + "vmla.f32 q14, q11, q6 @ w2 * inr6\n" + "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" + "vmla.f32 q15, q11, q8 @ w2 * inr8\n" + /* mul r1 with w3, w4, w5 */ + "vmla.f32 q12, q9, q0 @ w3 * inr0\n" + "vmla.f32 q13, q9, q2 @ w3 * inr2\n" + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ + "vmla.f32 q14, q9, q4 @ w3 * inr4\n" + "vmla.f32 q15, q9, q6 @ w3 * inr6\n" + "vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ + "vmla.f32 q12, q10, q1 @ w4 * inr1\n" + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" + "vmla.f32 q13, q10, q3 @ w4 * inr3\n" + "vmla.f32 q14, q10, q5 @ w4 * inr5\n" + "vmla.f32 q15, q10, q7 @ w4 * inr7\n" + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ + "vmla.f32 q12, q11, q2 @ w5 * inr2\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" + "vmla.f32 q13, q11, q4 @ w5 * inr4\n" + "vld1.32 {d8-d11}, [%[r2]]! 
@ load r2, 4, 5\n"
+          "vmla.f32 q14, q11, q6                  @ w5 * inr6\n"
+          "vld1.32 {d12-d15}, [%[r2]]!            @ load r2, 6, 7\n"
+          "vmla.f32 q15, q11, q8                  @ w5 * inr8\n"
+          /* mul r2 with w6, w7, w8 */
+          "vmla.f32 q12, q9, q0                   @ w6 * inr0\n"
+          "vmla.f32 q13, q9, q2                   @ w6 * inr2\n"
+          "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */
+          "vmla.f32 q14, q9, q4                   @ w6 * inr4\n"
+          "vmla.f32 q15, q9, q6                   @ w6 * inr6\n"
+          "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/
+          "vmla.f32 q12, q10, q1                  @ w7 * inr1\n"
+          "vmla.f32 q13, q10, q3                  @ w7 * inr3\n"
+          "vmla.f32 q14, q10, q5                  @ w7 * inr5\n"
+          "vmla.f32 q15, q10, q7                  @ w7 * inr7\n"
+          "sub %[wc0], %[wc0], #144               @ wc0 - 144 to start address\n"
+          "vmla.f32 q12, q11, q2                  @ w8 * inr2\n"
+          "vmla.f32 q13, q11, q4                  @ w8 * inr4\n"
+          "vmla.f32 q14, q11, q6                  @ w8 * inr6\n"
+          "vmla.f32 q15, q11, q8                  @ w8 * inr8\n"
+          /* transpose */
+          "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/
+          "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/
+          "vswp d25, d28\n"    /* a0a1a2a3, c0c1c2c3*/
+          "vswp d27, d30\n"    /* b0b1b2b3, d0d1d2d3*/
+          "cmp %[flag_relu], #0\n"
+          "beq 0f\n" /* skip relu*/
+          "vmov.u32 q0, #0\n"
+          "vmax.f32 q12, q12, q0\n"
+          "vmax.f32 q13, q13, q0\n"
+          "vmax.f32 q14, q14, q0\n"
+          "vmax.f32 q15, q15, q0\n"
+          "0:\n"
+          "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/
+          "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/
+          "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/
+          "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
+          :[r0] "+r"(inr0), [r1] "+r"(inr1),
+           [r2] "+r"(inr2), [wc0] "+r" (weight_c),
+           [outc0]"+r"(outc0), [outc1]"+r"(outc1),
+           [outc2]"+r"(outc2), [outc3]"+r"(outc3)
+          :[bias] "r" (bias_local),
+           [flag_relu]"r"(flag_relu)
+          :"cc", "memory",
+           "q0","q1","q2","q3","q4","q5","q6","q7",
+           "q8", "q9","q10","q11","q12","q13","q14","q15"
+          );
+#endif  // __aarch64__
+        // clang-format on
+        if (flag_mask) {
+          for (int i = 0; i < remain; ++i) {
+            c0[i] = pre_out[i];
+            c1[i] = pre_out[i + 4];
+            c2[i] = pre_out[i + 8];
+            c3[i] = pre_out[i + 12];
+          }
+        }
+      }
+    }
+  }
+  }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2e475fc6067cf52962038fc4bf18c99909e4bafd
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc
@@ -0,0 +1,497 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
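+
+// Overview (illustrative comment; w/in/acc below are hypothetical names):
+// this file computes a 3x3, stride-2, int8 depthwise convolution. Input is
+// repacked into 8-channel blocks (prepack_input_nxwc8_int8_dw); the NEON
+// loops multiply int8 by int8 into int16 (smull/smlal) and widen-accumulate
+// into int32 (saddw); write_int32_nchwc8_to_nchw then applies per-channel
+// scale, bias and optional relu while storing NCHW output. Per channel ch
+// and output pixel (oh, ow), ignoring padding, the accumulator is:
+//   int32_t acc = 0;
+//   for (int kh = 0; kh < 3; ++kh)
+//     for (int kw = 0; kw < 3; ++kw)
+//       acc += (int32_t)w[ch][kh][kw] * in[ch][2 * oh + kh][2 * ow + kw];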
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_depthwise.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
+
+template <typename Dtype>
+void conv_depthwise_3x3s2_int8(Dtype* dout,
+                               const int8_t* din,
+                               const int8_t* weights,
+                               const float* scale,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int padw,
+                               int padh,
+                               ARMContext* ctx) {
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / 4;
+
+  const int hout_c_block = 8;
+  const int hout_r_kernel = 1;
+  const int wout_block = 4;
+  const int wout_round = ROUNDUP(wout, wout_block);
+  const int win_round = wout_round * 2 /*stride*/ + 1;
+
+  //! get h block
+  //! llc_size = threads * win_round * hin_r_block * hout_c_block *
+  //! sizeof(int8_t)
+  //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! win_round = wout_round + 2
+  //! hin_r_block = hout_r_block + 2
+  int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
+                     (2 * win_round * threads * hout_c_block +
+                      hout_c_block * wout_round * threads * 4);
+  hout_r_block = hout_r_block > hout ? hout : hout_r_block;
+  hout_r_block =
+      ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
+  hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+
+  const int hin_r_block = hout_r_block * 2 /*stride*/ + 1;
+
+  auto tmp_work_space = ctx->workspace_data<int8_t>();
+  int8_t ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(int8_t) * win_round);
+  Dtype ptr_write[wout_round];  // NOLINT
+
+  int in_len = win_round * hout_c_block;
+  int pre_in_size = hin_r_block * in_len;
+  pre_in_size = ROUNDUP(pre_in_size, 4);
+  int pre_out_size = hout_c_block * hout_r_block * wout_round;
+
+  int8_t* tmp_din = tmp_work_space;
+
+  int size_in_channel = win * hin;
+  int size_out_channel = wout * hout;
+  int w_stride = 9;  // kernel_w * kernel_h;
+
+  int ws = -padw;
+  int we = ws + win_round;
+  int w_loop = wout_round / 4;
+  int chout = chin;
+
+  int out_row_stride = hout_c_block * wout_round;
+
+  for (int n = 0; n < num; ++n) {
+    const int8_t* din_batch = din + n * chin * size_in_channel;
+    int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
+                         n * chout * size_out_channel * sizeof(Dtype);
+    for (int h = 0; h < hout; h += hout_r_block) {
+      int h_kernel = hout_r_block;
+      if (h + hout_r_block > hout) {
+        h_kernel = hout - h;
+      }
+      int hs = h * 2 /*stride*/ - padh;
+      int he = hs + h_kernel * 2 /*stride*/ + 1;
+
+#pragma omp parallel for num_threads(threads)
+      for (int c = 0; c < chout; c += hout_c_block) {
+#ifdef ARM_WITH_OMP
+        int8_t* pre_din =
+            tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
+        int32_t* pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size);
+#else
+        int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
+        auto pre_din = tmp_din;
+#endif
+        prepack_input_nxwc8_int8_dw(
+            din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
+        const int8_t* block_inr0 = pre_din;
+        const int8_t* block_inr1 = block_inr0 + in_len;
+        const int8_t* block_inr2 = block_inr1 + in_len;
+
+        const int8_t* weight_c = weights + c * w_stride;
+        float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+        if (flag_bias) {
+          bias_local[0] = bias[c];
+          bias_local[1] = bias[c + 1];
+ bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + bias_local[4] = bias[c + 4]; + bias_local[5] = bias[c + 5]; + bias_local[6] = bias[c + 6]; + bias_local[7] = bias[c + 7]; + } +#ifdef __aarch64__ + int8x8_t vw0 = vld1_s8(weight_c); + int8x8_t vw1 = vld1_s8(weight_c + 8); + int8x8_t vw2 = vld1_s8(weight_c + 16); + int8x8_t vw3 = vld1_s8(weight_c + 24); + int8x8_t vw4 = vld1_s8(weight_c + 32); + int8x8_t vw5 = vld1_s8(weight_c + 40); + int8x8_t vw6 = vld1_s8(weight_c + 48); + int8x8_t vw7 = vld1_s8(weight_c + 56); + int8x8_t vw8 = vld1_s8(weight_c + 64); +#endif + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + int32_t* ptr_out0 = pre_out + hk * out_row_stride; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]], #32\n" + "1:\n" + /* inr0 -> outr0 */ + "smull v20.8h, v0.8b, %[w0].8b\n" /* int16, out0 */ + "smull v21.8h, v2.8b, %[w0].8b\n" /* int16, out1 */ + "smull v22.8h, v4.8b, %[w0].8b\n" /* int16, out2 */ + "smull v23.8h, v6.8b, %[w0].8b\n" /* int16, out3 */ + "smlal v20.8h, v1.8b, %[w1].8b\n" /* int16, out0 */ + "smlal v21.8h, v3.8b, %[w1].8b\n" /* int16, out1 */ + "smlal v22.8h, v5.8b, %[w1].8b\n" /* int16, out2 */ + "smlal v23.8h, v7.8b, %[w1].8b\n" /* int16, out3 */ + "ldr d8, [%[r0]]\n" /* load r0, 8 */ + "ldp d0, d1, [%[r1]], #16\n" /* load r1, 0,1 */ + "sxtl v24.4s, v20.4h\n" + "sxtl2 v25.4s, v20.8h\n" + "smull v20.8h, v2.8b, %[w2].8b\n" /* int16, out0 */ + "ldp d2, d3, [%[r1]], #16\n" /* load r1, 2,3 */ + "sxtl v26.4s, v21.4h\n" + "sxtl2 v27.4s, v21.8h\n" + "smull v21.8h, v4.8b, %[w2].8b\n" /* int16, out1 */ + "ldp d4, d5, [%[r1]], #16\n" /* load r1, 4,5 */ + "sxtl v28.4s, v22.4h\n" + "sxtl2 v29.4s, v22.8h\n" + "smull v22.8h, v6.8b, %[w2].8b\n" /* int16, out2 */ + "ldp d6, d7, [%[r1]], #16\n" /* load r1, 6,7 */ + "sxtl v30.4s, v23.4h\n" + "sxtl2 v31.4s, v23.8h\n" + "smull v23.8h, v8.8b, %[w2].8b\n" /* int16, out3 */ + "smlal v20.8h, v0.8b, %[w3].8b\n" /* int16, out0 */ + "smlal v21.8h, v2.8b, %[w3].8b\n" /* int16, out1 */ + "smlal v22.8h, v4.8b, %[w3].8b\n" /* int16, out2 */ + "smlal v23.8h, v6.8b, %[w3].8b\n" /* int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldr d8, [%[r1]]\n" /* load r1, 8 */ + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull v20.8h, v1.8b, %[w4].8b\n" /* int16, out0 */ + "smull v21.8h, v3.8b, %[w4].8b\n" /* int16, out1 */ + "smull v22.8h, v5.8b, %[w4].8b\n" /* int16, out1 */ + "smull v23.8h, v7.8b, %[w4].8b\n" /* int16, out1 */ + "ldp d0, d1, [%[r2]], #16\n" /* load r2, 0,1 */ + "smlal v20.8h, v2.8b, %[w5].8b\n" /* int16, out0 */ + "smlal v21.8h, v4.8b, %[w5].8b\n" /* int16, out1 */ + "ldp d2, d3, [%[r2]], #16\n" /* load r2, 2,3 */ + "smlal v22.8h, v6.8b, %[w5].8b\n" /* int16, out2 */ + "smlal v23.8h, v8.8b, %[w5].8b\n" /* int16, out3 */ + "ldp d4, d5, [%[r2]], #16\n" /* load r2, 4,5 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldp d6, d7, [%[r2]], #16\n" /* load r2, 6,7 */ + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull 
v20.8h, v0.8b, %[w6].8b\n" /* int16, out0 */ + "smull v21.8h, v2.8b, %[w6].8b\n" /* int16, out1 */ + "smull v22.8h, v4.8b, %[w6].8b\n" /* int16, out1 */ + "smull v23.8h, v6.8b, %[w6].8b\n" /* int16, out1 */ + "smlal v20.8h, v1.8b, %[w7].8b\n" /* int16, out0 */ + "smlal v21.8h, v3.8b, %[w7].8b\n" /* int16, out1 */ + "smlal v22.8h, v5.8b, %[w7].8b\n" /* int16, out1 */ + "smlal v23.8h, v7.8b, %[w7].8b\n" /* int16, out1 */ + "ldp d0, d1, [%[r0]], #16\n" /* load r0, 0,1 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldr d8, [%[r2]]\n" /* load r2 */ + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "smull v20.8h, v2.8b, %[w8].8b\n" /* int16, out0 */ + "smull v21.8h, v4.8b, %[w8].8b\n" /* int16, out1 */ + "ldp d2, d3, [%[r0]], #16\n" /* load r0, 2,3 */ + "smull v22.8h, v6.8b, %[w8].8b\n" /* int16, out1 */ + "smull v23.8h, v8.8b, %[w8].8b\n" /* int16, out1 */ + "ldp d4, d5, [%[r0]], #16\n" /* load r0, 5 */ + "saddw v24.4s, v24.4s, v20.4h\n" + "saddw2 v25.4s, v25.4s, v20.8h\n" + "saddw v26.4s, v26.4s, v21.4h\n" + "saddw2 v27.4s, v27.4s, v21.8h\n" + "ldp d6, d7, [%[r0]], #16\n" /* load r0, 6 */ + "stp q24, q25, [%[ptr_out0]], #32\n" + "saddw v28.4s, v28.4s, v22.4h\n" + "saddw2 v29.4s, v29.4s, v22.8h\n" + "stp q26, q27, [%[ptr_out0]], #32\n" + "saddw v30.4s, v30.4s, v23.4h\n" + "saddw2 v31.4s, v31.4s, v23.8h\n" + "subs %w[cnt], %w[cnt], #1\n" + "stp q28, q29, [%[ptr_out0]], #32\n" + "stp q30, q31, [%[ptr_out0]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [ptr_out0] "+r"(ptr_out0) + : [w0] "w"(vw0), + [w1] "w"(vw1), + [w2] "w"(vw2), + [w3] "w"(vw3), + [w4] "w"(vw4), + [w5] "w"(vw5), + [w6] "w"(vw6), + [w7] "w"(vw7), + [w8] "w"(vw8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31" + + ); +#else + auto wptr = weight_c; + asm volatile( + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */ + "1:\n" + /* inr0 -> outr0 */ + "vmull.s8 q4, d1, d7\n" /* int16, out0 */ + "vld1.32 {d1}, [%[r0]]!\n" /* load r0, 6 */ + "vmull.s8 q5, d3, d7\n" /* int16, out1 */ + "vld1.32 {d3}, [%[r0]]!\n" /* load r0, 7 */ + "vmull.s8 q6, d5, d7\n" /* int16, out2 */ + "vld1.32 {d5}, [%[r0]]\n" /* load r0, 8 */ + "vmull.s8 q7, d1, d6\n" /* int16, out0 */ + "vmlal.s8 q4, d0, d6\n" /* int16, out3 */ + "vmlal.s8 q5, d2, d6\n" /* int16, out1 */ + "vmlal.s8 q6, d4, d6\n" /* int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* int16, out3 */ + "vmovl.s16 q8, d8\n" + "vmovl.s16 q9, d9\n" + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */ + "vmovl.s16 q10, d10\n" + "vmovl.s16 q11, d11\n" + "vmovl.s16 q12, d12\n" + "vmovl.s16 q13, d13\n" + "vmovl.s16 q14, d14\n" + "vmovl.s16 q15, d15\n" + "vmull.s8 q4, d2, d6\n" /* int16, out0 */ + "vmull.s8 q6, d1, d6\n" /* int16, out2 */ + "vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */ + "vmull.s8 q5, d4, d6\n" /* int16, out1 */ + "vmull.s8 q7, d5, d6\n" /* int16, out3 */ + "vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4,5 */ + /* inr1 -> outr0 */ + "vmlal.s8 q4, d0, d7\n" /* int16, out0 */ + "vld1.32 {d0}, [%[r1]]!\n" /* load r1, 6 */ + "vmlal.s8 q5, d2, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* int16, out2 */ + 
"vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vmlal.s8 q7, d0, d7\n" /* int16, out3 */ + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "vmull.s8 q4, d1, d6\n" /* int16, out0 */ + "vld1.32 {d1}, [%[r1]]!\n" /* load r1, 7 */ + "vmull.s8 q5, d3, d6\n" /* int16, out1 */ + "vld1.32 {d3}, [%[r1]]\n" /* load r1, 8 */ + "vmull.s8 q6, d5, d6\n" /* int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* int16, out3 */ + "vmlal.s8 q4, d2, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* int16, out2 */ + "vmlal.s8 q6, d0, d7\n" /* int16, out1 */ + "vmlal.s8 q7, d3, d7\n" /* int16, out3 */ + "vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */ + /* inr2 -> outr0 */ + "vmull.s8 q4, d1, d7\n" /* int16, out0 */ + "vld1.32 {d1}, [%[r2]]!\n" /* load r2, 6 */ + "vmull.s8 q5, d3, d7\n" /* int16, out1 */ + "vld1.32 {d3}, [%[r2]]!\n" /* load r2, 7 */ + "vmull.s8 q6, d5, d7\n" /* int16, out2 */ + "vld1.32 {d5}, [%[r2]]\n" /* load r2, 8 */ + "vmull.s8 q7, d1, d6\n" /* int16, out3 */ + "vmlal.s8 q4, d0, d6\n" /* int16, out0 */ + "vmlal.s8 q5, d2, d6\n" /* int16, out1 */ + "vmlal.s8 q6, d4, d6\n" /* int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* int16, out3 */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w8 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "sub %[wptr], %[wptr], #72\n" + "vmull.s8 q4, d2, d6\n" /* int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* int16, out1 */ + "vmull.s8 q6, d1, d6\n" /* int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* int16, out3 */ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vaddw.s16 q8, q8, d8\n" + "vaddw.s16 q9, q9, d9\n" + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */ + "vaddw.s16 q10, q10, d10\n" + "vaddw.s16 q11, q11, d11\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]!\n" + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "vaddw.s16 q12, q12, d12\n" + "vaddw.s16 q13, q13, d13\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]!\n" + "vaddw.s16 q14, q14, d14\n" + "vaddw.s16 q15, q15, d15\n" + "subs %[cnt], #1\n" + "vst1.32 {d24-d27}, [%[ptr_out0]]!\n" + "vst1.32 {d28-d31}, [%[ptr_out0]]!\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [ptr_out0] "+r"(ptr_out0), + [wptr] "+r"(wptr) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + block_inr0 = block_inr2; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + } + write_int32_nchwc8_to_nchw(pre_out, + reinterpret_cast(dout_batch), + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + bias_local, + flag_bias, + ptr_write, + scale + c); + } + } + } +} + +template void conv_depthwise_3x3s2_int8(int8_t* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool 
flag_relu,
+                                                int num,
+                                                int chin,
+                                                int hin,
+                                                int win,
+                                                int hout,
+                                                int wout,
+                                                int padw,
+                                                int padh,
+                                                ARMContext* ctx);
+
+template void conv_depthwise_3x3s2_int8<float>(float* dout,
+                                               const int8_t* din,
+                                               const int8_t* weights,
+                                               const float* scale,
+                                               const float* bias,
+                                               bool flag_bias,
+                                               bool flag_relu,
+                                               int num,
+                                               int chin,
+                                               int hin,
+                                               int win,
+                                               int hout,
+                                               int wout,
+                                               int padw,
+                                               int padh,
+                                               ARMContext* ctx);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8260718a50f8e2fa8497d41d958e82a45ea0480d
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc
@@ -0,0 +1,849 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+const int OUT_C_BLOCK = 4;
+const int OUT_H_BLOCK = 2;
+const int OUT_W_BLOCK = 4;
+
+size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
+                                       ARMContext* ctx) {
+  auto dim_in = param.x->dims();
+  auto dim_out = param.output->dims();
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / sizeof(float);
+  const int pad_w = param.paddings[1];
+  const int pad_h = param.paddings[0];
+  int ow = dim_out[3];
+  int oh = dim_out[2];
+  int ic = dim_in[1];
+  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
+  const int win_round = wout_round * 2 /*stride_w*/ + 1;
+  const int hin_r_block = OUT_H_BLOCK * 2 /*stride_h*/ + 1;
+
+  int hout_r_block =
+      (llc_size - 2 * wout_round * ic - ic) /
+      ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
+  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
+  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
+  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
+
+  int in_len = win_round * ic;
+  int pre_in_size = hin_r_block * in_len;
+  int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
+
+  return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size);
+}
+
+void conv_3x3s2_direct_fp32(const float* i_data,
+                            float* o_data,
+                            int bs,
+                            int oc,
+                            int oh,
+                            int ow,
+                            int ic,
+                            int ih,
+                            int win,
+                            const float* weights,
+                            const float* bias,
+                            const operators::ConvParam& param,
+                            ARMContext* ctx) {
+  //! 3x3s2 convolution, implemented by direct algorithm
+  //! prepack input to tmp buffer
+  //!
write output to tmp buffer + const int threads = ctx->threads(); + int l2_size = ctx->llc_size() / sizeof(float); + const int pad_w = param.paddings[1]; + const int pad_h = param.paddings[0]; + const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); + const int win_round = wout_round * 2 /*stride_w*/ + 1; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + + //! get h block + //! win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block + //! * threads = l2_size + //! win_round = 2 * wout_round + 1 + //! hin_r_block = 2 * hout_r_block + 1 + int hout_r_block = + (l2_size - 2 * wout_round * ic - ic) / + ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads); + hout_r_block = hout_r_block > oh ? oh : hout_r_block; + hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; + hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; + + const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1; + + int in_len = win_round * ic; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; + + float* tmp_work_space = ctx->workspace_data(); + float ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; // NOLINT + + //! l2_cache start + float* pre_din = tmp_work_space; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + int w_stride = ic * 9; /*kernel_w * kernel_h*/ + int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h * + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK; + int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK; + + int out_row_stride = OUT_C_BLOCK * wout_round; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; + for (int h = 0; h < oh; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > oh) { + h_kernel = oh - h; + } + + int hs = h * 2 /*stride_h*/ - pad_h; + int he = hs + h_kernel * 2 /*stride_h*/ + 1; + + prepack_input_nxw( + din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); + + const float* cblock_inr0 = pre_din; + const float* cblock_inr1 = cblock_inr0 + in_len; + const float* cblock_inr2 = cblock_inr1 + in_len; + const float* cblock_inr3 = cblock_inr2 + in_len; + const float* cblock_inr4 = cblock_inr3 + in_len; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < c_round_down; c += OUT_C_BLOCK) { +#ifdef ARM_WITH_OMP + float* pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float* pre_out = pre_din + pre_in_size; +#endif + const float* block_inr0 = cblock_inr0; + const float* block_inr1 = cblock_inr1; + const float* block_inr2 = cblock_inr2; + const float* block_inr3 = cblock_inr3; + const float* block_inr4 = cblock_inr4; + + const float* weight_c = weights + c * w_stride; + const float* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c; + } + fill_packed_biasc4( + pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { + const float* wc0 = weight_c; + + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + + float* pre_out0 = pre_out + hk * out_row_stride; + float* pre_out1 = pre_out0 + out_row_stride; +#ifdef __aarch64__ + for (int 
i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + float32x4_t w0 = vld1q_f32(wc0); // w0, v23 + float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + // clang-format off + asm volatile( + "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00, outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "ldp q0, q1, [%[r0]], #32\n" /* load input r0*/ + "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/ + "ldp q4, q5, [%[r2]], #32\n" /* load input r2*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "2:\n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ + "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ + "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ + "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ + "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ + "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ + "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ + "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ + /* r2 mul w6, get out r0*/ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ + "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ + "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ + "ldr d11, [%[r1]]\n" /* load input r1, 9th element*/ + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ + "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ + "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ + "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ + "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ + "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ + "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ + "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ + /* r2 mul w7, get out r0 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ + "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ + "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ + "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/ + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ + "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ + "fmla v18.4s , %[w2].4s, v10.s[0]\n"/* outr03 = w2 * r0[8]*/ + "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ + "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ + "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ + 
"fmla v22.4s , %[w2].4s, v12.s[0]\n"/* outr13 = w2 * r2[8]*/ + "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ + /* r2, mul w8, get out r0 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ + "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ + "fmla v18.4s , %[w8].4s, v12.s[0]\n"/* outr03 = w8 * r2[8]*/ + "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/ + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ + "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ + "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ + "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ + "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ + "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ + "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ + "fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/ + "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ + "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ + "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ + "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ + "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ + "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/ + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ + "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ + "fmla v18.4s , %[w5].4s, v11.s[0]\n"/* outr03 = w5 * r1[8]*/ + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ + "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ + "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ + "fmla v22.4s , %[w5].4s, v13.s[0]\n"/* outr13 = w5 * r3[8]*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + /* r4, mul w6, get out r1 */ + "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ + "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/ + "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ + "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + /* r4, mul w7, get out r1 */ + "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ + "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ + "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ + "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + /* r4, mul w8, get out r1 */ + "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ + "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ + "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ + "fmla v22.4s , %[w8].4s, v14.s[0]\n"/* outr13 = w8 * r4[8]*/ + "subs %w[cnt], %w[cnt], #1\n" /*loop count -1*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save 
outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), + [w1] "w"(w1), [w2] "w"(w2), + [w3] "w"(w3), [w4] "w"(w4), + [w5] "w"(w5), [w6] "w"(w6), + [w7] "w"(w7), [w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3","v4", + "v5","v6","v7","v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19","v20","v21","v22"); + // clang-format on + wc0 += 9 * OUT_C_BLOCK; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + const float* wc0 = weight_c + i * w_stride_chin; + + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + // clang-format off + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + /* load r0, r2 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n" + "vld1.32 {d8}, [%[r0]] @ load r0\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 -32\n" + /* main loop */ + "0: @ main loop\n" + /* mul r0, with w0, w1, w2 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n" + "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n" + "vmla.f32 q9, q5, d1[0] @ w0 * inr02\n" + "vmla.f32 q10, q5, d2[0] @ w0 * inr04\n" + "vmla.f32 q11, q5, d3[0] @ w0 * inr06\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2\n" + "vmla.f32 q8, q6, d0[1] @ w1 * inr01\n" + "vmla.f32 q9, q6, d1[1] @ w1 * inr03\n" + "vmla.f32 q10, q6, d2[1] @ w1 * inr05\n" + "vmla.f32 q11, q6, d3[1] @ w1 * inr07\n" + "vld1.32 {d9}, [%[r2]] @ load r2, 9th float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n" + "vmla.f32 q9, q7, d2[0] @ w2 * inr04\n" + "vmla.f32 q10, q7, d3[0] @ w2 * inr06\n" + "vmla.f32 q11, q7, d8[0] @ w2 * inr08\n" + "sub %[r2], %[r2], #32 @ r2 - 32\n" + /* mul r2, with w0, w1, w2 */ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1\n" + "vmla.f32 q12, q5, d4[0] @ w0 * inr20\n" + "vmla.f32 q13, q5, d5[0] @ w0 * inr22\n" + "vmla.f32 q14, q5, d6[0] @ w0 * inr24\n" + "vmla.f32 q15, q5, d7[0] @ w0 * inr26\n" + "vld1.32 {d8}, [%[r1]] @ load r1, 9th float\n" + "vmla.f32 q12, q6, d4[1] @ w1 * inr21\n" + "vmla.f32 q13, q6, d5[1] @ w1 * inr23\n" + "vmla.f32 q14, q6, d6[1] @ w1 * inr25\n" + "vmla.f32 q15, q6, d7[1] @ w1 * inr27\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w2 * inr22\n" + "vmla.f32 q13, q7, d6[0] @ w2 * inr24\n" + "vmla.f32 q14, q7, d7[0] @ w2 * inr26\n" + "vmla.f32 q15, q7, d9[0] @ w2 * inr28\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, to q7\n" + /* mul r1, with w3, w4, w5 */ + "vmla.f32 q8, q5, d0[0] @ w3 * inr10\n" + "vmla.f32 q9, q5, d1[0] @ w3 * inr12\n" + "vmla.f32 q10, q5, d2[0] @ w3 * inr14\n" + "vmla.f32 q11, q5, d3[0] @ w3 * inr16\n" + "vld1.32 {d4-d7}, [%[r3]]! 
@ load r3, 8 float\n" + "vmla.f32 q8, q6, d0[1] @ w4 * inr11\n" + "vmla.f32 q9, q6, d1[1] @ w4 * inr13\n" + "vmla.f32 q10, q6, d2[1] @ w4 * inr15\n" + "vmla.f32 q11, q6, d3[1] @ w4 * inr17\n" + "vld1.32 {d9}, [%[r3]] @ load r3, 9th float\n" + "vmla.f32 q8, q7, d1[0] @ w5 * inr12\n" + "vmla.f32 q9, q7, d2[0] @ w5 * inr14\n" + "vmla.f32 q10, q7, d3[0] @ w5 * inr16\n" + "vmla.f32 q11, q7, d8[0] @ w5 * inr18\n" + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n" + /* mul r3, with w3, w4, w5 */ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2\n" + "vmla.f32 q12, q5, d4[0] @ w3 * inr30\n" + "vmla.f32 q13, q5, d5[0] @ w3 * inr32\n" + "vmla.f32 q14, q5, d6[0] @ w3 * inr34\n" + "vmla.f32 q15, q5, d7[0] @ w3 * inr36\n" + "vld1.32 {d8}, [%[r2]] @ load r2, 9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * inr31\n" + "vmla.f32 q13, q6, d5[1] @ w4 * inr33\n" + "vmla.f32 q14, q6, d6[1] @ w4 * inr35\n" + "vmla.f32 q15, q6, d7[1] @ w4 * inr37\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n" + "vmla.f32 q12, q7, d5[0] @ w5 * inr32\n" + "vmla.f32 q13, q7, d6[0] @ w5 * inr34\n" + "vmla.f32 q14, q7, d7[0] @ w5 * inr36\n" + "vmla.f32 q15, q7, d9[0] @ w5 * inr38\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n" + /* mul r2, with w6, w7, w8 */ + "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n" + "vmla.f32 q9, q5, d1[0] @ w6 * inr22\n" + "vmla.f32 q10, q5, d2[0] @ w6 * inr24\n" + "vmla.f32 q11, q5, d3[0] @ w6 * inr26\n" + "vld1.32 {d4-d7}, [%[r4]]! @ load r4\n" + "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n" + "vmla.f32 q9, q6, d1[1] @ w7 * inr23\n" + "vmla.f32 q10, q6, d2[1] @ w7 * inr25\n" + "vmla.f32 q11, q6, d3[1] @ w7 * inr27\n" + "vld1.32 {d9}, [%[r4]] @ load r4, 9th float\n" + "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n" + "vmla.f32 q9, q7, d2[0] @ w8 * inr24\n" + "vmla.f32 q10, q7, d3[0] @ w8 * inr26\n" + "vmla.f32 q11, q7, d8[0] @ w8 * inr28\n" + "sub %[wc0], %[wc0], #144 @ wc0 - 144\n" + /* mul r4, with w6, w7, w8 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n" + "vmla.f32 q12, q5, d4[0] @ w3 * inr40\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n" + "vmla.f32 q13, q5, d5[0] @ w3 * inr42\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n" + "vmla.f32 q14, q5, d6[0] @ w3 * inr44\n" + "vmla.f32 q15, q5, d7[0] @ w3 * inr46\n" + "vld1.32 {d8}, [%[r0]] @ load r0, 9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * inr41\n" + "vmla.f32 q13, q6, d5[1] @ w4 * inr43\n" + "vmla.f32 q14, q6, d6[1] @ w4 * inr45\n" + "vmla.f32 q15, q6, d7[1] @ w4 * inr47\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vmla.f32 q12, q7, d5[0] @ w5 * inr42\n" + "vmla.f32 q13, q7, d6[0] @ w5 * inr44\n" + "vmla.f32 q14, q7, d7[0] @ w5 * inr46\n" + "vmla.f32 q15, q7, d9[0] @ w5 * inr48\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ load outr0\n"
+          "vld1.32 {d20-d23}, [%[ptr_out0]]  @ load outr0\n"
+          "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
+          "subs %[cnt], #1                   @ loop count--\n"
+          "bne 0b                            @ jump to main loop\n"
+          : [cnt] "+r"(cnt),
+            [r0] "+r"(r0), [r1] "+r"(r1),
+            [r2] "+r"(r2), [r3] "+r"(r3),
+            [r4] "+r"(r4),
+            [ptr_out0] "+r"(ptr_out0),
+            [ptr_out1] "+r"(ptr_out1),
+            [wc0] "+r"(wc0)
+          :
+          : "cc","memory","q0","q1","q2","q3","q4",
+            "q5","q6","q7","q8","q9","q10",
+            "q11","q12","q13","q14","q15"
+          );
+          // clang-format on
+
+          inr0 += win_round;
+          inr1 += win_round;
+          inr2 += win_round;
+          inr3 += win_round;
+          inr4 += win_round;
+        }
+#endif  // __aarch64__
+          block_inr0 = block_inr4;
+          block_inr1 = block_inr0 + in_len;
+          block_inr2 = block_inr1 + in_len;
+          block_inr3 = block_inr2 + in_len;
+          block_inr4 = block_inr3 + in_len;
+        }
+
+        write_to_output_c4_fp32(pre_out,
+                                dout_batch,
+                                c,
+                                c + OUT_C_BLOCK,
+                                h,
+                                h + h_kernel,
+                                0,
+                                wout_round,
+                                oc,
+                                oh,
+                                ow,
+                                flag_relu,
+                                ptr_write);
+      }
+
+#pragma omp parallel for num_threads(threads)
+      for (int c = 0; c < c_remain; ++c) {
+#ifdef ARM_WITH_OMP
+        float* pre_out =
+            pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
+#else
+        float* pre_out = pre_din + pre_in_size;
+#endif
+
+        const float* block_inr0 = cblock_inr0;
+        const float* block_inr1 = cblock_inr1;
+        const float* block_inr2 = cblock_inr2;
+        const float* block_inr3 = cblock_inr3;
+        const float* block_inr4 = cblock_inr4;
+
+        //! get weights ptr of remained
+        const float* weight_c = weights + c_round_down * w_stride;
+
+        //! fill bias to one channel
+        const float* bias_ptr = ptr_zero;
+        if (flag_bias) {
+          bias_ptr = bias + c_round_down + c;
+        }
+        fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
+
+        for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
+          const float* wc0 = weight_c;
+
+          const float* inr0 = block_inr0;
+          const float* inr1 = block_inr1;
+          const float* inr2 = block_inr2;
+          const float* inr3 = block_inr3;
+          const float* inr4 = block_inr4;
+
+          float* pre_out0 = pre_out + hk * wout_round;
+          float* pre_out1 = pre_out0 + wout_round;
+#ifdef __aarch64__
+          for (int i = 0; i < ic; ++i) {
+            float* ptr_out0 = pre_out0;
+            float* ptr_out1 = pre_out1;
+
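The remainder-channel loop above computes each leftover output channel on its own, but the weights are still stored in interleaved c4 blocks, so tap k of leftover channel c has to be gathered with stride 4 and broadcast, which is what the vdupq_n_f32 loads just below do. A minimal NEON sketch of that gather (the helper name is assumed, not from the patch):

#include <arm_neon.h>

// Weights stay packed as [9 taps][OUT_C_BLOCK = 4]; tap k of channel c
// lives at wc0[c + 4 * k] and is splatted to all four lanes.
inline void load_tail_taps(const float* wc0, int c, float32x4_t w[9]) {
  for (int k = 0; k < 9; ++k) {
    w[k] = vdupq_n_f32(wc0[c + 4 * k]);
  }
}

+            //!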
get valid weights of current output channel + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[c + 4]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[c + 8]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[c + 12]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[c + 16]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[c + 20]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[c + 24]); // w6, v29 + float32x4_t w7 = vdupq_n_f32(wc0[c + 28]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[c + 32]); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + // clang-format off + asm volatile( + "ldr q21, [%[ptr_out0]]\n" /* load outr00-outr03*/ + "ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/ + "ldr d10, [%[r0]]\n"/* load input r0, 9th element*/ + "ld2 {v4.4s, v5.4s}, [%[r2]], #32\n" /* load input r2*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "2:\n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldr q22, [%[ptr_out1]]\n" /* load outr10 - outr13*/ + "fmla v21.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v4.4s\n" /* outr1 = w0 * r2*/ + "ld2 {v2.4s, v3.4s}, [%[r1]], #32\n" /* load input r1*/ + /* r2 mul w6, get out r0*/ + "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr0 = w6 * r2*/ + "ldr d11, [%[r1]]\n" /* load input r1, 9th element*/ + /* shift left 1 */ + "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ + "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v21.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0*/ + "fmla v22.4s , %[w1].4s, v5.4s\n" /* outr1 = w1 * r2*/ + "ld2 {v6.4s, v7.4s}, [%[r3]], #32\n" /* load input r3*/ + /* r2 mul w7, get out r0 */ + "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr00 = w7 * r2*/ + "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/ + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v21.4s , %[w2].4s, v15.4s\n" /* outr0 = w2 * r0*/ + "fmla v22.4s , %[w2].4s, v16.4s\n" /* outr1 = w2 * r2*/ + "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/ + /* r2, mul w8, get out r0 */ + "fmla v21.4s , %[w8].4s, v16.4s\n" /* outr00 = w8 * r2*/ + "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/ + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr0 = w3 * r1*/ + "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr1 = w3 * r3*/ + /* shift left 1 */ + "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/ + "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/ + "ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/ + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr0 = w4 * r1*/ + "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr1 = w4 * r3*/ + "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/ + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v21.4s , %[w5].4s, v15.4s\n" /* outr0 = w5 * r1[2]*/ + "fmla v22.4s , %[w5].4s, v16.4s\n" /* outr1 = w5 * r1[4]*/ + "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "str q21, [%[ptr_out0]], #16\n" /* save outr00, outr01*/ + /* r4, mul w6, get out r1 */ + "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4*/ + "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/ + "ldr q21, [%[ptr_out0]] \n" /* load outr0*/ + /* r4, mul w7, get out r1 */ + "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4*/ + /* r4, mul w8, get out r1 */ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = 
w8 * r4*/ + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), + [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3", + "v4","v5","v6","v7","v8","v9","v10","v11", + "v12","v13","v14","v15","v16","v21","v22"); + // clang-format on + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + //! get valid weights of current output channel + float w_tmp[12] = {wc0[c], + wc0[c + 4], + wc0[c + 8], + 0.f, + wc0[c + 12], + wc0[c + 16], + wc0[c + 20], + 0.f, + wc0[c + 24], + wc0[c + 28], + wc0[c + 32], + 0.f}; + float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 + float32x4_t w1 = vld1q_f32(w_tmp + 4); // w3, w4, w5, q1 + float32x4_t w2 = vld1q_f32(w_tmp + 8); // w6, w7, w8, q2 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop / 2; + if (cnt > 0) { + // clang-format off + asm volatile( + /* main loop */ + "0: @ main loop\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, or11\n" + "vld2.32 {d6-d9}, [%[r2]]! @ load r2\n" + "vld2.32 {d10-d13}, [%[r2]]! @ load r2\n" + "vld1.32 {d22}, [%[r2]] @ load 16th float\n" + /* r2 * w2, r2 * w0, get or0, or1 */ + "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2\n " + "vld2.32 {d14-d17}, [%[r0]]! @ load r0\n" + "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2\n" + "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2\n" + "vext.32 q4, q3, q5, #1 @ r2, shift left 1\n" + "vext.32 q6, q5, q11, #1 @ r2, shift left 1\n" + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n" + "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2\n" + "vld2.32 {d18-d21}, [%[r0]]! @ load r0\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2\n" + "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2\n" + "vld1.32 {d22}, [%[r0]] @ load 16th float\n" + "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2\n" + "vld2.32 {d6-d9}, [%[r3]]! @ load r3\n" + "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2\n" + "vld2.32 {d10-d13}, [%[r3]]! @ load r3\n" + /* r0 * w0, get or0, r3 * w1, get or1*/ + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n" + "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0\n" + "vext.32 q8, q7, q9, #1 @ r0, shift left 1\n" + "vext.32 q10, q9, q11, #1 @ r0, shift left 1\n" + "vld1.32 {d22}, [%[r3]] @ load 16th float\n" + "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3\n" + "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3\n" + "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0\n" + "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0\n" + "vext.32 q4, q3, q5, #1 @ r3, shift left 1\n" + "vext.32 q6, q5, q11, #1 @ r3, shift left 1\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3\n" + "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3\n" + "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, " + "2, 4, 6, 8\n" + "vld2.32 {d14-d17}, [%[r1]]! @ load r1\n" + "vmla.f32 q13, q10,%f[w0][0] @ w02 * r0\n" + "vld2.32 {d18-d21}, [%[r1]]! @ load r1\n" + "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3\n" + "vld2.32 {d6-d9}, [%[r4]]! 
@ load r4\n" + "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3\n" + "vld2.32 {d10-d13}, [%[r4]]! @ load r4\n" + "vld1.32 {d22}, [%[r1]] @ load 16th float\n" + /* r1 * w1, get or0, r4 * w2, get or1 */ + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n" + "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1\n" + "vext.32 q8, q7, q9, #1 @ r1, shift left 1\n" + "vext.32 q10, q9, q11, #1 @ r1, shift left 1\n" + "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4\n" + "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4\n" + "vld1.32 {d22}, [%[r4]] @ load 16th float\n" + "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1\n" + "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1\n" + "vext.32 q4, q3, q5, #1 @ r1, shift left 1\n" + "vext.32 q6, q5, q11, #1 @ r1, shift left 1\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4\n" + "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4\n" + "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4\n" + "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4\n" + "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ jump to main loop\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1),[r2] "+r"(r2), + [r3] "+r"(r3),[r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) + : "cc","memory","q3","q4", + "q5","q6","q7","q8","q9","q10", + "q11","q12","q13","q14","q15" + ); + // clang-format on + } + //! deal with remain ow + if (w_loop & 1) { + ptr_out0[0] += + r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + + r1[0] * w_tmp[4] + r1[1] * w_tmp[5] + r1[2] * w_tmp[6] + + r2[0] * w_tmp[8] + r2[1] * w_tmp[9] + r2[2] * w_tmp[10]; + + ptr_out0[1] += + r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r1[4] * w_tmp[6] + + r2[2] * w_tmp[8] + r2[3] * w_tmp[9] + r2[4] * w_tmp[10]; + + ptr_out0[2] += + r0[4] * w_tmp[0] + r0[5] * w_tmp[1] + r0[6] * w_tmp[2] + + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r1[6] * w_tmp[6] + + r2[4] * w_tmp[8] + r2[5] * w_tmp[9] + r2[6] * w_tmp[10]; + + ptr_out0[3] += + r0[6] * w_tmp[0] + r0[7] * w_tmp[1] + r0[8] * w_tmp[2] + + r1[6] * w_tmp[4] + r1[7] * w_tmp[5] + r1[8] * w_tmp[6] + + r2[6] * w_tmp[8] + r2[7] * w_tmp[9] + r2[8] * w_tmp[10]; + + ptr_out1[0] += + r2[0] * w_tmp[0] + r2[1] * w_tmp[1] + r2[2] * w_tmp[2] + + r3[0] * w_tmp[4] + r3[1] * w_tmp[5] + r3[2] * w_tmp[6] + + r4[0] * w_tmp[8] + r4[1] * w_tmp[9] + r4[2] * w_tmp[10]; + + ptr_out1[1] += + r2[2] * w_tmp[0] + r2[3] * w_tmp[1] + r2[4] * w_tmp[2] + + r3[2] * w_tmp[4] + r3[3] * w_tmp[5] + r3[4] * w_tmp[6] + + r4[2] * w_tmp[8] + r4[3] * w_tmp[9] + r4[4] * w_tmp[10]; + + ptr_out1[2] += + r2[4] * w_tmp[0] + r2[5] * w_tmp[1] + r2[6] * w_tmp[2] + + r3[4] * w_tmp[4] + r3[5] * w_tmp[5] + r3[6] * w_tmp[6] + + r4[4] * w_tmp[8] + r4[5] * w_tmp[9] + r4[6] * w_tmp[10]; + + ptr_out1[3] += + r2[6] * w_tmp[0] + r2[7] * w_tmp[1] + r2[8] * w_tmp[2] + + r3[6] * w_tmp[4] + r3[7] * w_tmp[5] + r3[8] * w_tmp[6] + + r4[6] * w_tmp[8] + r4[7] * w_tmp[9] + r4[8] * w_tmp[10]; + } + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + write_to_output_c1_fp32(pre_out, + dout_batch, + c + c_round_down, + c + c_round_down + 1, + h, + h + h_kernel, + 0, + wout_round, + oc, + 
oh,
+                                ow,
+                                flag_relu,
+                                ptr_write);
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv3x3s2_direct_int8.cc b/lite/backends/arm/math/conv3x3s2_direct_int8.cc
index 6169ad5d12f131d93fa1244a70d92ee827cbeb0e..01b7a812ebc05a054bb9952bf53605ce7aed135a 100644
--- a/lite/backends/arm/math/conv3x3s2_direct_int8.cc
+++ b/lite/backends/arm/math/conv3x3s2_direct_int8.cc
@@ -28,8 +28,9 @@ namespace math {
 
 #ifdef __aarch64__
 int conv_3x3s2_direct_int8_c_num() { return 8; }
+template <typename Dtype>
 void conv_3x3s2_direct_int8(const int8_t* din,
-                            int32_t* dout,
+                            Dtype* dout,
                             int num,
                             int chout,
                             int hout,
@@ -38,27 +39,25 @@ void conv_3x3s2_direct_int8(const int8_t* din,
                             int hin,
                             int win,
                             const int8_t* weights,
-                            const int32_t* bias,
+                            const float* bias,
                             const operators::ConvParam& param,
                             Context* ctx,
-                            PrecisionType out_type,
                             const float* scale) {
   //! 3x3s2 int8 convolution, implemented by direct algorithm
   //! prepack input to tmp buffer
   //! write output to tmp buffer
-  int threads = ctx->threads();
-  int stride_w = param.strides[1];
-  int pad_w = param.paddings[1];
-  int pad_h = param.paddings[0];
   bool flag_relu = param.fuse_relu;
-  bool flag_bias = (param.bias != nullptr);
+  bool flag_bias = param.bias;
+  int pad_h = param.paddings[0];
+  int pad_w = param.paddings[1];
+
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / 4;
-  //! set 2/3 l2 cache
-  int l2_size = ctx->llc_size() / 3 * 2;
   const int hout_c_block = 8;
   const int hout_r_kernel = 2;
   const int wout_round = ((wout + 3) / 4) * 4;
-  const int win_round = wout_round * stride_w + 1;
+  const int win_round = wout_round * 2 /*stride_w*/ + 1;
 
   //! get h block
   //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round *
@@ -66,7 +65,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
   //! win_round = 2 * wout_round + 1
   //! hin_r_block = 2 * hout_r_block + 1
   int hout_r_block =
-      (l2_size - 2 * wout_round * chin - chin) /
+      (llc_size - 2 * wout_round * chin - chin) /
       ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4);
   hout_r_block = hout_r_block > hout ? hout : hout_r_block;
   hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
@@ -74,16 +73,15 @@
 
   const int hin_r_block = hout_r_block * 2 + 1;
 
-  int8_t* tmp_work_space = ctx->workspace_data<int8_t>();
+  auto tmp_work_space = ctx->workspace_data<int8_t>();
   int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4;
-  const int kZeroSize = zero_size;
-  int32_t ptr_zero[kZeroSize];
+  int32_t ptr_zero[zero_size];  // NOLINT
   memset(ptr_zero, 0, sizeof(int32_t) * zero_size);
-  const int kWoutRound = wout_round;
-  int32_t ptr_write[kWoutRound];
+  Dtype ptr_write[wout_round];  // NOLINT
 
   int in_len = win_round * chin;
   int pre_in_size = hin_r_block * in_len;
+  pre_in_size = ROUNDUP(pre_in_size, 4);
   int pre_out_size = hout_c_block * hout_r_block * wout_round;
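Two details in the hunk above are easy to miss: the h-block denominator carries a trailing * 4 because the per-thread output tile holds int32 accumulators while the packed input tile is byte-sized int8, and pre_in_size is rounded up so the accumulator region starts on a 4-byte boundary. A small C++ sketch of the resulting workspace split (the helper name is illustrative, not from the patch):

#include <cstdint>

// Packed int8 input tile first, then per-thread int32 accumulator tiles;
// pre_in_size_rounded is ROUNDUP(pre_in_size, 4), keeping int32s aligned.
inline int32_t* accum_tile(int8_t* workspace,
                           int pre_in_size_rounded,
                           int thread_id,
                           int pre_out_size) {
  return reinterpret_cast<int32_t*>(workspace + pre_in_size_rounded) +
         thread_id * pre_out_size;
}

This mirrors the pre_out pointer arithmetic used in the OpenMP branch below.

 //!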
l2_cache start @@ -100,10 +98,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, int out_row_stride = hout_c_block * wout_round; for (int n = 0; n < num; ++n) { - const int8_t* din_batch = din + n * chin * size_in_channel; - int8_t* dout_batch = - reinterpret_cast(dout) + - n * chout * size_out_channel * PrecisionTypeLength(out_type); + auto din_batch = din + n * chin * size_in_channel; + auto dout_batch = dout + n * chout * size_out_channel; for (int h = 0; h < hout; h += hout_r_block) { int h_kernel = hout_r_block; if (h + hout_r_block > hout) { @@ -133,12 +129,10 @@ void conv_3x3s2_direct_int8(const int8_t* din, #pragma omp parallel for num_threads(threads) for (int c = 0; c < chout; c += hout_c_block) { #ifdef ARM_WITH_OMP - int32_t* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + - omp_get_thread_num() * pre_out_size; + auto pre_out = reinterpret_cast(pre_din + pre_in_size) + + omp_get_thread_num() * pre_out_size; #else - int32_t* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); + auto pre_out = reinterpret_cast(pre_din + pre_in_size); #endif const int8_t* block_inr0 = cblock_inr0; const int8_t* block_inr1 = cblock_inr1; @@ -147,12 +141,19 @@ void conv_3x3s2_direct_int8(const int8_t* din, const int8_t* block_inr4 = cblock_inr4; const int8_t* weight_c = weights + c * w_stride; - const int32_t* bias_ptr = ptr_zero; + float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; if (flag_bias) { - bias_ptr = bias + c; + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + bias_local[4] = bias[c + 4]; + bias_local[5] = bias[c + 5]; + bias_local[6] = bias[c + 6]; + bias_local[7] = bias[c + 7]; } - fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 8, h_kernel, wout_round); + memset(pre_out, 0, pre_out_size * sizeof(int32_t)); for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { const int8_t* wc0 = weight_c; @@ -186,490 +187,236 @@ void conv_3x3s2_direct_int8(const int8_t* din, int32_t* ptr_out0 = pre_out0; int32_t* ptr_out1 = pre_out1; int cnt = w_loop; - + // clang-format off asm volatile( - "ldr q0, [%[r0]], #8 \n" /* load input r0 */ - "ldr q1, [%[r2]], #8 \n" /* load input r2 */ - "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ - "sshll v1.8h, v1.8b, #0 \n" /* r1: int8 -> int16*/ - "1: \n" /* main loop */ - - /* r0, r2 mul w00 */ - "smull v4.4s, %[v0].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smull2 v5.4s, %[v0].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smull v6.4s, %[v0].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] - */ - "smull2 v7.4s, %[v0].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smull v8.4s, %[v0].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] - */ - "smull2 v9.4s, %[v0].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - "smull v10.4s, %[v0].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] - */ - "smull2 v11.4s, %[v0].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - - "smull v12.4s, %[v0].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] - */ - "smull2 v13.4s, %[v0].8h, v1.h[0]\n" /* outr11 = v0 * r2[2] - */ - "smull v14.4s, %[v0].4h, v1.h[2]\n" /* outr12 = v0 * r2[4] - */ - "smull2 v15.4s, %[v0].8h, v1.h[2]\n" /* outr13 = v0 * r2[6] - */ - "smull v16.4s, %[v0].4h, v1.h[4]\n" /* outr10 = v0 * r2[0] - */ - "smull2 v17.4s, %[v0].8h, v1.h[4]\n" /* outr11 = v0 * r2[2] - */ - "smull v18.4s, %[v0].4h, v1.h[6]\n" /* outr12 = v0 * r2[4] - */ - "smull2 v19.4s, %[v0].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] - */ - - /* r2, mul w06 */ - "smlal v4.4s, %[v6].4h, v1.h[0]\n" /* outr00 = v6 * r2[1] - */ - "smlal2 v5.4s, %[v6].8h, v1.h[0]\n" /* 
outr01 = v6 * r2[3] - */ - "smlal v6.4s, %[v6].4h, v1.h[2]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v7.4s, %[v6].8h, v1.h[2]\n" /* outr03 = v6 * r2[7] - */ - "smlal v8.4s, %[v6].4h, v1.h[4]\n" /* outr00 = v6 * r2[1] - */ - "smlal2 v9.4s, %[v6].8h, v1.h[4]\n" /* outr01 = v6 * r2[3] - */ - "smlal v10.4s, %[v6].4h, v1.h[6]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v11.4s, %[v6].8h, v1.h[6]\n" /* outr03 = v6 * r2[7] - */ - - "ldr q2, [%[r0]] \n" /* load r0, 9th - data,v10.s[0] */ - - /* r0, r2, mul w01 */ - "smlal v4.4s, %[v1].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v5.4s, %[v1].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal v6.4s, %[v1].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v7.4s, %[v1].8h, v0.h[3]\n" /* outr00 = v0 * r0[0] - */ - "sshll v2.8h, v2.8b, #0 \n" /* r0: int8 -> int16 */ - "smlal v8.4s, %[v1].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v9.4s, %[v1].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] - */ - "smlal v10.4s, %[v1].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v11.4s, %[v1].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] - */ - - "smlal v12.4s, %[v1].4h, v1.h[1]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v13.4s, %[v1].8h, v1.h[1]\n" /* outr11 = v0 * r2[2] - */ - "smlal v14.4s, %[v1].4h, v1.h[3]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v15.4s, %[v1].8h, v1.h[3]\n" /* outr13 = v0 * r2[6] - */ - "smlal v16.4s, %[v1].4h, v1.h[5]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v17.4s, %[v1].8h, v1.h[5]\n" /* outr11 = v0 * r2[2] - */ - "smlal v18.4s, %[v1].4h, v1.h[7]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v19.4s, %[v1].8h, v1.h[7]\n" /* outr13 = v0 * r2[6] - */ - - /* r2, mul w07 */ - "smlal v4.4s, %[v7].4h, v1.h[1]\n" /* outr00 = v6 * r2[1] - */ - "smlal2 v5.4s, %[v7].8h, v1.h[1]\n" /* outr01 = v6 * r2[3] - */ - "smlal v6.4s, %[v7].4h, v1.h[3]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v7.4s, %[v7].8h, v1.h[3]\n" /* outr03 = v6 * r2[7] - */ - "smlal v8.4s, %[v7].4h, v1.h[5]\n" /* outr00 = v6 * r2[1] - */ - "smlal2 v9.4s, %[v7].8h, v1.h[5]\n" /* outr01 = v6 * r2[3] - */ - "smlal v10.4s, %[v7].4h, v1.h[7]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v11.4s, %[v7].8h, v1.h[7]\n" /* outr03 = v6 * r2[7] - */ - - "ldr q3, [%[r2]] \n" /* load r2, 9th - data,v11.s[0] */ - - /* r0, r2, mul w02 */ - "smlal v4.4s, %[v2].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v5.4s, %[v2].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal v6.4s, %[v2].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v7.4s, %[v2].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - "sshll v3.8h, v3.8b, #0 \n" /* r2: int8 -> int16*/ - "smlal v8.4s, %[v2].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v9.4s, %[v2].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - "smlal v10.4s, %[v2].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v11.4s, %[v2].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] - */ - - "ldr q0, [%[r1]], #8 \n" /* load input r1 */ - - "smlal v12.4s, %[v2].4h, v1.h[2]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v13.4s, %[v2].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] - */ - "smlal v14.4s, %[v2].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v15.4s, %[v2].8h, v1.h[4]\n" /* outr13 = v0 * r2[6] - */ - "sshll v0.8h, v0.8b, #0 \n" /* r1 : int8 -> int16 */ - "smlal v16.4s, %[v2].4h, v1.h[6]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v17.4s, %[v2].8h, v1.h[6]\n" /* outr11 = v0 * r2[2] - */ - "smlal v18.4s, %[v2].4h, v3.h[0]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v19.4s, %[v2].8h, v3.h[0]\n" /* outr13 = v0 * r2[6] - */ - - /* r2, mul w08 */ - "smlal v4.4s, %[v8].4h, v1.h[2]\n" /* outr00 = v6 * r2[1] - 
*/ - "smlal2 v5.4s, %[v8].8h, v1.h[2]\n" /* outr01 = v6 * r2[3] - */ - "smlal v6.4s, %[v8].4h, v1.h[4]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v7.4s, %[v8].8h, v1.h[4]\n" /* outr03 = v6 * r2[7] - */ - "smlal v8.4s, %[v8].4h, v1.h[6]\n" /* outr00 = v6 * r2[1] - */ - "smlal2 v9.4s, %[v8].8h, v1.h[6]\n" /* outr01 = v6 * r2[3] - */ - "smlal v10.4s, %[v8].4h, v3.h[0]\n" /* outr02 = v6 * r2[5] - */ - "smlal2 v11.4s, %[v8].8h, v3.h[0]\n" /* outr03 = v6 * r2[7] - */ - - "ldr q1, [%[r3]], #8 \n" /* load input r3 */ - - /* r1, r3, mul w03 */ - "smlal v4.4s, %[v3].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v5.4s, %[v3].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smlal v6.4s, %[v3].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v7.4s, %[v3].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "sshll v1.8h, v1.8b, #0 \n" /* r3: int8 -> int16 */ - "smlal v8.4s, %[v3].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v9.4s, %[v3].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - "smlal v10.4s, %[v3].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v11.4s, %[v3].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - "ldr q2, [%[r1]] \n" /* load r1, 9th - data,v10.s[0] */ - - "smlal v12.4s, %[v3].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v13.4s, %[v3].8h, v1.h[0]\n" /* outr11 = v0 * r2[2] - */ - "smlal v14.4s, %[v3].4h, v1.h[2]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v15.4s, %[v3].8h, v1.h[2]\n" /* outr13 = v0 * r2[6] - */ - "ldr q3, [%[r3]] \n" /* load r3, 9th - data,v11.s[0] */ - "smlal v16.4s, %[v3].4h, v1.h[4]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v17.4s, %[v3].8h, v1.h[4]\n" /* outr11 = v0 * r2[2] - */ - "smlal v18.4s, %[v3].4h, v1.h[6]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v19.4s, %[v3].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] - */ - "sshll v2.8h, v2.8b, #0 \n" /* r1 : int8 -> int16 */ - - /* r1, r3, mul w05 */ - "smlal v4.4s, %[v5].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v5.4s, %[v5].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal v6.4s, %[v5].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v7.4s, %[v5].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - "sshll v3.8h, v3.8b, #0 \n" /* r3 : int8 -> int16 */ - "smlal v8.4s, %[v5].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v9.4s, %[v5].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - "smlal v10.4s, %[v5].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v11.4s, %[v5].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] - */ - - "smlal v12.4s, %[v5].4h, v1.h[2]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v13.4s, %[v5].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] - */ - "smlal v14.4s, %[v5].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v15.4s, %[v5].8h, v1.h[4]\n" /* outr13 = v0 * r2[6] - */ - "smlal v16.4s, %[v5].4h, v1.h[6]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v17.4s, %[v5].8h, v1.h[6]\n" /* outr11 = v0 * r2[2] - */ - "smlal v18.4s, %[v5].4h, v3.h[0]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v19.4s, %[v5].8h, v3.h[0]\n" /* outr13 = v0 * r2[6] - */ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ - - /* r1, r3, mul w04 */ - "smlal v4.4s, %[v4].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v5.4s, %[v4].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal v6.4s, %[v4].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v7.4s, %[v4].8h, v0.h[3]\n" /* outr00 = v0 * r0[0] - */ - "smlal v8.4s, %[v4].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v9.4s, %[v4].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] - */ - "smlal v10.4s, %[v4].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v11.4s, %[v4].8h, v0.h[7]\n" /* outr00 
= v0 * r0[0] - */ - - "ldr q0, [%[r4]], #8 \n" /* load input r4 */ - - "smlal v12.4s, %[v4].4h, v1.h[1]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v13.4s, %[v4].8h, v1.h[1]\n" /* outr11 = v0 * r2[2] - */ - "smlal v14.4s, %[v4].4h, v1.h[3]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v15.4s, %[v4].8h, v1.h[3]\n" /* outr13 = v0 * r2[6] - */ - "sshll v0.8h, v0.8b, #0 \n" /* r4 : int8 -> int16 */ - "smlal v16.4s, %[v4].4h, v1.h[5]\n" /* outr10 = v0 * r2[0] - */ - "smlal2 v17.4s, %[v4].8h, v1.h[5]\n" /* outr11 = v0 * r2[2] - */ - "smlal v18.4s, %[v4].4h, v1.h[7]\n" /* outr12 = v0 * r2[4] - */ - "smlal2 v19.4s, %[v4].8h, v1.h[7]\n" /* outr13 = v0 * r2[6] - */ - - "ldr q2, [%[r4]] \n" /* load r4, 9th - data,v10.s[0] */ - "sshll v2.8h, v2.8b, #0 \n" /* r4 : int8 -> int16 */ - - "ldp q1, q3, [%[ptr_out0]] \n" /* load ptr_out + 0 -> - q2, q3 */ - "ldp q20, q21, [%[ptr_out0], #32]\n" /* load ptr_out + 32 -> - q4, q5 */ - - "add v4.4s, v1.4s , v4.4s \n" /* v10 = outr00[0].low - + q2 */ - "add v5.4s, v3.4s , v5.4s \n" /* v11 = outr00[0].high - + q3 */ - "add v6.4s, v20.4s, v6.4s \n" /* v12 = outr01[0].low - + q4 */ - "add v7.4s, v21.4s, v7.4s \n" /* v13 = outr01[0].high - + q5 */ - - "ldp q1 , q3 , [%[ptr_out0], #64]\n" /* load ptr_out + 64 -> - q6, q7 */ - "ldp q20, q21, [%[ptr_out0], #96]\n" /* load ptr_out + 96 -> - q8, q9 */ - - "stp q4, q5 , [%[ptr_out0]], #32\n" /* store q10, q11 -> - ptr_out */ - "stp q6, q7 , [%[ptr_out0]], #32\n" /* store q10, q11 -> - ptr_out */ - - "add v8.4s , v1.4s , v8.4s \n" /* v10 = outr00[0].low - + q2 */ - "add v9.4s , v3.4s , v9.4s \n" /* v11 = outr00[0].high - + q3 */ - "add v10.4s, v20.4s, v10.4s \n" /* v12 = outr01[0].low - + q4 */ - "add v11.4s, v21.4s, v11.4s \n" /* v13 = outr01[0].high - + q5 */ - "stp q8, q9, [%[ptr_out0]], #32\n" /* store q14, q15 -> - ptr_out += 64 */ - "stp q10, q11, [%[ptr_out0]], #32\n" /* store q16, q17 -> - ptr_out += 96 */ - - /* r4, mul w08 */ - "smlal v12.4s, %[v8].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v13.4s, %[v8].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - */ - "smlal v14.4s, %[v8].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v15.4s, %[v8].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - - "smlal v16.4s, %[v8].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v17.4s, %[v8].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - "smlal v18.4s, %[v8].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v19.4s, %[v8].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] - */ - - /* r4, mul w07 */ - "smlal v12.4s, %[v7].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v13.4s, %[v7].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] - */ - "smlal v14.4s, %[v7].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] - */ - "smlal2 v15.4s, %[v7].8h, v0.h[3]\n" /* outr00 = v0 * r0[0] - */ - - "ldr q1, [%[r2]], #8 \n" /* load input r2 */ - - "smlal v16.4s, %[v7].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v17.4s, %[v7].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] - */ - "smlal v18.4s, %[v7].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v19.4s, %[v7].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] - */ - - "sshll v1.8h, v1.8b, #0 \n" /* r2: int8 -> int16 - */ - - /* r4, mul w06 */ - "ldp q4, q5, [%[ptr_out1]] \n" /* load ptr_out + 0 -> - q2, q3 */ - - "smlal v12.4s, %[v6].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smlal2 v13.4s, %[v6].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] - */ - "smlal v14.4s, %[v6].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] - */ - - "ldp q8, q9, [%[ptr_out1], #64]\n" /* load ptr_out + 64 -> - q6, q7 */ - - "smlal2 v15.4s, %[v6].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] - 
*/ - "smlal v16.4s, %[v6].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] - */ - "smlal2 v17.4s, %[v6].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] - */ - - "ldp q10, q11, [%[ptr_out1], #96]\n" /* load ptr_out + 96 -> - q8, q9 */ - - "smlal v18.4s, %[v6].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] - */ - "smlal2 v19.4s, %[v6].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] - */ - - "ldr q0, [%[r0]], #8 \n" /* load input r2 */ - "ldp q6, q7, [%[ptr_out1], #32]\n" /* load ptr_out + 32 -> - q4, q5 */ - - "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ - - /* store outr1 */ - "add v12.4s, v4.4s , v12.4s\n" /* v10 = outr10[0].low + q2 */ - "add v13.4s, v5.4s , v13.4s\n" /* v11 = outr10[0].high + q3 */ - "add v14.4s, v6.4s , v14.4s\n" /* v12 = outr11[0].low + q4 */ - "add v15.4s, v7.4s , v15.4s\n" /* v13 = outr11[0].high + q5 */ - - "stp q12, q13, [%[ptr_out1]], #32\n" /* store q10, q11 -> - ptr_out */ - - "add v16.4s, v8.4s , v16.4s\n" /* v14 = outr12[0].low + q6 */ - "add v17.4s, v9.4s , v17.4s\n" /* v15 = outr12[0].high + q7 */ - - "stp q14, q15, [%[ptr_out1]], #32\n" /* store q12, q13 -> - ptr_out += 32 */ - - "add v18.4s, v10.4s, v18.4s\n" /* v16 = outr13[0].low + q8 */ - "add v19.4s, v11.4s, v19.4s\n" /* v17 = outr13[0].high + q9 */ - - "stp q16, q17, [%[ptr_out1]], #32\n" /* store q14, q15 -> - ptr_out += 64 */ - "stp q18, q19, [%[ptr_out1]], #32\n" /* store q16, q17 -> - ptr_out += 96 */ - - "bne 1b \n" /* jump to main loop */ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [v0] "w"(v0), - [v1] "w"(v1), - [v2] "w"(v2), - [v3] "w"(v3), - [v4] "w"(v4), - [v5] "w"(v5), - [v6] "w"(v6), - [v7] "w"(v7), - [v8] "w"(v8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); - + "ldr q0, [%[r0]], #8 \n" /* load input r0 */ + "ldr q1, [%[r2]], #8 \n" /* load input r2 */ + "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ + "sshll v1.8h, v1.8b, #0 \n" /* r1: int8 -> int16*/ + "1: \n" /* main loop */ + /* r0, r2 mul w00 */ + "smull v4.4s, %[v0].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smull2 v5.4s, %[v0].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smull v6.4s, %[v0].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]*/ + "smull2 v7.4s, %[v0].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smull v8.4s, %[v0].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]*/ + "smull2 v9.4s, %[v0].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "smull v10.4s, %[v0].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]*/ + "smull2 v11.4s, %[v0].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "smull v12.4s, %[v0].4h, v1.h[0]\n" /* outr10 = v0 * r2[0]*/ + "smull2 v13.4s, %[v0].8h, v1.h[0]\n" /* outr11 = v0 * r2[2]*/ + "smull v14.4s, %[v0].4h, v1.h[2]\n" /* outr12 = v0 * r2[4]*/ + "smull2 v15.4s, %[v0].8h, v1.h[2]\n" /* outr13 = v0 * r2[6]*/ + "smull v16.4s, %[v0].4h, v1.h[4]\n" /* outr10 = v0 * r2[0]*/ + "smull2 v17.4s, %[v0].8h, v1.h[4]\n" /* outr11 = v0 * r2[2]*/ + "smull v18.4s, %[v0].4h, v1.h[6]\n" /* outr12 = v0 * r2[4]*/ + "smull2 v19.4s, %[v0].8h, v1.h[6]\n" /* outr13 = v0 * r2[6]*/ + /* r2, mul w06 */ + "smlal v4.4s, %[v6].4h, v1.h[0]\n" /* outr00 = v6 * r2[1]*/ + "smlal2 v5.4s, %[v6].8h, v1.h[0]\n" /* outr01 = v6 * r2[3]*/ + "smlal v6.4s, %[v6].4h, v1.h[2]\n" /* outr02 = v6 * r2[5]*/ + "smlal2 v7.4s, %[v6].8h, v1.h[2]\n" /* outr03 = v6 * r2[7]*/ + "smlal v8.4s, %[v6].4h, v1.h[4]\n" /* outr00 = v6 * r2[1]*/ + 
"smlal2 v9.4s, %[v6].8h, v1.h[4]\n" /* outr01 = v6 * r2[3]*/ + "smlal v10.4s, %[v6].4h, v1.h[6]\n" /* outr02 = v6 * r2[5]*/ + "smlal2 v11.4s, %[v6].8h, v1.h[6]\n" /* outr03 = v6 * r2[7]*/ + "ldr q2, [%[r0]]\n" /* load r0, 9th data,v10.s[0] */ + /* r0, r2, mul w01 */ + "smlal v4.4s, %[v1].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v5.4s, %[v1].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal v6.4s, %[v1].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v7.4s, %[v1].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]*/ + "sshll v2.8h, v2.8b, #0 \n" /* r0: int8 -> int16 */ + "smlal v8.4s, %[v1].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v9.4s, %[v1].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]*/ + "smlal v10.4s, %[v1].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v11.4s, %[v1].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]*/ + "smlal v12.4s, %[v1].4h, v1.h[1]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v13.4s, %[v1].8h, v1.h[1]\n" /* outr11 = v0 * r2[2]*/ + "smlal v14.4s, %[v1].4h, v1.h[3]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v15.4s, %[v1].8h, v1.h[3]\n" /* outr13 = v0 * r2[6]*/ + "smlal v16.4s, %[v1].4h, v1.h[5]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v17.4s, %[v1].8h, v1.h[5]\n" /* outr11 = v0 * r2[2]*/ + "smlal v18.4s, %[v1].4h, v1.h[7]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v19.4s, %[v1].8h, v1.h[7]\n" /* outr13 = v0 * r2[6]*/ + /* r2, mul w07 */ + "smlal v4.4s, %[v7].4h, v1.h[1]\n" /* outr00 = v6 * r2[1]*/ + "smlal2 v5.4s, %[v7].8h, v1.h[1]\n" /* outr01 = v6 * r2[3]*/ + "smlal v6.4s, %[v7].4h, v1.h[3]\n" /* outr02 = v6 * r2[5]*/ + "smlal2 v7.4s, %[v7].8h, v1.h[3]\n" /* outr03 = v6 * r2[7]*/ + "smlal v8.4s, %[v7].4h, v1.h[5]\n" /* outr00 = v6 * r2[1]*/ + "smlal2 v9.4s, %[v7].8h, v1.h[5]\n" /* outr01 = v6 * r2[3]*/ + "smlal v10.4s, %[v7].4h, v1.h[7]\n" /* outr02 = v6 * r2[5]*/ + "smlal2 v11.4s, %[v7].8h, v1.h[7]\n" /* outr03 = v6 * r2[7]*/ + "ldr q3, [%[r2]]\n" /* load r2, 9th data,v11.s[0] */ + /* r0, r2, mul w02 */ + "smlal v4.4s, %[v2].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v5.4s, %[v2].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal v6.4s, %[v2].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v7.4s, %[v2].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "sshll v3.8h, v3.8b, #0 \n" /* r2: int8 -> int16*/ + "smlal v8.4s, %[v2].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v9.4s, %[v2].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "smlal v10.4s, %[v2].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v11.4s, %[v2].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]*/ + "ldr q0, [%[r1]], #8 \n" /* load input r1 */ + "smlal v12.4s, %[v2].4h, v1.h[2]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v13.4s, %[v2].8h, v1.h[2]\n" /* outr11 = v0 * r2[2]*/ + "smlal v14.4s, %[v2].4h, v1.h[4]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v15.4s, %[v2].8h, v1.h[4]\n" /* outr13 = v0 * r2[6]*/ + "sshll v0.8h, v0.8b, #0 \n" /* r1 : int8 -> int16 */ + "smlal v16.4s, %[v2].4h, v1.h[6]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v17.4s, %[v2].8h, v1.h[6]\n" /* outr11 = v0 * r2[2]*/ + "smlal v18.4s, %[v2].4h, v3.h[0]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v19.4s, %[v2].8h, v3.h[0]\n" /* outr13 = v0 * r2[6]*/ + /* r2, mul w08 */ + "smlal v4.4s, %[v8].4h, v1.h[2]\n" /* outr00 = v6 * r2[1]*/ + "smlal2 v5.4s, %[v8].8h, v1.h[2]\n" /* outr01 = v6 * r2[3]*/ + "smlal v6.4s, %[v8].4h, v1.h[4]\n" /* outr02 = v6 * r2[5]*/ + "smlal2 v7.4s, %[v8].8h, v1.h[4]\n" /* outr03 = v6 * r2[7]*/ + "smlal v8.4s, %[v8].4h, v1.h[6]\n" /* outr00 = v6 * r2[1]*/ + "smlal2 v9.4s, %[v8].8h, v1.h[6]\n" /* outr01 = v6 * r2[3]*/ + "smlal v10.4s, %[v8].4h, v3.h[0]\n" /* outr02 = v6 * r2[5]*/ 
+ "smlal2 v11.4s, %[v8].8h, v3.h[0]\n" /* outr03 = v6 * r2[7]*/ + "ldr q1, [%[r3]], #8 \n" /* load input r3 */ + /* r1, r3, mul w03 */ + "smlal v4.4s, %[v3].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v5.4s, %[v3].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smlal v6.4s, %[v3].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v7.4s, %[v3].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "sshll v1.8h, v1.8b, #0 \n" /* r3: int8 -> int16 */ + "smlal v8.4s, %[v3].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v9.4s, %[v3].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "smlal v10.4s, %[v3].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v11.4s, %[v3].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "ldr q2, [%[r1]]\n" /* load r1, 9th data,v10.s[0] */ + "smlal v12.4s, %[v3].4h, v1.h[0]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v13.4s, %[v3].8h, v1.h[0]\n" /* outr11 = v0 * r2[2]*/ + "smlal v14.4s, %[v3].4h, v1.h[2]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v15.4s, %[v3].8h, v1.h[2]\n" /* outr13 = v0 * r2[6]*/ + "ldr q3, [%[r3]]\n" /* load r3, 9th data,v11.s[0] */ + "smlal v16.4s, %[v3].4h, v1.h[4]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v17.4s, %[v3].8h, v1.h[4]\n" /* outr11 = v0 * r2[2]*/ + "smlal v18.4s, %[v3].4h, v1.h[6]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v19.4s, %[v3].8h, v1.h[6]\n" /* outr13 = v0 * r2[6]*/ + "sshll v2.8h, v2.8b, #0 \n" /* r1 : int8 -> int16 */ + /* r1, r3, mul w05 */ + "smlal v4.4s, %[v5].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v5.4s, %[v5].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal v6.4s, %[v5].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v7.4s, %[v5].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "sshll v3.8h, v3.8b, #0 \n" /* r3 : int8 -> int16 */ + "smlal v8.4s, %[v5].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v9.4s, %[v5].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "smlal v10.4s, %[v5].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v11.4s, %[v5].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smlal v12.4s, %[v5].4h, v1.h[2]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v13.4s, %[v5].8h, v1.h[2]\n" /* outr11 = v0 * r2[2]*/ + "smlal v14.4s, %[v5].4h, v1.h[4]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v15.4s, %[v5].8h, v1.h[4]\n" /* outr13 = v0 * r2[6]*/ + "smlal v16.4s, %[v5].4h, v1.h[6]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v17.4s, %[v5].8h, v1.h[6]\n" /* outr11 = v0 * r2[2]*/ + "smlal v18.4s, %[v5].4h, v3.h[0]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v19.4s, %[v5].8h, v3.h[0]\n" /* outr13 = v0 * r2[6]*/ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + /* r1, r3, mul w04 */ + "smlal v4.4s, %[v4].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v5.4s, %[v4].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal v6.4s, %[v4].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v7.4s, %[v4].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]*/ + "smlal v8.4s, %[v4].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v9.4s, %[v4].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]*/ + "smlal v10.4s, %[v4].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v11.4s, %[v4].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]*/ + "ldr q0, [%[r4]], #8 \n" /* load input r4 */ + "smlal v12.4s, %[v4].4h, v1.h[1]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v13.4s, %[v4].8h, v1.h[1]\n" /* outr11 = v0 * r2[2]*/ + "smlal v14.4s, %[v4].4h, v1.h[3]\n" /* outr12 = v0 * r2[4]*/ + "smlal2 v15.4s, %[v4].8h, v1.h[3]\n" /* outr13 = v0 * r2[6]*/ + "sshll v0.8h, v0.8b, #0 \n" /* r4 : int8 -> int16 */ + "smlal v16.4s, %[v4].4h, v1.h[5]\n" /* outr10 = v0 * r2[0]*/ + "smlal2 v17.4s, %[v4].8h, v1.h[5]\n" /* outr11 = v0 * r2[2]*/ + "smlal v18.4s, %[v4].4h, v1.h[7]\n" 
/* outr12 = v0 * r2[4]*/ + "smlal2 v19.4s, %[v4].8h, v1.h[7]\n" /* outr13 = v0 * r2[6]*/ + "ldr q2, [%[r4]]\n" /* load r4, 9th data,v10.s[0] */ + "sshll v2.8h, v2.8b, #0\n" /* r4 : int8 -> int16 */ + "ldp q1, q3, [%[ptr_out0]]\n" /* load ptr_out */ + "ldp q20, q21, [%[ptr_out0], #32]\n" /* load ptr_out */ + "add v4.4s, v1.4s , v4.4s\n" /* v10 = outr00[0].low + q2 */ + "add v5.4s, v3.4s , v5.4s\n" /* v11 = outr00[0].high+ q3 */ + "add v6.4s, v20.4s, v6.4s\n" /* v12 = outr01[0].low + q4 */ + "add v7.4s, v21.4s, v7.4s\n" /* v13 = outr01[0].high+ q5 */ + "ldp q1 , q3 , [%[ptr_out0], #64]\n" /* load ptr_out*/ + "ldp q20, q21, [%[ptr_out0], #96]\n" /* load ptr_out*/ + "stp q4, q5 , [%[ptr_out0]], #32\n" /* store q10, q11*/ + "stp q6, q7 , [%[ptr_out0]], #32\n" /* store q10, q11*/ + "add v8.4s , v1.4s , v8.4s\n" /* v10 = outr00[0].low+ q2 */ + "add v9.4s , v3.4s , v9.4s\n" /* v11 = outr00[0].high+q3 */ + "add v10.4s, v20.4s, v10.4s\n" /* v12 = outr01[0].low+q4 */ + "add v11.4s, v21.4s, v11.4s\n" /* v13 = outr01[0].high+q5 */ + "stp q8, q9, [%[ptr_out0]], #32\n" /* store q14, q15*/ + "stp q10, q11, [%[ptr_out0]], #32\n" /* store q16, q17*/ + /* r4, mul w08 */ + "smlal v12.4s, %[v8].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v13.4s, %[v8].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal v14.4s, %[v8].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v15.4s, %[v8].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "smlal v16.4s, %[v8].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v17.4s, %[v8].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "smlal v18.4s, %[v8].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v19.4s, %[v8].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]*/ + /* r4, mul w07 */ + "smlal v12.4s, %[v7].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v13.4s, %[v7].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/ + "smlal v14.4s, %[v7].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]*/ + "smlal2 v15.4s, %[v7].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]*/ + "ldr q1, [%[r2]], #8 \n" /* load input r2 */ + "smlal v16.4s, %[v7].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v17.4s, %[v7].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]*/ + "smlal v18.4s, %[v7].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v19.4s, %[v7].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]*/ + "sshll v1.8h, v1.8b, #0 \n" /* r2: int8 -> int16*/ + /* r4, mul w06 */ + "ldp q4, q5, [%[ptr_out1]] \n" /* load ptr_out*/ + "smlal v12.4s, %[v6].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smlal2 v13.4s, %[v6].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/ + "smlal v14.4s, %[v6].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]*/ + "ldp q8, q9, [%[ptr_out1], #64]\n" /* load ptr_out*/ + "smlal2 v15.4s, %[v6].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/ + "smlal v16.4s, %[v6].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]*/ + "smlal2 v17.4s, %[v6].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/ + "ldp q10, q11, [%[ptr_out1], #96]\n" /* load ptr_out*/ + "smlal v18.4s, %[v6].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]*/ + "smlal2 v19.4s, %[v6].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/ + "ldr q0, [%[r0]], #8 \n" /* load input r2 */ + "ldp q6, q7, [%[ptr_out1], #32]\n" /* load ptr_out*/ + "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ + /* store outr1 */ + "add v12.4s, v4.4s , v12.4s\n" /* v10 = outr10[0].low + q2 */ + "add v13.4s, v5.4s , v13.4s\n" /* v11 = outr10[0].high + q3 */ + "add v14.4s, v6.4s , v14.4s\n" /* v12 = outr11[0].low + q4 */ + "add v15.4s, v7.4s , v15.4s\n" /* v13 = outr11[0].high + q5 */ + "stp q12, q13, [%[ptr_out1]], #32\n" /* store q10, q11*/ + "add v16.4s, v8.4s , v16.4s\n" /* v14 = outr12[0].low + q6 */ 
+ "add v17.4s, v9.4s , v17.4s\n" /* v15 = outr12[0].high + q7 */ + "stp q14, q15, [%[ptr_out1]], #32\n" /* store q12, q13*/ + "add v18.4s, v10.4s, v18.4s\n" /* v16 = outr13[0].low + q8 */ + "add v19.4s, v11.4s, v19.4s\n" /* v17 = outr13[0].high + q9 */ + "stp q16, q17, [%[ptr_out1]], #32\n" /* store q14, q15*/ + "stp q18, q19, [%[ptr_out1]], #32\n" /* store q16, q17*/ + "bne 1b\n" /* jump to main loop */ + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), + [r2] "+r"(r2), [r3] "+r"(r3), [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1) + : [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2), + [v3] "w"(v3), [v4] "w"(v4), [v5] "w"(v5), + [v6] "w"(v6), [v7] "w"(v7), [v8] "w"(v8) + : "cc", "memory", "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22" + ); + // clang-format on wc0 += 9 * hout_c_block; inr0 += win_round; inr1 += win_round; @@ -683,47 +430,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, block_inr3 = block_inr2 + in_len; block_inr4 = block_inr3 + in_len; } - if (out_type == PRECISION(kFloat)) { - write_to_output_c8_int32_1(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - 2, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else if (out_type == PRECISION(kInt8)) { - write_to_output_c8_int32_1(pre_out, - dout_batch, - hout_c_block, - 2, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else { - write_to_output_c8_int32(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - 2, + write_int32_nchwc8_to_nchw(pre_out, + dout_batch, c, c + hout_c_block, h, @@ -734,8 +442,10 @@ void conv_3x3s2_direct_int8(const int8_t* din, hout, wout, flag_relu, - ptr_write); - } + bias_local, + flag_bias, + ptr_write, + scale + c); } } } @@ -743,8 +453,10 @@ void conv_3x3s2_direct_int8(const int8_t* din, #else // __aarch64__ int conv_3x3s2_direct_int8_c_num() { return 4; } + +template void conv_3x3s2_direct_int8(const int8_t* din, - int32_t* dout, + Dtype* dout, int num, int chout, int hout, @@ -753,27 +465,24 @@ void conv_3x3s2_direct_int8(const int8_t* din, int hin, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, Context* ctx, - PrecisionType out_type, const float* scale) { //! 3x3s2 int8 convolution, implemented by direct algorithm //! prepack input to tmp buffer //! write output to tmp buffer - int threads = ctx->threads(); - int stride_w = param.strides[1]; - int pad_w = param.paddings[1]; - int pad_h = param.paddings[0]; bool flag_relu = param.fuse_relu; - bool flag_bias = (param.bias != nullptr); - - //! set 2/3 l2 cache - int l2_size = ctx->llc_size() / 3 * 2; + bool flag_bias = param.bias; + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + const int threads = ctx->threads(); + //! set 1/4 l2 cache + int llc_size = ctx->llc_size() / 4; const int hout_c_block = 4; const int hout_r_kernel = 1; const int wout_round = ((wout + 3) / 4) * 4; - const int win_round = wout_round * stride_w + 1; + const int win_round = wout_round * 2 /*stride_w*/ + 1; //! get h block //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round * @@ -781,7 +490,7 @@ void conv_3x3s2_direct_int8(const int8_t* din, //! win_round = 2 * wout_round + 1 //! 
hin_r_block = 2 * hout_r_block + 1 int hout_r_block = - (l2_size - 2 * wout_round * chin - chin) / + (llc_size - 2 * wout_round * chin - chin) / ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4); hout_r_block = hout_r_block > hout ? hout : hout_r_block; hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; @@ -789,16 +498,15 @@ void conv_3x3s2_direct_int8(const int8_t* din, const int hin_r_block = hout_r_block * 2 + 1; - int8_t* tmp_work_space = ctx->workspace_data(); + auto tmp_work_space = ctx->workspace_data(); int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4; - const int kZeroSize = zero_size; - int32_t ptr_zero[kZeroSize]; + int32_t ptr_zero[zero_size]; // NOLINT memset(ptr_zero, 0, sizeof(int32_t) * zero_size); - const int kWoutRound = wout_round; - int32_t ptr_write[kWoutRound]; + Dtype ptr_write[wout_round]; // NOLINT int in_len = win_round * chin; int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); int pre_out_size = hout_c_block * hout_r_block * wout_round; //! l2_cache start @@ -815,10 +523,9 @@ void conv_3x3s2_direct_int8(const int8_t* din, int out_row_stride = hout_c_block * wout_round; for (int n = 0; n < num; ++n) { - const int8_t* din_batch = din + n * chin * size_in_channel; - int8_t* dout_batch = - reinterpret_cast(dout) + - n * chout * size_out_channel * PrecisionTypeLength(out_type); + const int8_t* din_batch = + static_cast(din) + n * chin * size_in_channel; + auto dout_batch = dout + n * chout * size_out_channel; for (int h = 0; h < hout; h += hout_r_block) { int h_kernel = hout_r_block; if (h + hout_r_block > hout) { @@ -845,24 +552,23 @@ void conv_3x3s2_direct_int8(const int8_t* din, #pragma omp parallel for num_threads(threads) for (int c = 0; c < chout; c += hout_c_block) { #ifdef ARM_WITH_OMP - int32_t* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + - omp_get_thread_num() * pre_out_size; + int32_t* pre_out = reinterpret_cast(pre_din + pre_in_size) + + omp_get_thread_num() * pre_out_size; #else - int32_t* pre_out = - reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); + int32_t* pre_out = reinterpret_cast(pre_din + pre_in_size); #endif const int8_t* block_inr0 = cblock_inr0; const int8_t* block_inr1 = cblock_inr1; const int8_t* block_inr2 = cblock_inr2; - const int8_t* weight_c = weights + c * w_stride; - const int32_t* bias_ptr = ptr_zero; + float bias_local[4] = {0, 0, 0, 0}; if (flag_bias) { - bias_ptr = bias + c; + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; } - - fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 4, h_kernel, wout_round); + memset(pre_out, 0, pre_out_size * sizeof(int32_t)); for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { const int8_t* wc0 = weight_c; @@ -879,134 +585,97 @@ void conv_3x3s2_direct_int8(const int8_t* din, int32_t* ptr_out0 = pre_out0; const signed char* ptr_wc0 = wc0; int cnt = w_loop; + // clang-format off asm volatile( - "vld1.s32 {d0-d3}, [%[wc0]]! \n" /* w0-w7 */ - "vld1.s32 {d4}, [%[wc0]]! \n" /* w8 */ - "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ - "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ - "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ - "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ - "vmovl.s8 q7, d4 \n" /* q7 = w8 */ - "vld1.s32 {d0}, [%[r0]]! 
\n" /* load input r0 -> d0 */ - "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ - "1: \n" /* main loop */ - - /* r0 mul w0 */ - "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ - "vmull.s16 q9, d6, d0[2] \n" /* q9 = w0 * r0[2] */ - "vmull.s16 q10, d6, d1[0] \n" /* q10 = w0 * r0[4] */ - "vmull.s16 q11, d6, d1[2] \n" /* q11 = w0 * r0[6] */ - - "vld1.s32 {d2}, [%[r1]]! \n" /* load input r1 -> d2 */ - "vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ - - /* r0 mul w1 */ - "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ - "vmlal.s16 q9, d7, d0[3] \n" /* q9 = w1 * r0[3] */ - "vmlal.s16 q10, d7, d1[1] \n" /* q10 = w1 * r0[5] */ - "vmlal.s16 q11, d7, d1[3] \n" /* q11 = w1 * r0[7] */ - - "vld1.s32 {d4}, [%[r0]] \n" /* load r0[8] -> d4 */ - "vmovl.s8 q2 , d4 \n" /* movl d4 -> q2 */ - - /* r0 mul w2 */ - "vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ - "vmlal.s16 q9, d8, d1[0] \n" /* q9 = w2 * r0[4] */ - "vmlal.s16 q10, d8, d1[2] \n" /* q10 = w2 * r0[6] */ - "vmlal.s16 q11, d8, d4[0] \n" /* q11 = w2 * r0[8] */ - - "subs %[cnt], #1 \n" /* loop count -1 */ - - /* r1 mul w3 */ - "vmlal.s16 q8, d9, d2[0] \n" /* q8 = w3 * r1[0] */ - "vmlal.s16 q9, d9, d2[2] \n" /* q9 = w3 * r1[2] */ - "vmlal.s16 q10, d9, d3[0] \n" /* q10 = w3 * r1[4] */ - "vmlal.s16 q11, d9, d3[2] \n" /* q11 = w3 * r1[6] */ - - "vld1.s32 {d4}, [%[r2]]! \n" /* load input r2 -> d4*/ - "vmovl.s8 q2, d4 \n" /* movl d4 -> q2 */ - - /* r1 mul w4 */ - "vmlal.s16 q8, d10, d2[1] \n" /* q8 = w4 * r1[1] */ - "vmlal.s16 q9, d10, d2[3] \n" /* q9 = w4 * r1[3] */ - "vmlal.s16 q10, d10, d3[1] \n" /* q10 = w4 * r1[5] */ - "vmlal.s16 q11, d10, d3[3] \n" /* q11 = w4 * r1[7] */ - - "vld1.s32 {d0}, [%[r1]] \n" /* load r1[8] -> d0 */ - "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ - - /* r1 mul w5 */ - "vmlal.s16 q8, d11, d2[2] \n" /* q8 = w5 * r1[2] */ - "vmlal.s16 q9, d11, d3[0] \n" /* q9 = w5 * r1[4] */ - "vmlal.s16 q10, d11, d3[2] \n" /* q10 = w5 * r1[6] */ - "vmlal.s16 q11, d11, d0[0] \n" /* q11 = w5 * r1[8] */ - - /* r2 mul w6 */ - "vmlal.s16 q8, d12, d4[0] \n" /* q8 = w6 * r2[0] */ - "vmlal.s16 q9, d12, d4[2] \n" /* q9 = w6 * r2[2] */ - "vmlal.s16 q10, d12, d5[0] \n" /* q10 = w6 * r2[4] */ - "vmlal.s16 q11, d12, d5[2] \n" /* q11 = w6 * r2[6] */ - - "vld1.s32 {d24-d27}, [%[ptr_out0]] \n" /* load output -> q12, - q13 */ - - /* r2 mul w7 */ - "vmlal.s16 q8, d13, d4[1] \n" /* q8 = w7 * r2[1] */ - "vmlal.s16 q9, d13, d4[3] \n" /* q9 = w7 * r2[3] */ - "vmlal.s16 q10, d13, d5[1] \n" /* q10 = w7 * r2[5] */ - "vmlal.s16 q11, d13, d5[3] \n" /* q11 = w7 * r2[7] */ - - "vld1.s32 {d0}, [%[r2]] \n" /* load r2[8] -> d0 */ - "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ - - /* r2 mul w8 */ - "vmlal.s16 q8, d14, d4[2] \n" /* q8 = w8 * r2[2] */ - "vmlal.s16 q9, d14, d5[0] \n" /* q9 = w8 * r2[4] */ - "vmlal.s16 q10, d14, d5[2] \n" /* q10 = w8 * r2[6] */ - "vmlal.s16 q11, d14, d0[0] \n" /* q11 = w8 * r2[8] */ - - "vadd.s32 q12, q8, q12 \n" /* out[0] += q8 */ - "vadd.s32 q13, q9, q13 \n" /* out[1] += q9 */ - "vst1.s32 {d24-d27}, [%[ptr_out0]]! \n" /* store q12, q13 -> - output[0,1] */ - - "vld1.s32 {d0}, [%[r0]]! \n" /* load next input r0 -> d0*/ - "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ - - "vld1.s32 {d28-d31}, [%[ptr_out0]] \n" /* load output[0,1] -> - q14, q15 */ - "vadd.s32 q14, q10, q14 \n" /* out[2] += q10 */ - "vadd.s32 q15, q11, q15 \n" /* out[3] += q11 */ - "vst1.s32 {d28-d31}, [%[ptr_out0]]! 
\n" /* store q14, q15 -> - output[2,3] */ - - "bne 1b \n" /* jump to main loop */ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [ptr_out0] "+r"(ptr_out0), - [wc0] "+r"(ptr_wc0) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + "vld1.s32 {d0-d3}, [%[wc0]]! \n" /* w0-w7 */ + "vld1.s32 {d4}, [%[wc0]]! \n" /* w8 */ + "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ + "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ + "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ + "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ + "vmovl.s8 q7, d4 \n" /* q7 = w8 */ + "vld1.s32 {d0}, [%[r0]]! \n" /* load input r0, d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + "1: \n" /* main loop */ + /* r0 mul w0 */ + "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ + "vmull.s16 q9, d6, d0[2] \n" /* q9 = w0 * r0[2] */ + "vmull.s16 q10, d6, d1[0] \n" /* q10 = w0 * r0[4] */ + "vmull.s16 q11, d6, d1[2] \n" /* q11 = w0 * r0[6] */ + "vld1.s32 {d2}, [%[r1]]! \n" /* load input r1, d2 */ + "vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ + /* r0 mul w1 */ + "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ + "vmlal.s16 q9, d7, d0[3] \n" /* q9 = w1 * r0[3] */ + "vmlal.s16 q10, d7, d1[1] \n" /* q10 = w1 * r0[5] */ + "vmlal.s16 q11, d7, d1[3] \n" /* q11 = w1 * r0[7] */ + "vld1.s32 {d4}, [%[r0]] \n" /* load r0[8] -> d4 */ + "vmovl.s8 q2 , d4 \n" /* movl d4 -> q2 */ + /* r0 mul w2 */ + "vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ + "vmlal.s16 q9, d8, d1[0] \n" /* q9 = w2 * r0[4] */ + "vmlal.s16 q10, d8, d1[2] \n" /* q10 = w2 * r0[6] */ + "vmlal.s16 q11, d8, d4[0] \n" /* q11 = w2 * r0[8] */ + "subs %[cnt], #1 \n" /* loop count -1 */ + /* r1 mul w3 */ + "vmlal.s16 q8, d9, d2[0] \n" /* q8 = w3 * r1[0] */ + "vmlal.s16 q9, d9, d2[2] \n" /* q9 = w3 * r1[2] */ + "vmlal.s16 q10, d9, d3[0] \n" /* q10 = w3 * r1[4] */ + "vmlal.s16 q11, d9, d3[2] \n" /* q11 = w3 * r1[6] */ + "vld1.s32 {d4}, [%[r2]]! 
\n" /* load input r2, d4*/ + "vmovl.s8 q2, d4 \n" /* movl d4 -> q2 */ + /* r1 mul w4 */ + "vmlal.s16 q8, d10, d2[1] \n" /* q8 = w4 * r1[1] */ + "vmlal.s16 q9, d10, d2[3] \n" /* q9 = w4 * r1[3] */ + "vmlal.s16 q10, d10, d3[1] \n" /* q10 = w4 * r1[5] */ + "vmlal.s16 q11, d10, d3[3] \n" /* q11 = w4 * r1[7] */ + "vld1.s32 {d0}, [%[r1]] \n" /* load r1[8] -> d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + /* r1 mul w5 */ + "vmlal.s16 q8, d11, d2[2] \n" /* q8 = w5 * r1[2] */ + "vmlal.s16 q9, d11, d3[0] \n" /* q9 = w5 * r1[4] */ + "vmlal.s16 q10, d11, d3[2] \n" /* q10 = w5 * r1[6] */ + "vmlal.s16 q11, d11, d0[0] \n" /* q11 = w5 * r1[8] */ + /* r2 mul w6 */ + "vmlal.s16 q8, d12, d4[0] \n" /* q8 = w6 * r2[0] */ + "vmlal.s16 q9, d12, d4[2] \n" /* q9 = w6 * r2[2] */ + "vmlal.s16 q10, d12, d5[0] \n" /* q10 = w6 * r2[4] */ + "vmlal.s16 q11, d12, d5[2] \n" /* q11 = w6 * r2[6] */ + "vld1.s32 {d24-d27}, [%[ptr_out0]] \n" /* load output, q12,q13 */ + /* r2 mul w7 */ + "vmlal.s16 q8, d13, d4[1] \n" /* q8 = w7 * r2[1] */ + "vmlal.s16 q9, d13, d4[3] \n" /* q9 = w7 * r2[3] */ + "vmlal.s16 q10, d13, d5[1] \n" /* q10 = w7 * r2[5] */ + "vmlal.s16 q11, d13, d5[3] \n" /* q11 = w7 * r2[7] */ + "vld1.s32 {d0}, [%[r2]] \n" /* load r2[8] -> d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + /* r2 mul w8 */ + "vmlal.s16 q8, d14, d4[2] \n" /* q8 = w8 * r2[2] */ + "vmlal.s16 q9, d14, d5[0] \n" /* q9 = w8 * r2[4] */ + "vmlal.s16 q10, d14, d5[2] \n" /* q10 = w8 * r2[6] */ + "vmlal.s16 q11, d14, d0[0] \n" /* q11 = w8 * r2[8] */ + "vadd.s32 q12, q8, q12 \n" /* out[0] += q8 */ + "vadd.s32 q13, q9, q13 \n" /* out[1] += q9 */ + "vst1.s32 {d24-d27}, [%[ptr_out0]]! \n" /* store output[0,1]*/ + "vld1.s32 {d0}, [%[r0]]! \n" /* load next input r0, d0*/ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + "vld1.s32 {d28-d31}, [%[ptr_out0]] \n" /* load output[0,1]*/ + "vadd.s32 q14, q10, q14 \n" /* out[2] += q10 */ + "vadd.s32 q15, q11, q15 \n" /* out[3] += q11 */ + "vst1.s32 {d28-d31}, [%[ptr_out0]]! 
\n" /* store output[2,3] */ + "bne 1b \n" /* jump to main loop */ + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [ptr_out0] "+r"(ptr_out0), + [wc0] "+r"(ptr_wc0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "q14", "q15" + ); + // clang-format on wc0 += 9 * hout_c_block; inr0 += win_round; inr1 += win_round; @@ -1016,47 +685,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, block_inr1 = block_inr0 + in_len; block_inr2 = block_inr1 + in_len; } - if (out_type == PRECISION(kFloat)) { - write_to_output_c4_int32_1(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - 1, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else if (out_type == PRECISION(kInt8)) { - write_to_output_c4_int32_1(pre_out, - dout_batch, - hout_c_block, - 1, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - &scale[c], - out_type); - } else { - write_to_output_c4_int32(pre_out, - reinterpret_cast(dout_batch), - hout_c_block, - 1, + write_int32_nchwc4_to_nchw(pre_out, + dout_batch, c, c + hout_c_block, h, @@ -1067,14 +697,46 @@ void conv_3x3s2_direct_int8(const int8_t* din, hout, wout, flag_relu, - ptr_write); - } + bias_local, + flag_bias, + ptr_write, + scale + c); } } } } #endif // __aarch64__ +template void conv_3x3s2_direct_int8(const int8_t* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const float* scale); + +template void conv_3x3s2_direct_int8(const int8_t* din, + int8_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const float* scale); + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_depthwise_5x5s1.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc similarity index 99% rename from lite/backends/arm/math/conv_depthwise_5x5s1.cc rename to lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc index 2b9744665c504323551b2e9d6ba164eb6b75d0fc..1a2e42e0a9ca4193be84a21247112de8cdc144a1 100644 --- a/lite/backends/arm/math/conv_depthwise_5x5s1.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/backends/arm/math/conv_depthwise.h" #include +#include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { namespace lite { @@ -5073,7 +5073,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din, int w_in_new = w_in + 2 * pad_new; int h_out_new = h_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; + float zero_ptr[w_in_new + w_out]; // NOLINT memset(zero_ptr, 0, w_in_new * sizeof(float)); float* write_ptr = zero_ptr + w_in_new; int pad_cnt = pad_0 >> 2; @@ -5320,7 +5320,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, int pad_0 = pad - pad_new; int h_in_new = h_in + 2 * pad_new; int w_in_new = w_in + 2 * pad_new; - float zero_ptr[w_in_new + w_out]; + float zero_ptr[w_in_new + w_out]; // NOLINT memset(zero_ptr, 0, w_in_new * sizeof(float)); float* write_ptr = zero_ptr + w_in_new; int h_out_new = h_out - 2 * pad_0; @@ -9177,7 +9177,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din, int w_in_new = w_in + 2 * pad_new; int h_out_new = h_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; + float zero_ptr[w_in_new + w_out]; // NOLINT memset(zero_ptr, 0, w_in_new * sizeof(float)); float* write_ptr = zero_ptr + w_in_new; int pad_cnt = pad_0 >> 2; @@ -9359,7 +9359,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, int w_in_new = w_in + 2 * pad_new; int h_out_new = h_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; + float zero_ptr[w_in_new + w_out]; // NOLINT memset(zero_ptr, 0, w_in_new * sizeof(float)); float* write_ptr = zero_ptr + w_in_new; int pad_cnt = pad_0 >> 2; @@ -9523,21 +9523,21 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, } #endif // __aarch64__ -void conv_depthwise_5x5s1(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { +void conv_depthwise_5x5s1_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { if (win < 4) { if (flag_relu) { conv_depthwise_5x5s1_small_relu_impl(din, diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..802082048c86beeeecfe64a0de09880b1b9b0137 --- /dev/null +++ b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc @@ -0,0 +1,776 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <arm_neon.h> +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include <omp.h> +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) + +template <typename Dtype> +void conv_depthwise_5x5s1_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx) { + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + + const int hout_c_block = 8; + const int hout_r_kernel = 1; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(int8_t) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 4 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + auto tmp_work_space = ctx->workspace_data<int8_t>(); + int8_t ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(int8_t) * win_round); + Dtype ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + int8_t* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; + for (int n = 0; n < num; ++n) { + const int8_t* din_batch = din + n * chin * size_in_channel; + int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) + + n * chout * size_out_channel * sizeof(Dtype); + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + int8_t* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4); + int32_t* pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size); +#else + int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size); + auto pre_din = tmp_din; +#endif + prepack_input_nxwc8_int8_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin); + + const int8_t* block_inr0 = pre_din; + const int8_t* block_inr1 = block_inr0 + in_len; + const int8_t* block_inr2 = block_inr1 + in_len; + const int8_t* block_inr3 = block_inr2 + in_len; + const int8_t* block_inr4 = block_inr3 + in_len; + + const int8_t* weight_c = weights + c * w_stride; + float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + if
(flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + bias_local[4] = bias[c + 4]; + bias_local[5] = bias[c + 5]; + bias_local[6] = bias[c + 6]; + bias_local[7] = bias[c + 7]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + const int8_t* inr3 = block_inr3; + const int8_t* inr4 = block_inr4; + + int32_t* ptr_out0 = pre_out + hk * out_row_stride; +// clang-format off +#ifdef __aarch64__ + auto wptr = weight_c; + asm volatile( + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "1:\n" + /* in r0 */ + "smull v20.8h, v0.8b, v8.8b\n" /* w0, int16, out0 */ + "smull v21.8h, v1.8b, v8.8b\n" /* w0, int16, out1 */ + "smull v22.8h, v2.8b, v8.8b\n" /* w0, int16, out2 */ + "smull v23.8h, v3.8b, v8.8b\n" /* w0, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]]\n" /* load r0 4-7 */ + "smlal v20.8h, v1.8b, v9.8b\n" /* w1, int16, out0 */ + "smlal v21.8h, v2.8b, v9.8b\n" /* w1, int16, out1 */ + "smlal v22.8h, v3.8b, v9.8b\n" /* w1, int16, out2 */ + "smlal v23.8h, v4.8b, v9.8b\n" /* w1, int16, out3 */ + "sxtl v24.4s, v20.4h\n" /* mov to out0 low */ + "sxtl2 v25.4s, v20.8h\n" /* mov to out0 hig */ + "sxtl v26.4s, v21.4h\n" /* mov to out1 low */ + "sxtl2 v27.4s, v21.8h\n" /* mov to out1 hig */ + "sxtl v28.4s, v22.4h\n" /* mov to out2 low */ + "sxtl2 v29.4s, v22.8h\n" /* mov to out2 hig */ + "sxtl v30.4s, v23.4h\n" /* mov to out3 low */ + "sxtl2 v31.4s, v23.8h\n" /* mov to out3 hig */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 4-7 */ + + "smull v20.8h, v2.8b, v10.8b\n" /* w2, int16, out0 */ + "smull v21.8h, v3.8b, v10.8b\n" /* w2, int16, out1 */ + "smull v22.8h, v4.8b, v10.8b\n" /* w2, int16, out2 */ + "smull v23.8h, v5.8b, v10.8b\n" /* w2, int16, out3 */ + "smlal v20.8h, v3.8b, v11.8b\n" /* w3, int16, out0 */ + "smlal v21.8h, v4.8b, v11.8b\n" /* w3, int16, out1 */ + "smlal v22.8h, v5.8b, v11.8b\n" /* w3, int16, out2 */ + "smlal v23.8h, v6.8b, v11.8b\n" /* w3, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r1]], #32\n" /* load r1 0-3 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v4.8b, v12.8b\n" /* w4, int16, out0 */ + "smull v21.8h, v5.8b, v12.8b\n" /* w4, int16, out1 */ + "smull v22.8h, v6.8b, v12.8b\n" /* w4, int16, out2 */ + "smull v23.8h, v7.8b, v12.8b\n" /* w4, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r1]]\n" /* load r1 4-7 */ + /* in r1 */ + "smlal v20.8h, v0.8b, v13.8b\n" /* w5, int16, out0 */ + "smlal v21.8h, v1.8b, v13.8b\n" /* w5, int16, out1 */ + "smlal v22.8h, v2.8b, v13.8b\n" /* w5, int16, out2 */ + "smlal v23.8h, v3.8b, v13.8b\n" /* w5, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, 
v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 8-11 */ + + "smull v20.8h, v1.8b, v14.8b\n" /* w6, int16, out0 */ + "smull v21.8h, v2.8b, v14.8b\n" /* w6, int16, out1 */ + "smull v22.8h, v3.8b, v14.8b\n" /* w6, int16, out2 */ + "smull v23.8h, v4.8b, v14.8b\n" /* w6, int16, out3 */ + "smlal v20.8h, v2.8b, v15.8b\n" /* w7, int16, out0 */ + "smlal v21.8h, v3.8b, v15.8b\n" /* w7, int16, out1 */ + "smlal v22.8h, v4.8b, v15.8b\n" /* w7, int16, out2 */ + "smlal v23.8h, v5.8b, v15.8b\n" /* w7, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v16.8b\n" /* w8, int16, out0 */ + "smull v21.8h, v4.8b, v16.8b\n" /* w8, int16, out1 */ + "smull v22.8h, v5.8b, v16.8b\n" /* w8, int16, out2 */ + "smull v23.8h, v6.8b, v16.8b\n" /* w8, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r2]], #32\n" /* load r2 0-3 */ + "smlal v20.8h, v4.8b, v17.8b\n" /* w9, int16, out0 */ + "smlal v21.8h, v5.8b, v17.8b\n" /* w9, int16, out1 */ + "smlal v22.8h, v6.8b, v17.8b\n" /* w9, int16, out2 */ + "smlal v23.8h, v7.8b, v17.8b\n" /* w9, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r2]]\n" /* load r2 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r2 */ + "smull v20.8h, v0.8b, v18.8b\n" /* w10, int16, out0 */ + "smull v21.8h, v1.8b, v18.8b\n" /* w10, int16, out1 */ + "smull v22.8h, v2.8b, v18.8b\n" /* w10, int16, out2 */ + "smull v23.8h, v3.8b, v18.8b\n" /* w10, int16, out3 */ + "smlal v20.8h, v1.8b, v19.8b\n" /* w11, int16, out0 */ + "smlal v21.8h, v2.8b, v19.8b\n" /* w11, int16, out1 */ + "smlal v22.8h, v3.8b, v19.8b\n" /* w11, int16, out2 */ + "smlal v23.8h, v4.8b, v19.8b\n" /* w11, int16, out3 */ + + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 12-15 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v2.8b, v8.8b\n" /* w12, int16, out0 */ + "smull v21.8h, v3.8b, v8.8b\n" /* w12, int16, out1 */ + "smull v22.8h, v4.8b, v8.8b\n" /* w12, int16, out2 */ + "smull v23.8h, v5.8b, v8.8b\n" /* w12, int16, out3 */ + "smlal v20.8h, v3.8b, v9.8b\n" /* w13, int16, out0 */ + "smlal v21.8h, v4.8b, 
v9.8b\n" /* w13, int16, out1 */ + "smlal v22.8h, v5.8b, v9.8b\n" /* w13, int16, out2 */ + "smlal v23.8h, v6.8b, v9.8b\n" /* w13, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r3]], #32\n" /* load r3 0-3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "smull v20.8h, v4.8b, v10.8b\n" /* w14, int16, out0 */ + "smull v21.8h, v5.8b, v10.8b\n" /* w14, int16, out1 */ + "smull v22.8h, v6.8b, v10.8b\n" /* w14, int16, out2 */ + "smull v23.8h, v7.8b, v10.8b\n" /* w14, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r3]]\n" /* load r3 4-7 */ + /* in r3 */ + "smlal v20.8h, v0.8b, v11.8b\n" /* w15, int16, out0 */ + "smlal v21.8h, v1.8b, v11.8b\n" /* w15, int16, out1 */ + "smlal v22.8h, v2.8b, v11.8b\n" /* w15, int16, out2 */ + "smlal v23.8h, v3.8b, v11.8b\n" /* w15, int16, out3 */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 16-19 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v1.8b, v12.8b\n" /* w16, int16, out0 */ + "smull v21.8h, v2.8b, v12.8b\n" /* w16, int16, out1 */ + "smull v22.8h, v3.8b, v12.8b\n" /* w16, int16, out2 */ + "smull v23.8h, v4.8b, v12.8b\n" /* w16, int16, out3 */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 20-23 */ + "smlal v20.8h, v2.8b, v13.8b\n" /* w17, int16, out0 */ + "smlal v21.8h, v3.8b, v13.8b\n" /* w17, int16, out1 */ + "smlal v22.8h, v4.8b, v13.8b\n" /* w17, int16, out2 */ + "smlal v23.8h, v5.8b, v13.8b\n" /* w17, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v14.8b\n" /* w18, int16, out0 */ + "smull v21.8h, v4.8b, v14.8b\n" /* w18, int16, out1 */ + "smull v22.8h, v5.8b, v14.8b\n" /* w18, int16, out2 */ + "smull v23.8h, v6.8b, v14.8b\n" /* w18, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r4]], #32\n" /* load r4 0-3 */ + "smlal v20.8h, v4.8b, v15.8b\n" /* w19, int16, out0 */ + "smlal v21.8h, v5.8b, v15.8b\n" /* w19, int16, out1 */ + "smlal v22.8h, v6.8b, v15.8b\n" /* w19, int16, out2 */ + "smlal v23.8h, v7.8b, v15.8b\n" /* w19, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r4]]\n" /* 
load r4 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r4 */ + "smull v20.8h, v0.8b, v16.8b\n" /* w20, int16, out0 */ + "smull v21.8h, v1.8b, v16.8b\n" /* w20, int16, out1 */ + "smull v22.8h, v2.8b, v16.8b\n" /* w20, int16, out2 */ + "smull v23.8h, v3.8b, v16.8b\n" /* w20, int16, out3 */ + "smlal v20.8h, v1.8b, v17.8b\n" /* w21, int16, out0 */ + "smlal v21.8h, v2.8b, v17.8b\n" /* w21, int16, out1 */ + "smlal v22.8h, v3.8b, v17.8b\n" /* w21, int16, out2 */ + "smlal v23.8h, v4.8b, v17.8b\n" /* w21, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "ld1 {v12.8b}, [%[wc]], #8\n" /* load wc 24 */ + "smull v20.8h, v2.8b, v18.8b\n" /* w22, int16, out0 */ + "smull v21.8h, v3.8b, v18.8b\n" /* w22, int16, out1 */ + "smull v22.8h, v4.8b, v18.8b\n" /* w22, int16, out2 */ + "smull v23.8h, v5.8b, v18.8b\n" /* w22, int16, out3 */ + "smlal v20.8h, v3.8b, v19.8b\n" /* w23, int16, out0 */ + "smlal v21.8h, v4.8b, v19.8b\n" /* w23, int16, out1 */ + "smlal v22.8h, v5.8b, v19.8b\n" /* w23, int16, out2 */ + "smlal v23.8h, v6.8b, v19.8b\n" /* w23, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "sub %[wc], %[wc], #200 \n" + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "smull v20.8h, v4.8b, v12.8b\n" /* w24, int16, out0 */ + "smull v21.8h, v5.8b, v12.8b\n" /* w24, int16, out1 */ + "smull v22.8h, v6.8b, v12.8b\n" /* w24, int16, out2 */ + "smull v23.8h, v7.8b, v12.8b\n" /* w24, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "stp q24, q25, [%[ptr_out0]], #32\n" + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "stp q26, q27, [%[ptr_out0]], #32\n" + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "subs %w[cnt], %w[cnt], #1\n" + "stp q28, q29, [%[ptr_out0]], #32\n" + "stp q30, q31, [%[ptr_out0]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc] "+r"(wptr), + [ptr_out0] "+r"(ptr_out0) + : + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + 
"v26","v27","v28","v29","v30","v31" + ); +#else + auto wptr = weight_c; + asm volatile( + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "1:\n" + /* inr0 */ + "vmull.s8 q4, d0, d6\n" /* int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* int16, out3 */ + "vmlal.s8 q4, d1, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* int16, out3 */ + "vmovl.s16 q8, d8\n" /* mov to out0 low */ + "vmovl.s16 q9, d9\n" /* mov to out0 hig */ + "vmovl.s16 q10, d10\n" /* mov to out1 low */ + "vmovl.s16 q11, d11\n" /* mov to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */ + "vmovl.s16 q12, d12\n" /* mov to out2 low */ + "vmovl.s16 q13, d13\n" /* mov to out2 hig */ + "vmovl.s16 q14, d14\n" /* mov to out3 low */ + "vmovl.s16 q15, d15\n" /* mov to out3 hig */ + "vld1.32 {d0-d1}, [%[r0]]\n" /* load r0, 6-7 */ + + "vmull.s8 q4, d2, d6\n" /* w2, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w2, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w2, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w2, int16, out3 */ + "vmlal.s8 q4, d3, d7\n" /* w3, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w3, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w3, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w3, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */ + "sub %[r0], %[r0], #16\n" /* r0 = r0 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w4, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w4, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w4, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w4, int16, out3 */ + "vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */ + /* inr1 */ + "vmlal.s8 q4, d0, d7\n" /* w5, int16, out0 */ + "vmlal.s8 q5, d1, d7\n" /* w5, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w5, int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* w5, int16, out3 */ + "vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4-5 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w6, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w6, int16, out1 */ + "vmull.s8 q6, d3, d6\n" /* w6, int16, out2 */ + "vmull.s8 q7, d4, d6\n" /* w6, int16, out3 */ + "vld1.32 {d0-d1}, [%[r1]]\n" /* load r1, 6-7 */ + "vmlal.s8 q4, d2, d7\n" /* w7, int16, out0 */ + "vmlal.s8 q5, d3, d7\n" /* w7, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w7, int16, out2 */ + "vmlal.s8 q7, d5, d7\n" /* w7, int16, out3 */ + "sub %[r1], %[r1], #16\n" /* r0 = r0 - 16 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" 
/* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w8-w9 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d3, d6\n" /* w8, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w8, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w8, int16, out2 */ + "vmull.s8 q7, d0, d6\n" /* w8, int16, out3 */ + "vmlal.s8 q4, d4, d7\n" /* w9, int16, out0 */ + "vmlal.s8 q5, d5, d7\n" /* w9, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w9, int16, out2 */ + "vmlal.s8 q7, d1, d7\n" /* w9, int16, out3 */ + "vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w10-w11 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */ + + /* inr2 */ + "vmull.s8 q4, d0, d6\n" /* w10, int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* w10, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w10, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w10, int16, out3 */ + "vmlal.s8 q4, d1, d7\n" /* w11, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w11, int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* w11, int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* w11, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w12-w13 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d1}, [%[r2]]\n" /* load r2, 6-7 */ + + "vmull.s8 q4, d2, d6\n" /* w12, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w12, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w12, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w12, int16, out3 */ + "vmlal.s8 q4, d3, d7\n" /* w13, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w13, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w13, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w13, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w14-w15 */ + "sub %[r2], %[r2], #16\n" /* r2 = r2 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w14, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w14, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w14, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w14, int16, out3 */ + "vld1.32 {d0-d3}, [%[r3]]!\n" /* load r3, 0-3 */ + /* inr3 */ + "vmlal.s8 q4, d0, d7\n" /* w15, int16, out0 */ + "vmlal.s8 q5, d1, d7\n" /* w15, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w15, int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* w15, 
int16, out3 */ + "vld1.32 {d4-d5}, [%[r3]]!\n" /* load r3, 4-5 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w16-w17 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w16, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w16, int16, out1 */ + "vmull.s8 q6, d3, d6\n" /* w16, int16, out2 */ + "vmull.s8 q7, d4, d6\n" /* w16, int16, out3 */ + "vld1.32 {d0-d1}, [%[r3]]\n" /* load r3, 6-7 */ + "vmlal.s8 q4, d2, d7\n" /* w17, int16, out0 */ + "vmlal.s8 q5, d3, d7\n" /* w17, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w17, int16, out2 */ + "vmlal.s8 q7, d5, d7\n" /* w17, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w18-w19 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "sub %[r3], %[r3], #16\n" /* r3 = r3 - 16 */ + + "vmull.s8 q4, d3, d6\n" /* w18, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w18, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w18, int16, out2 */ + "vmull.s8 q7, d0, d6\n" /* w18, int16, out3 */ + "vmlal.s8 q4, d4, d7\n" /* w19, int16, out0 */ + "vmlal.s8 q5, d5, d7\n" /* w19, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w19, int16, out2 */ + "vmlal.s8 q7, d1, d7\n" /* w19, int16, out3 */ + "vld1.32 {d0-d3}, [%[r4]]!\n" /* load r4, 0-3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w20-w21 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d4-d5}, [%[r4]]!\n" /* load r4, 4-5 */ + + /* inr4 */ + "vmull.s8 q4, d0, d6\n" /* w20, int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* w20, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w20, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w20, int16, out3 */ + "vmlal.s8 q4, d1, d7\n" /* w21, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w21, int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* w21, int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* w21, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w22-w23 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d1}, [%[r4]]\n" /* load r4, 5-6 */ + + "vmull.s8 q4, d2, d6\n" /* w22, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w22, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w22, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w22, int16, 
out3 */ + "vmlal.s8 q4, d3, d7\n" /* w23, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w23, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w23, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w23, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w24 */ + "sub %[r4], %[r4], #16\n" /* r4 = r4 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "sub %[wptr], %[wptr], #200 \n" /* wptr = wptr - 200 */ + + "vmull.s8 q4, d4, d6\n" /* w22, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w22, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w22, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w22, int16, out3 */ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 0-3 */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vst1.32 {d16-d19}, [%[ptr_out0]]!\n"/* store out0 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vst1.32 {d20-d23}, [%[ptr_out0]]!\n"/*store out1 */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "subs %[cnt], #1\n" /* cnt = cnt - 1 */ + "vst1.32 {d24-d27}, [%[ptr_out0]]!\n"/* store out2 */ + "vst1.32 {d28-d31}, [%[ptr_out0]]!\n"/* store out3 */ + "bne 1b\n" /* branch main loop */ + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [ptr_out0] "+r"(ptr_out0), + [wptr] "+r"(wptr) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + // clang-format on + int32_t* ptr_tmp = ptr_out0 - w_loop * 32; + block_inr0 = block_inr1; + block_inr1 = block_inr2; + block_inr2 = block_inr3; + block_inr3 = block_inr4; + block_inr4 = block_inr3 + in_len; + } + write_int32_nchwc8_to_nchw(pre_out, + reinterpret_cast(dout_batch), + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + bias_local, + flag_bias, + ptr_write, + scale + c); + } + } + } +} + +template void conv_depthwise_5x5s1_int8(int8_t* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); + +template void conv_depthwise_5x5s1_int8(float* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_5x5s2.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc similarity index 99% rename from lite/backends/arm/math/conv_depthwise_5x5s2.cc rename to lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index 
dd715fd5345c20380de9a93e035243bcabbd1fb0..dced24db72f71630c0cb9d7ff4275f740a2b69a4 100644 --- a/lite/backends/arm/math/conv_depthwise_5x5s2.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/arm/math/conv_depthwise.h" #include +#include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { namespace lite { @@ -80,21 +80,21 @@ void conv_depthwise_5x5s2p2_relu_s(const float* din, bool flag_relu, ARMContext* ctx); -void conv_depthwise_5x5s2(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { +void conv_depthwise_5x5s2_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { if (pad == 2) { if (win >= 9) { if (flag_relu) { diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 3deb6bcb5ff716405ce113a26072d6fefb1b2ebd..b2d16d18d2300ea51de8c8e9f25664ffdf4aebc7 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -15,7 +15,9 @@ #pragma once #include #include +#include "lite/backends/arm/math/gemm_s8.h" #include "lite/backends/arm/math/saturate.h" +#include "lite/backends/arm/math/sgemm.h" #include "lite/backends/arm/math/type_trans.h" #include "lite/core/target_wrapper.h" #include "lite/utils/cp_logging.h" @@ -26,6 +28,47 @@ namespace arm { namespace math { #define LITEMAX(a, b) ((a) > (b) ? 
(a) : (b)) +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) + +template +inline void trans_gemm_weights(const Tensor& tin, + Tensor& tout, // NOLINT + int group, + ARMContext* ctx); + +template <> +inline void trans_gemm_weights(const Tensor& tin, + Tensor& tout, // NOLINT + int group, + ARMContext* ctx) { + CHECK_EQ(tin.dims().size(), 4) << "conv weights dims size must = 4"; + int m = tin.dims()[0] / group; + int k = tin.dims().count(1, 4); + int hblock = lite::arm::math::get_hblock(ctx); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int group_size_round_up = ((m_roundup * k + 15) / 16) * 16; + float* w_trans_ptr = nullptr; + tout.Resize({group_size_round_up * group}); + w_trans_ptr = tout.mutable_data(); + const auto* w_data = tin.data(); + for (int g = 0; g < group; ++g) { + const float* weights_group = w_data + g * m * k; + float* weights_trans_ptr = w_trans_ptr + g * group_size_round_up; + lite::arm::math::prepackA( + weights_trans_ptr, weights_group, 1.f, k, 0, m, 0, k, false, ctx); + } +} + +template <> +inline void trans_gemm_weights(const Tensor& tin, + Tensor& tout, // NOLINT + int group, + ARMContext* ctx) { + CHECK_EQ(tin.dims().size(), 4) << "conv weights dims size must = 4"; + int m = tin.dims()[0] / group; + int k = tin.dims().count(1, 4); + prepackA_int8(&tout, tin, m, k, group, false, ctx); +} inline void fill_packed_biasc4(float* dout, const float* bias, int size) { float32x4_t vb = vld1q_f32(bias); @@ -159,6 +202,391 @@ static bool prepack_input_nxw(const dtype* din, return true; } +inline void transpose_4x4(float32x4_t v0, + float32x4_t v1, + float32x4_t v2, + float32x4_t v3, + float* dout) { +#ifdef __aarch64__ + asm volatile( + "trn1 v0.4s, %[v0].4s, %[v1].4s\n" /* trans q0, q1, a0b0a2b2*/ + "trn2 v1.4s, %[v0].4s, %[v1].4s\n" /* trans q0, q1, a1b1a3b3*/ + "trn1 v2.4s, %[v2].4s, %[v3].4s\n" /* trans q2, q3, c0d0c2d2*/ + "trn2 v3.4s, %[v2].4s, %[v3].4s\n" /* trans q2, q3, c1d1c3d3*/ + "trn1 v4.2d, v0.2d, v2.2d\n" /* trans q0, q2, a0b0c0d0*/ + "trn2 v6.2d, v0.2d, v2.2d\n" /* trans q0, q2, a2b2c2d2*/ + "trn1 v5.2d, v1.2d, v3.2d\n" /* trans q1, q3, a1b1c1d1*/ + "trn2 v7.2d, v1.2d, v3.2d\n" /* trans q1, q3, a3b3c3d3*/ + "stp q4, q5, [%[dout]], #32\n" + "stp q6, q7, [%[dout]]\n" + : [dout] "+r"(dout) + : [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2), [v3] "w"(v3) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + "vtrn.32 %q[v0], %q[v1]\n" /* trans q0, q1, a0b0a2b2, a1b1a3b3*/ + "vtrn.32 %q[v2], %q[v3]\n" /* trans q2, q3, c0d0c2d2, c1d1c3d3*/ + "vswp %f[v0], %e[v2]\n" /* trans q0, q2, a0b0c0d0, a2b2c2d2*/ + "vswp %f[v1], %e[v3]\n" /* trans q1, q3, a1b1c1d1, a3b3c3d3*/ + "vst1.32 {%e[v0], %f[v0]}, [%[dout]]!\n" + "vst1.32 {%e[v1], %f[v1]}, [%[dout]]!\n" + "vst1.32 {%e[v2], %f[v2]}, [%[dout]]!\n" + "vst1.32 {%e[v3], %f[v3]}, [%[dout]]\n" + : [dout] "+r"(dout) + : [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2), [v3] "w"(v3) + :); +#endif +} + +inline void prepack_input_nxwc4_dw(const float* din, + float* dout, + int cs, + int hs, + int he, + int ws, + int we, + int channel, + int width, + int height, + float* zero_ptr) { + int n = he - hs; + if (n <= 0) { + LOG(FATAL) << "prepack_dw_input, valid height must > zero"; + } + float32x4_t vzero = vdupq_n_f32(0.f); + + int size_w = we - ws; + int w0 = ws < 0 ? 0 : ws; + int w1 = we > width ? width : we; + int valid_w = w1 - w0; + + int mask[4] = {0, 1, 2, 3}; + + int pad_l = ws < 0 ? -ws : 0; + int pad_r = we > width ? 
we - width : 0; + int cnt_l = pad_l / 4; + int left_remain = pad_l - cnt_l * 4; + + bool flag_ext_l = left_remain > 0; + int left_sl = 4 - left_remain; + uint32x4_t vmask_padl; + bool flag_mask_l = false; + if (flag_ext_l) { + if (valid_w < 3) { + flag_mask_l = true; + vmask_padl = vcltq_s32(vld1q_s32(mask), vdupq_n_s32(valid_w)); + } + valid_w -= left_sl; + valid_w = valid_w > 0 ? valid_w : 0; + } + int cnt_valid = valid_w / 4; + int valid_sl = valid_w - cnt_valid * 4; + bool flag_mask_valid = valid_sl > 0; + uint32x4_t vmask_valid; + if (flag_mask_valid) { + vmask_valid = vcltq_s32(vld1q_s32(mask), vdupq_n_s32(valid_sl)); + pad_r -= 4 - valid_sl; + pad_r = pad_r > 0 ? pad_r : 0; + } + int size_c = width * height; + for (int h = hs; h < he; ++h) { + auto ptr_c0 = din + cs * size_c + h * width; + auto ptr_c1 = ptr_c0 + size_c; + auto ptr_c2 = ptr_c1 + size_c; + auto ptr_c3 = ptr_c2 + size_c; + if (h < 0 || h >= height) { + memset(dout, 0, sizeof(float) * size_w * 4); + dout += size_w * 4; + continue; + } else if (cs + 4 > channel) { + switch (cs + 4 - channel) { + case 3: + ptr_c1 = zero_ptr; + case 2: + ptr_c2 = zero_ptr; + case 1: + ptr_c3 = zero_ptr; + default: + break; + } + } + /// left padding + if (cnt_l > 0) { + memset(dout, 0, sizeof(float) * 16 * cnt_l); + dout += 16 * cnt_l; + } + /// left mask + if (flag_ext_l) { + float32x4_t vc0 = vld1q_f32(ptr_c0); + float32x4_t vc1 = vld1q_f32(ptr_c1); + float32x4_t vc2 = vld1q_f32(ptr_c2); + float32x4_t vc3 = vld1q_f32(ptr_c3); + if (flag_mask_l) { + vc0 = vbslq_f32(vmask_padl, vc0, vzero); + vc1 = vbslq_f32(vmask_padl, vc1, vzero); + vc2 = vbslq_f32(vmask_padl, vc2, vzero); + vc3 = vbslq_f32(vmask_padl, vc3, vzero); + } + switch (left_sl) { + case 1: + vc0 = vextq_f32(vzero, vc0, 1); + vc1 = vextq_f32(vzero, vc1, 1); + vc2 = vextq_f32(vzero, vc2, 1); + vc3 = vextq_f32(vzero, vc3, 1); + break; + case 2: + vc0 = vextq_f32(vzero, vc0, 2); + vc1 = vextq_f32(vzero, vc1, 2); + vc2 = vextq_f32(vzero, vc2, 2); + vc3 = vextq_f32(vzero, vc3, 2); + break; + case 3: + vc0 = vextq_f32(vzero, vc0, 3); + vc1 = vextq_f32(vzero, vc1, 3); + vc2 = vextq_f32(vzero, vc2, 3); + vc3 = vextq_f32(vzero, vc3, 3); + break; + default: + break; + } + transpose_4x4(vc0, vc1, vc2, vc3, dout); + dout += 16; + ptr_c0 += left_sl; + ptr_c1 += left_sl; + ptr_c2 += left_sl; + ptr_c3 += left_sl; + } + /// valid + for (int i = 0; i < cnt_valid; ++i) { + float32x4_t vc0 = vld1q_f32(ptr_c0); + float32x4_t vc1 = vld1q_f32(ptr_c1); + float32x4_t vc2 = vld1q_f32(ptr_c2); + float32x4_t vc3 = vld1q_f32(ptr_c3); + transpose_4x4(vc0, vc1, vc2, vc3, dout); + dout += 16; + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; + } + if (flag_mask_valid) { + float32x4_t vc0 = vld1q_f32(ptr_c0); + float32x4_t vc1 = vld1q_f32(ptr_c1); + float32x4_t vc2 = vld1q_f32(ptr_c2); + float32x4_t vc3 = vld1q_f32(ptr_c3); + vc0 = vbslq_f32(vmask_valid, vc0, vzero); + vc1 = vbslq_f32(vmask_valid, vc1, vzero); + vc2 = vbslq_f32(vmask_valid, vc2, vzero); + vc3 = vbslq_f32(vmask_valid, vc3, vzero); + transpose_4x4(vc0, vc1, vc2, vc3, dout); + dout += 16; + } + /// right padding + if (pad_r > 0) { + memset(dout, 0, sizeof(float) * 4 * pad_r); + dout += 4 * pad_r; + } + } +} + +inline void prepack_input_nxwc8_int8_dw(const int8_t* din, + int8_t* dout, + int cs, + int hs, + int he, + int ws, + int we, + int channel, + int width, + int height) { + int n = he - hs; + if (n <= 0) { + LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero"; + } + int size_w = we - ws; + int w0 = ws < 0 ? 
0 : ws; + int w1 = we > width ? width : we; + int valid_w = w1 - w0; + int pad_l = ws < 0 ? -ws : 0; + int pad_r = we > width ? we - width : 0; + int size_c = width * height; + + int valid_cnt = valid_w >> 3; + int remain = valid_w & 7; + + int8_t zero_ptr[size_w * 2]; // NOLINT + memset(zero_ptr, 0, size_w * 2); + + for (int h = hs; h < he; ++h) { + const int8_t* ptr_c0 = din + h * width + cs * size_c; + const int8_t* ptr_c1 = ptr_c0 + size_c; + const int8_t* ptr_c2 = ptr_c1 + size_c; + const int8_t* ptr_c3 = ptr_c2 + size_c; + const int8_t* ptr_c4 = ptr_c3 + size_c; + const int8_t* ptr_c5 = ptr_c4 + size_c; + const int8_t* ptr_c6 = ptr_c5 + size_c; + const int8_t* ptr_c7 = ptr_c6 + size_c; + if (h < 0 || h >= height) { + memset(dout, 0, 8 * size_w * sizeof(int8_t)); + dout += size_w * 8; + continue; + } else if (cs + 8 > channel) { + switch (cs + 8 - channel) { + case 7: + ptr_c1 = zero_ptr; + case 6: + ptr_c2 = zero_ptr; + case 5: + ptr_c3 = zero_ptr; + case 4: + ptr_c4 = zero_ptr; + case 3: + ptr_c5 = zero_ptr; + case 2: + ptr_c6 = zero_ptr; + case 1: + ptr_c7 = zero_ptr; + default: + break; + } + } + if (pad_l) { + memset(dout, 0, pad_l * 8 * sizeof(int8_t)); + dout += pad_l * 8; + } + if (valid_cnt) { + int cnt = valid_cnt; +#ifdef __aarch64__ + asm volatile( + /* main loop */ + "1:\n" + "ldr d0, [%[r0]], #8\n" + "ldr d1, [%[r1]], #8\n" + "ldr d2, [%[r2]], #8\n" + "ldr d3, [%[r3]], #8\n" + "ldr d4, [%[r4]], #8\n" + "ldr d5, [%[r5]], #8\n" + "ldr d6, [%[r6]], #8\n" + "ldr d7, [%[r7]], #8\n" + "trn1 v8.8b, v0.8b, v1.8b\n" + "trn2 v9.8b, v0.8b, v1.8b\n" + "trn1 v10.8b, v2.8b, v3.8b\n" + "trn2 v11.8b, v2.8b, v3.8b\n" + "trn1 v12.8b, v4.8b, v5.8b\n" + "trn2 v13.8b, v4.8b, v5.8b\n" + "trn1 v14.8b, v6.8b, v7.8b\n" + "trn2 v15.8b, v6.8b, v7.8b\n" + "trn1 v0.4h, v8.4h, v10.4h\n" + "trn2 v1.4h, v8.4h, v10.4h\n" + "trn1 v2.4h, v9.4h, v11.4h\n" + "trn2 v3.4h, v9.4h, v11.4h\n" + "trn1 v4.4h, v12.4h, v14.4h\n" + "trn2 v5.4h, v12.4h, v14.4h\n" + "trn1 v6.4h, v13.4h, v15.4h\n" + "trn2 v7.4h, v13.4h, v15.4h\n" + "trn1 v8.2s, v0.2s, v4.2s\n" + "trn1 v9.2s, v2.2s, v6.2s\n" + "trn1 v10.2s, v1.2s, v5.2s\n" + "trn1 v11.2s, v3.2s, v7.2s\n" + "stp d8, d9, [%[ptr_out]], #16\n" + "trn2 v12.2s, v0.2s, v4.2s\n" + "trn2 v13.2s, v2.2s, v6.2s\n" + "stp d10, d11, [%[ptr_out]], #16\n" + "trn2 v14.2s, v1.2s, v5.2s\n" + "trn2 v15.2s, v3.2s, v7.2s\n" + "subs %w[cnt], %w[cnt], #1\n" + "stp d12, d13, [%[ptr_out]], #16\n" + "stp d14, d15, [%[ptr_out]], #16\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(ptr_c0), + [r1] "+r"(ptr_c1), + [r2] "+r"(ptr_c2), + [r3] "+r"(ptr_c3), + [r4] "+r"(ptr_c4), + [r5] "+r"(ptr_c5), + [r6] "+r"(ptr_c6), + [r7] "+r"(ptr_c7), + [ptr_out] "+r"(dout) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile( + /* main loop */ + "1:\n" + "vld1.32 {d0}, [%[r0]]!\n" + "vld1.32 {d1}, [%[r1]]!\n" + "vld1.32 {d2}, [%[r2]]!\n" + "vld1.32 {d3}, [%[r3]]!\n" + "vld1.32 {d4}, [%[r4]]!\n" + "vld1.32 {d5}, [%[r5]]!\n" + "vld1.32 {d6}, [%[r6]]!\n" + "vld1.32 {d7}, [%[r7]]!\n" + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + "vtrn.16 d0, d2\n" + "vtrn.16 d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + "vtrn.32 d0, d4\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d3, d7\n" + "subs %[cnt], #1\n" + "vst1.32 {d0-d3}, [%[ptr_out]]!\n" + "vst1.32 {d4-d7}, [%[ptr_out]]!\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(ptr_c0), + [r1] 
"+r"(ptr_c1), + [r2] "+r"(ptr_c2), + [r3] "+r"(ptr_c3), + [r4] "+r"(ptr_c4), + [r5] "+r"(ptr_c5), + [r6] "+r"(ptr_c6), + [r7] "+r"(ptr_c7), + [ptr_out] "+r"(dout) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ + } + for (int i = 0; i < remain; ++i) { + dout[0] = *(ptr_c0++); + dout[1] = *(ptr_c1++); + dout[2] = *(ptr_c2++); + dout[3] = *(ptr_c3++); + dout[4] = *(ptr_c4++); + dout[5] = *(ptr_c5++); + dout[6] = *(ptr_c6++); + dout[7] = *(ptr_c7++); + dout += 8; + } + if (pad_r) { + memset(dout, 0, pad_r * 8 * sizeof(int8_t)); + dout += pad_r * 8; + } + } +} + /*wirte result in outputs * input din: [n, c, h, w], output dout: [n, c, h, w] */ @@ -1195,2881 +1623,1274 @@ inline bool write_to_output_c8_fp32(const float* din, return true; } -/*wirte result in outputs -* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] -*/ -inline bool write_to_output_c4_int32(const int* din, - int* dout, - int ch_n, - int hei_n, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int height, - int width, - bool flag_relu, - int* trash_ptr) { - if (ch_n != 4 || hei_n <= 0) { - LOG(ERROR) << "ch_n must be equal 4 and hei_n is more than zero"; - return false; +template +inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT + Dtype*& dout1, // NOLINT + Dtype*& dout2, // NOLINT + Dtype*& dout3, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale, + float32x4_t bias, + bool is_relu); + +#ifdef __aarch64__ +#define NCHWC4_TRANS_INT32 \ + "ldp q0, q1, [%[ptr_din]], #32\n" \ + "ldp q2, q3, [%[ptr_din]], #32\n" \ + "movi v20.4s, #0\n" \ + "1:\n" \ + "trn1 v8.4s, v0.4s, v1.4s\n" \ + "trn2 v9.4s, v0.4s, v1.4s\n" \ + "ldp q0, q1, [%[ptr_din]], #32\n" \ + "trn1 v10.4s, v2.4s, v3.4s\n" \ + "trn2 v11.4s, v2.4s, v3.4s\n" \ + "ldp q2, q3, [%[ptr_din]], #32\n" \ + "trn1 v16.2d, v8.2d, v10.2d\n" \ + "trn2 v17.2d, v8.2d, v10.2d\n" \ + "trn1 v18.2d, v9.2d, v11.2d\n" \ + "trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \ + "scvtf v4.4s, v16.4s\n" \ + "scvtf v5.4s, v17.4s\n" \ + "scvtf v6.4s, v18.4s\n" \ + "scvtf v7.4s, v19.4s\n" /* add bias */ \ + "dup v16.4s, %[bias].s[0]\n" \ + "dup v17.4s, %[bias].s[2]\n" \ + "dup v18.4s, %[bias].s[1]\n" \ + "dup v19.4s, %[bias].s[3]\n" /* mul scale */ \ + "fmla v16.4s, v4.4s, %[scale].s[0]\n" \ + "fmla v17.4s, v5.4s, %[scale].s[2]\n" \ + "fmla v18.4s, v6.4s, %[scale].s[1]\n" \ + "fmla v19.4s, v7.4s, %[scale].s[3]\n" /* relu */ \ + "cbz %w[relu], 2f\n" \ + "fmax v16.4s, v16.4s, v20.4s \n" \ + "fmax v17.4s, v17.4s, v20.4s \n" \ + "fmax v18.4s, v18.4s, v20.4s \n" \ + "fmax v19.4s, v19.4s, v20.4s \n" \ + "2:\n" + +#else +#define NCHWC4_TRANS_INT32 \ + "vld1.32 {d4-d7}, [%[ptr_din]]!\n" \ + "vld1.32 {d8-d11}, [%[ptr_din]]!\n" \ + "vmov.u32 q15, #0\n" \ + "1:\n" /* transpose */ \ + "vtrn.32 q2, q3\n" \ + "vtrn.32 q4, q5\n" \ + "vswp.32 d5, d8\n" \ + "vswp.32 d7, d10\n" /* int32-> fp32 */ \ + "vcvt.f32.s32 q6, q2\n" \ + "vcvt.f32.s32 q7, q3\n" \ + "vcvt.f32.s32 q8, q4\n" \ + "vcvt.f32.s32 q9, q5\n" /* add bias */ \ + "vdup.32 q10, %e[bias][0]\n" \ + "vdup.32 q11, %e[bias][1]\n" \ + "vdup.32 q12, %f[bias][0]\n" \ + "vdup.32 q13, %f[bias][1]\n" /* mul scale */ \ + "vmla.f32 q10, q6, %e[scale][0]\n" \ + "vmla.f32 q11, q7, %e[scale][1]\n" \ + "vmla.f32 q12, q8, %f[scale][0]\n" \ + "vmla.f32 q13, q9, %f[scale][1]\n" /* relu */ \ + "cmp %[relu], #0\n" \ + "beq 2f\n" \ + "vmax.f32 q10, q10, q15\n" \ + "vmax.f32 q11, q11, q15\n" \ + "vmax.f32 q12, q12, q15\n" \ + "vmax.f32 q13, q13, q15\n" \ + "2:\n" + +#endif + +template <> 
+inline void int32_nchwc4_kernel(float*& dout0, // NOLINT + float*& dout1, // NOLINT + float*& dout2, // NOLINT + float*& dout3, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale, + float32x4_t bias, + bool is_relu) { +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_INT32 + "subs %w[cnt], %w[cnt], #1\n" + /* store result */ + "str q16, [%[doutc0r0]], #16\n" + "str q17, [%[doutc2r0]], #16\n" + "str q18, [%[doutc1r0]], #16\n" + "str q19, [%[doutc3r0]], #16\n" + "bne 1b\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v31"); +#else + asm volatile(NCHWC4_TRANS_INT32 + "subs %[cnt], %[cnt], #1\n" + /* store result */ + "vld1.32 {d4-d7}, [%[ptr_din]]!\n" + "vst1.32 {d20-d21}, [%[doutc0r0]]!\n" + "vst1.32 {d22-d23}, [%[doutc1r0]]!\n" + "vld1.32 {d8-d11}, [%[ptr_din]]!\n" + "vst1.32 {d24-d25}, [%[doutc2r0]]!\n" + "vst1.32 {d26-d27}, [%[doutc3r0]]!\n" + "bne 1b\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : "cc", + "memory", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif +} + +template <> +inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT + int8_t*& dout1, // NOLINT + int8_t*& dout2, // NOLINT + int8_t*& dout3, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale, + float32x4_t bias, + bool is_relu) { +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_INT32 + "subs %w[cnt], %w[cnt], #1\n" + /* fp32-int32 */ + "fcvtas v4.4s, v16.4s\n" + "fcvtas v5.4s, v18.4s\n" + "fcvtas v6.4s, v17.4s\n" + "fcvtas v7.4s, v19.4s\n" + /* int32-int16 */ + "sqxtn v8.4h, v4.4s\n" + "sqxtn v9.4h, v5.4s\n" + "sqxtn v10.4h, v6.4s\n" + "sqxtn v11.4h, v7.4s\n" + /* int16-int8 */ + "sqxtn v16.8b, v8.8h\n" + "sqxtn v17.8b, v9.8h\n" + "sqxtn v18.8b, v10.8h\n" + "sqxtn v19.8b, v11.8h\n" + /* store result */ + "str s16, [%[doutc0r0]], #4\n" + "str s17, [%[doutc1r0]], #4\n" + "str s18, [%[doutc2r0]], #4\n" + "str s19, [%[doutc3r0]], #4\n" + "bne 1b\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v31"); +#else + asm volatile(NCHWC4_TRANS_INT32 + /* set 0.5 offset */ + "vmov.f32 q2, #0.5\n" + "vmov.f32 q14, #-0.5\n" + "vand.i32 q3, q2, q2 @ set offset, 0.5\n" + "vand.i32 q4, q2, q2 @ set offset, 0.5\n" + "vand.i32 q5, q2, q2 @ set offset, 0.5\n" + "vcgt.f32 q6, q10, q15 @ get mask > 0, in0\n" + "vcgt.f32 q7, q11, q15 @ get mask > 0, in1\n" + "vcgt.f32 q8, q12, q15 @ get mask > 0, in2\n" + "vcgt.f32 q9, q13, q15 @ get mask > 0, in3\n" + /* set 0.5 offset */ + "vbif.f32 q2, q14, q6 @ get right offset\n" + "vbif.f32 q3, q14, q7 @ get right offset\n" + "vbif.f32 q4, q14, q8 @ get right offset\n" + "vbif.f32 q5, q14, q9 @ get 
right offset\n" + /* add offset */ + "vadd.f32 q10, q2, q10\n" + "vadd.f32 q11, q3, q11\n" + "vadd.f32 q12, q4, q12\n" + "vadd.f32 q13, q5, q13\n" + /* fp32 to int32 */ + "vcvt.s32.f32 q6, q10 @ cvt to int32\n" + "vcvt.s32.f32 q7, q11 @ cvt to int32\n" + "vcvt.s32.f32 q8, q12 @ cvt to int32\n" + "vcvt.s32.f32 q9, q13 @ cvt to int32\n" + /* int32 to int16 */ + "vqmovn.s32 d20, q6 @ cnt to int16\n" + "vqmovn.s32 d22, q7 @ cnt to int16\n" + "vqmovn.s32 d24, q8 @ cnt to int16\n" + "vqmovn.s32 d26, q9 @ cnt to int16\n" + /* int16 to int8 */ + "vqmovn.s16 d12, q10 @ cnt to int8\n" + "vqmovn.s16 d13, q11 @ cnt to int8\n" + "vqmovn.s16 d14, q12 @ cnt to int8\n" + "vqmovn.s16 d15, q13 @ cnt to int8\n" + "subs %[cnt], %[cnt], #1\n" + /* store */ + "vld1.32 {d4-d7}, [%[ptr_din]]!\n" + "vst1.32 {d12[0]}, [%[doutc0r0]]!\n" + "vst1.32 {d13[0]}, [%[doutc1r0]]!\n" + "vld1.32 {d8-d11}, [%[ptr_din]]!\n" + "vst1.32 {d14[0]}, [%[doutc2r0]]!\n" + "vst1.32 {d15[0]}, [%[doutc3r0]]!\n" + "bne 1b @ jump to main loop\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : "cc", + "memory", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif +} + +template <> +inline void int32_nchwc4_kernel(int32_t*& dout0, // NOLINT + int32_t*& dout1, // NOLINT + int32_t*& dout2, // NOLINT + int32_t*& dout3, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale, + float32x4_t bias, + bool is_relu) { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "cbz %w[relu], 2f\n" + "smax v16.4s, v16.4s, v20.4s \n" /* relu */ + "smax v17.4s, v17.4s, v20.4s \n" /* relu */ + "smax v18.4s, v18.4s, v20.4s \n" /* relu */ + "smax v19.4s, v19.4s, v20.4s \n" /* relu */ + "2:\n" + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q1 @ trans q0, q1 \n" + "vtrn.32 q2, q3 @ trans q2, q3 \n" + "vswp.32 d1, d4 @ swap d1, d4 \n" + "vswp.32 d3, d6 @ swap d3, d6 \n" + "cmp %[relu], #0\n" + "bne 2f\n" + "vmax.s32 q0, q0, q15 @ relu\n" + "vmax.s32 q1, q1, q15 @ relu\n" + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + "2:\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "bne 1b @ jump to main loop\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [relu] "r"(is_relu) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15"); +#endif +} + +template +inline Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu); + +template <> +inline float cvt_kernel(int din, float scale, float bias, bool flag_relu) { + if (flag_relu) { + return LITEMAX(din * scale + bias, 0); } - int size_c_out = width * height; + return din * scale + bias; +} + +template <> +inline int8_t cvt_kernel(int din, float scale, float bias, bool flag_relu) { + if (flag_relu) { + return saturate_cast(round(LITEMAX(din * scale + bias, 0))); + } + return saturate_cast(round(din * scale + bias)); +} - int* doutc0r0 = dout + cs * size_c_out + hs * width + ws; - int* doutc1r0 = doutc0r0 + size_c_out; - int* doutc2r0 = doutc1r0 + size_c_out; - int* doutc3r0 = doutc2r0 + size_c_out; +template <> +inline int32_t cvt_kernel(int din, float scale, float bias, bool flag_relu) { + if (flag_relu) { + return LITEMAX(din, 0); + } + return din; +} - const int* ptr_din = din; +template +inline void write_int32_nchwc4_to_nchw(const int* din, + Dtype* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* bias, + bool flag_bias, + Dtype* trash_ptr, + const float* scale) { + int size_c_out = width * height; + + Dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + Dtype* doutc1r0 = doutc0r0 + size_c_out; + Dtype* doutc2r0 = doutc1r0 + size_c_out; + Dtype* doutc3r0 = doutc2r0 + size_c_out; int size_h = (he > height ? height : he) - hs; // size_h == hei_n int valid_w = we - ws; int cnt = valid_w / 4; + float32x4_t w_scale = vld1q_f32(scale); + float32x4_t w_bias = flag_bias ? 
vld1q_f32(bias) : vdupq_n_f32(0.f); + if (we > width) { cnt--; } - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - int* doutc1_ptr = doutc1r0 + size_w; - int* doutc2_ptr = doutc2r0 + size_w; - int* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + Dtype* doutc0_ptr = doutc0r0 + size_w; + Dtype* doutc1_ptr = doutc1r0 + size_w; + Dtype* doutc2_ptr = doutc2r0 + size_w; + Dtype* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; + } + int index = i * valid_w * 4; + const int* din_hei_ptr = din + index; + if (cnt > 0) { + int32_nchwc4_kernel(doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + din_hei_ptr, + cnt, + w_scale, + w_bias, + flag_relu); + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = din + index + offset; + int j = we - 4; + for (; j < width; ++j) { + *(doutc0_ptr++) = + cvt_kernel(din_hei_ptr[0], scale[0], bias[0], flag_relu); + *(doutc1_ptr++) = + cvt_kernel(din_hei_ptr[1], scale[1], bias[1], flag_relu); + *(doutc2_ptr++) = + cvt_kernel(din_hei_ptr[2], scale[2], bias[2], flag_relu); + *(doutc3_ptr++) = + cvt_kernel(din_hei_ptr[3], scale[3], bias[3], flag_relu); + din_hei_ptr += 4; + } + } + } +} + +template +inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT + Dtype*& dout1, // NOLINT + Dtype*& dout2, // NOLINT + Dtype*& dout3, // NOLINT + Dtype*& dout4, // NOLINT + Dtype*& dout5, // NOLINT + Dtype*& dout6, // NOLINT + Dtype*& dout7, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale0, + float32x4_t scale1, + float32x4_t bias0, + float32x4_t bias1, + bool is_relu); + +// clang-format off #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "smax v16.4s, v16.4s, v20.4s \n" /* relu */ - "smax v17.4s, v17.4s, v20.4s \n" /* relu */ - "smax v18.4s, v18.4s, v20.4s \n" /* relu */ - "smax v19.4s, v19.4s, v20.4s \n" /* relu */ - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ +#define INT32_NCHWC8_TO_NCHW_FP32 \ + "ldp q0, q1, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */ \ + "ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ + "ldp q4, q5, [%[ptr_din]], #32\n" 
/* load r00, r01 to q0, q1 */ \ + "ldp q6, q7, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ + "movi v31.4s, #0\n" /* main loop*/ \ + "1:\n" \ + "trn1 v8.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \ + "trn2 v9.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \ + "trn1 v10.4s, v1.4s, v3.4s\n" /* trans q2, q3*/ \ + "trn2 v11.4s, v1.4s, v3.4s\n" /* trans q2, q3*/ \ + "ldp q0, q1, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */ \ + "trn1 v12.4s, v4.4s, v6.4s\n" /* trans q0, q1*/ \ + "trn2 v13.4s, v4.4s, v6.4s\n" /* trans q0, q1*/ \ + "trn1 v14.4s, v5.4s, v7.4s\n" /* trans q2, q3*/ \ + "trn2 v15.4s, v5.4s, v7.4s\n" /* trans q2, q3*/ \ + "ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ + "trn1 v16.2d, v8.2d, v12.2d\n" /* trans q8, q10 00 01 02 03*/ \ + "trn2 v17.2d, v8.2d, v12.2d\n" /* trans q8, q10 20 21 22 23*/ \ + "trn1 v18.2d, v9.2d, v13.2d\n" /* trans q9, q11 10 11 12 13*/ \ + "trn2 v19.2d, v9.2d, v13.2d\n" /* trans q9, q11 30 31 32 33*/ \ + "ldp q4, q5, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */ \ + "trn1 v8.2d, v10.2d, v14.2d\n" /* trans q8, q10 40 41 42 43*/ \ + "trn2 v9.2d, v10.2d, v14.2d\n" /* trans q8, q10 60 61 62 63*/ \ + "trn1 v12.2d, v11.2d, v15.2d\n" /* trans q9, q11 50 51 52 53*/ \ + "trn2 v13.2d, v11.2d, v15.2d\n" /* trans q9, q11 70 71 72 73*/ \ + "ldp q6, q7, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ + /* int32->fp32 */ \ + "scvtf v10.4s, v16.4s\n" \ + "scvtf v11.4s, v17.4s\n" \ + "scvtf v14.4s, v18.4s\n" \ + "scvtf v15.4s, v19.4s\n" \ + /* add bias */ \ + "dup v16.4s, %[bias0].s[0]\n" \ + "dup v17.4s, %[bias0].s[2]\n" \ + "dup v18.4s, %[bias0].s[1]\n" \ + "dup v19.4s, %[bias0].s[3]\n" \ + /* mul scale */ \ + "fmla v16.4s, v10.4s, %[scale0].s[0]\n" \ + "fmla v17.4s, v11.4s, %[scale0].s[2]\n" \ + "fmla v18.4s, v14.4s, %[scale0].s[1]\n" \ + "fmla v19.4s, v15.4s, %[scale0].s[3]\n" \ + "scvtf v10.4s, v8.4s\n" \ + "scvtf v11.4s, v9.4s\n" \ + "scvtf v14.4s, v12.4s\n" \ + "scvtf v15.4s, v13.4s\n" \ + /* add bias */ \ + "dup v8.4s, %[bias1].s[0]\n" \ + "dup v9.4s, %[bias1].s[2]\n" \ + "dup v12.4s, %[bias1].s[1]\n" \ + "dup v13.4s, %[bias1].s[3]\n" \ + /* mul scale */ \ + "fmla v8.4s, v10.4s, %[scale1].s[0]\n" \ + "fmla v9.4s, v11.4s, %[scale1].s[2]\n" \ + "fmla v12.4s, v14.4s, %[scale1].s[1]\n" \ + "fmla v13.4s, v15.4s, %[scale1].s[3]\n" \ + /* relu */ \ + "cbz %w[relu], 2f\n" \ + "fmax v16.4s, v16.4s, v31.4s\n" /*relu*/ \ + "fmax v17.4s, v17.4s, v31.4s\n" /*relu*/ \ + "fmax v18.4s, v18.4s, v31.4s\n" /*relu*/ \ + "fmax v19.4s, v19.4s, v31.4s\n" /*relu*/ \ + "fmax v8.4s, v8.4s, v31.4s\n" /*relu*/ \ + "fmax v9.4s, v9.4s, v31.4s\n" /*relu*/ \ + "fmax v12.4s, v12.4s, v31.4s\n" /*relu*/ \ + "fmax v13.4s, v13.4s, v31.4s\n" /*relu*/ \ + "2:\n" - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ +#else +#define INT32_NCHWC8_TO_NCHW_FP32 \ + "1: @ main loop\n" \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" \ + /* int32-> fp32 */ \ + "vcvt.f32.s32 q8, q0\n" \ + "vcvt.f32.s32 q9, q1\n" \ + "vcvt.f32.s32 q10, q2\n" \ + "vcvt.f32.s32 q11, q3\n" \ + "vand.32 q0, %q[bias0], %q[bias0]\n" \ + "vand.32 q1, %q[bias1], %q[bias1]\n" \ + "vand.32 q2, %q[bias0], %q[bias0]\n" \ + "vand.32 q3, %q[bias1], %q[bias1]\n" \ + /* mul scale */ \ + "vmla.f32 q0, q8, %q[scale0]\n" \ + "vmla.f32 q1, q9, %q[scale1]\n" \ + "vmla.f32 q2, q10, %q[scale0]\n" \ + "vmla.f32 q3, q11, %q[scale1]\n" \ + /* int32-> fp32 */ \ + "vcvt.f32.s32 q8, q4\n" \ + "vcvt.f32.s32 q9, q5\n" \ + "vcvt.f32.s32 q10, q6\n" \ + "vcvt.f32.s32 q11, q7\n" \ + "vand.32 q4, %q[bias0], %q[bias0]\n" \ + "vand.32 q5, %q[bias1], %q[bias1]\n" \ + "vand.32 q6, %q[bias0], %q[bias0]\n" \ + "vand.32 q7, %q[bias1], %q[bias1]\n" \ + /* mul scale */ \ + "vmla.f32 q4, q8, %q[scale0]\n" \ + "vmla.f32 q5, q9, %q[scale1]\n" \ + "vmla.f32 q6, q10, %q[scale0]\n" \ + "vmla.f32 q7, q11, %q[scale1]\n" \ + /* transpose */ \ + "vtrn.32 q0, q2\n" \ + "vtrn.32 q1, q3\n" \ + "vtrn.32 q4, q6\n" \ + "vtrn.32 q5, q7\n" \ + "vswp d1, d8\n" /* q0: a0-a3, q4: c0-c3 */ \ + "vswp d5, d12\n" /* q2: b0-b3, q6: d0-d3 */ \ + "vswp d3, d10\n" /* q1: e0-e3, q5: g0-g3 */ \ + "vswp d7, d14\n" /* q3: f0-f3, q7: h0-h3 */ \ + /* relu */ \ + "vmov.i32 q8, #0\n" \ + "cmp %[relu], #0\n" \ + "beq 2f\n" \ + "vmax.f32 q0, q0, q8\n" /*relu*/ \ + "vmax.f32 q2, q2, q8\n" /*relu*/ \ + "vmax.f32 q4, q4, q8\n" /*relu*/ \ + "vmax.f32 q6, q6, q8\n" /*relu*/ \ + "vmax.f32 q1, q1, q8\n" /*relu*/ \ + "vmax.f32 q3, q3, q8\n" /*relu*/ \ + "vmax.f32 q5, q5, q8\n" /*relu*/ \ + "vmax.f32 q7, q7, q8\n" /*relu*/ \ + "2:\n" - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); +#endif +// clang-format on + +template <> +inline void int32_nchwc8_kernel(float*& dout0, // NOLINT + float*& dout1, // NOLINT + float*& dout2, // NOLINT + float*& dout3, // NOLINT + float*& dout4, // NOLINT + float*& dout5, // NOLINT + float*& dout6, // NOLINT + float*& dout7, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale0, + float32x4_t scale1, + float32x4_t bias0, + float32x4_t bias1, + bool is_relu) { +#ifdef __aarch64__ + asm volatile(INT32_NCHWC8_TO_NCHW_FP32 + "subs %w[cnt], %w[cnt], #1\n" /* loop count -1*/ + "str q16, [%[doutc0r0]], #16\n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16\n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16\n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16\n" /* store c3r0*/ + "str q8, [%[doutc4r0]], #16\n" /* store c4r0*/ + "str q9, [%[doutc6r0]], #16\n" /* store c6r0*/ + "str q12, [%[doutc5r0]], #16\n" /* store c5r0*/ + "str q13, [%[doutc7r0]], #16\n" /* store c7r0*/ + "bne 1b\n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale0] "w"(scale0), + [scale1] "w"(scale1), + [bias0] "w"(bias0), + [bias1] "w"(bias1), + [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", 
+ "v31"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans q0, q1 \n" - "vtrn.32 q2, q3 @ trans q2, q3 \n" - "vswp.32 d1, d4 @ swap d1, d4 \n" - "vswp.32 d3, d6 @ swap d3, d6 \n" - - "vmax.s32 q0, q0, q15 @ relu\n" - "vmax.s32 q1, q1, q15 @ relu\n" - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" + asm volatile(INT32_NCHWC8_TO_NCHW_FP32 + "subs %[cnt], #1\n" /* loop count -1*/ + "vst1.32 {d0-d1}, [%[doutc0r0]]!\n" /* store c0r0*/ + "vst1.32 {d4-d5}, [%[doutc1r0]]!\n" /* store c0r0*/ + "vst1.32 {d8-d9}, [%[doutc2r0]]!\n" /* store c0r0*/ + "vst1.32 {d12-d13}, [%[doutc3r0]]!\n" /* store c0r0*/ + "vst1.32 {d2-d3}, [%[doutc4r0]]!\n" /* store c0r0*/ + "vst1.32 {d6-d7}, [%[doutc5r0]]!\n" /* store c0r0*/ + "vst1.32 {d10-d11}, [%[doutc6r0]]!\n" /* store c0r0*/ + "vst1.32 {d14-d15}, [%[doutc7r0]]!\n" /* store c0r0*/ + "bne 1b\n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale0] "w"(scale0), + [scale1] "w"(scale1), + [bias0] "w"(bias0), + [bias1] "w"(bias1), + [relu] "r"(is_relu) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif +} - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" +template <> +inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT + int8_t*& dout1, // NOLINT + int8_t*& dout2, // NOLINT + int8_t*& dout3, // NOLINT + int8_t*& dout4, // NOLINT + int8_t*& dout5, // NOLINT + int8_t*& dout6, // NOLINT + int8_t*& dout7, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale0, + float32x4_t scale1, + float32x4_t bias0, + float32x4_t bias1, + bool is_relu) { +#ifdef __aarch64__ + asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* fp32-int32 */ + "fcvtas v10.4s, v16.4s\n" + "fcvtas v11.4s, v17.4s\n" + "fcvtas v14.4s, v18.4s\n" + "fcvtas v15.4s, v19.4s\n" + "fcvtas v20.4s, v8.4s\n" + "fcvtas v21.4s, v9.4s\n" + "fcvtas v22.4s, v12.4s\n" + "fcvtas v23.4s, v13.4s\n" + /* int32-int16 */ + "sqxtn v16.4h, v10.4s\n" + "sqxtn v17.4h, v11.4s\n" + "sqxtn v18.4h, v14.4s\n" + "sqxtn v19.4h, v15.4s\n" + "sqxtn v8.4h, v20.4s\n" + "sqxtn v9.4h, v21.4s\n" + "sqxtn v12.4h, v22.4s\n" + "sqxtn v13.4h, v23.4s\n" + /* int16-int8 */ + "sqxtn v10.8b, v16.8h\n" + "sqxtn v11.8b, v17.8h\n" + "sqxtn v14.8b, v18.8h\n" + "sqxtn v15.8b, v19.8h\n" + "sqxtn v20.8b, v8.8h\n" + "sqxtn v21.8b, v9.8h\n" + "sqxtn v22.8b, v12.8h\n" + "sqxtn v23.8b, v13.8h\n" + "str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/ + "str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/ + "str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/ + "str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/ + "str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/ + "str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/ + "str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/ + "bne 1b \n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] 
"+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale0] "w"(scale0), + [scale1] "w"(scale1), + [bias0] "w"(bias0), + [bias1] "w"(bias1), + [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v31"); +#else + asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* set +-0.5 offset */ + "vmov.f32 q10, #-0.5\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q0, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q0, q0, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q2, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q2, q2, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q4, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q4, q4, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q6, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q6, q6, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q1, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q1, q1, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q3, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q3, q3, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q5, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q5, q5, q9\n" + "vmov.f32 q9, #0.5\n" + "vcgt.f32 q11, q7, q8 @ get mask > 0, in0\n" + "vbif.f32 q9, q10, q11 @ get right offset\n" + "vadd.f32 q7, q7, q9\n" + /* fp32 to int32 */ + "vcvt.s32.f32 q8, q0 @ cvt to int32\n" + "vcvt.s32.f32 q9, q2 @ cvt to int32\n" + "vcvt.s32.f32 q10, q4 @ cvt to int32\n" + "vcvt.s32.f32 q11, q6 @ cvt to int32\n" + /* int32 to int16 */ + "vqmovn.s32 d0, q8 @ cnt to int16\n" + "vqmovn.s32 d4, q9 @ cnt to int16\n" + "vqmovn.s32 d8, q10 @ cnt to int16\n" + "vqmovn.s32 d12, q11 @ cnt to int16\n" + /* fp32 to int32 */ + "vcvt.s32.f32 q8, q1 @ cvt to int32\n" + "vcvt.s32.f32 q9, q3 @ cvt to int32\n" + "vcvt.s32.f32 q10, q5 @ cvt to int32\n" + "vcvt.s32.f32 q11, q7 @ cvt to int32\n" + /* int32 to int16 */ + "vqmovn.s32 d2, q8 @ cnt to int16\n" + "vqmovn.s32 d6, q9 @ cnt to int16\n" + "vqmovn.s32 d10, q10 @ cnt to int16\n" + "vqmovn.s32 d14, q11 @ cnt to int16\n" + /* int16 to int8 */ + "vqmovn.s16 d16, q0 @ cnt to int8\n" + "vqmovn.s16 d17, q2 @ cnt to int8\n" + "vqmovn.s16 d18, q4 @ cnt to int8\n" + "vqmovn.s16 d19, q6 @ cnt to int8\n" + "vst1.32 {d16[0]}, [%[doutc0r0]]!\n" + "vqmovn.s16 d20, q1 @ cnt to int8\n" + "vst1.32 {d17[0]}, [%[doutc1r0]]!\n" + "vqmovn.s16 d21, q3 @ cnt to int8\n" + "vst1.32 {d18[0]}, [%[doutc2r0]]!\n" + "vqmovn.s16 d22, q5 @ cnt to int8\n" + "vst1.32 {d19[0]}, [%[doutc3r0]]!\n" + "vqmovn.s16 d23, q7 @ cnt to int8\n" + "subs %[cnt], #1\n" + "vst1.32 {d20[0]}, [%[doutc4r0]]!\n" + "vst1.32 {d21[0]}, [%[doutc5r0]]!\n" + "vst1.32 {d22[0]}, [%[doutc6r0]]!\n" + "vst1.32 {d23[0]}, [%[doutc7r0]]!\n" + "bne 1b\n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [scale0] "w"(scale0), + [scale1] "w"(scale1), + [bias0] "w"(bias0), + [bias1] "w"(bias1), + 
[relu] "r"(is_relu) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif +} - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" +template <> +inline void int32_nchwc8_kernel(int32_t*& dout0, // NOLINT + int32_t*& dout1, // NOLINT + int32_t*& dout2, // NOLINT + int32_t*& dout3, // NOLINT + int32_t*& dout4, // NOLINT + int32_t*& dout5, // NOLINT + int32_t*& dout6, // NOLINT + int32_t*& dout7, // NOLINT + const int32_t*& din, // NOLINT + int cnt, + float32x4_t scale0, + float32x4_t scale1, + float32x4_t bias0, + float32x4_t bias1, + bool is_relu) { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "cbz %w[relu], 2f\n" + "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ + "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ + "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ + "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ + "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ + "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ + "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ + "2:\n" + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + "bne 1b \n" /* jump to main loop*/ + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din), + [cnt] "+r"(cnt) + : [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm 
volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "vmov.s32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + "cmp %[relu], #0\n" + "bne 2f\n" + "vmax.s32 q0, q0, q15 @ relu\n" + "vmax.s32 q1, q1, q15 @ relu\n" + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + "vmax.s32 q4, q4, q15 @ relu\n" + "vmax.s32 q5, q5, q15 @ relu\n" + "vmax.s32 q6, q6, q15 @ relu\n" + "vmax.s32 q7, q7, q15 @ relu\n" + "2:\n" + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "bne 1b @ jump to main loop\n" + : [doutc0r0] "+r"(dout0), + [doutc1r0] "+r"(dout1), + [doutc2r0] "+r"(dout2), + [doutc3r0] "+r"(dout3), + [doutc4r0] "+r"(dout4), + [doutc5r0] "+r"(dout5), + [doutc6r0] "+r"(dout6), + [doutc7r0] "+r"(dout7), + [ptr_din] "+r"(din) + : [cnt] "r"(cnt), [relu] "r"(is_relu) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15"); +#endif +} - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" +/*wirte result in outputs +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +template +inline void write_int32_nchwc8_to_nchw(const int* din, + Dtype* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* bias, + bool flag_bias, + Dtype* trash_ptr, + const float* scale) { + int size_c_out = width * height; - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4", "q15"); -#endif - } - if (we > width) { - int offset = 16 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0); - din_hei_ptr += 4; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - int* doutc1_ptr = doutc1r0 + size_w; - int* doutc2_ptr = doutc2r0 + size_w; - int* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans q0, q1\n" - "vtrn.32 q2, q3 @ trans q2, q3\n" - "vswp.32 d1, d4 @ swap d1, d4 \n" - "vswp.32 d3, d6 @ swap d3, d6 \n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! 
@ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4", "q15"); -#endif - } - if (we > width) { - int offset = 16 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = din_hei_ptr[0]; - *(doutc1_ptr++) = din_hei_ptr[1]; - *(doutc2_ptr++) = din_hei_ptr[2]; - *(doutc3_ptr++) = din_hei_ptr[3]; - din_hei_ptr += 4; - } - } - } - } - return true; -} - -/*wirte result in outputs --int8, fp32 -* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] -*/ -template -inline bool write_to_output_c4_int32_1(const int* din, - dtype* dout, - int ch_n, - int hei_n, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int height, - int width, - bool flag_relu, - dtype* trash_ptr, - const float* scale, - PrecisionType out_dtype) { - if (ch_n != 4 || hei_n <= 0) { - LOG(ERROR) << "ch_n must be equal 4 and hei_n is more than zero"; - return false; - } - int size_c_out = width * height; - - dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws; - dtype* doutc1r0 = doutc0r0 + size_c_out; - dtype* doutc2r0 = doutc1r0 + size_c_out; - dtype* doutc3r0 = doutc2r0 + size_c_out; - - const int* ptr_din = din; - - int size_h = (he > height ? height : he) - hs; // size_h == hei_n - - int valid_w = we - ws; - int cnt = valid_w / 4; - - float32x4_t w_scale = vld1q_f32(scale); - // float32x4_t vzero = vdupq_n_f32(0.f); - - if (we > width) { - cnt--; - } - if (out_dtype == PRECISION(kFloat)) { - // int32_to_fp32 - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "smax v16.4s, v16.4s, v20.4s \n" /* relu */ - "smax v17.4s, v17.4s, v20.4s \n" /* relu */ - "smax v18.4s, v18.4s, v20.4s \n" /* relu */ - "smax v19.4s, v19.4s, v20.4s \n" /* relu */ - // int32 --> fp32 - "scvtf v4.4s, v16.4s \n" - "scvtf v5.4s, v17.4s \n" - "scvtf v6.4s, v18.4s \n" - "scvtf v7.4s, 
v19.4s \n" - // mul - "fmul v16.4s, v4.4s, %[scale].s[0] \n" - "fmul v17.4s, v5.4s, %[scale].s[2] \n" - "fmul v18.4s, v6.4s, %[scale].s[1] \n" - "fmul v19.4s, v7.4s, %[scale].s[3] \n" - // res - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale] "w"(w_scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q2, q3 @ trans q0, q1 \n" - "vtrn.32 q4, q5 @ trans q2, q3 \n" - "vswp.32 d5, d8 @ swap d1, d4 \n" - "vswp.32 d7, d10 @ swap d3, d6 \n" - - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - "vmax.s32 q4, q4, q15 @ relu\n" - "vmax.s32 q5, q5, q15 @ relu\n" - - // int32-> fp32 - "vcvt.f32.s32 q6, q2 \n" - "vcvt.f32.s32 q7, q3 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - - // mul - "vmul.f32 q2, q6, %e[scale][0] \n" - "vmul.f32 q3, q7, %e[scale][1] \n" - "vmul.f32 q4, q8, %f[scale][0] \n" - "vmul.f32 q5, q9, %f[scale][1] \n" - - "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : [scale] "w"(w_scale) - : "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - } - if (we > width) { - int offset = 16 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int j = we - 4; - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0); - din_hei_ptr += 4; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - // int32 --> fp32 - "scvtf v4.4s, v16.4s \n" - "scvtf v5.4s, v17.4s \n" - "scvtf v6.4s, v18.4s \n" - "scvtf v7.4s, v19.4s \n" - // mul - "fmul v16.4s, v4.4s, %[scale].s[0] \n" - "fmul v17.4s, v5.4s, %[scale].s[2] \n" - "fmul v18.4s, v6.4s, %[scale].s[1] \n" - "fmul v19.4s, v7.4s, %[scale].s[3] \n" - // res - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale] "w"(w_scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q2, q3 @ trans q0, q1 \n" - "vtrn.32 q4, q5 @ trans q2, q3 \n" - "vswp.32 d5, d8 @ swap d1, d4 \n" - "vswp.32 d7, d10 @ swap d3, d6 \n" - - // int32-> fp32 - "vcvt.f32.s32 q6, q2 \n" - "vcvt.f32.s32 q7, q3 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - - // mul - "vmul.f32 q2, q6, %e[scale][0] \n" - "vmul.f32 q3, q7, %e[scale][1] \n" - "vmul.f32 q4, q8, %f[scale][0] \n" - "vmul.f32 q5, q9, %f[scale][1] \n" - - "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : [scale] "w"(w_scale) - : "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - } - if (we > width) { - int offset = 16 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int j = we - 4; - for (; j < width; ++j) { - *(doutc0_ptr++) = din_hei_ptr[0] * scale[0]; - *(doutc1_ptr++) = din_hei_ptr[1] * scale[1]; - *(doutc2_ptr++) = din_hei_ptr[2] * scale[2]; - *(doutc3_ptr++) = din_hei_ptr[3] * scale[3]; - din_hei_ptr += 4; - } - } - } - } - - } else if (out_dtype == PRECISION(kInt8)) { - // int32_to_int8 - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "smax v16.4s, v16.4s, v20.4s \n" /* relu */ - "smax v17.4s, v17.4s, v20.4s \n" /* relu */ - "smax v18.4s, v18.4s, v20.4s \n" /* relu */ - "smax v19.4s, v19.4s, v20.4s \n" /* relu */ - // int32 --> fp32 - "scvtf v4.4s, v16.4s \n" - "scvtf v5.4s, v17.4s \n" - "scvtf v6.4s, v18.4s \n" - "scvtf v7.4s, v19.4s \n" - - // mul - "fmul v16.4s, v4.4s, %[scale].s[0] \n" - "fmul v17.4s, v5.4s, %[scale].s[2] \n" - "fmul v18.4s, v6.4s, %[scale].s[1] \n" - "fmul v19.4s, 
v7.4s, %[scale].s[3] \n" - - // fp32-int32 - "fcvtas v4.4s, v16.4s \n" - "fcvtas v5.4s, v17.4s \n" - "fcvtas v6.4s, v18.4s \n" - "fcvtas v7.4s, v19.4s \n" - - // int32-int16 - "sqxtn v8.4h, v4.4s \n" - "sqxtn v9.4h, v5.4s \n" - "sqxtn v10.4h, v6.4s \n" - "sqxtn v11.4h, v7.4s \n" - - "sqxtn v16.8b, v8.8h \n" - "sqxtn v17.8b, v9.8h \n" - "sqxtn v18.8b, v10.8h \n" - "sqxtn v19.8b, v11.8h \n" - // res - "str s16, [%[doutc0r0]], #4 \n" - "str s17, [%[doutc2r0]], #4 \n" - "str s18, [%[doutc1r0]], #4 \n" - "str s19, [%[doutc3r0]], #4 \n" - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale] "w"(w_scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q2, q3 @ trans q0, q1 \n" - "vtrn.32 q4, q5 @ trans q2, q3 \n" - "vswp.32 d5, d8 @ swap d1, d4 \n" - "vswp.32 d7, d10 @ swap d3, d6 \n" - - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - "vmax.s32 q4, q4, q15 @ relu\n" - "vmax.s32 q5, q5, q15 @ relu\n" - - // int32-> fp32 - "vcvt.f32.s32 q6, q2 \n" - "vcvt.f32.s32 q7, q3 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - - "vmov.f32 q2, #0.5 \n" - - // "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" - "vand.i32 q3, q2, q2 @ set offset, 0.5\n" - "vand.i32 q4, q2, q2 @ set offset, 0.5\n" - "vand.i32 q5, q2, q2 @ set offset, 0.5\n" - - "vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n" - "vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n" - "vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n" - "vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n" - - "vmov.f32 q15, #-0.5 \n" - - "vbif.f32 q2, q15, q10 @ get right offset\n" - "vbif.f32 q3, q15, q11 @ get right offset\n" - "vbif.f32 q4, q15, q12 @ get right offset\n" - "vbif.f32 q5, q15, q13 @ get right offset\n" - - "vmla.f32 q2, q6, %e[scale][0] @ mul scale\n" - "vmla.f32 q3, q7, %e[scale][1] @ mul scale\n" - "vmla.f32 q4, q8, %f[scale][0] @ mul scale\n" - "vmla.f32 q5, q9, %f[scale][1] @ mul scale\n" - - "vcvt.s32.f32 q6, q2 @ cvt to int32\n" - "vcvt.s32.f32 q7, q3 @ cvt to int32\n" - "vcvt.s32.f32 q8, q4 @ cvt to int32\n" - "vcvt.s32.f32 q9, q5 @ cvt to int32\n" - - "vqmovn.s32 d20, q6 @ cnt to int16\n" - "vqmovn.s32 d22, q7 @ cnt to int16\n" - "vqmovn.s32 d24, q8 @ cnt to int16\n" - "vqmovn.s32 d26, q9 @ cnt to int16\n" - - "vqmovn.s16 d8, q10 @ cnt to int8\n" - "vqmovn.s16 d9, q11 @ cnt to int8\n" - "vqmovn.s16 d10, q12 @ cnt to int8\n" - "vqmovn.s16 d11, q13 @ cnt to int8\n" - - "vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n" - "vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n" - "vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n" - "vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n" - - "add %[doutc0r0], #4 \n" - "add %[doutc1r0], #4 \n" - "add %[doutc2r0], #4 \n" - "add %[doutc3r0], #4 \n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vmov.u32 q15, #0 @ dump zero\n" - - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : [scale] "w"(w_scale) - : "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - } - if (we > width) { - int offset = 16 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int j = we - 4; - for (; j < width; ++j) { - *(doutc0_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[0], 0) * scale[0])); - *(doutc1_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[1], 0) * scale[1])); - *(doutc2_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[2], 0) * scale[2])); - *(doutc3_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[3], 0) * scale[3])); - din_hei_ptr += 4; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { // size_h - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - // int32 --> fp32 - "scvtf v4.4s, v16.4s \n" - "scvtf v5.4s, v17.4s \n" - "scvtf v6.4s, v18.4s \n" - "scvtf v7.4s, v19.4s \n" - - // mul - "fmul v16.4s, v4.4s, %[scale].s[0] \n" - "fmul v17.4s, v5.4s, %[scale].s[2] \n" - "fmul v18.4s, v6.4s, %[scale].s[1] \n" - "fmul v19.4s, v7.4s, %[scale].s[3] \n" - - // fp32-int32 - "fcvtas v4.4s, v16.4s \n" - "fcvtas v5.4s, v17.4s \n" - "fcvtas v6.4s, v18.4s \n" - "fcvtas v7.4s, v19.4s \n" - - // int32-int16 - "sqxtn v8.4h, v4.4s \n" - "sqxtn v9.4h, v5.4s \n" - "sqxtn v10.4h, v6.4s \n" - "sqxtn v11.4h, v7.4s \n" - - "sqxtn v16.8b, v8.8h \n" - "sqxtn v17.8b, v9.8h \n" - "sqxtn v18.8b, v10.8h \n" - "sqxtn v19.8b, v11.8h \n" - // res - "str s16, [%[doutc0r0]], #4 \n" - "str s17, [%[doutc2r0]], #4 \n" - "str s18, [%[doutc1r0]], #4 \n" - "str s19, [%[doutc3r0]], #4 \n" - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale] "w"(w_scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm 
volatile( - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q2, q3 @ trans q0, q1 \n" - "vtrn.32 q4, q5 @ trans q2, q3 \n" - "vswp.32 d5, d8 @ swap d1, d4 \n" - "vswp.32 d7, d10 @ swap d3, d6 \n" - - // int32-> fp32 - "vcvt.f32.s32 q6, q2 \n" - "vcvt.f32.s32 q7, q3 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - - "vmov.f32 q2, #0.5 \n" - - // "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" - "vand.i32 q3, q2, q2 @ set offset, 0.5\n" - "vand.i32 q4, q2, q2 @ set offset, 0.5\n" - "vand.i32 q5, q2, q2 @ set offset, 0.5\n" - - "vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n" - "vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n" - "vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n" - "vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n" - - "vmov.f32 q15, #-0.5 \n" - - "vbif.f32 q2, q15, q10 @ get right offset\n" - "vbif.f32 q3, q15, q11 @ get right offset\n" - "vbif.f32 q4, q15, q12 @ get right offset\n" - "vbif.f32 q5, q15, q13 @ get right offset\n" - - "vmla.f32 q2, q6, %e[scale][0] @ mul scale\n" - "vmla.f32 q3, q7, %e[scale][1] @ mul scale\n" - "vmla.f32 q4, q8, %f[scale][0] @ mul scale\n" - "vmla.f32 q5, q9, %f[scale][1] @ mul scale\n" - - "vcvt.s32.f32 q6, q2 @ cvt to int32\n" - "vcvt.s32.f32 q7, q3 @ cvt to int32\n" - "vcvt.s32.f32 q8, q4 @ cvt to int32\n" - "vcvt.s32.f32 q9, q5 @ cvt to int32\n" - - "vqmovn.s32 d20, q6 @ cnt to int16\n" - "vqmovn.s32 d22, q7 @ cnt to int16\n" - "vqmovn.s32 d24, q8 @ cnt to int16\n" - "vqmovn.s32 d26, q9 @ cnt to int16\n" - - "vqmovn.s16 d8, q10 @ cnt to int8\n" - "vqmovn.s16 d9, q11 @ cnt to int8\n" - "vqmovn.s16 d10, q12 @ cnt to int8\n" - "vqmovn.s16 d11, q13 @ cnt to int8\n" - - "vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n" - "vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n" - "vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n" - "vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n" - - "add %[doutc0r0], #4 \n" - "add %[doutc1r0], #4 \n" - "add %[doutc2r0], #4 \n" - "add %[doutc3r0], #4 \n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n"
-              "vmov.u32 q15, #0             @ dump zero\n"
-
-              "bne 1b                       @ jump to main loop\n"
-
-              : [doutc0r0] "+r"(doutc0_ptr),
-                [doutc1r0] "+r"(doutc1_ptr),
-                [doutc2r0] "+r"(doutc2_ptr),
-                [doutc3r0] "+r"(doutc3_ptr),
-                [ptr_din] "+r"(din_hei_ptr),
-                [cnt] "+r"(cnt_loop)
-              : [scale] "w"(w_scale)
-              : "q2",
-                "q3",
-                "q4",
-                "q5",
-                "q6",
-                "q7",
-                "q8",
-                "q9",
-                "q10",
-                "q11",
-                "q12",
-                "q13",
-                "q14",
-                "q15");
-#endif
-        }
-        if (we > width) {
-          int offset = 16 * (valid_w / 4 - 1);
-          din_hei_ptr = ptr_din + offset;
-          int j = we - 4;
-          for (; j < width; ++j) {
-            *(doutc0_ptr++) =
-                saturate_cast<dtype>(roundf(din_hei_ptr[0] * scale[0]));
-            *(doutc1_ptr++) =
-                saturate_cast<dtype>(roundf(din_hei_ptr[1] * scale[1]));
-            *(doutc2_ptr++) =
-                saturate_cast<dtype>(roundf(din_hei_ptr[2] * scale[2]));
-            *(doutc3_ptr++) =
-                saturate_cast<dtype>(roundf(din_hei_ptr[3] * scale[3]));
-            din_hei_ptr += 4;
-          }
-        }
-      }
-    }
-  } else {
-    LOG(ERROR) << "ERROR: unsupported input data type!!";
-    return false;
-  }
-  return true;
-}
-
-/* write result in outputs
-* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
-*/
-inline bool write_to_output_c8_int32(const int* din,
-                                     int* dout,
-                                     int ch_n,
-                                     int hei_n,
-                                     int cs,
-                                     int ce,
-                                     int hs,
-                                     int he,
-                                     int ws,
-                                     int we,
-                                     int channel,
-                                     int height,
-                                     int width,
-                                     bool flag_relu,
-                                     int* trash_ptr) {
-  if (ch_n != 8 || hei_n <= 0) {
-    LOG(ERROR) << "ch_n must equal 8 and hei_n must be greater than zero";
-    return false;
-  }
-  int size_c_out = width * height;
-
-  int* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
-  int* doutc1r0 = doutc0r0 + size_c_out;
-  int* doutc2r0 = doutc1r0 + size_c_out;
-  int* doutc3r0 = doutc2r0 + size_c_out;
-  int* doutc4r0 = doutc3r0 + size_c_out;
-  int* doutc5r0 = doutc4r0 + size_c_out;
-  int* doutc6r0 = doutc5r0 + size_c_out;
-  int* doutc7r0 = doutc6r0 + size_c_out;
-
-  const int* ptr_din = din;
-
-  int size_h = (he > height ?
height : he) - hs; // size_h == hei_n - - int valid_w = we - ws; - int cnt = valid_w / 4; - - if (we > width) { - cnt--; - } - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - int* doutc1_ptr = doutc1r0 + size_w; - int* doutc2_ptr = doutc2r0 + size_w; - int* doutc3_ptr = doutc3r0 + size_w; - int* doutc4_ptr = doutc4r0 + size_w; - int* doutc5_ptr = doutc5r0 + size_w; - int* doutc6_ptr = doutc6r0 + size_w; - int* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ - - "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ - "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ - "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] 
"+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.s32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "vmax.s32 q0, q0, q15 @ relu\n" - "vmax.s32 q1, q1, q15 @ relu\n" - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - - "vmax.s32 q4, q4, q15 @ relu\n" - "vmax.s32 q5, q5, q15 @ relu\n" - "vmax.s32 q6, q6, q15 @ relu\n" - "vmax.s32 q7, q7, q15 @ relu\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4", "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0); - *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0); - *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0); - *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0); - *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0); - din_hei_ptr += 8; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - int* doutc1_ptr = doutc1r0 + size_w; - int* doutc2_ptr = doutc2r0 + size_w; - int* doutc3_ptr = doutc3r0 + size_w; - int* doutc4_ptr = doutc4r0 + size_w; - int* doutc5_ptr = doutc5r0 + size_w; - int* doutc6_ptr = doutc6r0 + size_w; - int* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - 
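/* editorial note: the two trn1/trn2 passes above leave the rows
   interleaved, which is why the stores pair c0 with c2 and c1 with c3
   (and, just below, c4 with c6 and c5 with c7); the destination
   pointers are matched to that order, as the 00../20../10../30..
   row comments indicate. */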
- "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n"
-
-              "bne 1b                       @ jump to main loop\n"
-
-              : [doutc0r0] "+r"(doutc0_ptr),
-                [doutc1r0] "+r"(doutc1_ptr),
-                [doutc2r0] "+r"(doutc2_ptr),
-                [doutc3r0] "+r"(doutc3_ptr),
-                [doutc4r0] "+r"(doutc4_ptr),
-                [doutc5r0] "+r"(doutc5_ptr),
-                [doutc6r0] "+r"(doutc6_ptr),
-                [doutc7r0] "+r"(doutc7_ptr),
-                [ptr_din] "+r"(din_hei_ptr),
-                [cnt] "+r"(cnt_loop)
-              :
-              : "q0", "q1", "q2", "q3", "q4", "q15");
-#endif
-      }
-      if (we > width) {
-        int offset = 32 * (valid_w / 4 - 1);
-        din_hei_ptr = ptr_din + offset;
-        int i = we - 4;
-        for (; i < width; ++i) {
-          *(doutc0_ptr++) = din_hei_ptr[0];
-          *(doutc1_ptr++) = din_hei_ptr[1];
-          *(doutc2_ptr++) = din_hei_ptr[2];
-          *(doutc3_ptr++) = din_hei_ptr[3];
-          *(doutc4_ptr++) = din_hei_ptr[4];
-          *(doutc5_ptr++) = din_hei_ptr[5];
-          *(doutc6_ptr++) = din_hei_ptr[6];
-          *(doutc7_ptr++) = din_hei_ptr[7];
-          din_hei_ptr += 8;
-        }
-      }
-    }
-  }
-  return true;
-}
-
-/* write result in outputs--int8, fp32
-* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
-*/
-template <typename dtype>
-static bool write_to_output_c8_int32_1(const int* din,
-                                       dtype* dout,
-                                       int ch_n,
-                                       int hei_n,
-                                       int cs,
-                                       int ce,
-                                       int hs,
-                                       int he,
-                                       int ws,
-                                       int we,
-                                       int channel,
-                                       int height,
-                                       int width,
-                                       bool flag_relu,
-                                       dtype* trash_ptr,
-                                       const float* scale,
-                                       PrecisionType out_dtype) {
-  if (ch_n != 8 || hei_n <= 0) {
-    LOG(ERROR) << "ch_n must equal 8 and hei_n must be greater than zero";
-    return false;
-  }
-  int size_c_out = width * height;
-
-  dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
-  dtype* doutc1r0 = doutc0r0 + size_c_out;
-  dtype* doutc2r0 = doutc1r0 + size_c_out;
-  dtype* doutc3r0 = doutc2r0 + size_c_out;
-  dtype* doutc4r0 = doutc3r0 + size_c_out;
-  dtype* doutc5r0 = doutc4r0 + size_c_out;
-  dtype* doutc6r0 = doutc5r0 + size_c_out;
-  dtype* doutc7r0 = doutc6r0 + size_c_out;
+  Dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
+  Dtype* doutc1r0 = doutc0r0 + size_c_out;
+  Dtype* doutc2r0 = doutc1r0 + size_c_out;
+  Dtype* doutc3r0 = doutc2r0 + size_c_out;
+  Dtype* doutc4r0 = doutc3r0 + size_c_out;
+  Dtype* doutc5r0 = doutc4r0 + size_c_out;
+  Dtype* doutc6r0 = doutc5r0 + size_c_out;
+  Dtype* doutc7r0 = doutc6r0 + size_c_out;
   const int* ptr_din = din;
   int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
-  int valid_w = we - ws;
+  int w_stride = we - ws;
+  int valid_w = (we > width ? width : we) - ws;
   int cnt = valid_w / 4;
   float32x4_t w_scale0 = vld1q_f32(scale);
   float32x4_t w_scale1 = vld1q_f32(scale + 4);
+  float32x4_t w_bias0 = flag_bias ? vld1q_f32(bias) : vdupq_n_f32(0.f);
+  float32x4_t w_bias1 = flag_bias ?
vld1q_f32(bias + 4) : vdupq_n_f32(0.f); - float32x4_t vzero = vdupq_n_f32(0.f); - - if (we > width) { - cnt--; - } - if (out_dtype == PRECISION(kFloat)) { - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - dtype* doutc4_ptr = doutc4r0 + size_w; - dtype* doutc5_ptr = doutc5r0 + size_w; - dtype* doutc6_ptr = doutc6r0 + size_w; - dtype* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ - - "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ - "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ - "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ - - // int32->fp32 - "scvtf v10.4s, v16.4s \n" - "scvtf v11.4s, v17.4s \n" - "scvtf v14.4s, v18.4s \n" - "scvtf v15.4s, v19.4s \n" - // mul - "fmul v16.4s, v10.4s, %[scale0].s[0] \n" - "fmul v17.4s, v11.4s, %[scale0].s[2] \n" - "fmul v18.4s, v14.4s, %[scale0].s[1] \n" - "fmul v19.4s, v15.4s, %[scale0].s[3] \n" - - "scvtf v10.4s, v8.4s \n" - "scvtf v11.4s, v9.4s \n" - "scvtf v14.4s, v12.4s \n" - "scvtf v15.4s, v13.4s \n" - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], 
#16 \n" /* store c3r0*/ - - // mul - "fmul v8.4s, v10.4s, %[scale1].s[0] \n" - "fmul v9.4s, v11.4s, %[scale1].s[2] \n" - "fmul v12.4s, v14.4s, %[scale1].s[1] \n" - "fmul v13.4s, v15.4s, %[scale1].s[3] \n" - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.s32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vmax.s32 q0, q0, q15 @ relu\n" - "vmax.s32 q1, q1, q15 @ relu\n" - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - - "vmax.s32 q4, q4, q15 @ relu\n" - "vmax.s32 q5, q5, q15 @ relu\n" - "vmax.s32 q6, q6, q15 @ relu\n" - "vmax.s32 q7, q7, q15 @ relu\n" - - // int32-> fp32 - "vcvt.f32.s32 q8, q0 \n" - "vcvt.f32.s32 q9, q1 \n" - "vcvt.f32.s32 q10, q2 \n" - "vcvt.f32.s32 q11, q3 \n" - - // mul - "vmul.f32 q0, q8, %q[scale0] \n" - "vmul.f32 q1, q9, %q[scale1] \n" - "vmul.f32 q2, q10, %q[scale0] \n" - "vmul.f32 q3, q11, %q[scale1] \n" - - // int32-> fp32 - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - "vcvt.f32.s32 q10, q6 \n" - "vcvt.f32.s32 q11, q7 \n" - - // mul - "vmul.f32 q4, q8, %q[scale0] \n" - "vmul.f32 q5, q9, %q[scale1] \n" - "vmul.f32 q6, q10, %q[scale0] \n" - "vmul.f32 q7, q11, %q[scale1] \n" - - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0); - *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4] * scale[4], 0); - *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5] * scale[5], 0); - *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6] * scale[6], 0); - *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7] * scale[7], 0); - din_hei_ptr += 8; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - dtype* doutc4_ptr = doutc4r0 + size_w; - dtype* doutc5_ptr = doutc5r0 + size_w; - dtype* doutc6_ptr = doutc6r0 + size_w; - dtype* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, 
[%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - // int32->fp32 - "scvtf v10.4s, v16.4s \n" - "scvtf v11.4s, v17.4s \n" - "scvtf v14.4s, v18.4s \n" - "scvtf v15.4s, v19.4s \n" - // mul - "fmul v16.4s, v10.4s, %[scale0].s[0] \n" - "fmul v17.4s, v11.4s, %[scale0].s[2] \n" - "fmul v18.4s, v14.4s, %[scale0].s[1] \n" - "fmul v19.4s, v15.4s, %[scale0].s[3] \n" - - "scvtf v10.4s, v8.4s \n" - "scvtf v11.4s, v9.4s \n" - "scvtf v14.4s, v12.4s \n" - "scvtf v15.4s, v13.4s \n" - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - // mul - "fmul v8.4s, v10.4s, %[scale1].s[0] \n" - "fmul v9.4s, v11.4s, %[scale1].s[2] \n" - "fmul v12.4s, v14.4s, %[scale1].s[1] \n" - "fmul v13.4s, v15.4s, %[scale1].s[3] \n" - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.s32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - // int32-> fp32 - "vcvt.f32.s32 q8, q0 \n" - "vcvt.f32.s32 q9, q1 \n" - "vcvt.f32.s32 q10, q2 \n" - "vcvt.f32.s32 q11, q3 \n" - - // mul - "vmul.f32 q0, q8, %q[scale0] \n" - "vmul.f32 q1, q9, %q[scale1] \n" - "vmul.f32 q2, q10, %q[scale0] \n" - "vmul.f32 q3, q11, %q[scale1] \n" - - // int32-> fp32 - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - "vcvt.f32.s32 q10, q6 \n" - "vcvt.f32.s32 q11, q7 \n" - - // mul - "vmul.f32 q4, q8, %q[scale0] \n" - "vmul.f32 q5, q9, %q[scale1] \n" - "vmul.f32 q6, q10, %q[scale0] \n" - "vmul.f32 q7, q11, %q[scale1] \n" - - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = din_hei_ptr[0] * scale[0]; - *(doutc1_ptr++) = din_hei_ptr[1] * scale[1]; - *(doutc2_ptr++) = din_hei_ptr[2] * scale[2]; - *(doutc3_ptr++) = din_hei_ptr[3] * scale[3]; - *(doutc4_ptr++) = din_hei_ptr[4] * scale[4]; - *(doutc5_ptr++) = din_hei_ptr[5] * scale[5]; - *(doutc6_ptr++) = din_hei_ptr[6] * scale[6]; - *(doutc7_ptr++) = din_hei_ptr[7] * scale[7]; - din_hei_ptr += 8; - } - } + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + Dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + Dtype* doutc1_ptr = doutc1r0 + size_w; + Dtype* doutc2_ptr = doutc2r0 + size_w; + Dtype* doutc3_ptr = doutc3r0 + size_w; + Dtype* doutc4_ptr = doutc4r0 + size_w; + Dtype* doutc5_ptr = doutc5r0 + size_w; + Dtype* doutc6_ptr = doutc6r0 + size_w; + Dtype* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; } } - } else if (out_dtype == PRECISION(kInt8)) { - // int32_to_int8 - float32x4_t vpoff = vdupq_n_f32(0.5f); - float32x4_t vnoff = vdupq_n_f32(-0.5f); - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - dtype* doutc4_ptr = doutc4r0 + size_w; - dtype* doutc5_ptr = doutc5r0 + size_w; - dtype* doutc6_ptr = doutc6r0 + size_w; - dtype* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - // "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 
to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "smax v16.4s, v16.4s, %[vzero].4s \n" /*relu*/ - "smax v17.4s, v17.4s, %[vzero].4s \n" /*relu*/ - "smax v18.4s, v18.4s, %[vzero].4s \n" /*relu*/ - "smax v19.4s, v19.4s, %[vzero].4s \n" /*relu*/ - - "smax v8.4s, v8.4s, %[vzero].4s \n" /*relu*/ - "smax v9.4s, v9.4s, %[vzero].4s \n" /*relu*/ - "smax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - "smax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - // int32 --> fp32 - "scvtf v10.4s, v16.4s \n" - "scvtf v11.4s, v17.4s \n" - "scvtf v14.4s, v18.4s \n" - "scvtf v15.4s, v19.4s \n" - - "scvtf v20.4s, v8.4s \n" - "scvtf v21.4s, v9.4s \n" - "scvtf v22.4s, v12.4s \n" - "scvtf v23.4s, v13.4s \n" - - // mul - "fmul v16.4s, v10.4s, %[scale0].s[0] \n" - "fmul v17.4s, v11.4s, %[scale0].s[2] \n" - "fmul v18.4s, v14.4s, %[scale0].s[1] \n" - "fmul v19.4s, v15.4s, %[scale0].s[3] \n" - - "fmul v8.4s, v20.4s, %[scale1].s[0] \n" - "fmul v9.4s, v21.4s, %[scale1].s[2] \n" - "fmul v12.4s, v22.4s, %[scale1].s[1] \n" - "fmul v13.4s, v23.4s, %[scale1].s[3] \n" - - // fp32-int32 - "fcvtas v10.4s, v16.4s \n" - "fcvtas v11.4s, v17.4s \n" - "fcvtas v14.4s, v18.4s \n" - "fcvtas v15.4s, v19.4s \n" - - "fcvtas v20.4s, v8.4s \n" - "fcvtas v21.4s, v9.4s \n" - "fcvtas v22.4s, v12.4s \n" - "fcvtas v23.4s, v13.4s \n" - - // int32-int16 - "sqxtn v16.4h, v10.4s \n" - "sqxtn v17.4h, v11.4s \n" - "sqxtn v18.4h, v14.4s \n" - "sqxtn v19.4h, v15.4s \n" - - "sqxtn v8.4h, v20.4s \n" - "sqxtn v9.4h, v21.4s \n" - "sqxtn v12.4h, v22.4s \n" - "sqxtn v13.4h, v23.4s \n" - - // int16-int8 - "sqxtn v10.8b, v16.8h \n" - "sqxtn v11.8b, v17.8h \n" - "sqxtn v14.8b, v18.8h \n" - "sqxtn v15.8b, v19.8h \n" - - "sqxtn v20.8b, v8.8h \n" - "sqxtn v21.8b, v9.8h \n" - "sqxtn v22.8b, v12.8h \n" - "sqxtn v23.8b, v13.8h \n" - - "str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/ - "str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/ - "str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/ - "str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/ - "str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/ - "str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/ - "str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) - : 
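/* editorial note: per element, the aarch64 int8 path above computes
   (assuming non-negative per-channel scales) the same value as the
   scalar tail loops of this function, i.e. a sketch of
       out = saturate_cast<dtype>(roundf(LITEMAX(v, 0) * scale));
   "fcvtas" already rounds to nearest with saturation, so no explicit
   +/-0.5 offset is needed here, unlike the armv7 path below. */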
"v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23"); -#else - asm volatile( - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - - "1: @ main loop\n" - "vmax.s32 q4, q4, %q[vzero] @ relu\n" - "vmax.s32 q5, q5, %q[vzero] @ relu\n" - "vmax.s32 q6, q6, %q[vzero] @ relu\n" - "vmax.s32 q7, q7, %q[vzero] @ relu\n" - - // int32-> fp32 - "vmov.f32 q15, #0.5 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - "vcvt.f32.s32 q10, q6 \n" - "vcvt.f32.s32 q11, q7 \n" - - "vand.i32 q4, q15, q15 @ set offset, 0.5\n" - "vand.i32 q5, q15, q15 @ set offset, 0.5\n" - "vand.i32 q6, q15, q15 @ set offset, 0.5\n" - "vand.i32 q7, q15, q15 @ set offset, 0.5\n" - - "vmov.f32 q15, #-0.5 \n" - - "vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q4, q15, q12 @ get right offset\n" - "vbif.f32 q5, q15, q13 @ get right offset\n" - "vbif.f32 q6, q15, q14 @ get right offset\n" - "vbif.f32 q7, q15, q3 @ get right offset\n" - - "vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n" - "vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n" - "vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n" - - "vmla.f32 q4, q8, %q[scale0] @ mul scale\n" - "vmla.f32 q5, q9, %q[scale1] @ mul scale\n" - "vmla.f32 q6, q10, %q[scale0] @ mul scale\n" - "vmla.f32 q7, q11, %q[scale1] @ mul scale\n" - - "vmax.s32 q12, q12, %q[vzero] @ relu\n" - "vmax.s32 q13, q13, %q[vzero] @ relu\n" - "vmax.s32 q14, q14, %q[vzero] @ relu\n" - "vmax.s32 q3, q3, %q[vzero] @ relu\n" - - "vcvt.s32.f32 q8, q4 @ cvt to int32\n" - "vcvt.s32.f32 q9, q5 @ cvt to int32\n" - "vcvt.s32.f32 q10, q6 @ cvt to int32\n" - "vcvt.s32.f32 q11, q7 @ cvt to int32\n" - - "vqmovn.s32 d8, q8 @ cnt to int16\n" - "vqmovn.s32 d10, q9 @ cnt to int16\n" - "vqmovn.s32 d12, q10 @ cnt to int16\n" - "vqmovn.s32 d14, q11 @ cnt to int16\n" - - "vqmovn.s16 d16, q4 @ cnt to int8\n" - "vqmovn.s16 d17, q5 @ cnt to int8\n" - "vqmovn.s16 d18, q6 @ cnt to int8\n" - "vqmovn.s16 d19, q7 @ cnt to int8\n" - - "vmov.f32 q15, #0.5 \n" - - "vcvt.f32.s32 q4, q12 \n" - "vcvt.f32.s32 q5, q13 \n" - "vcvt.f32.s32 q6, q14 \n" - "vcvt.f32.s32 q7, q3 \n" - - "vand.i32 q12, q15, q15 @ set offset, 0.5\n" - "vand.i32 q13, q15, q15 @ set offset, 0.5\n" - "vand.i32 q14, q15, q15 @ set offset, 0.5\n" - "vand.i32 q3, q15, q15 @ set offset, 0.5\n" - - "vmov.f32 q15, #-0.5 \n" - - "vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q12, q15, q10 @ get right offset\n" - "vbif.f32 q13, q15, q11 @ get right offset\n" - - "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q14, q15, q10 @ get right offset\n" - "vbif.f32 q3, q15, q11 @ get right offset\n" - - "vmla.f32 q12, q4, %q[scale0] @ mul scale\n" - "vmla.f32 q13, q5, %q[scale1] @ mul scale\n" - "vmla.f32 q14, q6, %q[scale0] @ mul scale\n" - "vmla.f32 q3, q7, %q[scale1] @ mul scale\n" - - "vcvt.s32.f32 q4, q12 @ cvt to int32\n" - "vcvt.s32.f32 q5, q13 @ cvt to int32\n" - "vcvt.s32.f32 q6, q14 @ cvt to int32\n" - "vcvt.s32.f32 q7, q3 @ cvt to int32\n" - - "vqmovn.s32 d24, q4 @ cnt to int16\n" - "vqmovn.s32 d26, q5 @ cnt to int16\n" - "vqmovn.s32 d28, q6 @ cnt to int16\n" - 
"vqmovn.s32 d6, q7 @ cnt to int16\n" - - "vqmovn.s16 d20, q12 @ cnt to int8\n" - "vqmovn.s16 d21, q13 @ cnt to int8\n" - "vqmovn.s16 d22, q14 @ cnt to int8\n" - "vqmovn.s16 d23, q3 @ cnt to int8\n" - - "vtrn.8 d16, d18 @ trans q0, q2 \n" - "vtrn.8 d20, d22 @ trans q4, q6 \n" - "vtrn.16 d16, d20 @ trans q0, q2 \n" - "vtrn.16 d18, d22 @ trans q4, q6 \n" - - "vtrn.8 d17, d19 @ trans q0, q2 \n" - "vtrn.8 d21, d23 @ trans q4, q6 \n" - "vtrn.16 d17, d21 @ trans q0, q2 \n" - "vtrn.16 d19, d23 @ trans q4, q6 \n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add " - "pointer\n" - "vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add " - "pointer\n" - "vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add " - "pointer\n" - "vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add " - "pointer\n" - - "vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add " - "pointer\n" - "vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add " - "pointer\n" - "vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add " - "pointer\n" - "vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add " - "pointer\n" - - "add %[doutc0r0], #4 @ add \n" - "add %[doutc1r0], #4 @ add \n" - "add %[doutc2r0], #4 @ add \n" - "add %[doutc3r0], #4 @ add \n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "add %[doutc4r0], #4 @ add \n" - "add %[doutc5r0], #4 @ add \n" - "add %[doutc6r0], #4 @ add \n" - "add %[doutc7r0], #4 @ add \n" - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) - : "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[0] * scale[0], 0))); - *(doutc1_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[1] * scale[1], 0))); - *(doutc2_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[2] * scale[2], 0))); - *(doutc3_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[3] * scale[3], 0))); - *(doutc4_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[4] * scale[4], 0))); - *(doutc5_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[5] * scale[5], 0))); - *(doutc6_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[6] * scale[6], 0))); - *(doutc7_ptr++) = saturate_cast( - roundf(LITEMAX(din_hei_ptr[7] * scale[7], 0))); - din_hei_ptr += 8; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - dtype* doutc1_ptr = doutc1r0 + size_w; - dtype* doutc2_ptr = doutc2r0 + size_w; - dtype* doutc3_ptr = doutc3r0 + size_w; - dtype* doutc4_ptr = doutc4r0 + size_w; - dtype* doutc5_ptr = doutc5r0 + size_w; - dtype* doutc6_ptr = doutc6r0 + size_w; - dtype* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - 
doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const int* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - // "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - // int32 --> fp32 - "scvtf v10.4s, v16.4s \n" - "scvtf v11.4s, v17.4s \n" - "scvtf v14.4s, v18.4s \n" - "scvtf v15.4s, v19.4s \n" - - "scvtf v20.4s, v8.4s \n" - "scvtf v21.4s, v9.4s \n" - "scvtf v22.4s, v12.4s \n" - "scvtf v23.4s, v13.4s \n" - - // mul - "fmul v16.4s, v10.4s, %[scale0].s[0] \n" - "fmul v17.4s, v11.4s, %[scale0].s[2] \n" - "fmul v18.4s, v14.4s, %[scale0].s[1] \n" - "fmul v19.4s, v15.4s, %[scale0].s[3] \n" - - "fmul v8.4s, v20.4s, %[scale1].s[0] \n" - "fmul v9.4s, v21.4s, %[scale1].s[2] \n" - "fmul v12.4s, v22.4s, %[scale1].s[1] \n" - "fmul v13.4s, v23.4s, %[scale1].s[3] \n" - - // fp32-int32 - "fcvtas v10.4s, v16.4s \n" - "fcvtas v11.4s, v17.4s \n" - "fcvtas v14.4s, v18.4s \n" - "fcvtas v15.4s, v19.4s \n" - - "fcvtas v20.4s, v8.4s \n" - "fcvtas v21.4s, v9.4s \n" - "fcvtas v22.4s, v12.4s \n" - "fcvtas v23.4s, v13.4s \n" - - // int32-int16 - "sqxtn v16.4h, v10.4s \n" - "sqxtn v17.4h, v11.4s \n" - "sqxtn v18.4h, v14.4s \n" - "sqxtn v19.4h, v15.4s \n" - - "sqxtn v8.4h, v20.4s \n" - "sqxtn v9.4h, v21.4s \n" - "sqxtn v12.4h, v22.4s \n" - "sqxtn v13.4h, v23.4s \n" - - // int16-int8 - "sqxtn v10.8b, v16.8h \n" - "sqxtn v11.8b, v17.8h \n" - "sqxtn v14.8b, v18.8h \n" - "sqxtn v15.8b, v19.8h \n" - - "sqxtn v20.8b, v8.8h \n" - "sqxtn v21.8b, v9.8h \n" - "sqxtn v22.8b, v12.8h \n" - "sqxtn v23.8b, v13.8h \n" - - "str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/ - "str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/ - "str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/ - "str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/ - "str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/ - "str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/ - "str s23, [%[doutc7r0]], #4 \n" /* 
store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23"); -#else - asm volatile( - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - - "1: @ main loop\n" - // int32-> fp32 - "vmov.f32 q15, #0.5 \n" - "vcvt.f32.s32 q8, q4 \n" - "vcvt.f32.s32 q9, q5 \n" - "vcvt.f32.s32 q10, q6 \n" - "vcvt.f32.s32 q11, q7 \n" - - "vand.i32 q4, q15, q15 @ set offset, 0.5\n" - "vand.i32 q5, q4, q4 @ set offset, 0.5\n" - "vand.i32 q6, q4, q4 @ set offset, 0.5\n" - "vand.i32 q7, q4, q4 @ set offset, 0.5\n" - - "vmov.f32 q15, #-0.5 \n" - - "vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q4, q15, q12 @ get right offset\n" - "vbif.f32 q5, q15, q13 @ get right offset\n" - "vbif.f32 q6, q15, q14 @ get right offset\n" - "vbif.f32 q7, q15, q3 @ get right offset\n" - - "vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n" - "vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n" - "vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n" - - "vmla.f32 q4, q8, %q[scale0] @ mul scale\n" - "vmla.f32 q5, q9, %q[scale1] @ mul scale\n" - "vmla.f32 q6, q10, %q[scale0] @ mul scale\n" - "vmla.f32 q7, q11, %q[scale1] @ mul scale\n" - - "vcvt.s32.f32 q8, q4 @ cvt to int32\n" - "vcvt.s32.f32 q9, q5 @ cvt to int32\n" - "vcvt.s32.f32 q10, q6 @ cvt to int32\n" - "vcvt.s32.f32 q11, q7 @ cvt to int32\n" - - "vqmovn.s32 d8, q8 @ cnt to int16\n" - "vqmovn.s32 d10, q9 @ cnt to int16\n" - "vqmovn.s32 d12, q10 @ cnt to int16\n" - "vqmovn.s32 d14, q11 @ cnt to int16\n" - - "vqmovn.s16 d16, q4 @ cnt to int8\n" - "vqmovn.s16 d17, q5 @ cnt to int8\n" - "vqmovn.s16 d18, q6 @ cnt to int8\n" - "vqmovn.s16 d19, q7 @ cnt to int8\n" - - "vmov.f32 q15, #0.5 \n" - - "vcvt.f32.s32 q4, q12 \n" - "vcvt.f32.s32 q5, q13 \n" - "vcvt.f32.s32 q6, q14 \n" - "vcvt.f32.s32 q7, q3 \n" - - "vand.i32 q12, q15, q15 @ set offset, 0.5\n" - "vand.i32 q13, q12, q12 @ set offset, 0.5\n" - "vand.i32 q14, q12, q12 @ set offset, 0.5\n" - "vand.i32 q3, q12, q12 @ set offset, 0.5\n" - - "vmov.f32 q15, #-0.5 \n" - - "vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q12, q15, q10 @ get right offset\n" - "vbif.f32 q13, q15, q11 @ get right offset\n" - - "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n" - "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n" - - "vbif.f32 q14, q15, q10 @ get right offset\n" - "vbif.f32 q3, q15, q11 @ get right offset\n" - - "vmla.f32 q12, q4, %q[scale0] @ mul scale\n" - "vmla.f32 q13, q5, %q[scale1] @ mul scale\n" - "vmla.f32 q14, q6, %q[scale0] @ mul scale\n" - "vmla.f32 q3, q7, %q[scale1] @ mul scale\n" - - "vcvt.s32.f32 q4, q12 @ cvt to int32\n" - "vcvt.s32.f32 q5, q13 @ cvt to int32\n" - "vcvt.s32.f32 q6, q14 @ cvt to int32\n" - "vcvt.s32.f32 q7, q3 @ cvt to int32\n" - - "vqmovn.s32 d24, q4 @ cnt to int16\n" - "vqmovn.s32 
d26, q5 @ cvt to int16\n" - "vqmovn.s32 d28, q6 @ cvt to int16\n" - "vqmovn.s32 d6, q7 @ cvt to int16\n" - - "vqmovn.s16 d20, q12 @ cvt to int8\n" - "vqmovn.s16 d21, q13 @ cvt to int8\n" - "vqmovn.s16 d22, q14 @ cvt to int8\n" - "vqmovn.s16 d23, q3 @ cvt to int8\n" - - "vtrn.8 d16, d18 @ trans q0, q2 \n" - "vtrn.8 d20, d22 @ trans q4, q6 \n" - "vtrn.16 d16, d20 @ trans q0, q2 \n" - "vtrn.16 d18, d22 @ trans q4, q6 \n" - - "vtrn.8 d17, d19 @ trans q0, q2 \n" - "vtrn.8 d21, d23 @ trans q4, q6 \n" - "vtrn.16 d17, d21 @ trans q0, q2 \n" - "vtrn.16 d19, d23 @ trans q4, q6 \n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add " - "pointer\n" - "vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add " - "pointer\n" - "vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add " - "pointer\n" - "vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add " - "pointer\n" - - "vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add " - "pointer\n" - "vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add " - "pointer\n" - "vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add " - "pointer\n" - "vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add " - "pointer\n" - - "add %[doutc0r0], #4 @ add \n" - "add %[doutc1r0], #4 @ add \n" - "add %[doutc2r0], #4 @ add \n" - "add %[doutc3r0], #4 @ add \n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "add %[doutc4r0], #4 @ add \n" - "add %[doutc5r0], #4 @ add \n" - "add %[doutc6r0], #4 @ add \n" - "add %[doutc7r0], #4 @ add \n" - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) - : "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[0] * scale[0])); - *(doutc1_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[1] * scale[1])); - *(doutc2_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[2] * scale[2])); - *(doutc3_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[3] * scale[3])); - *(doutc4_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[4] * scale[4])); - *(doutc5_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[5] * scale[5])); - *(doutc6_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[6] * scale[6])); - *(doutc7_ptr++) = - saturate_cast<signed char>(roundf(din_hei_ptr[7] * scale[7])); - din_hei_ptr += 8; - } + ptr_din = din + i * w_stride * 8; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int32_nchwc8_kernel(doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + doutc4_ptr, + doutc5_ptr, + doutc6_ptr, + doutc7_ptr, + din_hei_ptr, + cnt, + w_scale0, + w_scale1, + w_bias0, + w_bias1, + flag_relu); + }
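// Editorial note: the scalar tail that follows covers the rightmost output
// columns the 8-wide NEON kernel above cannot fill. cvt_kernel itself is not
// shown in this hunk; a minimal sketch of what it is assumed to compute for
// the int8 output path (the name cvt_sketch and the exact ReLU/round order
// are assumptions, not taken from this patch):
auto cvt_sketch = [](int v, float scale, float bias, bool relu) -> signed char {
  float f = v * scale + bias;          // dequantize accumulator, add bias
  if (relu && f < 0.f) f = 0.f;        // optional fused ReLU
  f += (f >= 0.f) ? 0.5f : -0.5f;      // round half away from zero, matching
  if (f > 127.f) f = 127.f;            // the +/-0.5 offset trick used in the
  if (f < -128.f) f = -128.f;          // assembly above
  return static_cast<signed char>(f);  // saturate to int8
};
+ if (we > width) { + int offset = 32 * cnt; + din_hei_ptr = ptr_din + offset; + for (int j = ws + cnt * 4; j < width; ++j) { + if (flag_bias) { + *(doutc0_ptr++) = + cvt_kernel(din_hei_ptr[0], scale[0], bias[0], flag_relu); + *(doutc1_ptr++) = + cvt_kernel(din_hei_ptr[1], scale[1], bias[1], flag_relu); + *(doutc2_ptr++) = + cvt_kernel(din_hei_ptr[2], scale[2], bias[2], flag_relu); + *(doutc3_ptr++)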
= + cvt_kernel(din_hei_ptr[3], scale[3], bias[3], flag_relu); + *(doutc4_ptr++) = + cvt_kernel(din_hei_ptr[4], scale[4], bias[4], flag_relu); + *(doutc5_ptr++) = + cvt_kernel(din_hei_ptr[5], scale[5], bias[5], flag_relu); + *(doutc6_ptr++) = + cvt_kernel(din_hei_ptr[6], scale[6], bias[6], flag_relu); + *(doutc7_ptr++) = + cvt_kernel(din_hei_ptr[7], scale[7], bias[7], flag_relu); + } else { + *(doutc0_ptr++) = + cvt_kernel(din_hei_ptr[0], scale[0], 0.f, flag_relu); + *(doutc1_ptr++) = + cvt_kernel(din_hei_ptr[1], scale[1], 0.f, flag_relu); + *(doutc2_ptr++) = + cvt_kernel(din_hei_ptr[2], scale[2], 0.f, flag_relu); + *(doutc3_ptr++) = + cvt_kernel(din_hei_ptr[3], scale[3], 0.f, flag_relu); + *(doutc4_ptr++) = + cvt_kernel(din_hei_ptr[4], scale[4], 0.f, flag_relu); + *(doutc5_ptr++) = + cvt_kernel(din_hei_ptr[5], scale[5], 0.f, flag_relu); + *(doutc6_ptr++) = + cvt_kernel(din_hei_ptr[6], scale[6], 0.f, flag_relu); + *(doutc7_ptr++) = + cvt_kernel(din_hei_ptr[7], scale[7], 0.f, flag_relu); } + din_hei_ptr += 8; } } - } else { - LOG(ERROR) << "ERROR: unsupported input data type!!"; - return false; } - return true; } /* diff --git a/lite/backends/arm/math/conv_depthwise.cc b/lite/backends/arm/math/conv_depthwise.cc deleted file mode 100644 index 79b8cec57161668a5e7e8e66ab18d328a1c1ca23..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise.cc +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_depthwise.h" -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template <> -bool DepthwiseConv::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ic = x_dims[1]; - int ow = o_dims[3]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int sw = param.strides[1]; - // select dw conv kernel - if (kw == 3) { - VLOG(5) << "invoke 3x3 dw conv"; - impl_ = conv_depthwise_3x3; - } else if (kw == 5) { - VLOG(5) << "invoke 5x5 dw conv"; - this->ctx_->ExtendWorkspace((iw + ow) * sizeof(float)); - impl_ = conv_depthwise_5x5; - } else { - LOG(ERROR) << "this type dw conv not impl"; - return false; - } - return true; -} - -template <> -bool DepthwiseConv::init(const operators::ConvParam& param, - Context* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template <> -bool DepthwiseConv::run(const operators::ConvParam& param) { - // start timer - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? 
param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_); - - // timer end - return true; -} - -template -bool DepthwiseConvInt8::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int ic = x_dims[1]; - int ih = x_dims[2]; - int iw = x_dims[3]; // nchw - int oc = o_dims[1]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int kw = w_dims[3]; - int sw = param.strides[1]; - w_scale_ = param.weight_scale; - - //! select dw conv kernel - if (kw == 3) { - tmp_int32_out_.Resize(o_dims); - VLOG(5) << "invoke 3x3 depthwise int8 conv"; - impl_ = conv_depthwise_3x3_int8; - } else if (kw == 5) { - // update w_data scale - if (Ptype_out == PRECISION(kFloat) || Ptype_out == PRECISION(kInt8)) { - CHECK_EQ(w_scale_.size(), oc) << "w_data scale size must be oc"; - float input_scale = param.input_scale; - float output_scale = param.output_scale; - for (auto& ws : w_scale_) { - ws *= input_scale; - if (Ptype_out == PRECISION(kInt8)) { - ws /= output_scale; - } - } - } - - const int wout_round = ((ow + 7) / 8) * 8; - const int win_round = wout_round * sw + 5 - 1; - const int hout_round = ((oh + 2) / 3) * 3; - const int hin_round = hout_round * sw + 5 - 1; - const int tmp_size_out = wout_round * hout_round; - const int tmp_size_in = win_round * hin_round; - const int tmp_size_io_bytes = tmp_size_in + tmp_size_out * sizeof(int); - const int tmp_row_io_bytes = win_round + wout_round * sizeof(int); - const int tmp_size_io_float = - (tmp_size_io_bytes + sizeof(float) - 1) / sizeof(float); - const int tmp_row_io_float = - (tmp_row_io_bytes + sizeof(float) - 1) / sizeof(float); - ctx_->ExtendWorkspace( - (ctx_->threads() * tmp_size_io_float + tmp_row_io_float) * - sizeof(float)); - impl_ = conv_depthwise_5x5_int8; - VLOG(5) << "invoke conv_depthwise_5x5 int8 conv"; - } else { - LOG(ERROR) << "this type depthwise int8 conv not impl"; - return false; - } - return true; -} - -template -bool DepthwiseConvInt8::init(const operators::ConvParam& param, - Context* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template -bool DepthwiseConvInt8::run(const operators::ConvParam& param) { - const int8_t* i_data = param.x->data(); - int32_t* o_data = nullptr; - const int8_t* w_data = param.filter->data(); - const int32_t* b_data = param.bias ? 
param.bias->data() : nullptr; - - // LOG(INFO) << "input size: " << param.x->memory_size() << " " - // << param.input_scale << " " << w_scale_.size(); - - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - int bs = x_dims[0]; - int ic = x_dims[1]; - int ih = x_dims[2]; - int iw = x_dims[3]; // nchw - int oc = o_dims[1]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int kw = w_dims[3]; - int sw = param.strides[1]; - - if (kw == 3 && Ptype_out != PRECISION(kInt32)) { - o_data = tmp_int32_out_.mutable_data(); - } else if (kw == 5 || (kw == 3 && Ptype_out == PRECISION(kInt32))) { - o_data = param.output->mutable_data(); - } else { - LOG(ERROR) << "this type dw int8 conv not impl"; - return false; - } - - impl_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_, - Ptype_out, - w_scale_.data()); - - auto i_scale = param.input_scale; - auto o_scale = param.output_scale; - if (kw == 3) { - if (Ptype_out == PRECISION(kInt8)) { - trans_tensor_dtype( - &tmp_int32_out_, param.output, i_scale, o_scale, w_scale_); - } else if (Ptype_out == PRECISION(kFloat)) { - trans_tensor_dtype( - &tmp_int32_out_, param.output, i_scale, 1.f, w_scale_); - } else if (Ptype_out != PRECISION(kInt32)) { - LOG(ERROR) << "unsupported precision type!!"; - return false; - } - } - - return true; -} - -template class DepthwiseConvInt8; -template class DepthwiseConvInt8; -template class DepthwiseConvInt8; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index cdddda79d1277190f7fc1371e271f3cc0d137530..1a23982cd575afb6b249390de7081165c03414b9 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -16,20 +16,16 @@ #include #include -#include "lite/backends/arm/math/conv_impl.h" #include "lite/core/context.h" #include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { namespace arm { namespace math { -template -class DepthwiseConv - : public ImplBase { - public: - typedef void (*conv_dw_impl)(const float* i_data, +void conv_3x3s1_depthwise_fp32(const float* i_data, float* o_data, int bs, int oc, @@ -37,62 +33,175 @@ class DepthwiseConv int ow, int ic, int ih, - int kw, - const float* w_data, - const float* b_data, + int win, + const float* weights, + const float* bias, const operators::ConvParam& param, - Context* ctx); - DepthwiseConv() = default; - ~DepthwiseConv() {} + ARMContext* ctx); - virtual bool init(const operators::ConvParam& param, - Context* ctx); +void conv_3x3s2_depthwise_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); - virtual bool create(const operators::ConvParam& param, - Context* ctx); +void conv_depthwise_3x3s1_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); - virtual bool run(const operators::ConvParam& param); +void conv_depthwise_3x3s2_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool 
flag_bias, + bool flag_relu, + ARMContext* ctx); - private: - conv_dw_impl impl_{nullptr}; -}; +void conv_depthwise_3x3p0_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); -template <PrecisionType Ptype_out> -class DepthwiseConvInt8 - : public ImplBase { - public: - typedef void (*conv_dw_int8_impl)(const int8_t* i_data, - int32_t* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int kw, - const int8_t* w_data, - const int32_t* b_data, - const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale); +void conv_depthwise_3x3p1_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); - DepthwiseConvInt8() = default; - ~DepthwiseConvInt8() {} +template <typename Dtype> +void conv_depthwise_3x3s1_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); - virtual bool init(const operators::ConvParam& param, - Context* ctx); +template <typename Dtype> +void conv_depthwise_3x3s2_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); - virtual bool create(const operators::ConvParam& param, - Context* ctx); +void conv_depthwise_5x5s1_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); - virtual bool run(const operators::ConvParam& param); +void conv_depthwise_5x5s2_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); - private: - conv_dw_int8_impl impl_{nullptr}; - std::vector<float> w_scale_; - Tensor tmp_int32_out_; -}; +template <typename Dtype> +void conv_depthwise_5x5s1_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); } // namespace math } // namespace arm
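// Editorial note: with this change the DepthwiseConv/DepthwiseConvInt8
// wrapper classes are gone and callers invoke the free kernel functions
// declared above directly, choosing stride/pad/ksize themselves. A hedged
// usage sketch (the tensor pointers and dimensions are placeholders, not
// part of this header):
//
//   paddle::lite::arm::math::conv_depthwise_3x3s1_fp32(
//       din, dout, num, ch_out, h_out, w_out, ch_in, h_in, w_in,
//       weights, bias, /*pad=*/1, /*flag_bias=*/true, /*flag_relu=*/false,
//       &ctx);
diff --git a/lite/backends/arm/math/conv_depthwise_3x3_int8.cc b/lite/backends/arm/math/conv_depthwise_3x3_int8.cc deleted file mode 100644 index d1eedd9557db397ea97ae33d006c2919ba39c91b..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3_int8.cc +++ /dev/null @@ -1,5832 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.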
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <arm_neon.h> -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/operators/op_params.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { -
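// Editorial note: the deleted kernels declared below come in four flavors
// per stride: wide / narrow-input (the _s suffix), each with and without a
// fused ReLU (_relu). The dispatcher conv_depthwise_3x3_int8 further down
// selects among them; its decision rule, in sketch form:
//
//   bool wide = (stride == 1) ? (w_in > 8) : (w_in > 16);
//   // stride 1, wide   -> conv_depthwise_3x3s1p1_bias[_relu]_int8
//   // stride 1, narrow -> conv_depthwise_3x3s1p1_bias_s[_relu]_int8
//   // stride 2, wide   -> conv_depthwise_3x3s2p1_bias[_relu]_int8
//   // stride 2, narrow -> conv_depthwise_3x3s2p1_bias_s[_relu]_int8
-void conv_depthwise_3x3s1p1_bias_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 8 -void conv_depthwise_3x3s1p1_bias_s_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 8 -void conv_depthwise_3x3s2p1_bias_s_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p1_bias_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//!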
for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - ARMContext* ctx, - PrecisionType out_type, - const float* scale) { - int w_in = win; - int h_in = hin; - int ch_in = chin; - - int w_out = wout; - int h_out = hout; - int ch_out = chout; - int stride_h = param.strides[0]; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active){ - // if (param.activation_param.active == Active_relu || - // fabs(param.activation_param.negative_slope) > 1e-6f){ - // flag_relu = true; - // } - // } - //! only support stride = 1 or 2 - if (stride_h == 1) { - if (flag_relu) { - if (w_in > 8) { - conv_depthwise_3x3s1p1_bias_relu_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_relu_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 8) { - conv_depthwise_3x3s1p1_bias_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! stride = 2 - if (flag_relu) { - if (w_in > 16) { - conv_depthwise_3x3s2p1_bias_relu_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_relu_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 16) { - conv_depthwise_3x3s2p1_bias_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_int8(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ - -// 4line w_in > 8 -void conv_depthwise_3x3s1p1_bias_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - //! 
pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - const unsigned char right_pad_idx[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 7) >> 3; - int tile_h = (h_out + 1) >> 1; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); - - int size_pad_bottom = h_out % 2; - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; - -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! 
process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "sub %[din_ptr0], %[din_ptr0], #1 \n" - "sub %[din_ptr1], %[din_ptr1], #1 \n" - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ - - "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "sub %[din_ptr2], %[din_ptr2], #1 \n" - "sub %[din_ptr3], %[din_ptr3], #1 \n" - - "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, - 7); 00123456 */ - "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 - */ - - // r2 - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 
7); - 00123456 */ - "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * - w11 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * - w11 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ - "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ - "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ - "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - // r1 - "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ - "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, 
v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ - "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - - "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v21.8b, v14.8b \n" - "bif v1.8b, v21.8b, v15.8b \n" - "bif v2.8b, v21.8b, v14.8b \n" - "bif v3.8b, v21.8b, v15.8b \n" - - "ext v4.8b, v0.8b, v1.8b, #1 \n" - "ext v5.8b, v0.8b, v1.8b, #2 \n" - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v16.8b, v2.8b, v3.8b, #1 \n" - "ext v17.8b, v2.8b, v3.8b, #2 \n" - - "bif v6.8b, v21.8b, v14.8b \n" - "bif v7.8b, v21.8b, v15.8b \n" - - "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v8.8b, v21.8b, v14.8b \n" - "bif v9.8b, v21.8b, v15.8b \n" - - "ext v20.8b, v6.8b, v7.8b, #1 \n" - "ext v22.8b, v6.8b, v7.8b, #2 \n" - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - // r1 - "ext v4.8b, v8.8b, v9.8b, #1 \n" - "ext v5.8b, v8.8b, v9.8b, #2 \n" - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v14.4s}, [%[rmask]], #16 \n" - "ld1 
{v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" - "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v1.4s}, [%[ptr_out0]] \n" - "ld1 {v3.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "sub %[ptr_out0], %[ptr_out0], #16 \n" - "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v10.16b, v0.16b, v14.16b \n" - "bif v11.16b, v1.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "bif v12.16b, v2.16b, v14.16b \n" - "bif v13.16b, v3.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [ptr_out0] "+r"(doutr0), - [ptr_out1] "+r"(doutr1), - [vmask] "+r"(val_mask), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" 
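/* Editorial sketch: scalar form of what this armv7 block computes per
   iteration, under assumed names (r0..r3 = four consecutive input rows,
   w = the 3x3 weights; the pad-1 border columns are supplied by zero
   bytes and the vext shifts). int8 products are widened to int16 by
   vmull.s8/vmlal.s8 and then accumulated into int32 by vaddw.s16, two
   output rows at a time:

     for (int x = 0; x < 8; ++x) {
       int32_t out0 = bias_val;                 // output row i
       int32_t out1 = bias_val;                 // output row i + 1
       for (int k = 0; k < 3; ++k) {
         for (int j = 0; j < 3; ++j) {
           out0 += (int16_t)(r[k][x + j - 1] * w[k][j]);
           out1 += (int16_t)(r[k + 1][x + j - 1] * w[k][j]);
         }
       }
       doutr0[x] = out0;
       doutr1[x] = out1;
     }
*/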
- "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "add %[din_ptr0], #7 @add \n" - "add %[din_ptr1], #7 @add \n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #7 @add \n" - "add %[din_ptr3], #7 @add \n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ 
ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "add %[din_ptr0], #8 @add \n" - "add %[din_ptr1], #8 @add \n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #8 @add \n" - "add %[din_ptr3], #8 @add \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d5 @ out1 = din2 * 
w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "subs %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - "bne 2b \n" - // right - "1: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 - - "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d12, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with " - "right pad\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 - "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 - - "vbif.8 d14, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with " - "right pad\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 - "sub %[dout_ptr1], #16 @ sub \n" - "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 - "vbif q8, q14, q1 @ bit select, deal with right " - "pad\n" - "vbif q9, q6, q2 @ bit select, deal with right " - "pad\n" - "sub %[dout_ptr2], #16 @ sub \n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vbif q10, q7, q1 @ bit select, deal with right pad\n" - "vbif q11, q12, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += 2 * w_out; - } - } - } -} - -// w_in <= 8 -void conv_depthwise_3x3s1p1_bias_s_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! 
for 4x6 convolution window - const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_h = (h_out + 1) >> 1; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[8]; - vst1_u8(vmask, vmask_rp); - - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - - int out_buf1[8]; - int out_buf2[8]; - int trash_buf[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = trash_buf; - } -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v4.8b}, [%[vmask]] \n" - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "bif v0.8b, v21.8b, v4.8b \n" - "bif v1.8b, v21.8b, v4.8b \n" - "bif v2.8b, v21.8b, v4.8b \n" - "bif v3.8b, v21.8b, v4.8b \n" - - "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v10.4s}, [%[vbias]] \n" - "ld1 {v11.4s}, [%[vbias]] \n" - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v12.4s}, [%[vbias]] \n" - "ld1 {v13.4s}, [%[vbias]] \n" - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v14.4s}, [%[rmask]], #16 \n" - // "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" - // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, - // 1); 12345678 - - // "ld1 {v0.4s}, [%[ptr_out0]] \n" - // "ld1 {v1.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - // "sub %[ptr_out0], %[ptr_out0], #16 \n" - // "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, 
v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "bif v10.16b, v16.16b, v14.16b \n" - // "bif v11.16b, v0.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // "bif v12.16b, v17.16b, v14.16b \n" - // "bif v13.16b, v1.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [vbias] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [vmask] "r"(vmask), - [ptr_out0] "r"(out_buf1), - [ptr_out1] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" - "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 - - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 
@addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - // "sub %[dout_ptr1], #16 @ sub \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - - // "vld1.32 {d6-d7}, [%[dout_ptr2]]! 
@ load din00= 0 1 2 3 4 5 6 - // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 - // 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - - // "vbif q8, q14, q1 @ bit select, deal with right - // pad\n" "vbif q9, q6, q2 @ bit select, deal - // with right pad\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - // "sub %[dout_ptr2], #16 @ sub \n" - - "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // "vbif q10, q3, q1 @ bit select, deal with right - // pad\n" "vbif q11, q7, q2 @ bit select, deal - // with right pad\n" - - "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" - // "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [dout_ptr1] "r"(out_buf1), - [dout_ptr2] "r"(out_buf2) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - dout_ptr += 2 * w_out; - } - } - } -} - -// 4line w_in > 16 -void conv_depthwise_3x3s2p1_bias_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! 
for 4x6 convolution window
-  const unsigned char right_pad_idx[16] = {
-      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
-  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
-
-  // printf("conv3x3_dw start \n");
-  signed char* zero_ptr = ctx->workspace_data<signed char>();
-  memset(zero_ptr, 0, w_in * sizeof(signed char));
-  int* write_ptr =
-      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
-  int size_in_channel = w_in * h_in;
-  int size_out_channel = w_out * h_out;
-  int w_stride = 9;
-
-  int tile_w = (w_in + 15) >> 4;
-  int cnt_col = tile_w - 2;
-
-  unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4));
-  if (size_pad_right == 17) {
-    size_pad_right = 0;
-    cnt_col++;
-  }
-
-  uint8x8_t vmask_rp1 =
-      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
-  uint8x8_t vmask_rp2 =
-      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
-  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
-  uint32x4_t vmask_result1 =
-      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
-  uint32x4_t vmask_result2 =
-      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
-
-  uint8x16_t vmask_rp =
-      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
-  unsigned char vmask[16];
-  vst1q_u8(vmask, vmask_rp);
-
-  unsigned int rmask[8];
-  vst1q_u32(rmask, vmask_result1);
-  vst1q_u32(rmask + 4, vmask_result2);
-
-  int8x8_t vzero = vdup_n_s8(0);
-  // printf("cnt_col: %d, rst_remain: %d, size_pad_right: %d\n", cnt_col,
-  // rst_remain, size_pad_right);
-  for (int n = 0; n < num; ++n) {
-    const signed char* din_batch = din + n * ch_in * size_in_channel;
-    int* dout_batch = dout + n * ch_in * size_out_channel;
-#pragma omp parallel for
-    for (int c = 0; c < ch_in; c++) {
-      int* dout_ptr = dout_batch + c * size_out_channel;
-
-      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
-
-      int bias_val = flag_bias ? bias[c] : 0;
-
-      const signed char* wei_ptr = weights + c * w_stride;
-#ifdef __aarch64__
-      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
-      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
-      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
-
-      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
-      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
-      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
-
-      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
-      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
-      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
-#endif
-
-      int* doutr0 = nullptr;
-
-      const signed char* dr0 = din_ch_ptr;
-      const signed char* dr1 = dr0 + w_in;
-      const signed char* dr2 = dr1 + w_in;
-
-      const signed char* din_ptr0 = nullptr;
-      const signed char* din_ptr1 = nullptr;
-      const signed char* din_ptr2 = nullptr;
-
-      for (int i = 0; i < h_in; i += 2) {
-        //! process top pad pad_h = 1
-        din_ptr0 = dr0;
-        din_ptr1 = dr1;
-        din_ptr2 = dr2;
-
-        doutr0 = dout_ptr;
-        if (i == 0) {
-          din_ptr0 = zero_ptr;
-          din_ptr1 = dr0;
-          din_ptr2 = dr1;
-          dr0 = dr1;
-          dr1 = dr2;
-          dr2 = dr1 + w_in;
-        } else {
-          dr0 = dr2;
-          dr1 = dr0 + w_in;
-          dr2 = dr1 + w_in;
-        }
-        //!
process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - int cnt = cnt_col; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v10.4s, #0x0\n" - // left - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - "add %[din_ptr0], %[din_ptr0], #15 \n" - "add %[din_ptr1], %[din_ptr1], #15 \n" - "add %[din_ptr2], %[din_ptr2], #15 \n" - - // r1 - "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ - "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ - "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ - - "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - - // r0 - "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 
02468 * w00 */ - "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v10.8b, v14.8b \n" - "bif v1.8b, v10.8b, v15.8b \n" - "bif v2.8b, v10.8b, v14.8b \n" - "bif v3.8b, v10.8b, v15.8b \n" - "bif v4.8b, v10.8b, v14.8b \n" - "bif v5.8b, v10.8b, v15.8b \n" - - "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. */ - "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468..*/ - "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. 
*/ - - // r0 - "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ - - "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ - "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "bif v12.16b, v0.16b, v9.16b \n" - "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [ptr_out0] "+r"(doutr0), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); -#else - unsigned int* rst_mask = rmask; - int cnt = cnt_col; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, 
d1[0] @ d4 = w02, w02, w02, w02\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - "add %[din_ptr0], #15 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "add %[din_ptr1], #15 @add \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - "add %[din_ptr2], #15 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - - "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 - // 4 6 - "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 - // 4 6 - "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 - // 4 6 - - "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 - // 4 6 8 - "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 - "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - - "subs %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "cmp %[size_pad_right], #1 \n" - "blt 3f \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 - // 1 3 5 - "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vbif q11, q6, q1 @ bit select, deal with right pad\n" - "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - "3: \n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [dout_ptr1] "+r"(doutr0), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += w_out; - } - } - } -} -// w_in <= 16 -void conv_depthwise_3x3s2p1_bias_s_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! 
process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - int out_buf1[8]; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr2 + w_in; - dr2 = dr1 + w_in; - } - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v16.4s, #0x0\n" - // left - "ld1 {v10.8b}, [%[vmask]], #8 \n" - "ld1 {v11.8b}, [%[vmask]] \n" - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "bif v0.8b, v16.8b, v10.8b \n" - "bif v1.8b, v16.8b, v11.8b \n" - "bif v2.8b, v16.8b, v10.8b \n" - "bif v3.8b, v16.8b, v11.8b \n" - "bif v4.8b, v16.8b, v10.8b \n" - "bif v5.8b, v16.8b, v11.8b \n" - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, - // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* - // dup v10, bias */ - - // r1 - "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // "bif v12.16b, v0.16b, v10.16b \n" - // "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out - */ - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask), - [ptr_out0] "r"(out_buf1) - : "cc", - "memory", - "v0", - "v1", - 
"v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // "pld [%[dout_ptr1]] @ preload data\n" - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - // "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // "vbif q11, q6, q1 @ bit select, deal with right pad\n" - // "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [size_pad_right] "r"(size_pad_right), - [dout_ptr1] "r"(out_buf1) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - } - dout_ptr += w_out; - } - } - } -} - -// relu -void conv_depthwise_3x3s1p1_bias_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 7) >> 3; - int tile_h = (h_out + 1) >> 1; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); - - int size_pad_bottom = h_out % 2; - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = 
dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "sub %[din_ptr0], %[din_ptr0], #1 \n" - "sub %[din_ptr1], %[din_ptr1], #1 \n" - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ - - "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 
12345678 */ - - // r1 - "sub %[din_ptr2], %[din_ptr2], #1 \n" - "sub %[din_ptr3], %[din_ptr3], #1 \n" - - "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, - 7); 00123456 */ - "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - "smull v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 - */ - - // r2 - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - "smull v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 - */ - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * - w11 */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * - w11 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "stp q12, q13, 
[%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ - "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ - "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ - "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - // r1 - "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ - "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 - to q0*/ - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ - "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - - "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += 
outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v21.8b, v14.8b \n" - "bif v1.8b, v21.8b, v15.8b \n" - "bif v2.8b, v21.8b, v14.8b \n" - "bif v3.8b, v21.8b, v15.8b \n" - - "ext v4.8b, v0.8b, v1.8b, #1 \n" - "ext v5.8b, v0.8b, v1.8b, #2 \n" - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v16.8b, v2.8b, v3.8b, #1 \n" - "ext v17.8b, v2.8b, v3.8b, #2 \n" - - "bif v6.8b, v21.8b, v14.8b \n" - "bif v7.8b, v21.8b, v15.8b \n" - - "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v8.8b, v21.8b, v14.8b \n" - "bif v9.8b, v21.8b, v15.8b \n" - - "ext v20.8b, v6.8b, v7.8b, #1 \n" - "ext v22.8b, v6.8b, v7.8b, #2 \n" - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - // r1 - "ext v4.8b, v8.8b, v9.8b, #1 \n" - "ext v5.8b, v8.8b, v9.8b, #2 \n" - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v14.4s}, [%[rmask]], #16 \n" - "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - "smull v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" - "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v1.4s}, [%[ptr_out0]] \n" - "ld1 {v3.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "sub %[ptr_out0], %[ptr_out0], #16 \n" - "sub %[ptr_out1], %[ptr_out1], #16 \n" - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - 
"saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "bif v10.16b, v0.16b, v14.16b \n" - "bif v11.16b, v1.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "bif v12.16b, v2.16b, v14.16b \n" - "bif v13.16b, v3.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [ptr_out0] "+r"(doutr0), - [ptr_out1] "+r"(doutr1), - [vmask] "+r"(val_mask), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "add %[din_ptr0], #7 @add \n" - "add %[din_ptr1], #7 @add \n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, 
d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #7 @add \n" - "add %[din_ptr3], #7 @add \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmov.u32 q0, #0 @ mov \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "add %[din_ptr0], #8 @add \n" - "add %[din_ptr1], #8 @add \n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #8 @add \n" - "add %[din_ptr3], #8 @add \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" 
- "pld [%[din_ptr1]] @ preload data\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "subs %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 - - "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d12, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with " - "right pad\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d12, d5 @ 
out1 = din2 * w10 \n" // q13 = d12 * w01 - "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 - - "vbif.8 d14, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with " - "right pad\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 - "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 - "vbif q8, q14, q1 @ bit select, deal with right " - "pad\n" - "vbif q9, q6, q2 @ bit select, deal with right " - "pad\n" - "sub %[dout_ptr1], #16 @ sub \n" - "sub %[dout_ptr2], #16 @ sub \n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vbif q10, q7, q1 @ bit select, deal with right pad\n" - "vbif q11, q12, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += 2 * w_out; - } - } - } -} -// w_in <= 8 -void conv_depthwise_3x3s1p1_bias_s_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_h = (h_out + 3) >> 2; - - unsigned int size_pad_right = (unsigned int)(w_in); - - int size_pad_bottom = h_out % 4; - - uint8x8_t vmask_rp = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned char vmask[8]; - vst1_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! 
process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - int out_buf1[8]; - int out_buf2[8]; - int trash_buf[8]; - - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = trash_buf; - } -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v4.8b}, [%[vmask]] \n" - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "bif v0.8b, v21.8b, v4.8b \n" - "bif v1.8b, v21.8b, v4.8b \n" - "bif v2.8b, v21.8b, v4.8b \n" - "bif v3.8b, v21.8b, v4.8b \n" - - "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v10.4s}, [%[vbias]] \n" - "ld1 {v11.4s}, [%[vbias]] \n" - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v12.4s}, [%[vbias]] \n" - "ld1 {v13.4s}, [%[vbias]] \n" - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v14.4s}, [%[rmask]], #16 \n" - // "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" - // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, - // 1); 12345678 - - // "ld1 {v0.4s}, [%[ptr_out0]] \n" - // "ld1 {v1.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 
01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - // "sub %[ptr_out0], %[ptr_out0], #16 \n" - // "sub %[ptr_out1], %[ptr_out1], #16 \n" - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r3 - "smull v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu */ - "smax v11.4s, v11.4s, v21.4s \n" /* relu */ - - // "bif v10.16b, v16.16b, v14.16b \n" - // "bif v11.16b, v0.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu */ - "smax v13.4s, v13.4s, v21.4s \n" /* relu */ - - // "bif v12.16b, v17.16b, v14.16b \n" - // "bif v13.16b, v1.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [vbias] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [vmask] "r"(vmask), - [ptr_out0] "r"(out_buf1), - [ptr_out1] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d28 @ bit select, deal 
with right pad\n" - "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 - - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - - "vmlal.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - // "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - - "vmov.u32 q0, #0 @ zero\n" - - // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 - // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 - // 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - // "sub %[dout_ptr2], #16 @ sub \n" - // "vbif q8, q14, q1 @ bit select, deal with right - // pad\n" "vbif q9, q6, q2 @ bit select, deal - // with right pad\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - // "vbif q10, q3, q1 @ bit select, deal with right - // pad\n" "vbif q11, q7, q2 @ bit select, deal - // with right pad\n" - - "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" - // "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [dout_ptr1] "r"(out_buf1), - [dout_ptr2] "r"(out_buf2) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - dout_ptr += 2 * w_out; - } - } - } -} - -// 1 line w_in > 16 -void conv_depthwise_3x3s2p1_bias_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - //! 
for 4x6 convolution window
-  const unsigned char right_pad_idx[16] = {
-      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
-  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
-
-  // printf("conv3x3_dw start \n");
-  signed char* zero_ptr = ctx->workspace_data<signed char>();
-  memset(zero_ptr, 0, w_in * sizeof(signed char));
-  int* write_ptr =
-      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
-  int size_in_channel = w_in * h_in;
-  int size_out_channel = w_out * h_out;
-  int w_stride = 9;
-
-  int tile_w = (w_in + 15) >> 4;
-  int cnt_col = tile_w - 2;
-
-  unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4));
-  if (size_pad_right == 17) {
-    size_pad_right = 0;
-    cnt_col++;
-  }
-
-  uint8x8_t vmask_rp1 =
-      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
-  uint8x8_t vmask_rp2 =
-      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
-  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
-  uint32x4_t vmask_result1 =
-      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
-  uint32x4_t vmask_result2 =
-      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
-
-  int8x8_t vzero = vdup_n_s8(0);
-  int32x4_t vzero_32 = vdupq_n_s32(0);
-
-  uint8x16_t vmask_rp =
-      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
-  unsigned char vmask[16];
-  vst1q_u8(vmask, vmask_rp);
-
-  unsigned int rmask[8];
-  vst1q_u32(rmask, vmask_result1);
-  vst1q_u32(rmask + 4, vmask_result2);
-
-  for (int n = 0; n < num; ++n) {
-    const signed char* din_batch = din + n * ch_in * size_in_channel;
-    int* dout_batch = dout + n * ch_in * size_out_channel;
-
-#pragma omp parallel for
-    for (int c = 0; c < ch_in; c++) {
-      int* dout_ptr = dout_batch + c * size_out_channel;
-
-      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
-
-      int bias_val = flag_bias ? bias[c] : 0;
-
-      const signed char* wei_ptr = weights + c * w_stride;
-#ifdef __aarch64__
-      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
-      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
-      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
-
-      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
-      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
-      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
-
-      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
-      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
-      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
-#endif
-
-      int* doutr0 = nullptr;
-
-      const signed char* dr0 = din_ch_ptr;
-      const signed char* dr1 = dr0 + w_in;
-      const signed char* dr2 = dr1 + w_in;
-
-      const signed char* din_ptr0 = nullptr;
-      const signed char* din_ptr1 = nullptr;
-      const signed char* din_ptr2 = nullptr;
-
-      for (int i = 0; i < h_in; i += 2) {
-        //! process top pad pad_h = 1
-        din_ptr0 = dr0;
-        din_ptr1 = dr1;
-        din_ptr2 = dr2;
-
-        doutr0 = dout_ptr;
-        if (i == 0) {
-          din_ptr0 = zero_ptr;
-          din_ptr1 = dr0;
-          din_ptr2 = dr1;
-          dr0 = dr1;
-          dr1 = dr2;
-          dr2 = dr1 + w_in;
-        } else {
-          dr0 = dr2;
-          dr1 = dr0 + w_in;
-          dr2 = dr1 + w_in;
-        }
-        //!
process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; -#ifdef __aarch64__ - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v10.4s, #0x0\n" - // left - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - "add %[din_ptr0], %[din_ptr0], #15 \n" - "add %[din_ptr1], %[din_ptr1], #15 \n" - "add %[din_ptr2], %[din_ptr2], #15 \n" - - // r1 - "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ - "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ - "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ - - "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - - // r0 - "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v2].8b, v9.8b\n" 
/* outr00 += 246810 * w02 */ - - // r1 - "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v10.8b, v14.8b \n" - "bif v1.8b, v10.8b, v15.8b \n" - "bif v2.8b, v10.8b, v14.8b \n" - "bif v3.8b, v10.8b, v15.8b \n" - "bif v4.8b, v10.8b, v14.8b \n" - "bif v5.8b, v10.8b, v15.8b \n" - - "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. */ - "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468..*/ - "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. 
*/ - - // r0 - "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ - - "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ - "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "bif v12.16b, v0.16b, v9.16b \n" - "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [ptr_out0] "+r"(doutr0), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, 
w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - "add %[din_ptr0], #15 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "add %[din_ptr1], #15 @add \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "add %[din_ptr2], #15 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmov.u32 q8, #0 @ max \n" // max - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - - "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 - // 4 6 - "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 - // 4 6 - "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 - // 4 6 - - "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 - // 4 6 8 - "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 - "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vmov.u32 q8, #0 @ mov \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - - "subs %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "cmp %[size_pad_right], #1 \n" - "blt 3f \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 - // 1 3 5 - "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "sub %[dout_ptr1], #16 @ sub \n" - "vmov.u32 q8, #0 @mov \n" - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vbif q11, q6, q1 @ bit select, deal with right pad\n" - "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - "3: \n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [dout_ptr1] "+r"(doutr0), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += w_out; - } - } - } -} -// w_in <= 16 -void conv_depthwise_3x3s2p1_bias_s_relu_int8(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; - -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! 
process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - - int out_buf1[8]; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr2 + w_in; - dr2 = dr1 + w_in; - } - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v16.4s, #0x0\n" - // left - "ld1 {v10.8b}, [%[vmask]], #8 \n" - "ld1 {v11.8b}, [%[vmask]] \n" - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "bif v0.8b, v16.8b, v10.8b \n" - "bif v1.8b, v16.8b, v11.8b \n" - "bif v2.8b, v16.8b, v10.8b \n" - "bif v3.8b, v16.8b, v11.8b \n" - "bif v4.8b, v16.8b, v10.8b \n" - "bif v5.8b, v16.8b, v11.8b \n" - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, - // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* - // dup v10, bias */ - - // r1 - "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // r2 - "smull v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v16.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v16.4s \n" /*relu*/ - - // "bif v12.16b, v0.16b, v10.16b \n" - // "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out - */ - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - 
[rst_mask] "r"(rmask), - [ptr_out0] "r"(out_buf1) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); - -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // "pld [%[dout_ptr1]] @ preload data\n" - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // r2 - "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - // "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vmov.u32 q8, #0 @ mov \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - // "vbif q11, q6, q1 @ bit select, deal with right pad\n" - // "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [size_pad_right] "r"(size_pad_right), - [dout_ptr1] "r"(out_buf1) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - } - dout_ptr += w_out; - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_3x3p0.cc b/lite/backends/arm/math/conv_depthwise_3x3p0.cc index ec7f3cfb843af57767a4a97213bfc75a326fabaa..0c050ffe6fb0f064f5c26ea0da6acee17f4403ae 100644 --- a/lite/backends/arm/math/conv_depthwise_3x3p0.cc +++ b/lite/backends/arm/math/conv_depthwise_3x3p0.cc @@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, const int w_out, ARMContext* ctx); -void conv_depthwise_3x3p0(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { +void conv_depthwise_3x3p0_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { if (stride == 1) { if (flag_relu) { if (w_in > 5) { diff --git a/lite/backends/arm/math/conv_depthwise_3x3p1.cc b/lite/backends/arm/math/conv_depthwise_3x3p1.cc index b5de99d7f5c1f96cdb934ee5038882b8770d6c7f..6f28d48d6d2bdd60e0c33f9b4b753835337fc8a4 100644 --- a/lite/backends/arm/math/conv_depthwise_3x3p1.cc +++ b/lite/backends/arm/math/conv_depthwise_3x3p1.cc @@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, const int w_out, ARMContext* ctx); -void conv_depthwise_3x3p1(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { +void conv_depthwise_3x3p1_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { if (stride == 1) { if (flag_relu) { if (w_in > 4) { diff --git 
a/lite/backends/arm/math/conv_depthwise_3x3s1.cc b/lite/backends/arm/math/conv_depthwise_3x3s1.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d0ebb58ad1b7e325bae3649b13914641021038f --- /dev/null +++ b/lite/backends/arm/math/conv_depthwise_3x3s1.cc @@ -0,0 +1,2539 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/conv_depthwise.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void conv_depthwise_3x3s1p0_bias(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx); + +void conv_depthwise_3x3s1p0_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx); + +void conv_depthwise_3x3s1p1_bias(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx); + +void conv_depthwise_3x3s1p1_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx); + +void conv_depthwise_3x3s1_fp32(const float *din, + float *dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float *weights, + const float *bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext *ctx) { + if (pad == 0) { + if (w_in > 5) { + conv_depthwise_3x3s1p0_bias(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p0_bias_s(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 4) { + conv_depthwise_3x3s1p1_bias(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } +} + +#ifdef __aarch64__ +#define INIT_S1 \ + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" \ + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" \ + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" \ + "PRFM PLDL1KEEP, [%[din_ptr4]] \n" \ + "PRFM PLDL1KEEP, [%[din_ptr5]] \n" \ + "movi v21.4s, #0x0\n" /* out0 = 0 */ \ + \ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 
{v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + +#define LEFT_COMPUTE_S1 \ + "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \ + "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \ + \ + "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \ + "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \ + "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \ + \ + /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \ + 
"fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ + +#define LEFT_RESULT_S1 \ + /* r4 */ \ + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "blt 3f \n" + +#define MID_COMPUTE_S1 \ + "1: \n" /* r0 */ \ + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v16.4s, 
%[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + +#define MID_RESULT_S1 \ + /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 
{v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define RIGHT_COMPUTE_S1 \ + "3: \n" \ + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ + "ld1 {v22.4s}, [%[doutr0]] \n" \ + "ld1 {v23.4s}, [%[doutr1]] \n" \ + "ld1 {v24.4s}, [%[doutr2]] \n" \ + "ld1 {v25.4s}, [%[doutr3]] \n" \ + \ + "bif v0.16b, %[vzero].16b, v18.16b \n" \ + "bif v1.16b, %[vzero].16b, v19.16b \n" \ + "bif v2.16b, %[vzero].16b, v18.16b \n" \ + "bif v3.16b, %[vzero].16b, v19.16b \n" \ + \ + "bif v4.16b, %[vzero].16b, v18.16b \n" \ + "bif v5.16b, %[vzero].16b, v19.16b \n" \ + "bif v6.16b, %[vzero].16b, v18.16b \n" \ + "bif v7.16b, %[vzero].16b, v19.16b \n" \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v8.16b, %[vzero].16b, v18.16b \n" \ + "bif v9.16b, %[vzero].16b, v19.16b \n" \ + "bif v10.16b, %[vzero].16b, v18.16b \n" \ + "bif v11.16b, %[vzero].16b, v19.16b \n" \ + \ + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v18.4s}, [%[rmask]] \n" \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + +#define RIGHT_RESULT_S1 \ + /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v12.16b, v22.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += 
din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define LEFT_RESULT_S1_RELU \ + /* r4 */ \ + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ 
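As a reference while reading the macro fragments above and the `_RELU` variants that follow: they all implement the same 3x3 multiply-accumulate, tiled four output rows at a time. A plain scalar model of one output row of the stride-1, pad-1 case (illustrative only; names are not from the patch):

```cpp
// Scalar reference of what the vectorized kernels compute for output row r:
// a 3x3 MAC over input rows r-1..r+1 with zero padding on every border,
// plus the per-channel bias. The _RELU macro variants additionally clamp
// the result at zero before storing.
static void ref_dw3x3s1p1_row(float *out, const float *in, const float w[9],
                              float bias, int r, int h_in, int w_in) {
  for (int x = 0; x < w_in; ++x) {  // pad = 1, stride = 1 => w_out == w_in
    float acc = bias;
    for (int i = 0; i < 3; ++i) {
      for (int j = 0; j < 3; ++j) {
        int y = r + i - 1;
        int c = x + j - 1;
        if (y >= 0 && y < h_in && c >= 0 && c < w_in) {
          acc += in[y * w_in + c] * w[i * 3 + j];
        }
      }
    }
    out[x] = acc;
  }
}
```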
\ + "blt 3f \n" + +#define MID_RESULT_S1_RELU \ + /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define RIGHT_RESULT_S1_RELU \ + /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v22.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ 
\ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define COMPUTE_S_S1 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s}, [%[din0]], #16\n" \ + "ld1 {v1.4s}, [%[din1]], #16\n" \ + "ld1 {v2.4s}, [%[din2]], #16\n" \ + "ld1 {v3.4s}, [%[din3]], #16\n" \ + \ + "bif v0.16b, %[zero].16b, %[mask].16b\n" \ + "bif v1.16b, %[zero].16b, %[mask].16b\n" \ + "bif v2.16b, %[zero].16b, %[mask].16b\n" \ + "bif v3.16b, %[zero].16b, %[mask].16b\n" \ + \ + "ext v4.16b, %[zero].16b, v0.16b, #12\n" \ + "ext v5.16b, %[zero].16b, v1.16b, #12\n" \ + "ext v6.16b, %[zero].16b, v2.16b, #12\n" \ + "ext v7.16b, %[zero].16b, v3.16b, #12\n" \ + \ + "ext v8.16b, v0.16b, %[zero].16b, #4\n" \ + "ext v9.16b, v1.16b, %[zero].16b, #4\n" \ + "ext v10.16b, v2.16b, %[zero].16b, #4\n" \ + "ext v11.16b, v3.16b, %[zero].16b, #4\n" \ + \ + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ + \ + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ + \ + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ + \ + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ + \ + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ + \ + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ + \ + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ + \ + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v14.4s\n" \ + "fadd v12.4s, v12.4s, v16.4s\n" \ + \ + "fadd v13.4s, v13.4s, v15.4s\n" \ + "fadd v13.4s, v13.4s, v17.4s\n" \ + \ + "fadd v12.4s, v12.4s, %[bias].4s\n" \ + "fadd v13.4s, v13.4s, %[bias].4s\n" + +#define RESULT_S_S1 \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, 
[%[out2]]\n" + +#define RESULT_S_S1_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[zero].4s\n" \ + "fmax v13.4s, v13.4s, %[zero].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define COMPUTE_S_S1_P0 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s, v1.4s}, [%[din0]]\n" \ + "ld1 {v2.4s, v3.4s}, [%[din1]]\n" \ + "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ + "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ + \ + "bif v0.16b, %[zero].16b, %[mask1].16b\n" \ + "bif v1.16b, %[zero].16b, %[mask2].16b\n" \ + \ + "bif v2.16b, %[zero].16b, %[mask1].16b\n" \ + "bif v3.16b, %[zero].16b, %[mask2].16b\n" \ + \ + "bif v4.16b, %[zero].16b, %[mask1].16b\n" \ + "bif v5.16b, %[zero].16b, %[mask2].16b\n" \ + \ + "bif v6.16b, %[zero].16b, %[mask1].16b\n" \ + "bif v7.16b, %[zero].16b, %[mask2].16b\n" \ + \ + "ext v8.16b, v0.16b, v1.16b, #4\n" \ + "ext v9.16b, v0.16b, v1.16b, #8\n" \ + \ + "and v12.16b, %[vbias].16b, %[vbias].16b \n" \ + "and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \ + "fmul v10.4s, v0.4s, %[wr0].s[0]\n" \ + "fmul v11.4s, v8.4s, %[wr0].s[1]\n" \ + "fmla v12.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "ext v8.16b, v2.16b, v3.16b, #4\n" \ + "ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \ + "fmul v14.4s, v2.4s, %[wr0].s[0]\n" \ + "fmla v10.4s, v2.4s, %[wr1].s[0]\n" \ + \ + "fmul v15.4s, v8.4s, %[wr0].s[1]\n" \ + "fmla v11.4s, v8.4s, %[wr1].s[1]\n" \ + \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + "fmla v12.4s, v9.4s, %[wr1].s[2]\n" \ + \ + "ext v8.16b, v4.16b, v5.16b, #4\n" \ + "ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \ + "fmla v14.4s, v4.4s, %[wr1].s[0]\n" \ + "fmla v10.4s, v4.4s, %[wr2].s[0]\n" \ + \ + "fmla v15.4s, v8.4s, %[wr1].s[1]\n" \ + "fmla v11.4s, v8.4s, %[wr2].s[1]\n" \ + \ + "fmla v13.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v12.4s, v9.4s, %[wr2].s[2]\n" \ + \ + "ext v8.16b, v6.16b, v7.16b, #4\n" \ + "ext v9.16b, v6.16b, v7.16b, #8\n" \ + \ + "fmla v14.4s, v6.4s, %[wr2].s[0]\n" \ + \ + "fmla v15.4s, v8.4s, %[wr2].s[1]\n" \ + \ + "fadd v12.4s, v12.4s, v10.4s\n" \ + \ + "fmla v13.4s, v9.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v11.4s\n" \ + "fadd v13.4s, v13.4s, v14.4s\n" \ + "fadd v13.4s, v13.4s, v15.4s\n" // \ + // "prfm pldl1keep, [%[out1]]\n" \ + // "prfm pldl1keep, [%[out2]]\n" \ + // \ + // "st1 {v12.4s}, [%[out1]]\n" \ + // "st1 {v13.4s}, [%[out2]]\n" \ + + +#else +#define INIT_S1 \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \ + "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \ + "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \ + "vld1.32 {d28-d30}, [%[din3_ptr]]! 
@ load din r3\n" \ + \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vdup.32 q5, %[bias_val] @ and \n" + +#define LEFT_COMPUTE_S1 \ + "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \ + "vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \ + "vext.32 q7, q10, q11, #1 @ 1234\n" \ + \ + /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \ + "vext.32 q7, q12, q13, #1 @ 1234\n" \ + \ + /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \ + "vext.32 q7, q14, q15, #1 @ 1234\n" + +#define LEFT_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define MID_COMPUTE_S1 \ + "1: @ right pad entry\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[din0_ptr]]! 
@ load din r0\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" + +#define MID_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define RIGHT_COMPUTE_S1 \ + "3: @ right pad entry\n" \ + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ + \ + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" \ + \ + "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \ + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" + +#define RIGHT_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define LEFT_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define MID_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define RIGHT_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" + +#define COMPUTE_S_S1 \ + "pld [%[din0]]\n" \ + "pld [%[din1]]\n" \ + "pld [%[din2]]\n" \ + "pld [%[din3]]\n" \ + \ + "vld1.32 {d12-d13}, [%[din0]]!\n" \ + "vld1.32 {d14-d15}, [%[din1]]!\n" \ + "vld1.32 {d16-d17}, [%[din2]]!\n" \ + "vld1.32 {d18-d19}, [%[din3]]!\n" \ + \ + "vbif q6, %q[vzero], %q[mask]\n" \ + "vbif q7, %q[vzero], %q[mask]\n" \ + "vbif q8, %q[vzero], %q[mask]\n" \ + "vbif q9, %q[vzero], %q[mask]\n" \ + \ + "vmul.f32 q14, q6, %e[wr0][1]\n" \ + "vmul.f32 q15, q7, %e[wr0][1]\n" \ + \ + "vmla.f32 q14, q7, %e[wr1][1]\n" \ + "vmla.f32 q15, q8, %e[wr1][1]\n" \ + \ + "vmla.f32 q14, q8, %e[wr2][1]\n" \ + "vmla.f32 q15, q9, %e[wr2][1]\n" \ + \ + "vext.32 q10, %q[vzero], q6, #3\n" \ + "vext.32 q11, %q[vzero], q7, #3\n" \ + "vext.32 q12, %q[vzero], q8, #3\n" \ + "vext.32 q13, %q[vzero], q9, #3\n" \ + \ + "vmla.f32 q14, q10, %e[wr0][0]\n" \ + "vmla.f32 q15, q11, %e[wr0][0]\n" \ + \ + "vmla.f32 q14, q11, %e[wr1][0]\n" \ + "vmla.f32 q15, q12, %e[wr1][0]\n" \ + \ + "vmla.f32 q14, q12, %e[wr2][0]\n" \ + "vmla.f32 q15, q13, %e[wr2][0]\n" \ + \ + "vext.32 q10, q6, %q[vzero], #1\n" \ + "vext.32 q11, q7, %q[vzero], #1\n" \ + "vext.32 q12, q8, %q[vzero], #1\n" \ + "vext.32 q13, q9, %q[vzero], #1\n" \ + \ + "vmla.f32 q14, q10, %f[wr0][0]\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" \ + \ + "vmla.f32 q14, q11, %f[wr1][0]\n" \ + "vmla.f32 q15, q12, %f[wr1][0]\n" \ + \ + "vmla.f32 q14, q12, %f[wr2][0]\n" \ + "vmla.f32 q15, q13, %f[wr2][0]\n" \ + \ + "vadd.f32 q14, q14, %q[bias]\n" \ + "vadd.f32 q15, q15, %q[bias]\n" + +#define RESULT_S_S1 \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_RELU \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define COMPUTE_S_S1_P0 \ + "pld [%[din0]]\n" \ + "pld [%[din1]]\n" \ + "pld [%[din2]]\n" \ + "pld [%[din3]]\n" \ + "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \ + "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \ + "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \ + "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \ + \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ + \ + "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" \ + \ + "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ + \ + "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ + \ + "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ + "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + "vadd.f32 q4, q4, q10 @ q4 += q10 \n" \ + \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vadd.f32 q14, q4, q11 @ q4 += q10 \n" \ + \ + "vadd.f32 q5, q5, q8 @ q4 += q10 \n" \ + "vadd.f32 q15, q5, q9 @ q4 += q10 \n" + +#endif +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ +void conv_depthwise_3x3s1p1_bias(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! 
for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float *zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float *write_ptr = zero_ptr + w_in; + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 3) >> 2; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + float *dout_ptr = dout_batch + c * size_out_channel; + + const float *din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float *wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + + float *doutr0 = dout_ptr; + float *doutr1 = doutr0 + w_out; + float *doutr2 = doutr1 + w_out; + float *doutr3 = doutr2 + w_out; + + const float *dr0 = din_ch_ptr; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + const float *dr4 = dr3 + w_in; + const float *dr5 = dr4 + w_in; + + const float *din_ptr0 = dr0; + const float *din_ptr1 = dr1; + const float *din_ptr2 = dr2; + const float *din_ptr3 = dr3; + const float *din_ptr4 = dr4; + const float *din_ptr5 = dr5; + float *ptr_zero = const_cast(zero); +#ifdef __aarch64__ + for (int i = 0; i < h_in; i += 4) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + din_ptr4 = dr3; + din_ptr5 = dr4; + dr0 = dr3; + dr1 = dr4; + dr2 = dr5; + } else { + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + } + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! process bottom pad + if (i + 5 > h_in) { + switch (i + 5 - h_in) { + case 5: + din_ptr1 = zero_ptr; + case 4: + din_ptr2 = zero_ptr; + case 3: + din_ptr3 = zero_ptr; + case 2: + din_ptr4 = zero_ptr; + case 1: + din_ptr5 = zero_ptr; + default: + break; + } + } + //! 
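The switch above relies on deliberate fall-through (no `break` until `default`): when the six-row window hangs d rows past the bottom of the input, cases d, d-1, ..., 1 all fire in sequence and redirect every out-of-range row pointer to the shared zero row. An equivalent loop form, purely illustrative:

const float *rows[6] = {din_ptr0, din_ptr1, din_ptr2,
                        din_ptr3, din_ptr4, din_ptr5};
int d = i + 5 - h_in;               // rows hanging past the bottom edge
for (int k = (d > 5 ? 5 : d); k >= 1; --k)
  rows[6 - k] = zero_ptr;           // zero rows[5], rows[4], ... in turn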
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = cnt_col; + if (flag_relu) { + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } + dout_ptr = dout_ptr + 4 * w_out; + } +#else + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = dout_ptr + w_out; + // unsigned int* rst_mask = rmask; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! 
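All of these kernels share one extended-asm convention: "+r" operands are read-write pointers that the assembly itself advances, "w" hands a NEON register (q on ARMv7, v on A64) straight to the template, plain "r" passes scalars, and every vector register the macros touch is listed as clobbered so the compiler will not cache values there. A minimal, self-contained ARMv7 example of the same pattern (hypothetical, not from this file):

#include <arm_neon.h>
static inline float32x4_t vadd_asm(float32x4_t a, float32x4_t b) {
  asm("vadd.f32 %q[a], %q[a], %q[b]\n"  // %q prints the quad register
      : [a] "+w"(a)                     // read-write NEON operand
      : [b] "w"(b));                    // read-only NEON operand
  return a;
}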
process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; + unsigned int *rmask_ptr = rmask; + unsigned int *vmask_ptr = vmask; + if (flag_relu) { + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + dout_ptr += 2 * w_out; + } //! end of processing mid rows +#endif + } + } +} + +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p1_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
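For reference, the arithmetic every variant in this file must reproduce is the plain scalar form below: one channel, pad 1, stride 1, with the nine taps in row-major order as in `wei_ptr` above (hypothetical checker, not part of this patch):

void depthwise3x3s1p1_ref(const float *in, const float *w, float bias,
                          int h, int wd, float *out) {
  for (int oy = 0; oy < h; ++oy)
    for (int ox = 0; ox < wd; ++ox) {
      float s = bias;
      for (int ky = 0; ky < 3; ++ky)
        for (int kx = 0; kx < 3; ++kx) {
          int iy = oy + ky - 1, ix = ox + kx - 1;  // pad = 1 offsets
          if (iy >= 0 && iy < h && ix >= 0 && ix < wd)
            s += in[iy * wd + ix] * w[ky * 3 + kx];
        }
      out[oy * wd + ox] = s;
    }
}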
for 4x6 convolution window + const int right_pad_idx[4] = {3, 2, 1, 0}; + const float zero[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float *dout_channel = dout_batch + i * size_out_channel; + const float *din_channel = din_batch + i * size_in_channel; + const float *weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + int hs = -1; + int he = 3; + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + int h_cnt = (h_out + 1) >> 1; + float *doutr0 = dout_channel; + float *doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_cnt; ++j) { + const float *dr0 = din_channel + hs * w_in; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + + if (hs == -1) { + dr0 = zero; + } + + switch (he - h_in) { + case 2: + dr2 = zero; + doutr1 = trash_buf; + case 1: + dr3 = zero; + default: + break; + } +#ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + } else { + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + } +#else + if (flag_relu) { + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + doutr0 = doutr1; + doutr1 += w_out; + hs += 2; + he += 2; + } // end of processing heights + } // end of processing channels + } // end of processing 
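The `out_buf`/`trash_buf` pair above is the same trick `write_ptr` plays in the wide kernels: always compute and store a full 4-lane result, but point the destination at scratch whenever the row falls outside the output, so the inner loop stays branch-free. A minimal sketch of the pattern (illustrative names):

static void store_row_pair(float *out, int oy, int h_out, int w_out,
                           const float r0[4], const float r1[4]) {
  float trash[4];                              // absorbs out-of-range row
  float *d0 = out + oy * w_out;
  float *d1 = (oy + 1 < h_out) ? d0 + w_out : trash;
  for (int k = 0; k < w_out && k < 4; ++k) {
    d0[k] = r0[k];
    d1[k] = r1[k];
  }
}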
batchs +} + +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ +void conv_depthwise_3x3s1p0_bias(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float *zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float *write_ptr = zero_ptr + w_in; + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = w_out >> 2; + int remain = w_out % 4; + + unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); + const int remian_idx[4] = {0, 1, 2, 3}; + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + float *dout_ptr = dout_batch + c * size_out_channel; + + const float *din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float *wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + + float *doutr0 = dout_ptr; + float *doutr1 = doutr0 + w_out; + float *doutr2 = doutr1 + w_out; + float *doutr3 = doutr2 + w_out; + + const float *dr0 = din_ch_ptr; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + const float *dr4 = dr3 + w_in; + const float *dr5 = dr4 + w_in; + + const float *din_ptr0 = dr0; + const float *din_ptr1 = dr1; + const float *din_ptr2 = dr2; + const float *din_ptr3 = dr3; + const float *din_ptr4 = dr4; + const float *din_ptr5 = dr5; + + float *ptr_zero = const_cast(zero); +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 4) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! process bottom pad + if (i + 5 >= h_in) { + switch (i + 5 - h_in) { + case 4: + din_ptr1 = zero_ptr; + case 3: + din_ptr2 = zero_ptr; + case 2: + din_ptr3 = zero_ptr; + case 1: + din_ptr4 = zero_ptr; + case 0: + din_ptr5 = zero_ptr; + default: + break; + } + } + //! 
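A worked example of the pad-0 tiling set up above, assuming w_in = 12:

//   w_out  = w_in - 2            = 10
//   tile_w = w_out >> 2          = 2    full tiles of 4 outputs
//   remain = w_out % 4           = 2    columns in the masked tail
//   size_pad_right = 6 + (tile_w << 2) - w_in = 2
// i.e. the tail tile's 6-column input window (columns 8..13) hangs two
// columns past the row end, and vmask zeroes exactly those two lanes.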
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = tile_w; + if (flag_relu) { + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } else { + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1 "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } + dout_ptr = dout_ptr + 4 * w_out; + } +#else + for (int i = 0; i < h_out; i += 2) { + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = dout_ptr + w_out; + + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (i + 3 >= h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + case 0: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! 
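The two-row blocking in this 32-bit path advances the input window by two rows per iteration and re-reads the overlapped rows through pointer rotation rather than copying, mirroring the `dr0 = dr2; dr1 = dr3;` shuffle above. Illustrative restatement (names assumed):

static void rotate_rows_s1p0(const float *din, int w_in, int h_out) {
  const float *r0 = din, *r1 = r0 + w_in, *r2 = r1 + w_in, *r3 = r2 + w_in;
  for (int oy = 0; oy + 2 <= h_out; oy += 2) {
    // kernel(r0..r3) would emit output rows oy and oy + 1 here
    r0 = r2;  // advance two rows; overlapped rows re-read, never copied
    r1 = r3;
    r2 = r1 + w_in;
    r3 = r2 + w_in;
  }
}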
process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = tile_w; + unsigned int *rmask_ptr = rmask; + unsigned int *vmask_ptr = vmask; + if (flag_relu) { + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1 "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + dout_ptr += 2 * w_out; + } //! end of processing mid rows +#endif + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p0_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext *ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
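In this small-width variant a single six-column load spans the whole row, and the comparison against `6 - w_in` just below marks which of those six lanes actually exist. Worked out for w_in = 4 (so w_out = 2):

// 6 - w_in = 2, compared against descending indices {5,4,3,2,1,0,0,0}:
//   mask1 = {5,4,3,2} >= 2 -> all set   : columns 0..3 kept
//   mask2 = {1,0,0,0} >= 2 -> all clear : columns 4..5 read as zero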
for 4x6 convolution window + const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp1 = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); + uint32x4_t vmask_rp2 = + vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float *dout_channel = dout_batch + i * size_out_channel; + const float *din_channel = din_batch + i * size_in_channel; + const float *weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#endif // __aarch64__ + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + float *doutr0 = dout_channel; + float *doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_out; j += 2) { + const float *dr0 = din_channel + j * w_in; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + + doutr0 = dout_channel + j * w_out; + doutr1 = doutr0 + w_out; + + if (j + 3 >= h_in) { + switch (j + 3 - h_in) { + case 3: + dr1 = zero_ptr; + case 2: + dr2 = zero_ptr; + case 1: + dr3 = zero_ptr; + doutr1 = trash_buf; + case 0: + dr3 = zero_ptr; + doutr1 = trash_buf; + default: + break; + } + } +#ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [zero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [zero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } +#else + unsigned int *vmask_ptr = vmask; + float bias_val = flag_bias ? 
bias[i] : 0.f; + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_3x3s2.cc b/lite/backends/arm/math/conv_depthwise_3x3s2.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec039af98cb7e4fb037475dd4e5ee29204252165 --- /dev/null +++ b/lite/backends/arm/math/conv_depthwise_3x3s2.cc @@ -0,0 +1,1862 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
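The stride-2 kernels in this file all hinge on one idiom: `ld2`/`vld2` de-interleaves eight consecutive floats into even columns {0,2,4,6} and odd columns {1,3,5,7}, so four stride-2 outputs per kernel row reduce to three fused multiply-accumulates. A minimal intrinsics sketch of one kernel row (hypothetical helper, assuming at least nine readable floats at `r` and the three taps of one kernel row at `w3`):

#include <arm_neon.h>
static inline float32x4_t row3x3_s2(const float *r, const float *w3,
                                    float32x4_t acc) {
  float32x4x2_t v = vld2q_f32(r);  // val[0] = cols {0,2,4,6},
                                   // val[1] = cols {1,3,5,7}
  float32x4_t t2 = vextq_f32(v.val[0], vld1q_dup_f32(r + 8), 1);  // {2,4,6,8}
  acc = vmlaq_n_f32(acc, v.val[0], w3[0]);  // tap 0 * even columns
  acc = vmlaq_n_f32(acc, v.val[1], w3[1]);  // tap 1 * odd columns
  acc = vmlaq_n_f32(acc, t2, w3[2]);        // tap 2 * shifted evens
  return acc;
}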
+ +#include "lite/backends/arm/math/conv_depthwise.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void conv_depthwise_3x3s2p0_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + if (pad == 0) { + if (w_in > 7) { + conv_depthwise_3x3s2p0_bias(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } +} +#ifdef __aarch64__ +#define INIT_S2 \ + "prfm pldl1keep, [%[inptr0]] \n" \ + "prfm pldl1keep, [%[inptr1]] \n" \ + "prfm pldl1keep, [%[inptr2]] \n" \ + "prfm pldl1keep, [%[inptr3]] \n" \ + "prfm pldl1keep, [%[inptr4]] \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" + +#define LEFT_COMPUTE_S2 \ + "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[1] \n" /* {0,2,4,6} * w01 */ \ + "fmul v12.4s, v1.4s, %[w0].s[2] \n" /* {1,3,5,7} * w02 */ \ + "fmla v16.4s, v10.4s, %[w0].s[0] \n" /* {0,1,3,5} * w00*/ \ + \ + "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" /* v10 = {0,1,3,5} */ \ + \ + "sub %[inptr0], %[inptr0], #4 \n" \ + "sub %[inptr1], %[inptr1], #4 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[1] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" \ + \ + "sub %[inptr2], %[inptr2], #4 \n" \ + "sub %[inptr3], %[inptr3], #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[1] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[1] 
\n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[2] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[2] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[0] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" \ + \ + "sub %[inptr4], %[inptr4], #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[1] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" + +#define LEFT_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" + +#define MID_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %w[remain], #1 \n" \ + "blt 4f \n" \ + "3: \n" \ + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v1.16b, %[vzero].16b, 
%[mask2].16b \n" \ + \ + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" \ + \ + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" \ + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" \ + "ld1 {v0.4s}, [%[outptr0]] \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" \ + "ld1 {v1.4s}, [%[outptr1]] \n" + +#define RIGHT_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define LEFT_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ + \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, 
v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define COMPUTE_S_S2 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + "ext v6.16b, v9.16b, v11.16b, #12 \n" \ + "ext v7.16b, v9.16b, v13.16b, #12 \n" \ + "ext v8.16b, v9.16b, v15.16b, #12 \n" \ + \ + "fmul v4.4s, v10.4s, %[wr0].s[1] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[2] \n" \ + "fmul v6.4s, v6.4s, %[wr0].s[0] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[1] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[2] \n" \ + "fmla v6.4s, v7.4s, %[wr1].s[0] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[1] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[2] \n" \ + "fmla v6.4s, v8.4s, %[wr2].s[0] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + "fadd v4.4s, v4.4s, v6.4s \n" + +#define RESULT_S_S2 \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" + +#define RESULT_S_S2_RELU \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" + +#define COMPUTE_S_S2_P0 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + "and v4.16b, %[bias].16b, %[bias].16b \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + "ext v6.16b, v10.16b, v9.16b, #4 \n" \ + "ext v7.16b, v12.16b, v9.16b, #4 \n" \ + "ext v8.16b, v14.16b, v9.16b, #4 \n" \ + \ + "fmla v4.4s, v10.4s, %[wr0].s[0] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[1] \n" \ + "fmul v16.4s, v6.4s, %[wr0].s[2] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[0] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[1] \n" \ + "fmla v16.4s, v7.4s, %[wr1].s[2] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[0] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[1] \n" \ + "fmla v16.4s, v8.4s, %[wr2].s[2] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + "fadd v4.4s, v4.4s, v16.4s \n" + +#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n" + +#define RESULT_S_S2_P0_RELU \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + "st1 {v4.4s}, [%[out]] \n" + +#else +#define INIT_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! 
@ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + \ + "vdup.32 q3, %[bias] @ and \n" + +#define LEFT_COMPUTE_S2 \ + "vext.32 q6, q9, q11, #3 @ shift right 1 data\n" \ + "vext.32 q7, q9, q13, #3 @ shift right 1 data\n" \ + "vext.32 q8, q9, q15, #3 @ shift right 1 data\n" \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, out0\n" \ + \ + "sub %[din0_ptr], #4 @ inpitr0 - 1\n" \ + "sub %[din1_ptr], #4 @ inpitr1 - 1\n" \ + "sub %[din2_ptr], #4 @ inpitr2 - 1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, out1\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define LEFT_RESULT_S2 \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "cmp %[cnt], #1 \n" \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" \ + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define MID_RESULT_S2 \ + "subs %[cnt], #1 \n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %[remain], #1 \n" \ + "blt 3f \n" \ + \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! 
@ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" \ + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RIGHT_RESULT_S2 \ + "vbif.f32 q3, q10, q11 @ write mask\n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "3: \n" + +#define LEFT_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "cmp %[cnt], #1 \n" \ + "blt 1f \n" + +#define MID_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "subs %[cnt], #1 \n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vbif.f32 q3, q10, q11 @ write mask\n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "3: \n" + +#define COMPUTE_S_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q9, q11, #3 @ shift left 1 \n" \ + "vext.32 q7, q9, q13, #3 @ shift left 1 \n" \ + "vext.32 q8, q9, q15, #3 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n" + +#define RESULT_S_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu\n" \ + \ + "vst1.32 {d6-d7}, [%[out]] \n" + +#define COMPUTE_S_S2_P0 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + "vext.32 q8, q14, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n" + +#define RESULT_S_S2_P0_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vst1.32 {d6-d7}, [%[out]] \n" + +#endif + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + * w_in > 7 + */ +void conv_depthwise_3x3s2p1_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int cnt_col = (w_out >> 2) - 2; + int size_right_remain = w_in - (7 + cnt_col * 8); + if 
(size_right_remain >= 9) { + cnt_col++; + size_right_remain -= 8; + } + int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // + + int size_right_pad = w_out * 2 - w_in; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_in; i += 4) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 > h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! 
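A worked example of the column bookkeeping above, assuming w_in = 16 (stride 2, pad 1, so w_out = 8):

//   cnt_col           = (w_out >> 2) - 2       = 0
//   size_right_remain = w_in - (7 + 8*cnt_col) = 9   >= 9, so promoted:
//   cnt_col = 1, size_right_remain = 1, cnt_remain = w_out % 4 = 0
// The left tile yields outputs 0..3 from columns 0..7 plus the implicit
// left zero (column 7 is re-read by the next tile, hence the "7 +" in
// the remainder count); each middle tile turns 8 more input columns
// into 4 more outputs, and wmask blends any partial final store.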
process output pad + if (i / 2 + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + if (flag_relu) { + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } else { + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_in; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! 
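Row mapping of the stride-2, pad-1 kernels, restated as a sketch (single channel, illustrative names; `zero_row` stands in for the shared zero buffer that the pointer switches above select):

static void map_rows_s2p1(const float *din, const float *zero_row,
                          int w_in, int h_in, int h_out) {
  for (int oy = 0; oy < h_out; ++oy) {
    // output row oy reads input rows 2*oy-1 .. 2*oy+1; row -1 and any
    // row at or past h_in are clamped to the zero row
    const float *r0 = (oy == 0) ? zero_row : din + (2 * oy - 1) * w_in;
    const float *r1 = (2 * oy < h_in) ? din + 2 * oy * w_in : zero_row;
    const float *r2 =
        (2 * oy + 1 < h_in) ? din + (2 * oy + 1) * w_in : zero_row;
    // kernel(r0, r1, r2) would emit output row oy here
    (void)r0; (void)r1; (void)r2;
  }
}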
process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + if (flag_relu) { + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", 
+ "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } else { + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } +#else + if (flag_relu) { + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + */ +// w_in > 7 +void conv_depthwise_3x3s2p0_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 
+ w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 5 > h_in) { + switch (i * 2 + 5 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + if (flag_relu) { + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } else { + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! 
process bottom pad + if (i * 2 + 3 > h_in) { + switch (i * 2 + 3 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + if (flag_relu) { + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p0_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; j++) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + if (j * 2 + 2 >= h_in) { + switch (j + 2 - h_in) { + case 1: + din1_ptr = zero_ptr; + case 0: + din2_ptr = zero_ptr; + default: + break; + } + } + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + 
[wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + } else { + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + } +#else + if (flag_relu) { + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_5x5s1_int8.cc b/lite/backends/arm/math/conv_depthwise_5x5s1_int8.cc deleted file mode 100644 index 0d0034dd85e200cf7902bf997c688d4d69668278..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_5x5s1_int8.cc +++ /dev/null @@ -1,618 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include <arm_neon.h> -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/operators/op_params.h" -#ifdef ARM_WITH_OMP -#include <omp.h> -#endif - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_5x5s1_int8(int32_t* dout, - const int8_t* din, - const int8_t* weights, - const int* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int chin, - const int hin, - const int win, - const int hout, - const int wout, - ARMContext* ctx, - PrecisionType out_type, - const float* scale); - -void conv_depthwise_5x5_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - ARMContext* ctx, - PrecisionType out_type, - const float* scale) { - int stride_h = param.strides[0]; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active){ - // if (param.activation_param.active == Active_relu || - // fabs(param.activation_param.negative_slope) > 1e-6f){ - // flag_relu = true; - // } - // } - if (stride_h == 1) { -#ifdef __aarch64__ - conv_depthwise_5x5s1_int8(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - chin, - hin, - win, - hout, - wout, - ctx, - out_type, - scale); -#else - - LOG(FATAL) << "5x5 dw conv armv7 has not impl"; -#endif - } -} - -/** - * \brief depthwise convolution, kernel size 5x5, stride 1, pad 1, with bias, - * width > 4 - */ -// 2 line -#ifdef __aarch64__ - -template <typename Dtype> -inline void prefetch(const Dtype* din) { -#ifdef __aarch64__ - asm volatile("PRFM PLDL1KEEP, [%[din]] \n" : : [din] "r"(din) : "memory"); -#else - asm volatile("pld [%[din]] \n" : : [din] "r"(din) : "memory"); -#endif -} - -void conv_depthwise_5x5s1_int8( - int32_t* dout, - const int8_t* din, - const int8_t* weights, - const int32_t* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int chin, - const int hin, - const int win, - const int hout, - const int wout, - ARMContext* ctx, - PrecisionType od_type, - float const* scales) { /// scale_size = channel-out - - // printf("5*5 multiply\n"); - int size_in_channel = win * hin; - int size_out_channel = wout * hout; - int w_stride = 5 * 5; - - static int const stride_w = 1; - int const stride_h = stride_w; - int const chout = chin; - int const pad_w = 2; - int const pad_h = pad_w; - - int const wout_round = ((wout + 7) / 8) * 8; - int const win_round = wout_round * stride_w + 5 - 1; - int const hout_round = ((hout + 2) / 3) * 3; - int const hin_round = hout_round * stride_h + 5 - 1; - int const tile_h = hout_round / 3; - int const tile_w = wout_round / 8; - - int const pre_in_size = hin_round * win_round; - int const pre_out_size = hout_round * wout_round; - int const pre_io_size = pre_in_size + pre_out_size * sizeof(int); - - int const hs = -pad_h; - int const he = hs + hin_round; - int const ws = -pad_w; - int const we = ws + win_round; - - // signed char* tmp_work_space = new signed char [1024*5]; - signed char* tmp_work_space = ctx->workspace_data<signed char>(); - signed char* ptr_zero = tmp_work_space; - int* ptr_write = reinterpret_cast<int*>(ptr_zero + win_round); - signed char* pre_data = - reinterpret_cast<signed char*>(ptr_write + wout_round); - - memset(ptr_zero, 0, win_round * sizeof(signed char)); - - for (int n = 0; n < num; ++n) { - signed char const* din_batch = din + n * chin * size_in_channel; - int* dout_batch
= dout + n * chout * size_out_channel; - - // #pragma omp parallel for - for (int c = 0; c < chout; c++) { -#ifdef ARM_WITH_OMP - int const thno = omp_get_thread_num(); -#else - int const thno = 0; -#endif - signed char const* din_channel = din_batch + c * size_in_channel; - signed char* pre_din = pre_data + thno * pre_io_size; - int* pre_out = reinterpret_cast(pre_din + pre_in_size); - int* dout_ptr = pre_out; - - prepack_input_nxw(din_channel, - pre_din, - c, - c + 1, - hs, - he, - ws, - we, - 1, - win, - hin, - ptr_zero); - - signed char const* wei_ptr = weights + c * w_stride; - int bias_val = flag_bias ? bias[c] : 0.f; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0 * 5 + 0]); - int8x8_t wr01 = vdup_n_s8(wei_ptr[0 * 5 + 1]); - int8x8_t wr02 = vdup_n_s8(wei_ptr[0 * 5 + 2]); - int8x8_t wr03 = vdup_n_s8(wei_ptr[0 * 5 + 3]); - int8x8_t wr04 = vdup_n_s8(wei_ptr[0 * 5 + 4]); - - int8x8_t wr10 = vdup_n_s8(wei_ptr[1 * 5 + 0]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[1 * 5 + 1]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[1 * 5 + 2]); - int8x8_t wr13 = vdup_n_s8(wei_ptr[1 * 5 + 3]); - int8x8_t wr14 = vdup_n_s8(wei_ptr[1 * 5 + 4]); - - int8x8_t wr20 = vdup_n_s8(wei_ptr[2 * 5 + 0]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[2 * 5 + 1]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[2 * 5 + 2]); - int8x8_t wr23 = vdup_n_s8(wei_ptr[2 * 5 + 3]); - int8x8_t wr24 = vdup_n_s8(wei_ptr[2 * 5 + 4]); - - int8x8_t wr30 = vdup_n_s8(wei_ptr[3 * 5 + 0]); - int8x8_t wr31 = vdup_n_s8(wei_ptr[3 * 5 + 1]); - int8x8_t wr32 = vdup_n_s8(wei_ptr[3 * 5 + 2]); - int8x8_t wr33 = vdup_n_s8(wei_ptr[3 * 5 + 3]); - int8x8_t wr34 = vdup_n_s8(wei_ptr[3 * 5 + 4]); - - int8x8_t wr40 = vdup_n_s8(wei_ptr[4 * 5 + 0]); - int8x8_t wr41 = vdup_n_s8(wei_ptr[4 * 5 + 1]); - int8x8_t wr42 = vdup_n_s8(wei_ptr[4 * 5 + 2]); - int8x8_t wr43 = vdup_n_s8(wei_ptr[4 * 5 + 3]); - int8x8_t wr44 = vdup_n_s8(wei_ptr[4 * 5 + 4]); - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - int* doutr2 = nullptr; - - signed char const* dr0 = pre_din; - signed char const* dr1 = dr0 + win_round; - signed char const* dr2 = dr1 + win_round; - signed char const* dr3 = dr2 + win_round; - signed char const* dr4 = dr3 + win_round; - signed char const* dr5 = dr4 + win_round; - signed char const* dr6 = dr5 + win_round; - - signed char const* din_ptr0 = nullptr; - signed char const* din_ptr1 = nullptr; - signed char const* din_ptr2 = nullptr; - signed char const* din_ptr3 = nullptr; - signed char const* din_ptr4 = nullptr; - signed char const* din_ptr5 = nullptr; - signed char const* din_ptr6 = nullptr; - - for (int h = 0; h < tile_h; h++) { - // printf("c:%d h:%d\n", c, h); - doutr0 = dout_ptr; - doutr1 = doutr0 + wout_round; - doutr2 = doutr1 + wout_round; - - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - din_ptr6 = dr6; - - prefetch(doutr0); - prefetch(doutr1); - prefetch(doutr2); - prefetch(din_ptr0); - prefetch(din_ptr1); - prefetch(din_ptr2); - prefetch(din_ptr3); - prefetch(din_ptr4); - prefetch(din_ptr5); - prefetch(din_ptr6); - - for (int j = 0; j < tile_w; ++j) { - // printf("j:%d\n", j); - int32x4_t voutr00 = vdupq_n_s32(bias_val); - int32x4_t voutr01 = vdupq_n_s32(bias_val); - int32x4_t voutr10 = vdupq_n_s32(bias_val); - int32x4_t voutr11 = vdupq_n_s32(bias_val); - int32x4_t voutr20 = vdupq_n_s32(bias_val); - int32x4_t voutr21 = vdupq_n_s32(bias_val); - - // din data - int8x8_t vinr00 = vld1_s8(din_ptr0 + 0); - int8x8_t vinr01 = vld1_s8(din_ptr0 + 8); - int8x8_t vinr10 = vld1_s8(din_ptr1 + 0); - int8x8_t vinr11 = vld1_s8(din_ptr1 
+ 8); - int8x8_t vinr20 = vld1_s8(din_ptr2 + 0); - int8x8_t vinr21 = vld1_s8(din_ptr2 + 8); - int8x8_t vinr30 = vld1_s8(din_ptr3 + 0); - int8x8_t vinr31 = vld1_s8(din_ptr3 + 8); - int8x8_t vinr40 = vld1_s8(din_ptr4 + 0); - int8x8_t vinr41 = vld1_s8(din_ptr4 + 8); - int8x8_t vinr50 = vld1_s8(din_ptr5 + 0); - int8x8_t vinr51 = vld1_s8(din_ptr5 + 8); - int8x8_t vinr60 = vld1_s8(din_ptr6 + 0); - int8x8_t vinr61 = vld1_s8(din_ptr6 + 8); - - /// the first row - // r0 - int8x8_t vtmp1 = vext_s8(vinr00, vinr01, 1); // 12345678 - int8x8_t vtmp2 = vext_s8(vinr00, vinr01, 2); // 2345678 - int8x8_t vtmp3 = vext_s8(vinr00, vinr01, 3); // 345678 - int8x8_t vtmp4 = vext_s8(vinr00, vinr01, 4); // 45678 - - int16x8_t tvoutr0 = vmull_s8(vinr00, wr00); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr01); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr02); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr03); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr04); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - // r1 - vtmp1 = vext_s8(vinr10, vinr11, 1); // 12345678 - vtmp2 = vext_s8(vinr10, vinr11, 2); // 2345678 - vtmp3 = vext_s8(vinr10, vinr11, 3); // 345678 - vtmp4 = vext_s8(vinr10, vinr11, 4); // 45678 - - tvoutr0 = vmull_s8(vinr10, wr10); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr11); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr12); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr13); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr14); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - int16x8_t tvoutr1 = vmull_s8(vinr10, wr00); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr01); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr02); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr03); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr04); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - // r2 - vtmp1 = vext_s8(vinr20, vinr21, 1); // 12345678 - vtmp2 = vext_s8(vinr20, vinr21, 2); // 2345678 - vtmp3 = vext_s8(vinr20, vinr21, 3); // 345678 - vtmp4 = vext_s8(vinr20, vinr21, 4); // 45678 - - tvoutr0 = vmull_s8(vinr20, wr20); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr21); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr22); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr23); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr24); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr20, wr10); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr11); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr12); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr13); - voutr10 = vaddw_s16(voutr10, 
vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr14); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - int16x8_t tvoutr2 = vmull_s8(vinr20, wr00); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr01); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr02); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr03); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr04); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r3 - vtmp1 = vext_s8(vinr30, vinr31, 1); // 12345678 - vtmp2 = vext_s8(vinr30, vinr31, 2); // 2345678 - vtmp3 = vext_s8(vinr30, vinr31, 3); // 345678 - vtmp4 = vext_s8(vinr30, vinr31, 4); // 45678 - - tvoutr0 = vmull_s8(vinr30, wr30); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr31); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr32); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr33); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr34); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr30, wr20); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr21); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr22); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr23); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr24); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr30, wr10); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr11); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr12); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr13); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr14); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r4 - vtmp1 = vext_s8(vinr40, vinr41, 1); // 12345678 - vtmp2 = vext_s8(vinr40, vinr41, 2); // 2345678 - vtmp3 = vext_s8(vinr40, vinr41, 3); // 345678 - vtmp4 = vext_s8(vinr40, vinr41, 4); // 45678 - - tvoutr0 = vmull_s8(vinr40, wr40); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr41); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr42); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr43); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr44); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr40, wr30); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr31); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr32); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, 
wr33); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr34); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr40, wr20); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr21); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr22); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr23); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr24); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r5 - vtmp1 = vext_s8(vinr50, vinr51, 1); // 12345678 - vtmp2 = vext_s8(vinr50, vinr51, 2); // 2345678 - vtmp3 = vext_s8(vinr50, vinr51, 3); // 345678 - vtmp4 = vext_s8(vinr50, vinr51, 4); // 45678 - - tvoutr1 = vmull_s8(vinr50, wr40); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr41); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr42); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr43); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr44); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr50, wr30); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr31); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr32); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr33); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr34); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r6 - vtmp1 = vext_s8(vinr60, vinr61, 1); // 12345678 - vtmp2 = vext_s8(vinr60, vinr61, 2); // 2345678 - vtmp3 = vext_s8(vinr60, vinr61, 3); // 345678 - vtmp4 = vext_s8(vinr60, vinr61, 4); // 45678 - - tvoutr2 = vmull_s8(vinr60, wr40); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr41); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr42); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr43); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr44); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - /// data shift 8 bytes - din_ptr0 += 8; - din_ptr1 += 8; - din_ptr2 += 8; - din_ptr3 += 8; - din_ptr4 += 8; - din_ptr5 += 8; - din_ptr6 += 8; - - /// store - vst1q_s32(doutr0, voutr00); - vst1q_s32(doutr1, voutr10); - vst1q_s32(doutr2, voutr20); - doutr0 += 4; - doutr1 += 4; - doutr2 += 4; - vst1q_s32(doutr0, voutr01); - vst1q_s32(doutr1, voutr11); - vst1q_s32(doutr2, voutr21); - doutr0 += 4; - doutr1 += 4; - doutr2 += 4; - } /// end of tile_w - - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - dr3 = dr6; - dr4 = dr3 + win_round; - dr5 = dr4 + win_round; - dr6 = dr5 + win_round; - - dout_ptr = dout_ptr + 3 * wout_round; - } /// end of tile_h - - if (scales == 0) { - write_to_output_numc(pre_out, - dout_batch, - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - 
wout_round, - chout, - hout, - wout, - flag_relu, - ptr_write); - } else if (od_type == PRECISION(kFloat)) { - write2_to_output_numc(pre_out, - reinterpret_cast<float*>(dout_batch), - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast<float*>(ptr_write), - scales); - } else if (od_type == PRECISION(kInt8)) { - write2_to_output_numc(pre_out, - reinterpret_cast<signed char*>(dout_batch), - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast<signed char*>(ptr_write), - scales); - } - // else if (od_type == AK_INT32) { - // write2_to_output_numc(pre_out, (int*)dout_batch, 1, hout_round, c, - // c+1, - // 0, hout, 0, wout_round, chout, hout, wout, flag_relu, - // (int*)ptr_write, scales); - // } - } /// end of chout - } /// end of batch num -} - -#endif // __aarch64__ - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_direct.cc b/lite/backends/arm/math/conv_direct.cc deleted file mode 100644 index 51526aa2b3ce3e08afe41c9b1a150963fffebde1..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_direct.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
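// Illustrative sketch (standalone, not Paddle-Lite API): the accumulation
// pattern used throughout the 5x5 int8 kernel above. Each step multiplies two
// int8x8 vectors into an int16x8 (vmull_s8), chains one more product per lane
// (vmlal_s8), then widens into two int32x4 accumulators with vaddw_s16 before
// the 16-bit lanes can grow further; neighboring input windows (offsets 1..4)
// come from vext_s8 on two adjacent 8-byte loads.
#include <arm_neon.h>

static inline void mac_pair_s8(int32x4_t* acc_lo, int32x4_t* acc_hi,
                               int8x8_t in0, int8x8_t w0,
                               int8x8_t in1, int8x8_t w1) {
  int16x8_t t = vmull_s8(in0, w0);  // 8 lanes of s8*s8 -> s16
  t = vmlal_s8(t, in1, w1);         // accumulate a second product per lane
  *acc_lo = vaddw_s16(*acc_lo, vget_low_s16(t));   // widen lanes 0..3 to s32
  *acc_hi = vaddw_s16(*acc_hi, vget_high_s16(t));  // widen lanes 4..7 to s32
}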
- -#include "lite/backends/arm/math/conv_direct.h" -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template <> -bool DirectConv::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ic = x_dims[1]; - int ow = o_dims[3]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int sw = param.strides[1]; - // select dw conv kernel - const auto* w_data = param.filter->data(); - if (kw == 3 && sw == 1) { - VLOG(5) << "invoke 3x3s1 direct conv"; - impl_ = conv_3x3s1_direct_fp32; - - constexpr int cblock = 4; - int cround = (oc + cblock - 1) / cblock * cblock; - weights_trans_.Resize({cround, ic, kw, kw}); - float* transed_w_data = weights_trans_.mutable_data(); - - conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw); - is_weights_transed_ = true; - } else if (kw == 3 && sw == 2) { - VLOG(5) << "invoke 3x3s2 direct conv"; - impl_ = conv_3x3s2_direct_fp32; - - constexpr int cblock = 4; - int cround = (oc + cblock - 1) / cblock * cblock; - weights_trans_.Resize({cround, ic, kw, kw}); - float* transed_w_data = weights_trans_.mutable_data(); - conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw); - is_weights_transed_ = true; - } else { - LOG(ERROR) << "this type direct conv not impl"; - return false; - } - return true; -} - -template <> -bool DirectConv::init(const operators::ConvParam& param, - Context* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template <> -bool DirectConv::run(const operators::ConvParam& param) { - // start timer - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - - if (is_weights_transed_ == true) { - w_data = weights_trans_.data(); - } - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_); - - // timer end - return true; -} - -template -bool DirectConvInt8::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ic = x_dims[1]; - int ow = o_dims[3]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int sw = param.strides[1]; - // select dw conv kernel - w_scale_ = param.weight_scale; - //! 
update weights scale - const auto* w_data = param.filter->data(); - if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) { - CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout"; - float input_scale = param.input_scale; - for (auto& w_s : w_scale_) { - w_s *= input_scale; - if (Ptype_out == PRECISION(kInt8)) { - w_s /= param.output_scale; - } - } - } - if (kw == 3 && sw == 1) { - VLOG(5) << "invoke 3x3s1 direct conv"; - impl_int8_ = conv_3x3s1_direct_int8; - - constexpr int cblock = 4; - int inpad = 4; - int cround = (oc + cblock - 1) / cblock * cblock; - weights_trans_.Resize({cround, ic, kw, kw}); - int8_t* transed_w_data = weights_trans_.mutable_data(); - conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw); - - int wout_round = ((ow + 3) / 4) * 4; - int win_round = wout_round * sw + inpad; - int row_out = 2; - int row_in = 4; - int tmp_size_out = wout_round * row_out * cblock; - int in_len = win_round * ic; - int tmp_size_in = row_in * in_len; - ctx_->ExtendWorkspace(ctx_->threads() * tmp_size_out + - (tmp_size_in + 3) / 4 * 4 + wout_round + win_round); - is_weights_transed_ = true; - - } else if (kw == 3 && sw == 2) { - VLOG(5) << "invoke 3x3s2 direct conv"; - impl_int8_ = conv_3x3s2_direct_int8; - - // constexpr int cblock = 4; - int cblock = conv_3x3s2_direct_int8_c_num(); - int cround = (oc + cblock - 1) / cblock * cblock; - weights_trans_.Resize({cround, ic, kw, kw}); - int8_t* transed_w_data = weights_trans_.mutable_data(); - conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw); - is_weights_transed_ = true; - - } else { - LOG(ERROR) << "this type direct conv not impl"; - return false; - } - return true; -} - -template -bool DirectConvInt8::init(const operators::ConvParam& param, - Context* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template -bool DirectConvInt8::run(const operators::ConvParam& param) { - // start timer - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - if (is_weights_transed_ == true) { - w_data = weights_trans_.data(); - } - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_int8_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_, - Ptype_out, - w_scale_.data()); - - // Modified from int32 for debug convenience - if (Ptype_out == PRECISION(kInt8)) param.output->mutable_data(); - return true; -} - -template class DirectConvInt8; -template class DirectConvInt8; -template class DirectConvInt8; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_direct.h b/lite/backends/arm/math/conv_direct.h deleted file mode 100644 index e6132dca5e6f5899160c90cf03c3ece7c105c0e9..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_direct.h +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include <string> -#include <vector> -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/core/target_wrapper.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template -class DirectConv : public ImplBase { - public: - typedef void (*conv_direct_impl)(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx); - - DirectConv() = default; - ~DirectConv() {} - - virtual bool init(const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx); - - virtual bool create(const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx); - - virtual bool run(const operators::ConvParam& param); - - protected: - bool is_weights_transed_{false}; - Tensor weights_trans_; - Tensor _tmp_out; - - private: - conv_direct_impl impl_{nullptr}; -}; - -template <PrecisionType Ptype_out> -class DirectConvInt8 - : public ImplBase { - public: - typedef void (*conv_direct_int8_impl)(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx, - PrecisionType out_type, - const float* scale); - - DirectConvInt8() = default; - ~DirectConvInt8() {} - - virtual bool init(const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx); - - virtual bool create(const operators::ConvParam& param, - Context<TARGET(kARM)>* ctx); - - virtual bool run(const operators::ConvParam& param); - - private: - bool is_weights_transed_{false}; - Tensor weights_trans_; - Tensor _tmp_out; - conv_direct_int8_impl impl_int8_{nullptr}; - std::vector<float> w_scale_; -}; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_direct_3x3s1.cc b/lite/backends/arm/math/conv_direct_3x3s1.cc deleted file mode 100644 index 6991481ee10b0d2f54c8cec12958674dc9b51369..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_direct_3x3s1.cc +++ /dev/null @@ -1,1067 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
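// Illustrative sketch: the channel-block rounding used by DirectConv::create
// in the files above. Output channels are padded up to a multiple of the block
// (cblock == 4 for the fp32 kernels) so the micro-kernels always produce full
// 4-channel groups, and conv_trans_weights_numc re-lays the weights into that
// blocked order. Standalone arithmetic demo, not Paddle-Lite code:
#include <cstdio>

int main() {
  const int cblock = 4;
  const int test_oc[] = {1, 4, 6, 17};
  for (int oc : test_oc) {
    int cround = (oc + cblock - 1) / cblock * cblock;  // round up to multiple of 4
    std::printf("oc=%2d -> cround=%2d (%d padded channels)\n",
                oc, cround, cround - oc);
  }
  // e.g. oc=17 pads to 20: a few wasted weight channels in exchange for
  // branch-free 4-channel stores inside the assembly kernel.
  return 0;
}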
- -#include <arm_neon.h> -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/operators/op_params.h" -#ifdef ARM_WITH_OMP -#include <omp.h> -#endif - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_3x3s1_direct_fp32(const float* i_data, - float* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx) { - const int threads = ctx->threads(); - int l2_size = ctx->llc_size() / sizeof(float); - - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; - const int hout_c_block = 4; - const int hout_r_kernel = 2; - const int wout_block = 4; - const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block; - const int win_round = wout_round + 2; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // flag_relu = true; - // } - // } - int hout_r_block = (l2_size - 2 * win_round * ic) / - (win_round * ic + hout_c_block * wout_round * threads); - hout_r_block = hout_r_block > oh ? oh : hout_r_block; - hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; - hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; - - const int hin_r_block = hout_r_block + 2; - - float* tmp_work_space = ctx->workspace_data<float>(); - float ptr_zero[win_round]; // NOLINT - memset(ptr_zero, 0, sizeof(float) * win_round); - float ptr_write[wout_round]; // NOLINT - - int in_len = win_round * ic; - int pre_in_size = hin_r_block * in_len; - int pre_out_size = hout_c_block * hout_r_block * wout_round; - - float* pre_din = tmp_work_space; - - int size_in_channel = win * ih; - int size_out_channel = ow * oh; - int w_stride = ic * 9; // kernel_w * kernel_h; - int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * - - int ws = -pad_w; - int we = ws + win_round; - int w_loop = wout_round / 4; - - int c_remain = oc - (oc / hout_c_block) * hout_c_block; - int c_round_down = (oc / hout_c_block) * hout_c_block; - - int out_row_stride = hout_c_block * wout_round; - for (int n = 0; n < bs; ++n) { - const float* din_batch = i_data + n * ic * size_in_channel; - float* dout_batch = o_data + n * oc * size_out_channel; - for (int h = 0; h < oh; h += hout_r_block) { - int h_kernel = hout_r_block; - if (h + hout_r_block > oh) { - h_kernel = oh - h; - } - int hs = h - pad_h; - int he = hs + h_kernel + 2; - prepack_input_nxw( - din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < oc - (hout_c_block - 1); c += hout_c_block) { -#ifdef ARM_WITH_OMP - float* pre_out = - pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; -#else - float* pre_out = pre_din + pre_in_size; -#endif - const float* block_inr0 = pre_din; - const float* block_inr1 = block_inr0 + in_len; - const float* block_inr2 = block_inr1 + in_len; - const float* block_inr3 = block_inr2 + in_len; - - const float* weight_c = weights + c * w_stride; - const float* bias_ptr = ptr_zero; - if (flag_bias) { - bias_ptr = bias + c; - } - fill_packed_biasc4( - pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); - - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { - const float* wc0 = weight_c; - - const float* inr0 =
block_inr0; - const float* inr1 = block_inr1; - const float* inr2 = block_inr2; - const float* inr3 = block_inr3; - - float* pre_out0 = pre_out + hk * out_row_stride; - float* pre_out1 = pre_out0 + out_row_stride; -#ifdef __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - float32x4_t w0 = vld1q_f32(wc0); // w0, v23 - float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 - float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 - float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 - float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 - float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 - float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 - float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 - float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - - int cnt = w_loop; - asm volatile( - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, - outr01*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ - "2: \n" /* main loop*/ - /* r0, r1, mul w0, get out r0, r1 */ - "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ - "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ - "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ - "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ - "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ - "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ - "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ - "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ - - /* r0, r1, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ - "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ - "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ - "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ - "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ - "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ - "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ - "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ - - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - - /* r0, r1, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ - "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ - "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ - "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ - "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ - "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ - "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ - - /* r1, r2, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ - "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ - "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ - "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ - "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ - "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ - "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ - - "ldp q0, q1, 
[%[r0]], #16 \n" /* load next input r0*/ - - /* r1, r2, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ - "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ - "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ - "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ - "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ - "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ - "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ - "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ - - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - - /* r1, r2, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ - "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ - "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ - "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ - "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ - "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ - "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ - - /* r2, r3, mul w6, get out r0, r1 */ - "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ - "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ - "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ - "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ - "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ - "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ - "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ - - "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ - - /* r2, r3, mul w7, get out r0, r1 */ - "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ - "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ - "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ - "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ - "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ - "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ - "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - /* r2, r3, mul w8, get out r0, r1 */ - "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ - "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ - "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ - - "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ - "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ - "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ - "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ - "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ - "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ - "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] 
"+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - } -#else // not __aarch64__ - for (int i = 0; i < ic; ++i) { - const float* wc0 = weight_c + i * w_stride_chin; - - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - - int cnt = w_loop; - asm volatile( - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ " - "load outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - /* load r0, r1 */ - "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " - "4 float\n" - "vld1.32 {d2}, [%[r0]] @ load r0, " - "2 float\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - /* main loop */ - "0: @ main " - "loop\n" - /* mul r0 with w0, w1, w2, get out r0 */ - "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " - "outr1, w0, w1, c0~c3\n" - "vmla.f32 q8, q5, d0[0] @ w0 * " - "inr00\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " - "outr1, w2, w3, c0~c3\n" - "vmla.f32 q9, q5, d0[1] @ w0 * " - "inr01\n" - "vmla.f32 q10, q5, d1[0] @ w0 * " - "inr02\n" - "vmla.f32 q11, q5, d1[1] @ w0 * " - "inr03\n" - "vld1.32 {d3-d4}, [%[r1]]! @ load r1, " - "4 float\n" - "vmla.f32 q8, q6, d0[1] @ w1 * " - "inr01\n" - "vmla.f32 q9, q6, d1[0] @ w1 * " - "inr02\n" - "vmla.f32 q10, q6, d1[1] @ w1 * " - "inr03\n" - "vmla.f32 q11, q6, d2[0] @ w1 * " - "inr04\n" - "vld1.32 {d5}, [%[r1]] @ load r0, " - "2 float\n" - "vmla.f32 q8, q7, d1[0] @ w2 * " - "inr02\n" - "vmla.f32 q9, q7, d1[1] @ w2 * " - "inr03\n" - "vmla.f32 q10, q7, d2[0] @ w2 * " - "inr04\n" - "vmla.f32 q11, q7, d2[1] @ w2 * " - "inr05\n" - - "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " - "- 32, to start address\n" - - /* mul r1 with w0, w1, w2, get out r1 */ - "vmla.f32 q12, q5, d3[0] @ w0 * " - "inr10\n" - "vmla.f32 q13, q5, d3[1] @ w0 * " - "inr11\n" - "vmla.f32 q14, q5, d4[0] @ w0 * " - "inr12\n" - "vmla.f32 q15, q5, d4[1] @ w0 * " - "inr13\n" - "vmla.f32 q12, q6, d3[1] @ w1 * " - "inr11\n" - "vmla.f32 q13, q6, d4[0] @ w1 * " - "inr12\n" - "vmla.f32 q14, q6, d4[1] @ w1 * " - "inr13\n" - "vmla.f32 q15, q6, d5[0] @ w1 * " - "inr14\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " - "w4, to q5, q6\n" - "vmla.f32 q12, q7, d4[0] @ w2 * " - "inr12\n" - "vmla.f32 q13, q7, d4[1] @ w2 * " - "inr13\n" - "vmla.f32 q14, q7, d5[0] @ w2 * " - "inr14\n" - "vmla.f32 q15, q7, d5[1] @ w2 * " - "inr15\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " - "to q7\n" - - /* mul r1 with w3, w4, w5, get out r0 */ - "vmla.f32 q8, q5, d3[0] @ w3 * " - "inr10\n" - "vmla.f32 q9, q5, d3[1] @ w3 * " - "inr11\n" - "vmla.f32 q10, q5, d4[0] @ w3 * " - "inr12\n" - "vmla.f32 q11, q5, d4[1] @ w3 * " - "inr13\n" - "vld1.32 {d0-d1}, [%[r2]]! 
@ load r2, " - "4 float\n" - "vmla.f32 q8, q6, d3[1] @ w4 * " - "inr11\n" - "vmla.f32 q9, q6, d4[0] @ w4 * " - "inr12\n" - "vmla.f32 q10, q6, d4[1] @ w4 * " - "inr13\n" - "vmla.f32 q11, q6, d5[0] @ w4 * " - "inr14\n" - "vld1.32 {d2}, [%[r2]] @ load r2, " - "2 float\n" - "vmla.f32 q8, q7, d4[0] @ w5 * " - "inr12\n" - "vmla.f32 q9, q7, d4[1] @ w5 * " - "inr13\n" - "vmla.f32 q10, q7, d5[0] @ w5 * " - "inr14\n" - "vmla.f32 q11, q7, d5[1] @ w5 * " - "inr15\n" - - /* mul r2 with w3, w4, w5, get out r1 */ - "vmla.f32 q12, q5, d0[0] @ w3 * " - "inr20\n" - "vmla.f32 q13, q5, d0[1] @ w3 * " - "inr21\n" - "vmla.f32 q14, q5, d1[0] @ w3 * " - "inr22\n" - "vmla.f32 q15, q5, d1[1] @ w3 * " - "inr23\n" - "vmla.f32 q12, q6, d0[1] @ w4 * " - "inr21\n" - "vmla.f32 q13, q6, d1[0] @ w4 * " - "inr22\n" - "vmla.f32 q14, q6, d1[1] @ w4 * " - "inr23\n" - "vmla.f32 q15, q6, d2[0] @ w4 * " - "inr24\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " - "w7, to q5, q6\n" - "vmla.f32 q12, q7, d1[0] @ w5 * " - "inr22\n" - "vmla.f32 q13, q7, d1[1] @ w5 * " - "inr23\n" - "vmla.f32 q14, q7, d2[0] @ w5 * " - "inr24\n" - "vmla.f32 q15, q7, d2[1] @ w5 * " - "inr25\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " - "to q7\n" - - "sub %[wc0], %[wc0], #144 @ wc0 - " - "144 to start address\n" - - /* mul r2 with w6, w7, w8, get out r0 */ - "vmla.f32 q8, q5, d0[0] @ w6 * " - "inr20\n" - "vmla.f32 q9, q5, d0[1] @ w6 * " - "inr21\n" - "vld1.32 {d3-d4}, [%[r3]]! @ load r3, " - "4 float\n" - "vmla.f32 q10, q5, d1[0] @ w6 * " - "inr22\n" - "vmla.f32 q11, q5, d1[1] @ w6 * " - "inr23\n" - "vmla.f32 q8, q6, d0[1] @ w7 * " - "inr21\n" - "vmla.f32 q9, q6, d1[0] @ w7 * " - "inr22\n" - "vld1.32 {d5}, [%[r3]] @ load r3, " - "2 float\n" - "vmla.f32 q10, q6, d1[1] @ w7 * " - "inr23\n" - "vmla.f32 q11, q6, d2[0] @ w7 * " - "inr24\n" - "vmla.f32 q8, q7, d1[0] @ w8 * " - "inr22\n" - "vmla.f32 q9, q7, d1[1] @ w8 * " - "inr23\n" - "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " - "4 float\n" - "vmla.f32 q10, q7, d2[0] @ w8 * " - "inr24\n" - "vmla.f32 q11, q7, d2[1] @ w8 * " - "inr25\n" - "vld1.32 {d2}, [%[r0]] @ load r0, " - "2 float\n" - - /* mul r3 with w6, w7, w8, get out r1 */ - "vmla.f32 q12, q5, d3[0] @ w6 * " - "inr20\n" - "vmla.f32 q13, q5, d3[1] @ w6 * " - "inr21\n" - "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " - "r00, r01, c0~c3\n" - "vmla.f32 q14, q5, d4[0] @ w6 * " - "inr22\n" - "vmla.f32 q15, q5, d4[1] @ w6 * " - "inr23\n" - "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " - "r02, r03, c0~c3\n" - "vmla.f32 q12, q6, d3[1] @ w7 * " - "inr21\n" - "vmla.f32 q13, q6, d4[0] @ w7 * " - "inr22\n" - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " - "outr0, w0, w1, c0~c3\n" - "vmla.f32 q14, q6, d4[1] @ w7 * " - "inr23\n" - "vmla.f32 q15, q6, d5[0] @ w7 * " - "inr24\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vmla.f32 q12, q7, d4[0] @ w8 * " - "inr22\n" - "vmla.f32 q13, q7, d4[1] @ w8 * " - "inr23\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - "vmla.f32 q14, q7, d5[0] @ w8 * " - "inr24\n" - "vmla.f32 q15, q7, d5[1] @ w8 * " - "inr25\n" - - "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " - "r10, r11, c0~c3\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " - "r12, r13, c0~c3\n" - "vld1.32 {d14-d15}, [%[wc0]]! 
@ load w2, " - "to q7\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - "subs %[cnt], #1 @ loop " - "count--\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1), - [wc0] "+r"(wc0) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - } -#endif // __aarch64__ - block_inr0 = block_inr2; - block_inr1 = block_inr3; - block_inr2 = block_inr1 + in_len; - block_inr3 = block_inr2 + in_len; - } - write_to_output_c4_fp32(pre_out, - dout_batch, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - oc, - oh, - ow, - flag_relu, - ptr_write); - } - const float* weight_remain_ptr = weights + c_round_down * w_stride; -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < c_remain; ++c) { -#ifdef ARM_WITH_OMP - float* pre_out = - pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; -#else - float* pre_out = pre_din + pre_in_size; -#endif - - int c_idx = c_round_down + c; - - int h_kernel = hout_r_block; - if (h + hout_r_block > oh) { - h_kernel = oh - h; - } - - const float* block_inr0 = pre_din; - const float* block_inr1 = block_inr0 + in_len; - const float* block_inr2 = block_inr1 + in_len; - const float* block_inr3 = block_inr2 + in_len; - - const float* bias_ptr = ptr_zero; - if (flag_bias) { - bias_ptr = bias + c_idx; - } - fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); - - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { - const float* wc0 = weight_remain_ptr; - - const float* inr0 = block_inr0; - const float* inr1 = block_inr1; - const float* inr2 = block_inr2; - const float* inr3 = block_inr3; - - float* pre_out0 = pre_out + hk * wout_round; - float* pre_out1 = pre_out0 + wout_round; -#ifdef __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 - float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24 - float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25 - float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26 - float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27 - float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28 - float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29 - float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30 - float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - - int cnt = w_loop; - asm volatile( - "ldr q21, [%[ptr_out0]] \n" /* load outr0, - w0~w3*/ - "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - "2: \n" /* main loop*/ - - "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ - "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ - - "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ - "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ - "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ - "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ - - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - - "fmla v21.4s , %[w1].4s, 
v8.4s \n" /* outr0 = w1 * r1*/ - "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ - - "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ - "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ - - "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ - "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ - - "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ - "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ - "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ - "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ - - "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ - "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ - - "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ - "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ - - "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ - - "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ - "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ - - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - - "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ - "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ - - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - - "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ - "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ - - "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ - "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ - "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ - - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - } -#else // not __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - //! get valid weights of current output channel - float w_tmp[10] = {wc0[c], - wc0[c + 4], - wc0[c + 8], - wc0[c + 12], - wc0[c + 16], - wc0[c + 20], - wc0[c + 24], - wc0[c + 28], - wc0[c + 32], - 0.f}; - float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 - float32x4_t w1 = vld1q_f32(w_tmp + 3); // w3, w4, w5, q1 - float32x4_t w2 = vld1q_f32(w_tmp + 6); // w6, w7, w8, q2 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - int cnt = w_loop / 2; - if (cnt > 0) { - asm volatile( - "vld1.32 {d24-d27}, [%[ptr_out0]] @ " - "load or00, or01\n" - "vld1.32 {d6-d9}, [%[r0]]! 
@ load r0, 8 " - "float\n" - "vld1.32 {d10}, [%[r0]] @ load r0, 2 " - "float\n" - /* main loop */ - "0: @ main loop\n" - /* r0 * w0, w1, w2, get out r0*/ - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " - "or11\n" - "vext.32 q8, q3, q4, #1 @ r0, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r0, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r0, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r0, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r1]] @ load r1, 2 " - "float\n" - - /* r1 * w3, w4, w5, get out r0*/ - /* r1 * w0, w1, w2, get out r1*/ - "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, " - "4, 5, 6, 7\n" - "vext.32 q8, q3, q4, #1 @ r1, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r1, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r1, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r1, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, " - "5, 6, 7, 8\n" - "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, " - "1, 2, 3, 4\n" - "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, " - "6, 7, 8, 9\n" - "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r2]] @ load r2, 2 " - "float\n" - - /* r2 * w6, w7, w8, get out r0*/ - /* r2 * w3, w4, w5, get out r1*/ - "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, " - "4, 5, 6, 7\n" - "vext.32 q8, q3, q4, #1 @ r2, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r2, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r2, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r2, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, " - "5, 6, 7, 8\n" - "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, " - "1, 2, 3, 4\n" - "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r3]]! 
@ load r3, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, " - "6, 7, 8, 9\n" - "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r3]] @ load r3, 2 " - "float\n" - - /* r3 * w6, w7, w8, get out r1*/ - "vext.32 q8, q3, q4, #1 @ r3, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r3, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, " - "4, 5, 6, 7\n" - "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, " - "or01\n" - "vext.32 q10, q3, q4, #2 @ r3, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r3, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, " - "4, 5, 6, 7\n" - "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " - "or01\n" - "vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 " - "float\n" - "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r0]] @ load r0, 2 " - "float\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, " - "or11\n" - - "subs %[cnt], #1 @loop count " - "-1\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - r0 -= 8; - } - //! deal with remain ow - if (w_loop & 1) { - ptr_out0[0] += - r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + - r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] + - r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8]; - - ptr_out0[1] += - r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] + - r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + - r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8]; - - ptr_out0[2] += - r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + - r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] + - r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8]; - - ptr_out0[3] += - r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] + - r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + - r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8]; - - ptr_out1[0] += - r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] + - r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] + - r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8]; - - ptr_out1[1] += - r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] + - r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] + - r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8]; - - ptr_out1[2] += - r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] + - r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] + - r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8]; - - ptr_out1[3] += - r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] + - r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] + - r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8]; - } - - wc0 += 36; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - } -#endif // __aarch64__ - block_inr0 = block_inr2; - block_inr1 = block_inr3; - block_inr2 = block_inr1 
+ in_len; - block_inr3 = block_inr2 + in_len; - } - write_to_output_c1_fp32(pre_out, - dout_batch, - c_idx, - c_idx + 1, - h, - h + h_kernel, - 0, - wout_round, - oc, - oh, - ow, - flag_relu, - ptr_write); - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_direct_3x3s2.cc b/lite/backends/arm/math/conv_direct_3x3s2.cc deleted file mode 100644 index 4bc9c5d25bd455c9a3dc802790ec99d30804238c..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_direct_3x3s2.cc +++ /dev/null @@ -1,1209 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#ifdef ARM_WITH_OMP -#include -#endif - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_3x3s2_direct_fp32(const float* i_data, - float* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx) { - //! 3x3s2 convolution, implemented by direct algorithm - //! prepack input to tmp buffer - //! write output to tmp buffer - const int threads = ctx->threads(); - int l2_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; - const int hout_c_block = 4; - const int hout_r_kernel = 2; - const int wout_block = 4; - const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block; - const int win_round = wout_round * 2 /*stride_w*/ + 1; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // flag_relu = true; - // } - // } - //! get h block - //! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block - //! * threads = l2_size - //! win_round = 2 * wout_round + 1 - //! hin_r_block = 2 * hout_r_block + 1 - int hout_r_block = - (l2_size - 2 * wout_round * ic - ic) / - ((4 * wout_round + 2) * ic + wout_round * hout_c_block * threads); - hout_r_block = hout_r_block > oh ? oh : hout_r_block; - hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; - hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; - - const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1; - - float* tmp_work_space = ctx->workspace_data(); - float ptr_zero[win_round]; // NOLINT - memset(ptr_zero, 0, sizeof(float) * win_round); - float ptr_write[wout_round]; // NOLINT - - int in_len = win_round * ic; - int pre_in_size = hin_r_block * in_len; - int pre_out_size = hout_c_block * hout_r_block * wout_round; - - //! 
l2_cache start - float* pre_din = tmp_work_space; - - int size_in_channel = win * ih; - int size_out_channel = ow * oh; - int w_stride = ic * 9; /*kernel_w * kernel_h*/ - int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * - - int ws = -pad_w; - int we = ws + win_round; - int w_loop = wout_round / 4; - - int c_remain = oc - (oc / hout_c_block) * hout_c_block; - int c_round_down = (oc / hout_c_block) * hout_c_block; - - int out_row_stride = hout_c_block * wout_round; - - for (int n = 0; n < bs; ++n) { - const float* din_batch = i_data + n * ic * size_in_channel; - float* dout_batch = o_data + n * oc * size_out_channel; - for (int h = 0; h < oh; h += hout_r_block) { - int h_kernel = hout_r_block; - if (h + hout_r_block > oh) { - h_kernel = oh - h; - } - - int hs = h * 2 /*stride_h*/ - pad_h; - int he = hs + h_kernel * 2 /*stride_h*/ + 1; - - prepack_input_nxw( - din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); - - const float* cblock_inr0 = pre_din; - const float* cblock_inr1 = cblock_inr0 + in_len; - const float* cblock_inr2 = cblock_inr1 + in_len; - const float* cblock_inr3 = cblock_inr2 + in_len; - const float* cblock_inr4 = cblock_inr3 + in_len; - -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < c_round_down; c += hout_c_block) { -#ifdef ARM_WITH_OMP - float* pre_out = - pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; -#else - float* pre_out = pre_din + pre_in_size; -#endif - const float* block_inr0 = cblock_inr0; - const float* block_inr1 = cblock_inr1; - const float* block_inr2 = cblock_inr2; - const float* block_inr3 = cblock_inr3; - const float* block_inr4 = cblock_inr4; - - const float* weight_c = weights + c * w_stride; - const float* bias_ptr = ptr_zero; - if (flag_bias) { - bias_ptr = bias + c; - } - fill_packed_biasc4( - pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); - - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { - const float* wc0 = weight_c; - - const float* inr0 = block_inr0; - const float* inr1 = block_inr1; - const float* inr2 = block_inr2; - const float* inr3 = block_inr3; - const float* inr4 = block_inr4; - - float* pre_out0 = pre_out + hk * out_row_stride; - float* pre_out1 = pre_out0 + out_row_stride; -#ifdef __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - float32x4_t w0 = vld1q_f32(wc0); // w0, v23 - float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 - float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 - float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 - float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 - float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 - float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 - float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 - float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - const float* r4 = inr4; - - int cnt = w_loop; - asm volatile( - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, - outr01*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - - "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "2: \n" /* main loop*/ - /* r0, r2, mul w0, get out r0, r1 */ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ - "fmla v15.4s , %[w0].4s, v0.s[0]\n" 
/* outr00 = w0 * r0[0]*/ - "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ - "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ - "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ - "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ - "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ - "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ - "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ - - "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ - - /* r2 mul w6, get out r0*/ - "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ - "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ - "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ - - "ldr d11, [%[r1]] \n" /* load input r1, 9th - element*/ - - /* r0, r2, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ - "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ - "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ - "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ - "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ - "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ - "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ - "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ - - "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ - - /* r2 mul w7, get out r0 */ - "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ - "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ - "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ - - "ldr d13, [%[r3]] \n" /* load input r3, 9th - element*/ - - /* r0, r2, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ - "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ - "fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 * - r0[8]*/ - "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ - "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ - "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ - "fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 * - r2[8]*/ - - "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ - - /* r2, mul w8, get out r0 */ - "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ - "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ - "fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 * - r2[8]*/ - - "ldr d14, [%[r4]] \n" /* load input r4, 9th - element*/ - - /* r1, r3, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ - "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ - "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ - "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ - "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ - "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ - "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ - - "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ - - /* r1, r3, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ - "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ - "fmla v17.4s , %[w4].4s, 
v3.s[1]\n" /* outr02 = w4 * r1[5]*/ - "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ - "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ - "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ - "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ - "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ - - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - - /* r1, r3, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ - "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ - "fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 * - r1[8]*/ - - "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ - "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ - - "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ - "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ - "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ - "fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 * - r3[8]*/ - - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ - - /* r4, mul w6, get out r1 */ - "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ - "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/ - "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ - "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ - - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ - - /* r4, mul w7, get out r1 */ - "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ - "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ - "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ - "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ - - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - - /* r4, mul w8, get out r1 */ - "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ - "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ - "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ - "fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 * - r4[8]*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ - "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ - - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - inr4 += win_round; - } -#else // not __aarch64__ - for (int i = 0; i < ic; ++i) { - const float* wc0 = weight_c + i * w_stride_chin; - - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - const float* r4 = inr4; - - int cnt = w_loop; - asm volatile( - "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ " - "load outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - /* load r0, r2 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " - "8 float\n" - "vld1.32 {d8}, [%[r0]] @ load r0, " - "9th float\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - /* main loop */ - "0: @ main " - "loop\n" - /* mul r0, with w0, w1, w2 */ - "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " - "outr1, w0, w1, c0~c3\n" - "vmla.f32 q8, q5, d0[0] @ w0 * " - "inr00\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " - "outr1, w2, w3, c0~c3\n" - "vmla.f32 q9, q5, d1[0] @ w0 * " - "inr02\n" - "vmla.f32 q10, q5, d2[0] @ w0 * " - "inr04\n" - "vmla.f32 q11, q5, d3[0] @ w0 * " - "inr06\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w1 * " - "inr01\n" - "vmla.f32 q9, q6, d1[1] @ w1 * " - "inr03\n" - "vmla.f32 q10, q6, d2[1] @ w1 * " - "inr05\n" - "vmla.f32 q11, q6, d3[1] @ w1 * " - "inr07\n" - "vld1.32 {d9}, [%[r2]] @ load r2, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w2 * " - "inr02\n" - "vmla.f32 q9, q7, d2[0] @ w2 * " - "inr04\n" - "vmla.f32 q10, q7, d3[0] @ w2 * " - "inr06\n" - "vmla.f32 q11, q7, d8[0] @ w2 * " - "inr08\n" - - "sub %[r2], %[r2], #32 @ r2 - 32, " - "load r2 twice\n" - - /* mul r2, with w0, w1, w2 */ - "vld1.32 {d0-d3}, [%[r1]]! @ load r1, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w0 * " - "inr20\n" - "vmla.f32 q13, q5, d5[0] @ w0 * " - "inr22\n" - "vmla.f32 q14, q5, d6[0] @ w0 * " - "inr24\n" - "vmla.f32 q15, q5, d7[0] @ w0 * " - "inr26\n" - "vld1.32 {d8}, [%[r1]] @ load r1, " - "9th float\n" - "vmla.f32 q12, q6, d4[1] @ w1 * " - "inr21\n" - "vmla.f32 q13, q6, d5[1] @ w1 * " - "inr23\n" - "vmla.f32 q14, q6, d6[1] @ w1 * " - "inr25\n" - "vmla.f32 q15, q6, d7[1] @ w1 * " - "inr27\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " - "w4, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w2 * " - "inr22\n" - "vmla.f32 q13, q7, d6[0] @ w2 * " - "inr24\n" - "vmla.f32 q14, q7, d7[0] @ w2 * " - "inr26\n" - "vmla.f32 q15, q7, d9[0] @ w2 * " - "inr28\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " - "to q7\n" - - /* mul r1, with w3, w4, w5 */ - "vmla.f32 q8, q5, d0[0] @ w3 * " - "inr10\n" - "vmla.f32 q9, q5, d1[0] @ w3 * " - "inr12\n" - "vmla.f32 q10, q5, d2[0] @ w3 * " - "inr14\n" - "vmla.f32 q11, q5, d3[0] @ w3 * " - "inr16\n" - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w4 * " - "inr11\n" - "vmla.f32 q9, q6, d1[1] @ w4 * " - "inr13\n" - "vmla.f32 q10, q6, d2[1] @ w4 * " - "inr15\n" - "vmla.f32 q11, q6, d3[1] @ w4 * " - "inr17\n" - "vld1.32 {d9}, [%[r3]] @ load r3, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w5 * " - "inr12\n" - "vmla.f32 q9, q7, d2[0] @ w5 * " - "inr14\n" - "vmla.f32 q10, q7, d3[0] @ w5 * " - "inr16\n" - "vmla.f32 q11, q7, d8[0] @ w5 * " - "inr18\n" - - "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " - "- 32, to start address\n" - - /* mul r3, with w3, w4, w5 */ - "vld1.32 {d0-d3}, [%[r2]]! 
@ load r2, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w3 * " - "inr30\n" - "vmla.f32 q13, q5, d5[0] @ w3 * " - "inr32\n" - "vmla.f32 q14, q5, d6[0] @ w3 * " - "inr34\n" - "vmla.f32 q15, q5, d7[0] @ w3 * " - "inr36\n" - "vld1.32 {d8}, [%[r2]] @ load r2, " - "9th float\n" - "vmla.f32 q12, q6, d4[1] @ w4 * " - "inr31\n" - "vmla.f32 q13, q6, d5[1] @ w4 * " - "inr33\n" - "vmla.f32 q14, q6, d6[1] @ w4 * " - "inr35\n" - "vmla.f32 q15, q6, d7[1] @ w4 * " - "inr37\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " - "w7, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w5 * " - "inr32\n" - "vmla.f32 q13, q7, d6[0] @ w5 * " - "inr34\n" - "vmla.f32 q14, q7, d7[0] @ w5 * " - "inr36\n" - "vmla.f32 q15, q7, d9[0] @ w5 * " - "inr38\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " - "to q7\n" - - /* mul r2, with w6, w7, w8 */ - "vmla.f32 q8, q5, d0[0] @ w6 * " - "inr20\n" - "vmla.f32 q9, q5, d1[0] @ w6 * " - "inr22\n" - "vmla.f32 q10, q5, d2[0] @ w6 * " - "inr24\n" - "vmla.f32 q11, q5, d3[0] @ w6 * " - "inr26\n" - "vld1.32 {d4-d7}, [%[r4]]! @ load r4, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w7 * " - "inr21\n" - "vmla.f32 q9, q6, d1[1] @ w7 * " - "inr23\n" - "vmla.f32 q10, q6, d2[1] @ w7 * " - "inr25\n" - "vmla.f32 q11, q6, d3[1] @ w7 * " - "inr27\n" - "vld1.32 {d9}, [%[r4]] @ load r4, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w8 * " - "inr22\n" - "vmla.f32 q9, q7, d2[0] @ w8 * " - "inr24\n" - "vmla.f32 q10, q7, d3[0] @ w8 * " - "inr26\n" - "vmla.f32 q11, q7, d8[0] @ w8 * " - "inr28\n" - - "sub %[wc0], %[wc0], #144 @ wc0 - " - "144 to start address\n" - - /* mul r4, with w6, w7, w8 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w3 * " - "inr40\n" - "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " - "r00, r01, c0~c3\n" - "vmla.f32 q13, q5, d5[0] @ w3 * " - "inr42\n" - "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " - "r02, r03, c0~c3\n" - "vmla.f32 q14, q5, d6[0] @ w3 * " - "inr44\n" - "vmla.f32 q15, q5, d7[0] @ w3 * " - "inr46\n" - "vld1.32 {d8}, [%[r0]] @ load " - "r0, 9th float\n" - "vmla.f32 q12, q6, d4[1] @ w4 * " - "inr41\n" - "vmla.f32 q13, q6, d5[1] @ w4 * " - "inr43\n" - "vmla.f32 q14, q6, d6[1] @ w4 * " - "inr45\n" - "vmla.f32 q15, q6, d7[1] @ w4 * " - "inr47\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w5 * " - "inr42\n" - "vmla.f32 q13, q7, d6[0] @ w5 * " - "inr44\n" - "vmla.f32 q14, q7, d7[0] @ w5 * " - "inr46\n" - "vmla.f32 q15, q7, d9[0] @ w5 * " - "inr48\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " - "r10, r11, c0~c3\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " - "r12, r13, c0~c3\n" - - "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ load " - "outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - "subs %[cnt], #1 @ loop " - "count--\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1), - [wc0] "+r"(wc0) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - inr4 += win_round; - } -#endif // __aarch64__ - block_inr0 = block_inr4; - block_inr1 = block_inr0 + in_len; - block_inr2 = block_inr1 + in_len; - block_inr3 = block_inr2 + in_len; - block_inr4 = block_inr3 + in_len; - } - - write_to_output_c4_fp32(pre_out, - dout_batch, - c, - c + hout_c_block, - h, - h + h_kernel, - 0, - wout_round, - oc, - oh, - ow, - flag_relu, - ptr_write); - } - -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < c_remain; ++c) { -#ifdef ARM_WITH_OMP - float* pre_out = - pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; -#else - float* pre_out = pre_din + pre_in_size; -#endif - - const float* block_inr0 = cblock_inr0; - const float* block_inr1 = cblock_inr1; - const float* block_inr2 = cblock_inr2; - const float* block_inr3 = cblock_inr3; - const float* block_inr4 = cblock_inr4; - - //! get weights ptr of remained - const float* weight_c = weights + c_round_down * w_stride; - - //! fill bias to one channel - const float* bias_ptr = ptr_zero; - if (flag_bias) { - bias_ptr = bias + c_round_down + c; - } - fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); - - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { - const float* wc0 = weight_c; - - const float* inr0 = block_inr0; - const float* inr1 = block_inr1; - const float* inr2 = block_inr2; - const float* inr3 = block_inr3; - const float* inr4 = block_inr4; - - float* pre_out0 = pre_out + hk * wout_round; - float* pre_out1 = pre_out0 + wout_round; -#ifdef __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - //! 
get valid weights of current output channel - float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 - float32x4_t w1 = vdupq_n_f32(wc0[c + 4]); // w1, v24 - float32x4_t w2 = vdupq_n_f32(wc0[c + 8]); // w2, v25 - float32x4_t w3 = vdupq_n_f32(wc0[c + 12]); // w3, v26 - float32x4_t w4 = vdupq_n_f32(wc0[c + 16]); // w4, v27 - float32x4_t w5 = vdupq_n_f32(wc0[c + 20]); // w5, v28 - float32x4_t w6 = vdupq_n_f32(wc0[c + 24]); // w6, v29 - float32x4_t w7 = vdupq_n_f32(wc0[c + 28]); // w7, v30 - float32x4_t w8 = vdupq_n_f32(wc0[c + 32]); // w8, v31 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - const float* r4 = inr4; - - int cnt = w_loop; - asm volatile( - "ldr q21, [%[ptr_out0]] \n" /* load outr00, - outr01, - outr02, - outr03*/ - - "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "2: \n" /* main loop*/ - /* r0, r2, mul w0, get out r0, r1 */ - "ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11, - outr12, outr13*/ - - "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2, - 4, 6]*/ - "fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2, - 4, 6]*/ - - "ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/ - - /* r2 mul w6, get out r0*/ - "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2, - 4, 6]*/ - "ldr d11, [%[r1]] \n" /* load input r1, 9th - element*/ - - /* shift left 1 */ - "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ - "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ - - /* r0, r2, mul w1, get out r0, r1 */ - "fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3, - 5, 7]*/ - "fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3, - 5, 7]*/ - - "ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/ - - /* r2 mul w7, get out r0 */ - "fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1, - 3, 5, 7]*/ - - "ldr d13, [%[r3]] \n" /* load input r3, 9th - element*/ - - /* r0, r2, mul w2, get out r0, r1 */ - "fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4, - 6, 8]*/ - "fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4, - 6, 8]*/ - - "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/ - - /* r2, mul w8, get out r0 */ - "fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2, - 4, 6, 8]*/ - - "ldr d14, [%[r4]] \n" /* load input r4, 9th - element*/ - - /* r1, r3, mul w3, get out r0, r1 */ - "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2, - 4, 6]*/ - "fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2, - 4, 6]*/ - - /* shift left 1 */ - "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/ - "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/ - - "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ - - /* r1, r3, mul w4, get out r0, r1 */ - "fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3, - 5, 7]*/ - "fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3, - 5, 7]*/ - - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - - /* r1, r3, mul w5, get out r0, r1 */ - "fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/ - "fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/ - - "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/ - - /* r4, mul w6, get out r1 */ - "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 
= w6 * r4[0, 2, - 4, 6]*/ - - "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/ - "ldr q21, [%[ptr_out0]] \n" /* load outr0*/ - - /* r4, mul w7, get out r1 */ - "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3, - 5, 7]*/ - - /* r4, mul w8, get out r1 */ - "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4, - 6, 8]*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/ - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v21", - "v22"); - - wc0 += 36; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - inr4 += win_round; - } -#else // not __aarch64__ - for (int i = 0; i < ic; ++i) { - float* ptr_out0 = pre_out0; - float* ptr_out1 = pre_out1; - - //! get valid weights of current output channel - float w_tmp[12] = {wc0[c], - wc0[c + 4], - wc0[c + 8], - 0.f, - wc0[c + 12], - wc0[c + 16], - wc0[c + 20], - 0.f, - wc0[c + 24], - wc0[c + 28], - wc0[c + 32], - 0.f}; - float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 - float32x4_t w1 = vld1q_f32(w_tmp + 4); // w3, w4, w5, q1 - float32x4_t w2 = vld1q_f32(w_tmp + 8); // w6, w7, w8, q2 - - const float* r0 = inr0; - const float* r1 = inr1; - const float* r2 = inr2; - const float* r3 = inr3; - const float* r4 = inr4; - - int cnt = w_loop / 2; - if (cnt > 0) { - asm volatile( - /* main loop */ - "0: @ " - "main loop\n" - "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " - "or01\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " - "or11\n" - "vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 " - "float, interleave\n" - "vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 " - "float, interleave\n" - "vld1.32 {d22}, [%[r2]] @ load 16th " - "float\n" - - /* r2 * w2, r2 * w0, get or0, or1 */ - "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, " - "9, 11, 13, 15\n" - "vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 " - "float, interleave\n" - "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, " - "9, 11, 13, 15\n" - - "vext.32 q4, q3, q5, #1 @ r2, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r2, shift " - "left 1, get 10, 12, 14, 16\n" - - "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, " - "8, 10, 12, 14\n" - "vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 " - "float, interleave\n" - "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, " - "8, 10, 12, 14\n" - - "vld1.32 {d22}, [%[r0]] @ load 16th " - "float\n" - - "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, " - "2, 4, 6, 8\n" - "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, " - "2, 4, 6, 8\n" - "vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 " - "float, interleave\n" - "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, " - "10, 12, 14, 16\n" - "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, " - "10, 12, 14, 16\n" - "vld2.32 {d10-d13}, [%[r3]]! 
@ load r3, 8 " - "float, interleave\n" - - /* r0 * w0, get or0, r3 * w1, get or1*/ - "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, " - "9, 11, 13, 15\n" - "vext.32 q8, q7, q9, #1 @ r0, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q10, q9, q11, #1 @ r0, shift " - "left 1, get 10, 12, 14, 16\n" - "vld1.32 {d22}, [%[r3]] @ load 16th " - "float\n" - "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, " - "9, 11, 13, 15\n" - - "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, " - "8, 10, 12, 14\n" - "vext.32 q4, q3, q5, #1 @ r3, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r3, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, " - "8, 10, 12, 14\n" - - "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, " - "2, 4, 6, 8\n" - "vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 " - "float, interleave\n" - "vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, " - "10, 12, 14, 16\n" - "vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 " - "float, interleave\n" - "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, " - "2, 4, 6, 8\n" - "vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 " - "float, interleave\n" - "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, " - "10, 12, 14, 16\n" - "vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 " - "float, interleave\n" - - "vld1.32 {d22}, [%[r1]] @ load 16th " - "float\n" - - /* r1 * w1, get or0, r4 * w2, get or1 */ - "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, " - "9, 11, 13, 15\n" - "vext.32 q8, q7, q9, #1 @ r1, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q10, q9, q11, #1 @ r1, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, " - "9, 11, 13, 15\n" - "vld1.32 {d22}, [%[r4]] @ load 16th " - "float\n" - - "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, " - "8, 10, 12, 14\n" - "vext.32 q4, q3, q5, #1 @ r1, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r1, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, " - "8, 10, 12, 14\n" - - "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, " - "2, 4, 6, 8\n" - "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, " - "10, 12, 14, 16\n" - "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, " - "2, 4, 6, 8\n" - "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, " - "10, 12, 14, 16\n" - - "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n" - - "subs %[cnt], #1 @loop count " - "-1\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - //! 
deal with remain ow - if (w_loop & 1) { - ptr_out0[0] += - r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + - r1[0] * w_tmp[4] + r1[1] * w_tmp[5] + r1[2] * w_tmp[6] + - r2[0] * w_tmp[8] + r2[1] * w_tmp[9] + r2[2] * w_tmp[10]; - - ptr_out0[1] += - r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + - r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r1[4] * w_tmp[6] + - r2[2] * w_tmp[8] + r2[3] * w_tmp[9] + r2[4] * w_tmp[10]; - - ptr_out0[2] += - r0[4] * w_tmp[0] + r0[5] * w_tmp[1] + r0[6] * w_tmp[2] + - r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r1[6] * w_tmp[6] + - r2[4] * w_tmp[8] + r2[5] * w_tmp[9] + r2[6] * w_tmp[10]; - - ptr_out0[3] += - r0[6] * w_tmp[0] + r0[7] * w_tmp[1] + r0[8] * w_tmp[2] + - r1[6] * w_tmp[4] + r1[7] * w_tmp[5] + r1[8] * w_tmp[6] + - r2[6] * w_tmp[8] + r2[7] * w_tmp[9] + r2[8] * w_tmp[10]; - - ptr_out1[0] += - r2[0] * w_tmp[0] + r2[1] * w_tmp[1] + r2[2] * w_tmp[2] + - r3[0] * w_tmp[4] + r3[1] * w_tmp[5] + r3[2] * w_tmp[6] + - r4[0] * w_tmp[8] + r4[1] * w_tmp[9] + r4[2] * w_tmp[10]; - - ptr_out1[1] += - r2[2] * w_tmp[0] + r2[3] * w_tmp[1] + r2[4] * w_tmp[2] + - r3[2] * w_tmp[4] + r3[3] * w_tmp[5] + r3[4] * w_tmp[6] + - r4[2] * w_tmp[8] + r4[3] * w_tmp[9] + r4[4] * w_tmp[10]; - - ptr_out1[2] += - r2[4] * w_tmp[0] + r2[5] * w_tmp[1] + r2[6] * w_tmp[2] + - r3[4] * w_tmp[4] + r3[5] * w_tmp[5] + r3[6] * w_tmp[6] + - r4[4] * w_tmp[8] + r4[5] * w_tmp[9] + r4[6] * w_tmp[10]; - - ptr_out1[3] += - r2[6] * w_tmp[0] + r2[7] * w_tmp[1] + r2[8] * w_tmp[2] + - r3[6] * w_tmp[4] + r3[7] * w_tmp[5] + r3[8] * w_tmp[6] + - r4[6] * w_tmp[8] + r4[7] * w_tmp[9] + r4[8] * w_tmp[10]; - } - - wc0 += 36; - inr0 += win_round; - inr1 += win_round; - inr2 += win_round; - inr3 += win_round; - inr4 += win_round; - } -#endif // __aarch64__ - block_inr0 = block_inr4; - block_inr1 = block_inr0 + in_len; - block_inr2 = block_inr1 + in_len; - block_inr3 = block_inr2 + in_len; - block_inr4 = block_inr3 + in_len; - } - write_to_output_c1_fp32(pre_out, - dout_batch, - c + c_round_down, - c + c_round_down + 1, - h, - h + h_kernel, - 0, - wout_round, - oc, - oh, - ow, - flag_relu, - ptr_write); - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_gemmlike.cc b/lite/backends/arm/math/conv_gemmlike.cc deleted file mode 100644 index 1dd102db1e96efff89a7df678fcb35cf01890191..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_gemmlike.cc +++ /dev/null @@ -1,285 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
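The file removed below, conv_gemmlike.cc, dispatches between two GEMM-based lowerings: a 1x1, stride-1, zero-pad convolution is computed as a single GEMM directly over the input feature map, while every other shape goes through im2col first. Per group, the GEMM dimensions are m = oc/groups, k = ic*kh*kw/groups, n = oh*ow. A minimal standalone sketch of that dispatch rule follows; the function name is illustrative and not part of the library's API.

    #include <cstdio>

    // True when the convolution can skip im2col entirely: a 1x1 kernel with
    // stride 1 and no padding reads the input as an already-materialized
    // k x n matrix, so one GEMM per group produces the output directly.
    bool use_direct_1x1_gemm(int kw, int kh, int sw, int sh, int pw, int ph) {
      bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
      return kw == 1 && sw == 1 && pw == 0 && kps_equal;
    }

    int main() {
      // 1x1s1p0 selects the direct GEMM; 3x3s1p1 falls back to im2col + GEMM.
      std::printf("%d %d\n",
                  use_direct_1x1_gemm(1, 1, 1, 1, 0, 0),
                  use_direct_1x1_gemm(3, 3, 1, 1, 1, 1));
      return 0;
    }

The same kps_equal test appears verbatim in both the float and int8 create() methods of the deleted file.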
- -#include "lite/backends/arm/math/conv_gemmlike.h" -#include -#include "lite/backends/arm/math/gemm_prepacked_int8.h" -#include "lite/backends/arm/math/packed_sgemm.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -/********************* Gemmlike Conv Precision Is Float ***********************/ -template <> -bool GemmLikeConv::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int kh = w_dims[2]; - int sw = param.strides[1]; - int sh = param.strides[0]; - int pw = param.paddings[1]; - int ph = param.paddings[0]; - int dw = param.dilations[1]; - int dh = param.dilations[0]; - - int m = oc / param.groups; - int k = ic * kh * kw / param.groups; - int n = oh * ow; - bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh); - bool ks_equal = (sw == sh) && (kw == kh); - //! select conv gemmlike kernel - if (kw == 1 && sw == 1 && pw == 0 && kps_equal) { - //! 1x1s1p0 gemmlike conv - impl_ = conv1x1s1_gemm; - } else { - //! otherwise case - if (kw == 3 && sw == 1 && n > 1 && ks_equal) { - idx_data_.Resize({1, 1, 1, n * kh * kw}); - int* idx_out = idx_data_.mutable_data(); - for (int i = 0; i < oh; ++i) { - for (int j = 0; j < ow; ++j) { - compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw); - idx_out += kh * kw; - } - } - } - //! im2col gemmlike conv - impl_ = conv_im2col_gemm; - this->ctx_->ExtendWorkspace(k * n * sizeof(float)); - } - - if (n > 1) { - int hblock = get_hblock(this->ctx_->arch()); - int m_roundup = hblock * ((m + hblock - 1) / hblock); - int group_size_round_up = ((m_roundup * k + 15) / 16) * 16; - float* w_trans_ptr = nullptr; - weights_trans_.Resize({1, 1, 1, group_size_round_up * param.groups}); - w_trans_ptr = weights_trans_.mutable_data(); - const auto* w_data = param.filter->data(); - for (int g = 0; g < param.groups; ++g) { - const float* weights_group = w_data + g * m * k; - float* weights_trans_ptr = w_trans_ptr + g * group_size_round_up; - prepackA(weights_trans_ptr, - weights_group, - 1.f, - k, - 0, - m, - 0, - k, - false, - this->ctx_); - } - is_weights_transed_ = true; - } - return true; -} - -template <> -bool GemmLikeConv::init(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template <> -bool GemmLikeConv::run(const operators::ConvParam& param) { - // start timer - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? 
param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - const int* idx_data = idx_data_.mutable_data(); - - if (is_weights_transed_) { - w_data = weights_trans_.data(); - } - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_, - idx_data); - - // timer end - return true; -} - -/********************* Gemmlike Conv Precision Is Int8 ************************/ -template -bool GemmLikeConvInt8::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int kh = w_dims[2]; - int sw = param.strides[1]; - int sh = param.strides[0]; - int pw = param.paddings[1]; - int ph = param.paddings[0]; - int dw = param.dilations[1]; - int dh = param.dilations[0]; - - int m = oc / param.groups; - int k = ic * kh * kw / param.groups; - int n = oh * ow; - w_scale_ = param.weight_scale; - //! update weights scale - if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) { - CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout"; - float input_scale = param.input_scale; - for (auto& w_s : w_scale_) { - w_s *= input_scale; - if (Ptype_out == PRECISION(kInt8)) { - w_s /= param.output_scale; - } - } - } - - bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh); - bool ks_equal = (sw == sh) && (kw == kh); - //! select conv gemmlike kernel - if (kw == 1 && sw == 1 && pw == 0 && kps_equal) { - //! 1x1s1p0 gemmlike conv - impl_int8_ = conv1x1s1_gemm_int8; - } else { - //! otherwise case - if (kw == 3 && sw == 1 && n > 1 && ks_equal) { - idx_data_.Resize({1, 1, 1, n * kh * kw}); - int* idx_out = idx_data_.mutable_data(); - for (int i = 0; i < oh; ++i) { - for (int j = 0; j < ow; ++j) { - compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw); - idx_out += kh * kw; - } - } - } - //! im2col gemmlike conv - impl_int8_ = conv_im2col_gemm_int8; - this->ctx_->ExtendWorkspace(k * n); - } - - if (n > 1) { - prepackA_int8(&this->weights_trans_, - *param.filter, - m, - k, - param.groups, - false, - this->ctx_); - this->is_weights_transed_ = true; - } - return true; -} - -template -bool GemmLikeConvInt8::init(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template -bool GemmLikeConvInt8::run(const operators::ConvParam& param) { - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? 
param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - const int32_t* idx_data = idx_data_.mutable_data(); - - if (this->is_weights_transed_ == true) { - w_data = this->weights_trans_.template data(); - } - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_int8_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_, - Ptype_out, - this->w_scale_.data(), - idx_data); - - return true; -} - -template class GemmLikeConvInt8; -template class GemmLikeConvInt8; -template class GemmLikeConvInt8; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_gemmlike.h b/lite/backends/arm/math/conv_gemmlike.h deleted file mode 100644 index 5986b5c2c818e9c7f81fe78d5bf69cb0496fcd20..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_gemmlike.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
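Both create() methods in the conv_gemmlike.cc deleted above round the per-group packed-weight buffer up before calling prepackA: m is rounded to a multiple of the GEMM block height hblock, and the resulting m_roundup * k floats to a multiple of 16 (one 64-byte cache line of floats). A self-contained sketch of that arithmetic follows; hblock is an assumed example value here, whereas the real one comes from get_hblock() and depends on the CPU architecture.

    #include <cstdio>

    int main() {
      int m = 30, k = 27, hblock = 6;  // example group sizes, not real tuning values
      // Integer ceiling division: round m up to the next multiple of hblock.
      int m_roundup = hblock * ((m + hblock - 1) / hblock);       // 30 stays 30
      // Round each group's packed size up to a multiple of 16 floats.
      int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;  // 810 -> 816
      std::printf("m_roundup=%d group_size_round_up=%d\n",
                  m_roundup, group_size_round_up);
      return 0;
    }

Keeping every group's packed block a whole number of 16-float units means group g starts at offset g * group_size_round_up with the same alignment as the buffer itself.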
- -#pragma once - -#include -#include -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/core/target_wrapper.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template -class GemmLikeConv - : public ImplBase { - public: - typedef void (*conv_im2col_gemm_impl)(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx, - const int* idx_ptr); - - GemmLikeConv() = default; - ~GemmLikeConv() {} - - virtual bool init(const operators::ConvParam& param, ARMContext* ctx) { - LOG(FATAL) << "GemmLikeConv::init() not implemented."; - } - - virtual bool create(const operators::ConvParam& param, ARMContext* ctx) { - LOG(FATAL) << "GemmLikeConv::create() not implemented."; - } - - virtual bool run(const operators::ConvParam& param) { - LOG(FATAL) << "GemmLikeConv::run() not implemented."; - } - - protected: - bool is_weights_transed_{false}; - Tensor idx_data_; - Tensor weights_trans_; - - private: - conv_im2col_gemm_impl impl_{nullptr}; -}; - -template -class GemmLikeConvInt8 : public GemmLikeConv { - public: - typedef void (*conv_im2col_gemm_int8_impl)(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - ARMContext* ctx, - PrecisionType out_type, - const float* scale, - const int* idx_ptr); - - GemmLikeConvInt8() = default; - ~GemmLikeConvInt8() {} - - virtual bool init(const operators::ConvParam& param, ARMContext* ctx); - - virtual bool create(const operators::ConvParam& param, ARMContext* ctx); - - virtual bool run(const operators::ConvParam& param); - - private: - conv_im2col_gemm_int8_impl impl_int8_{nullptr}; - std::vector w_scale_; -}; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index dbea9d643ee12a8583f1c19aa5a20f35b2da70c6..010563bf936c2f8454162c8aad48cd8815c5f7af 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -12,14 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
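The im2col hunk below only drops const from the by-value int parameters; the output-size computation itself is unchanged. For reference, a self-contained sketch of that formula (illustrative code, not the library's function):

    #include <cstdio>

    // Output extent of a convolution along one axis; the dilated kernel
    // covers dilation * (kernel - 1) + 1 input elements.
    int conv_out_size(int in, int kernel, int pad, int stride, int dilation) {
      return (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1;
    }

    int main() {
      // A 3x3, stride-2, pad-1 kernel over a 224-wide input yields 112 columns,
      // the case handled by the conv_3x3s2_direct_fp32 kernel deleted earlier
      // in this diff.
      std::printf("%d\n", conv_out_size(224, 3, 1, 2, 1));
      return 0;
    }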
-// #include "saber/funcs/impl/arm/neon/impl/conv_arm_depthwise.h" -// #include "saber/funcs/impl/arm/neon/impl/conv_arm_impl.h" -// #include "saber/funcs/impl/arm/neon/impl/gemm_prepacked_int8.h" -// #include "saber/funcs/impl/arm/neon/impl/gemv_arm_int8.h" -// #include "saber/funcs/impl/arm/neon/impl/sgemv_arm.h" - #include "lite/backends/arm/math/conv_impl.h" #include +#include "lite/backends/arm/math/conv_depthwise.h" #include "lite/backends/arm/math/gemm_prepacked_int8.h" #include "lite/backends/arm/math/gemv_arm_int8.h" #include "lite/backends/arm/math/packed_sgemm.h" @@ -107,17 +102,17 @@ inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { */ template void im2col(const Dtype* data_im, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, Dtype* data_col) { const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; @@ -150,121 +145,6 @@ void im2col(const Dtype* data_im, } } } -void compute_offset(int* idx_out, - int h, - int w, - int kernel_h, - int kernel_w, - int height, - int width, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w) { - int idx_h[kernel_h]; // NOLINT - int idx_w[kernel_w]; // NOLINT - for (int i = 0; i < kernel_h; ++i) { - idx_h[i] = h - pad_h + i * dilation_h; - } - for (int i = 0; i < kernel_w; ++i) { - idx_w[i] = w - pad_w + i * dilation_w; - } - for (int k_h = 0; k_h < kernel_h; ++k_h) { - for (int k_w = 0; k_w < kernel_w; ++k_w) { - idx_out[k_h * kernel_w + k_w] = - (idx_h[k_h] >= 0 && idx_w[k_w] >= 0 && idx_h[k_h] < height && - idx_w[k_w] < width) - ? 
idx_h[k_h] * width + idx_w[k_w] - : -1; - } - } -} -template -void im2col3x3(const Dtype* data_im, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - Dtype* data_col, - const int* idx) { - const int output_h = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int output_w = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - int kernel_stride = kernel_h * kernel_w; - int in_channel_stride = height * width; - const int* idx_out = idx; - Dtype* data_col_ptr = data_col; - - bool flag_continue = false; - if (dilation_h == 1 && dilation_w == 1) { - flag_continue = true; - } - - for (int o = 0; o < output_h * output_w; o += 1) { - const Dtype* data_im_ptr = data_im; - - // int* idx_out_d = idx_out; - - int idx_out_d0 = idx_out[0]; - int idx_out_d1 = idx_out[1]; - int idx_out_d2 = idx_out[2]; - int idx_out_d3 = idx_out[3]; - int idx_out_d4 = idx_out[4]; - int idx_out_d5 = idx_out[5]; - int idx_out_d6 = idx_out[6]; - int idx_out_d7 = idx_out[7]; - int idx_out_d8 = idx_out[8]; - - for (int i = 0; i < channels; i += 1) { - if (idx_out_d0 >= 0 && idx_out_d2 >= 0 && idx_out_d6 >= 0 && - idx_out_d8 >= 0) { - if (flag_continue) { - memcpy( - data_col_ptr, data_im_ptr + idx_out_d0, kernel_w * sizeof(Dtype)); - memcpy(data_col_ptr + kernel_w, - data_im_ptr + idx_out_d3, - kernel_w * sizeof(Dtype)); - memcpy(data_col_ptr + kernel_w + kernel_w, - data_im_ptr + idx_out_d6, - kernel_w * sizeof(Dtype)); - } else { - data_col_ptr[0] = data_im_ptr[idx_out_d0]; - data_col_ptr[1] = data_im_ptr[idx_out_d1]; - data_col_ptr[2] = data_im_ptr[idx_out_d2]; - data_col_ptr[3] = data_im_ptr[idx_out_d3]; - data_col_ptr[4] = data_im_ptr[idx_out_d4]; - data_col_ptr[5] = data_im_ptr[idx_out_d5]; - data_col_ptr[6] = data_im_ptr[idx_out_d6]; - data_col_ptr[7] = data_im_ptr[idx_out_d7]; - data_col_ptr[8] = data_im_ptr[idx_out_d8]; - } - } else { - data_col_ptr[0] = (idx_out_d0 < 0) ? 0 : data_im_ptr[idx_out_d0]; - data_col_ptr[1] = (idx_out_d1 < 0) ? 0 : data_im_ptr[idx_out_d1]; - data_col_ptr[2] = (idx_out_d2 < 0) ? 0 : data_im_ptr[idx_out_d2]; - data_col_ptr[3] = (idx_out_d3 < 0) ? 0 : data_im_ptr[idx_out_d3]; - data_col_ptr[4] = (idx_out_d4 < 0) ? 0 : data_im_ptr[idx_out_d4]; - data_col_ptr[5] = (idx_out_d5 < 0) ? 0 : data_im_ptr[idx_out_d5]; - data_col_ptr[6] = (idx_out_d6 < 0) ? 0 : data_im_ptr[idx_out_d6]; - data_col_ptr[7] = (idx_out_d7 < 0) ? 0 : data_im_ptr[idx_out_d7]; - data_col_ptr[8] = (idx_out_d8 < 0) ? 
0 : data_im_ptr[idx_out_d8]; - } - data_im_ptr += height * width; - data_col_ptr += kernel_stride; - } - // data_col_ptr += channels * kernel_stride; - // idx_out += kernel_stride * 2; - idx_out += kernel_stride; - } -} /** * \brief convolution function for kernel size 1x1, stride size 1, gemm @@ -282,8 +162,7 @@ void conv1x1s1_gemm(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, - ARMContext* ctx, - const int* idx_ptr) { + ARMContext* ctx) { int channel_size_out = ow * oh; int channel_size_in = win * ih; @@ -294,21 +173,14 @@ void conv1x1s1_gemm(const float* i_data, bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // flag_relu = true; - // } - // } - int hblock = get_hblock(ctx->arch()); + + int hblock = get_hblock(ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int weights_size_per_group = m * k; if (n > 1) { weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; } - // int weights_size_per_group = m_roundup * k;//oc * ic / (group * - // group); //! use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC @@ -351,8 +223,9 @@ void conv1x1s1_gemm(const float* i_data, } } +template void conv1x1s1_gemm_int8(const int8_t* i_data, - int32_t* o_data, + Dtype* o_data, int num, int oc, int oh, @@ -361,12 +234,10 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, int ih, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, ARMContext* ctx, - PrecisionType out_type, - const float* scale, - const int32_t* idx_ptr) { + const float* scale) { int group = param.groups; int channel_size_out = ow * oh; int channel_size_in = win * ih; @@ -386,94 +257,71 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, for (int b = 0; b < num; ++b) { // dC for (int g = 0; g < group; ++g) { - signed char* dout_group = - reinterpret_cast(o_data) + - (b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type); + Dtype* dout_group = o_data + (b * oc + g * m) * channel_size_out; const int8_t* din_group = i_data + (b * ic + g * k) * channel_size_in; const int8_t* weights_group = weights + g * weights_size_per_group; - const int* bias_group = bias + g * m; + const float* bias_group = bias + g * m; const float* scale_group = scale + g * m; if (n == 1) { - if (out_type == PRECISION(kFloat)) { - gemv_int8(weights_group, - din_group, - reinterpret_cast(dout_group), - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } else if (out_type == PRECISION(kInt8)) { // int8 - gemv_int8(weights_group, - din_group, - dout_group, - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } else { - gemv_int8(weights_group, - din_group, - reinterpret_cast(dout_group), - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } + gemv_int8(weights_group, + din_group, + dout_group, + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu, + ctx); } else { - if (out_type == PRECISION(kFloat)) { - gemm_prepack_int8(weights_group, - din_group, - bias_group, - reinterpret_cast(dout_group), - m, - n, - k, - flag_bias, - flag_relu, - false, - scale_group, - ctx); - } else if (out_type == PRECISION(kInt8)) { // int8 - gemm_prepack_int8(weights_group, - din_group, - bias_group, - dout_group, - m, - n, - k, - flag_bias, - flag_relu, - false, - 
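For context on the im2col path above: im2col expands each kernel-sized input window into one column of a matrix so that convolution reduces to a single GEMM; the deleted compute_offset/im2col3x3 pair was a precomputed-offset fast path for the 3x3 stride-1 case (the flag_im2col2 branches removed below), which this diff drops in favor of the generic routine. A minimal sketch of the plain transform, assuming a single channel, stride 1, and zero padding for brevity:

#include <vector>

// col is a (ksize*ksize) x (out_h*out_w) matrix: row = kernel tap,
// column = output position. Out-of-image taps stay zero (zero padding).
std::vector<float> im2col_sketch(const float* im, int height, int width,
                                 int ksize, int pad) {
  int out_h = height + 2 * pad - ksize + 1;
  int out_w = width + 2 * pad - ksize + 1;
  std::vector<float> col(ksize * ksize * out_h * out_w, 0.f);
  for (int kh = 0; kh < ksize; ++kh) {
    for (int kw = 0; kw < ksize; ++kw) {
      for (int oh = 0; oh < out_h; ++oh) {
        for (int ow = 0; ow < out_w; ++ow) {
          int ih = oh + kh - pad;
          int iw = ow + kw - pad;
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            col[((kh * ksize + kw) * out_h + oh) * out_w + ow] =
                im[ih * width + iw];
          }
        }
      }
    }
  }
  return col;
}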
scale_group, - ctx); - } else { - gemm_prepack_int8(weights_group, - din_group, - bias_group, - reinterpret_cast(dout_group), - m, - n, - k, - flag_bias, - flag_relu, - false, - scale_group, - ctx); - } + gemm_prepack_int8(weights_group, + din_group, + bias_group, + dout_group, + m, + n, + k, + flag_bias, + flag_relu, + false, + scale_group, + ctx); } } } } +template void conv1x1s1_gemm_int8(const int8_t* i_data, + int8_t* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +template void conv1x1s1_gemm_int8(const int8_t* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + /** * \brief convolution function for kernel size 3x3, stride size 2, gemm * implementation @@ -490,8 +338,7 @@ void conv_im2col_gemm(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, - ARMContext* ctx, - const int* idx_ptr) { + ARMContext* ctx) { const int group = param.groups; auto filter_dims = param.filter->dims(); const int kernel_h = filter_dims[2]; @@ -504,22 +351,13 @@ void conv_im2col_gemm(const float* i_data, int channel_size_in = win * ih; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // flag_relu = true; - // } - // } - int hblock = get_hblock(ctx->arch()); + int hblock = get_hblock(ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int weights_size_per_group = m * k; if (n > 1) { weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; } - bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 && - param.strides[0] == 1 && param.strides[1] == 1 && n > 1); - float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); @@ -534,36 +372,20 @@ void conv_im2col_gemm(const float* i_data, const float* bias_group = bias + g * m; float* dB = tmp_work_space; - if (flag_im2col2) { - im2col3x3(din_group, - chin_per_group, - ih, - win, - kernel_h, - kernel_w, - param.paddings[0], - param.paddings[1], - param.strides[0], - param.strides[1], - param.dilations[0], - param.dilations[1], - dB, - idx_ptr); - } else { - im2col(din_group, - chin_per_group, - ih, - win, - kernel_h, - kernel_w, - param.paddings[0], - param.paddings[1], - param.strides[0], - param.strides[1], - param.dilations[0], - param.dilations[1], - dB); - } + im2col(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + param.paddings[0], + param.paddings[1], + param.strides[0], + param.strides[1], + param.dilations[0], + param.dilations[1], + dB); + if (n == 1) { sgemv(weights_group, dB, @@ -576,10 +398,7 @@ void conv_im2col_gemm(const float* i_data, flag_relu); } else { int ldb = n; - if (flag_im2col2) { - ldb = k; - } - sgemm_prepack(flag_im2col2, + sgemm_prepack(false, m, n, k, @@ -598,8 +417,9 @@ void conv_im2col_gemm(const float* i_data, } } +template void conv_im2col_gemm_int8(const int8_t* i_data, - int32_t* o_data, + Dtype* o_data, int num, int oc, int oh, @@ -608,12 +428,10 @@ void conv_im2col_gemm_int8(const int8_t* i_data, int ih, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, ARMContext* 
ctx, - PrecisionType out_type, - const float* scale, - const int32_t* idx_ptr) { + const float* scale) { int group = param.groups; auto filter_dims = param.filter->dims(); int kernel_h = filter_dims[2]; @@ -641,9 +459,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data, weights_size_per_group = ((m_roundup * k_roundup + 15) / 16) * 16; } - bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && - stride_w == 1 && n > 1); - int8_t* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(int8_t); @@ -651,249 +466,442 @@ void conv_im2col_gemm_int8(const int8_t* i_data, for (int b = 0; b < num; ++b) { // dC for (int g = 0; g < group; ++g) { - signed char* dout_group = - reinterpret_cast(o_data) + - (b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type); + Dtype* dout_group = o_data + (b * oc + g * m) * channel_size_out; const int8_t* din_group = static_cast(i_data) + (b * ic + g * chin_per_group) * channel_size_in; const int8_t* weights_group = static_cast(weights) + g * weights_size_per_group; - const int* bias_group = static_cast(bias) + g * m; + const float* bias_group = bias + g * m; int8_t* dB = tmp_work_space; const float* scale_group = scale + g * m; - if (flag_im2col2) { - im2col3x3(din_group, - chin_per_group, - ih, - win, - kernel_h, - kernel_w, - pad_h, - pad_w, - stride_h, - stride_w, - dila_h, - dila_w, - dB, - idx_ptr); - - } else { - im2col(din_group, - chin_per_group, - ih, - win, - kernel_h, - kernel_w, - pad_h, - pad_w, - stride_h, - stride_w, - dila_h, - dila_w, - dB); - } + im2col(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + dB); if (n == 1) { - if (out_type == PRECISION(kFloat)) { - gemv_int8(weights_group, - dB, - reinterpret_cast(dout_group), - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } else if (out_type == PRECISION(kInt8)) { // int8 - gemv_int8(weights_group, - dB, - dout_group, - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } else { - gemv_int8(weights_group, - dB, - reinterpret_cast(dout_group), - false, - m, - k, - scale_group, - flag_bias, - bias_group, - flag_relu); - } + gemv_int8(weights_group, + dB, + dout_group, + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu, + ctx); } else { - if (out_type == PRECISION(kFloat)) { - gemm_prepack_int8(weights_group, - dB, - bias_group, - reinterpret_cast(dout_group), - m, - n, - k, - flag_bias, - flag_relu, - flag_im2col2, - scale_group, - ctx); - } else if (out_type == PRECISION(kInt8)) { // int8 - gemm_prepack_int8(weights_group, - dB, - bias_group, - dout_group, - m, - n, - k, - flag_bias, - flag_relu, - flag_im2col2, - scale_group, - ctx); - } else { - gemm_prepack_int8(weights_group, - dB, - bias_group, - reinterpret_cast(dout_group), - m, - n, - k, - flag_bias, - flag_relu, - flag_im2col2, - scale_group, - ctx); - } + gemm_prepack_int8(weights_group, + dB, + bias_group, + dout_group, + m, + n, + k, + flag_bias, + flag_relu, + false, + scale_group, + ctx); } } } } -void conv_depthwise_3x3(const float* i_data, - float* o_data, - int num, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx) { - int pad = param.paddings[1]; +template void conv_im2col_gemm_int8(const int8_t* i_data, + int8_t* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const 
float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +template void conv_im2col_gemm_int8(const int8_t* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +void conv_depthwise_3x3_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + const int pad_h = param.paddings[0]; + const int pad_w = param.paddings[1]; + if (pad_w != pad_h) { + LOG(FATAL) << "fp32 depthwise conv3x3 pad_w: " << pad_w + << ", pad_h: " << pad_h << " must be equal"; + return; + } int stride = param.strides[1]; + int pad = pad_w; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // flag_relu = true; - // } - // } + if (stride == 1 && pad < 2) { // support pad = [0, 1] + conv_depthwise_3x3s1_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else if (stride == 2 && pad < 2) { // support pad = [0, 1] + conv_depthwise_3x3s2_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else { + LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride + << " or pad(<2): " << pad << " unsupported"; + } +#if 0 if (pad == 1) { - conv_depthwise_3x3p1(i_data, - o_data, - num, - oc, - oh, - ow, - ic, - ih, - win, - weights, - bias, - stride, - flag_bias, - flag_relu, - ctx); - } else if (pad == 0 && ih > 2) { - conv_depthwise_3x3p0(i_data, - o_data, - num, - oc, - oh, - ow, - ic, - ih, - win, - weights, - bias, - stride, - flag_bias, - flag_relu, - ctx); + conv_depthwise_3x3p1_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + stride, + flag_bias, + flag_relu, + ctx); + } else if (pad == 0 && h_in > 2) { + conv_depthwise_3x3p0_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + stride, + flag_bias, + flag_relu, + ctx); } else { LOG(FATAL) << "unsupport this type 3x3 dw conv"; } +#endif } -void conv_depthwise_5x5(const float* i_data, - float* o_data, - int num, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx) { +void conv_depthwise_5x5_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { int pad = param.paddings[1]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active && - // fabs(param.activation_param.negative_slope) < 1e-6f) { - // if (param.activation_param.active == 
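The new conv_depthwise_3x3_fp32 wrapper above only validates padding and stride (the tuned s1/s2 kernels assume pad_h == pad_w and pad in {0, 1}, hence the checks) and then dispatches. For reference, a naive sketch of what a 3x3 depthwise convolution computes, with each channel convolved against its own 3x3 filter and no cross-channel accumulation; illustrative code, not the Lite kernels:

// Stride-1 depthwise 3x3 with zero padding; out is ch x oh x ow, NCHW.
void depthwise3x3_sketch(const float* in, float* out, const float* w,
                         int ch, int h, int w_in, int pad) {
  int oh = h + 2 * pad - 2, ow = w_in + 2 * pad - 2;
  for (int c = 0; c < ch; ++c) {
    const float* ic = in + c * h * w_in;
    const float* wc = w + c * 9;  // one 3x3 filter per channel
    float* oc = out + c * oh * ow;
    for (int y = 0; y < oh; ++y) {
      for (int x = 0; x < ow; ++x) {
        float acc = 0.f;
        for (int ky = 0; ky < 3; ++ky) {
          for (int kx = 0; kx < 3; ++kx) {
            int iy = y + ky - pad, ix = x + kx - pad;
            if (iy >= 0 && iy < h && ix >= 0 && ix < w_in) {
              acc += ic[iy * w_in + ix] * wc[ky * 3 + kx];
            }
          }
        }
        oc[y * ow + x] = acc;
      }
    }
  }
}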
Active_relu) { - // flag_relu = true; - // } - // } + ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); if (pad == 2 && stride == 2) { - conv_depthwise_5x5s2(i_data, - o_data, - num, - oc, - oh, - ow, - ic, - ih, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); + conv_depthwise_5x5s2_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + flag_relu, + ctx); } else if (stride == 1) { - conv_depthwise_5x5s1(i_data, - o_data, - num, - oc, - oh, - ow, - ic, - ih, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); + conv_depthwise_5x5s1_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + flag_relu, + ctx); } else { LOG(FATAL) << "unsupport this type 5x5 dw conv"; } } +void conv_depthwise_3x3_int8_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_3x3s1_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else if (stride == 2) { + conv_depthwise_3x3s2_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 3x3 dw conv int8"; + } +} + +void conv_depthwise_3x3_int8_int8(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_3x3s1_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else if (stride == 2) { + conv_depthwise_3x3s2_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 3x3 dw conv int8"; + } +} + +void conv_depthwise_5x5_int8_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_5x5s1_int8(reinterpret_cast(dout), + 
reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; + } +} + +void conv_depthwise_5x5_int8_int8(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_5x5s1_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h index 38d799bb4c994cc08e2b26a71bcd0e5a7668c9ea..c5baa31e1414c4a7a0c926728e5c150c0fc3e21c 100644 --- a/lite/backends/arm/math/conv_impl.h +++ b/lite/backends/arm/math/conv_impl.h @@ -23,26 +23,9 @@ namespace lite { namespace arm { namespace math { -// TODO(TJ): move to somewhere else common -template -class ImplBase { - public: - ImplBase() {} - virtual ~ImplBase() {} - - virtual bool create(const Param& param, Context* ctx) { return false; } - - virtual bool init(const Param& param, Context* ctx) { return false; } - - virtual bool run(const Param& param) { return false; } - // void set_op_name(const char* name){_op_name = name;} - // const char* get_op_name() { return _op_name.c_str();} - - protected: - Param* param_; - Context* ctx_; -}; - +/// conv 3x3s1 +size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param, + ARMContext* ctx); void conv_3x3s1_direct_fp32(const float* din, float* dout, int num, @@ -55,26 +38,11 @@ void conv_3x3s1_direct_fp32(const float* din, const float* weights, const float* bias, const operators::ConvParam& param, - Context* ctx); + ARMContext* ctx); +template void conv_3x3s1_direct_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale); - -void conv_3x3s1_direct_int7(const int8_t* din, - int32_t* dout, + Dtype* dout, int num, int chout, int hout, @@ -83,12 +51,14 @@ void conv_3x3s1_direct_int7(const int8_t* din, int hin, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, + ARMContext* ctx, const float* scale); +/// conv3x3s2 +size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, + ARMContext* ctx); void conv_3x3s2_direct_fp32(const float* din, float* dout, int num, @@ -101,12 +71,13 @@ void conv_3x3s2_direct_fp32(const float* din, const float* weights, const float* bias, const operators::ConvParam& param, - Context* ctx); + ARMContext* ctx); int conv_3x3s2_direct_int8_c_num(); +template void conv_3x3s2_direct_int8(const int8_t* din, - int32_t* dout, + Dtype* dout, int num, int chout, int hout, @@ -115,14 +86,13 @@ void conv_3x3s2_direct_int8(const int8_t* din, int hin, 
int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, + ARMContext* ctx, const float* scale); -void conv_1x5s1_direct(const void* din, - void* dout, +void conv_1x5s1_direct(const float* din, + float* dout, int num, int chout, int hout, @@ -130,8 +100,8 @@ void conv_1x5s1_direct(const void* din, int chin, int hin, int win, - const void* weights, - const void* bias, + const float* weights, + const float* bias, int group, int kernel_w, int kernel_h, @@ -143,12 +113,10 @@ void conv_1x5s1_direct(const void* din, int pad_h, bool flag_bias, bool flag_relu, - Context& ctx, - void* work_space, - const void* idx_ptr); + ARMContext& ctx); // NOLINT -void conv_5x1s1_direct(const void* din, - void* dout, +void conv_5x1s1_direct(const float* din, + float* dout, int num, int chout, int hout, @@ -156,8 +124,8 @@ void conv_5x1s1_direct(const void* din, int chin, int hin, int win, - const void* weights, - const void* bias, + const float* weights, + const float* bias, int group, int kernel_w, int kernel_h, @@ -169,9 +137,7 @@ void conv_5x1s1_direct(const void* din, int pad_h, bool flag_bias, bool flag_relu, - Context& ctx, - void* work_space, - const void* idx_ptr); + ARMContext& ctx); // NOLINT void conv1x1s1_gemm(const float* din, float* dout, @@ -185,11 +151,11 @@ void conv1x1s1_gemm(const float* din, const float* weights, const float* bias, const operators::ConvParam& param, - Context* ctx, - const int* idx_ptr); + ARMContext* ctx); +template void conv1x1s1_gemm_int8(const int8_t* din, - int32_t* dout, + Dtype* dout, int num, int chout, int hout, @@ -198,12 +164,10 @@ void conv1x1s1_gemm_int8(const int8_t* din, int hin, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale, - const int32_t* idx_ptr); + ARMContext* ctx, + const float* scale); void conv_im2col_gemm(const float* din, float* dout, @@ -217,11 +181,11 @@ void conv_im2col_gemm(const float* din, const float* weights, const float* bias, const operators::ConvParam& param, - Context* ctx, - const int* idx_ptr); + ARMContext* ctx); +template void conv_im2col_gemm_int8(const int8_t* din, - int32_t* dout, + Dtype* dout, int num, int chout, int hout, @@ -230,157 +194,103 @@ void conv_im2col_gemm_int8(const int8_t* din, int hin, int win, const int8_t* weights, - const int32_t* bias, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale, - const int32_t* idx_ptr); - -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias - */ - -void conv_depthwise_3x3p0(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_3x3p1(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2(const float* din, - 
float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); + ARMContext* ctx, + const float* scale); -void conv_depthwise_3x3(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - Context* ctx); - -void conv_depthwise_3x3_int8(const int8_t* din, - int32_t* dout, +/// depthwise conv +void conv_depthwise_3x3_fp32(const void* din, + void* dout, int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, + ARMContext* ctx, const float* scale); -void conv_depthwise_3x3_int7(const int8_t* din, - int32_t* dout, +void conv_depthwise_3x3_int8_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +void conv_depthwise_3x3_int8_int8(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +void conv_depthwise_5x5_fp32(const void* din, + void* dout, int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int8_t* weights, - const int32_t* bias, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, - const float* scale); - -void conv_depthwise_5x5(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - Context* ctx); - -void conv_depthwise_5x5_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - Context* ctx, - PrecisionType out_type, + ARMContext* ctx, const float* scale); +void conv_depthwise_5x5_int8_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +void conv_depthwise_5x5_int8_int8(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + +/// winograd conv, only support 3x3s1 void conv_winograd3x3(const float* din, float* dout, int num, @@ -393,23 +303,11 @@ void conv_winograd3x3(const float* din, const float* weights, const float* bias, const operators::ConvParam& param, - Context* ctx); + ARMContext* ctx); void winograd_transform_weights( void* dout, const void* din, int ch_out, int ch_in, void* work_space); -void compute_offset(int* 
idx_out, - int h, - int w, - int kernel_h, - int kernel_w, - int height, - int width, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w); - void fill_bias(float* tensor, const float* bias, int channel, int channel_size); void fill_bias_int8(int* tensor, diff --git a/lite/backends/arm/math/conv_winograd.cc b/lite/backends/arm/math/conv_winograd.cc deleted file mode 100644 index 43ad9e2cd8a60edd02ba20ac02c30512ce8b5eb9..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_winograd.cc +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_winograd.h" -#include -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/backends/arm/math/packed_sgemm.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template <> -bool WinogradConv::create(const operators::ConvParam& param, - ARMContext* ctx) { - this->ctx_ = ctx; - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ic = x_dims[1]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; - int kw = w_dims[3]; - int sw = param.strides[1]; - if (kw == 3) { - is_weights_transed_ = true; - int tile_w = (ow + 5) / 6; - int tile_h = (oh + 5) / 6; - int size_tile = tile_h * tile_w; - int size_trans_channel = 8 * 8 * size_tile; - int max_ch = ic > oc ? ic : oc; - - const int m_wino = oc; - const int n_wino = size_tile; - int hblock = get_hblock(this->ctx_->arch()); - int m_round = hblock * ((m_wino + hblock - 1) / hblock); - weights_trans_.Resize({1, 1, 1, 8 * 8 * m_round * ic}); - this->ctx_->ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) * - sizeof(float)); - auto weights_wino = - static_cast(malloc(sizeof(float) * 8 * 8 * oc * ic)); - void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic); - if (weights_wino && trans_tmp_ptr) { - winograd_transform_weights( - weights_wino, param.filter->data(), oc, ic, trans_tmp_ptr); - auto weights_trans = weights_trans_.mutable_data(); - for (int i = 0; i < 64; ++i) { - float* packed_weights = weights_trans + i * m_round * ic; - const float* weights_wino_ptr = weights_wino + i * oc * ic; - prepackA(packed_weights, - weights_wino_ptr, - 1.f, - ic, - 0, - m_wino, - 0, - ic, - false, - this->ctx_); - } - impl_ = conv_winograd3x3; - free(trans_tmp_ptr); - free(weights_wino); - return true; - } - free(trans_tmp_ptr); - free(weights_wino); - } else { - LOG(ERROR) << "this type winograd conv not impl"; - } - return false; -} - -template <> -bool WinogradConv::init(const operators::ConvParam& param, - Context* ctx) { - this->ctx_ = ctx; - return create(param, ctx); -} - -template <> -bool WinogradConv::run(const operators::ConvParam& param) { - // start timer - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? 
param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - - if (is_weights_transed_) { - w_data = weights_trans_.data(); - } - - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - - int iw = x_dims[3]; // nchw - int ih = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int oh = o_dims[2]; - int ow = o_dims[3]; - int oc = o_dims[1]; - - impl_(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - this->ctx_); - - // timer end - return true; -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_winograd.h b/lite/backends/arm/math/conv_winograd.h deleted file mode 100644 index 1ae5edb0aacd7e6a1f12192ab2ad598f2755b590..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_winograd.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/core/target_wrapper.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -template -class WinogradConv - : public ImplBase { - public: - typedef void (*conv_winograd_impl)(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - Context* ctx); - - WinogradConv() = default; - ~WinogradConv() {} - - virtual bool init(const operators::ConvParam& param, - Context* ctx); - - virtual bool create(const operators::ConvParam& param, - Context* ctx); - - virtual bool run(const operators::ConvParam& param); - - private: - conv_winograd_impl impl_{nullptr}; - bool is_weights_transed_{false}; - Tensor weights_trans_; -}; - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_winograd_3x3.cc b/lite/backends/arm/math/conv_winograd_3x3.cc index 87f51381e6e5d4705befba3deadfa76e053d448c..87b08f63102104b325e95c093fe0fc0aaef243e0 100644 --- a/lite/backends/arm/math/conv_winograd_3x3.cc +++ b/lite/backends/arm/math/conv_winograd_3x3.cc @@ -102,7 +102,7 @@ void conv_winograd3x3(const float* din, //! dot mul //! transpose input, convert from ch_in * tile_h * tile_w * 64 to //! 
64 * ch_in * tile_h * tile_w - int hblock = get_hblock(ctx->arch()); + int hblock = get_hblock(ctx); int m_round = hblock * ((chout + hblock - 1) / hblock); int stride_a = m_round * chin; int stride_b = chin * size_tile; diff --git a/lite/backends/arm/math/dotprod/__gemm_sdot_meta__.h b/lite/backends/arm/math/dotprod/__gemm_sdot_meta__.h new file mode 100644 index 0000000000000000000000000000000000000000..e556c5eec3382f1cdb421d749ef374de0e3e2781 --- /dev/null +++ b/lite/backends/arm/math/dotprod/__gemm_sdot_meta__.h @@ -0,0 +1,369 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// clang-format off +#define GEMM_SDOT_INT8_KERNEL \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \ + "eor v8.16b, v8.16b, v8.16b\n" /* out0 = 0 */ \ + "eor v9.16b, v9.16b, v9.16b\n" /* out1 = 0 */ \ + "eor v10.16b, v10.16b, v10.16b\n" /* out2 = 0 */ \ + "eor v11.16b, v11.16b, v11.16b\n" /* out3 = 0 */ \ + "eor v12.16b, v12.16b, v12.16b\n" /* out4 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ + "eor v13.16b, v13.16b, v13.16b\n" /* out5 = 0 */ \ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ + "eor v14.16b, v14.16b, v14.16b\n" /* out6 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ + "eor v15.16b, v15.16b, v15.16b\n" /* out7 = 0 */ \ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ + "eor v16.16b, v16.16b, v16.16b\n" /* out8 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ + "eor v17.16b, v17.16b, v17.16b\n" /* out9 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \ + "eor v18.16b, v18.16b, v18.16b\n" /* out10 = 0 */ \ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \ + "eor v19.16b, v19.16b, v19.16b\n" /* out11 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ \ + "eor v20.16b, v20.16b, v20.16b\n" /* out12 = 0 */ \ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ \ + "eor v21.16b, v21.16b, v21.16b\n" /* out13 = 0 */ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \ + "eor v22.16b, v22.16b, v22.16b\n" /* out14 = 0 */ \ + "eor v23.16b, v23.16b, v23.16b\n" /* out15 = 0 */ \ + "eor v24.16b, v24.16b, v24.16b\n" /* out16 = 0 */ \ + "eor v25.16b, v25.16b, v25.16b\n" /* out17 = 0 */ \ + "eor v26.16b, v26.16b, v26.16b\n" /* out18 = 0 */ \ + "eor v27.16b, v27.16b, v27.16b\n" /* out19 = 0 */ \ + "eor v28.16b, v28.16b, v28.16b\n" /* out20 = 0 */ \ + "eor v29.16b, v29.16b, v29.16b\n" /* out21 = 0 */ \ + "eor v30.16b, v30.16b, v30.16b\n" /* out22 = 0 */ \ + "eor v31.16b, v31.16b, v31.16b\n" /* out23 = 0 */ \ + "cbz %w[k], 2f\n" /* check loop count > 0 */ \ + /* main loop, unrool 0*/ \ + "1:\n" /* main loop */ \ + "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4 */ \ + "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4 */ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ \ + "sdot v14.4s, v4.16b, 
v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4 */ \ + "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4 */ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ \ + "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4 */ \ + "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4 */ \ + "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4 */ \ + "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4 */ \ + "sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5 */ \ + "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5 */ \ + "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \ + "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \ + "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \ + "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \ + "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \ + "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ \ + "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \ + "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \ + "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \ + "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \ + "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \ + "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \ + "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ \ + /* unrool 1 */ \ + "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7 */ \ + "sdot v11.4s , v7.16b, v2.4b[1]\n"/* out1 = b0 * a10[1], b0 = q7 */ \ + "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7 */ \ + "prfm pldl1keep, [%[a_ptr], #256]\n" \ + "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7 */ \ + "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7 */ \ + "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7 */ \ + "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7 */ \ + "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7 */ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ \ + "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ \ + "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4 */ \ + "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \ + "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \ + "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \ + "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \ + "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \ + "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \ + "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \ + "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \ + "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \ + "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \ + "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \ + "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \ + "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * 
a10[0], b2 = q5*/ \ + "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ \ + /* unrool 2*/ \ + "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6 */ \ + "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6 */ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \ + "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \ + "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \ + "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \ + "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \ + "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \ + "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \ + "sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ \ + "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \ + "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \ + "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \ + "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \ + "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \ + "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ \ + "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \ + "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \ + "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \ + "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \ + "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q4*/ \ + "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \ + "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \ + "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \ + /* unrool 3*/ \ + "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \ + "sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \ + "sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \ + "sdot v9.4s, v6.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q6*/ \ + "prfm pldl1keep, [%[a_ptr], #256]\n" \ + "sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 
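The inner step of this macro is the ARMv8.2 sdot instruction: each "sdot vD.4s, vN.16b, vM.4b[k]" accumulates four 4-way int8 dot products into the four int32 lanes of vD, which is why one instruction retires 16 multiply-adds. A hedged intrinsics rendering of a single such step, assuming a core with the dot-product extension (this helper is for illustration and is not taken from this file):

#include <arm_neon.h>

#if defined(__ARM_FEATURE_DOTPROD)
// acc[i] += dot(b[4i..4i+3], a[0..3]) for i = 0..3, i.e. one sdot step
// with lane 0 of a playing the role of "v?.4b[0]" in the assembly above.
int32x4_t sdot_step(int32x4_t acc, int8x16_t b, int8x16_t a) {
  return vdotq_laneq_s32(acc, b, a, 0);
}
#endif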
* a10[0], b2 = q7*/ \ + "sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ \ + "sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "bne 1b\n" /* Target to use when K is 1 or 2 */ \ + "2:\n" /* process tail*/ \ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ \ + "beq 3f\n" /*jump to tail = 1*/ \ + /* final unrool 0, unrool 0, tail > 1*/ \ + "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ \ + "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ \ + "sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4*/ \ + "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4*/ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ \ + "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4*/ \ + "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4*/ \ + "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4*/ \ + "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4*/ \ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ \ + "sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5*/ \ + "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5*/ \ + "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \ + "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \ + "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \ + "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \ + "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \ + "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ \ + "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \ + "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \ + "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \ + "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \ + "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \ + "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \ + "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \ + "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \ + "beq 4f\n" /*jump to tail = 2*/ \ + /* unrool 1, tail > 2*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \ + "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ \ + "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7*/ \ + "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7*/ \ + "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7*/ \ + "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7*/ \ + "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ \ + "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7*/ \ + "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ \ + "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4*/ \ + "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * 
a10[1], b1 = q4*/ \ + "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \ + "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \ + "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \ + "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \ + "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \ + "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ \ + "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \ + "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \ + "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \ + "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \ + "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \ + "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \ + "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \ + "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \ + "beq 5f\n" /*jump to tail = 3*/ \ + /* unrool 2, tail = 4*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ \ + "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ \ + "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6*/ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \ + "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \ + "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \ + "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \ + "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \ + "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \ + "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \ + "sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ \ + "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \ + "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \ + "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \ + "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \ + "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \ + "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \ + "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ \ + "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \ + "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \ + "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \ + "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \ + "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q4*/ \ + "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \ + "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \ + "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \ + /* unrool 3, tail = 4*/ \ + "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \ + "sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * 
a11[2], b0 = q5*/ \ + "sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "sdot v9.4s, v6.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \ + "sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \ + "sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "b 11f\n" /* tails==1 final tail*/ \ + "3: \n" /* tail=1*/ \ + "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ \ + "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \ + "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \ + "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \ + "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \ + "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "b 11f\n" /* tails==2 final tail*/ \ + "4:\n" /* tail = 2*/ \ + "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 
= q5*/ \ + "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \ + "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \ + "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \ + "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "b 11f\n" /* tails==3 final tail*/ \ + "5:\n" /* tail = 3*/ \ + "ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ \ + "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \ + "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \ + "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \ + "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \ + "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "11: \n" /* end */ diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc index 28ae1ee4ca97b96ca55ab60eb4c62d69f00ad302..a4c61f9a9d181924c28cdd009f8412278d44f5bb 100644 --- a/lite/backends/arm/math/elementwise.cc +++ b/lite/backends/arm/math/elementwise.cc @@ -266,6 +266,251 @@ void elementwise_add_relu_broadcast(const float* dinx, } } +template <> +void elementwise_sub(const 
float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vsubq_f32(dinx0, diny0); + dinx1 = vsubq_f32(dinx1, diny1); + dinx2 = vsubq_f32(dinx2, diny2); + dinx3 = vsubq_f32(dinx3, diny3); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *dinx_ptr - *diny_ptr; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_sub_relu(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vsubq_f32(dinx0, diny0); + dinx1 = vsubq_f32(dinx1, diny1); + dinx2 = vsubq_f32(dinx2, diny2); + dinx3 = vsubq_f32(dinx3, diny3); + + // relu + dinx0 = vmaxq_f32(dinx0, vzero); + dinx1 = vmaxq_f32(dinx1, vzero); + dinx2 = vmaxq_f32(dinx2, vzero); + dinx3 = vmaxq_f32(dinx3, vzero); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + float tmp = *dinx_ptr - *diny_ptr; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_sub_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vsubq_f32(din0, rb); + din1 = vsubq_f32(din1, rb); + din2 = vsubq_f32(din2, rb); + din3 = vsubq_f32(din3, rb); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vsubq_f32(din0, rb); + din1 = vsubq_f32(din1, rb); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vsubq_f32(din0, rb); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + *dout_ptr = *din_ptr - diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_sub_relu_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vsubq_f32(din0, rb); + din1 = vsubq_f32(din1, rb); + din2 = vsubq_f32(din2, rb); + din3 = vsubq_f32(din3, rb); + + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + din2 = vmaxq_f32(din2, vzero); + din3 = vmaxq_f32(din3, vzero); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vsubq_f32(din0, rb); + din1 = vsubq_f32(din1, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vsubq_f32(din0, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + float tmp = *din_ptr - diny_data; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + din_ptr++; + } + } + } + } +} + template <> void elementwise_mul(const float* dinx, const float* diny, diff --git a/lite/backends/arm/math/elementwise.h b/lite/backends/arm/math/elementwise.h index 866277ae9c9751a1f936f019d1347012aef252cb..f8273a5bb39505b03e911b5699cc10c5be755619 100644 --- a/lite/backends/arm/math/elementwise.h +++ b/lite/backends/arm/math/elementwise.h @@ -33,6 +33,20 @@ template <typename T> void elementwise_add_relu_broadcast( const T* dinx, const T* diny, T* dout, int batch, int channels, int num); +template <typename T> +void elementwise_sub(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_sub_relu(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_sub_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_sub_relu_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + template <typename T> void elementwise_mul(const T* dinx, const T* diny, T* dout, int num); diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 9438a997b6752f3d496d65911a192c69dd2f13c0..d8ef6ff47d0392ac15caf2d94b7c53ff63659da2 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -27,14 +27,15 @@ #include "lite/backends/arm/math/box_coder.h" #include "lite/backends/arm/math/col_im_transform.h" #include "lite/backends/arm/math/concat.h" -#include "lite/backends/arm/math/conv_depthwise.h" -#include "lite/backends/arm/math/conv_direct.h" -#include "lite/backends/arm/math/conv_gemmlike.h" -#include "lite/backends/arm/math/conv_winograd.h" +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/decode_bboxes.h" #include "lite/backends/arm/math/dropout.h" #include "lite/backends/arm/math/elementwise.h" #include "lite/backends/arm/math/fill_bias_relu.h" +#include "lite/backends/arm/math/gemm_prepacked_int8.h" +#include "lite/backends/arm/math/gemm_s8.h" +#include "lite/backends/arm/math/gemv_arm_int8.h" #include "lite/backends/arm/math/im2sequence.h" #include "lite/backends/arm/math/increment.h" #include "lite/backends/arm/math/interpolate.h" @@ -61,6 +62,7 @@ #include "lite/backends/arm/math/stack.h" #include "lite/backends/arm/math/topk.h" #include "lite/backends/arm/math/yolo_box.h" + namespace paddle { namespace lite { namespace arm { @@ -261,7 +263,7 @@ inline float32x4_t exp_ps(float32x4_t x) { // almost no extra price so both sin_ps and cos_ps make use of // sincos_ps..
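The four elementwise_sub variants added above share one contract: the plain form walks num contiguous floats, while the broadcast forms view the input as a [batch, channels, num] volume and subtract the per-channel scalar diny[j] from every element of channel j. A scalar reference of the broadcast form (the _ref helper name is illustrative only, not part of the patch):

// Scalar reference for elementwise_sub_broadcast: dinx is viewed as a
// [batch, channels, num] volume and diny holds one float per channel.
void elementwise_sub_broadcast_ref(const float* dinx,
                                   const float* diny,
                                   float* dout,
                                   int batch,
                                   int channels,
                                   int num) {
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      for (int k = 0; k < num; ++k) {
        dout[offset + k] = dinx[offset + k] - diny[j];  // broadcast diny[j]
      }
    }
  }
}

The NEON implementations above compute exactly this, peeling the inner loop into 16-, 8- and 4-wide vsubq_f32 chunks plus a scalar remainder; the _relu variants additionally clamp with vmaxq_f32 against a zero vector before each store.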
// -inline void sincos_ps(float32x4_t x, float32x4_t *ysin, float32x4_t *ycos) { +inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) { // any x float32x4_t xmm1, xmm2, xmm3, y; @@ -350,23 +352,23 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { } template -void fill_bias_fc(T *tensor, const T *bias, int num, int channel); +void fill_bias_fc(T* tensor, const T* bias, int num, int channel); template -inline float32x4_t vactive_f32(const float32x4_t &x) { +inline float32x4_t vactive_f32(const float32x4_t& x) { return x; } template <> inline float32x4_t vactive_f32( - const float32x4_t &x) { + const float32x4_t& x) { float32x4_t __zero = vdupq_n_f32(0.f); return vmaxq_f32(x, __zero); } template <> inline float32x4_t vactive_f32( - const float32x4_t &x) { + const float32x4_t& x) { float32x4_t __zero = vdupq_n_f32(0.f); float32x4_t __six = vdupq_n_f32(6.f); return vminq_f32(vmaxq_f32(x, __zero), __six); @@ -374,7 +376,7 @@ inline float32x4_t vactive_f32( template <> inline float32x4_t vactive_f32( - const float32x4_t &x) { + const float32x4_t& x) { float32x4_t __one = vdupq_n_f32(1.f); float32x4_t __x = vnegq_f32(x); __x = exp_ps(__x); @@ -385,7 +387,7 @@ inline float32x4_t vactive_f32( template <> inline float32x4_t vactive_f32( - const float32x4_t &x) { + const float32x4_t& x) { float32x4_t __one = vdupq_n_f32(1.f); float32x4_t __x = vmulq_n_f32(x, -2.f); __x = exp_ps(__x); @@ -397,27 +399,27 @@ inline float32x4_t vactive_f32( } template -inline float active_f32(const float &x) { +inline float active_f32(const float& x) { return x; } template <> -inline float active_f32(const float &x) { +inline float active_f32(const float& x) { return std::max(x, 0.f); } template <> -inline float active_f32(const float &x) { +inline float active_f32(const float& x) { return std::min(std::max(x, 0.f), 6.f); } template <> -inline float active_f32(const float &x) { +inline float active_f32(const float& x) { return 1.f / (1.f + exp(-x)); } template <> -inline float active_f32(const float &x) { +inline float active_f32(const float& x) { return 2.f / (1.f + exp(-2.f * x)) - 1.f; } diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc index 9efae1115772a17b73565ffd886dd0dcbda5df34..d7e04bfc60b1214bd1e77738efa420d3e25e1456 100644 --- a/lite/backends/arm/math/gemm_prepacked_int8.cc +++ b/lite/backends/arm/math/gemm_prepacked_int8.cc @@ -14,7 +14,7 @@ #include "lite/backends/arm/math/gemm_prepacked_int8.h" #include -#include "lite/backends/arm/math/dot_toolchain_support.h" +#include "lite/backends/arm/math/dotprod/gemm_sdot.h" namespace paddle { namespace lite { @@ -189,7 +189,7 @@ void prepackA_int8(TensorLite* tout, template inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, + const float* bias, Dtype*& c_ptr0, // NOLINT Dtype*& c_ptr1, // NOLINT Dtype*& c_ptr2, // NOLINT @@ -198,496 +198,440 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, bool is_relu, int k, int rem); +// clang-format off #ifdef __aarch64__ -#define GEMM_INT8_KERNEL \ - "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ - "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ - "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ - "ldr q8, [%[bias]]\n" /* load bias */ \ - "ext v9.16b, v8.16b, v8.16b, #4\n" /* shift left 1s */ \ - "ext v10.16b, v8.16b, v8.16b, #8\n" /* shift left 2s */ \ - "ext v11.16b, v8.16b, v8.16b, #12\n" /* shift left 3s */ \ - "and v16.16b, v8.16b, 
v8.16b\n" /* set bias0 to out00 */ \ - "and v17.16b, v9.16b, v9.16b\n" /* set bias0 to out01 */ \ - "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ - "and v18.16b, v10.16b, v10.16b\n" /* set bias0 to out02 */ \ - "and v19.16b, v11.16b, v11.16b\n" /* set bias0 to out03 */ \ - "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ - "and v20.16b, v8.16b, v8.16b\n" /* set bias0 to out10 */ \ - "and v21.16b, v9.16b, v9.16b\n" /* set bias0 to out11 */ \ - "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ - "and v22.16b, v10.16b, v10.16b\n" /* set bias0 to out12 */ \ - "and v23.16b, v11.16b, v11.16b\n" /* set bias0 to out13 */ \ - "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ - "and v24.16b, v8.16b, v8.16b\n" /* set bias0 to out20 */ \ - "and v25.16b, v9.16b, v9.16b\n" /* set bias0 to out21 */ \ - "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \ - "and v26.16b, v10.16b, v10.16b\n" /* set bias0 to out22 */ \ - "and v27.16b, v11.16b, v11.16b\n" /* set bias0 to out23 */ \ - "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ - "and v28.16b, v8.16b, v8.16b\n" /* set bias0 to out30 */ \ - "and v29.16b, v9.16b, v9.16b\n" /* set bias0 to out31 */ \ - "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \ - "and v30.16b, v10.16b, v10.16b\n" /* set bias0 to out32 */ \ - "and v31.16b, v11.16b, v11.16b\n" /* set bias0 to out33 */ \ - "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes */ \ - "ins v1.h[3], v0.h[0]\n" /* insert element */ \ - "ins v1.h[7], v0.h[4]\n" /* insert element */ \ - "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ - "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ - "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload a*/ \ - "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \ - "cbz %w[k], 3f\n" /* if k = 0, jump to remains */ /* 1st b0, b1 */ \ - "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ - "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ - "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ - "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ - "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ - "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ - "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ - "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ \ - "subs %w[k], %w[k], #1\n" /* loop count -1 */ /* 2nd b0, b1 */ \ - "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ - "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ - "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ - "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ - "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ - "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ - "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ - "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ - "beq 8f\n" /* skip main loop */ /* main loop*/ \ - "0:\n" /* main loop */ \ - "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ - "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ - "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b2 = c02 */ \ - "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ - "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b3 = c03 */ \ - "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ - "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b2 = c12 */ \ - "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */ \ - "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b3 = c13 */ \ - "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ \ - "smull v10.8h, 
v2.8b, v6.8b\n" /* a2 * b2 = c22 */ \ - "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ - "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b3 = c23 */ \ - "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ - "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b2 = c02 */ \ - "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b3 = c03 */ \ - "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ - "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b2 = c32 */ \ - "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ - "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b3 = c33 */ /* 2nd b2, b3 */ \ - "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b2 = c12 */ \ - "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b3 = c13 */ \ - "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b2 = c22 */ \ - "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes*/ \ - "ins v1.h[3], v0.h[0]\n" /* insert element */ \ - "ins v1.h[7], v0.h[4]\n" /* insert element */ \ - "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b3 = c23 */ \ - "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b2 = c32 */ \ - "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b3 = c33 */ /* pre-process a */ \ - "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ - "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ - "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ - "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ - "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ - "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ - "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ - "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ - "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c00 */ \ - "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ - "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c01 */ \ - "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ - "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c00 */ \ - "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ - "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c01 */ \ - "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ - "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c00 */ \ - "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ - "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c01 */ \ - "subs %w[k], %w[k], #1\n" /* loop count -1 */ /* 2nd b0, b1 */ \ - "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ - "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ - "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ - "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ - "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ - "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ - "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ - "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ - "bgt 0b\n" /* jump to main loop */ \ - "8:\n" /* finish main loop */ /* 1st b2, b3 */ \ - "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ - "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ - "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ - "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ - "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ - "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ - "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */ \ - "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ - "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ 
\ - "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ - "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ - "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ - "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ - "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ - "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ - "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ /* 2nd b2, b3 */ \ - "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b0 = c02 */ \ - "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b1 = c03 */ \ - "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b0 = c12 */ \ - "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b1 = c23 */ \ - "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b0 = c13 */ \ - "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b1 = c32 */ \ - "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b0 = c22 */ \ - "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b1 = c33 */ \ - "cbz %w[rem], 5f\n" /* skip remain */ \ - "ld1 {v0.8b}, [%[a_ptr]]\n" /* load a to q0, final */ \ - "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ - "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ - "5:\n" /* no remain */ \ - "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ - "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ - "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ - "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ - "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ - "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ - "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ - "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ - "3: \n" /* process remains */ \ - "cbz %w[rem], 7f\n" /* skip remain */ /* process remain k */ \ - "4: \n" /* remain = 1, 2 */ \ - "ext v1.8b, v0.8b, v0.8b, #2\n" /* shift left 2bytes */ \ - "ext v2.8b, v0.8b, v0.8b, #4\n" /* shift left 4bytes */ \ - "ext v3.8b, v0.8b, v0.8b, #6\n" /* shift left 6bytes */ /* 1st b0, b1 */ \ - "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ - "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ - "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ - "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ - "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ - "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ - "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ - "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ /* 1st b2, b3 */ \ - "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ - "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ - "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ - "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ - "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ - "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ - "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */ \ - "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ - "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ \ - "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ - "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ - "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ - "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ - "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ - "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ - "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ \ - 
"sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ - "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ - "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ - "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ - "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ - "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ - "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ - "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ - "7: \n" /* do relu */ /* do relu */ \ - "cbz %w[is_relu], 9f\n" /* not relu, jump to unpack */ \ - "movi v0.4s, #0\n" /* for relu */ \ - "smax v16.4s, v16.4s, v0.4s\n" /* relu */ \ - "smax v17.4s, v17.4s, v0.4s\n" /* relu */ \ - "smax v18.4s, v18.4s, v0.4s\n" /* relu */ \ - "smax v19.4s, v19.4s, v0.4s\n" /* relu */ \ - "smax v20.4s, v20.4s, v0.4s\n" /* relu */ \ - "smax v21.4s, v21.4s, v0.4s\n" /* relu */ \ - "smax v22.4s, v22.4s, v0.4s\n" /* relu */ \ - "smax v23.4s, v23.4s, v0.4s\n" /* relu */ \ - "smax v24.4s, v24.4s, v0.4s\n" /* relu */ \ - "smax v25.4s, v25.4s, v0.4s\n" /* relu */ \ - "smax v26.4s, v26.4s, v0.4s\n" /* relu */ \ - "smax v27.4s, v27.4s, v0.4s\n" /* relu */ \ - "smax v28.4s, v28.4s, v0.4s\n" /* relu */ \ - "smax v29.4s, v29.4s, v0.4s\n" /* relu */ \ - "smax v30.4s, v30.4s, v0.4s\n" /* relu */ \ - "smax v31.4s, v31.4s, v0.4s\n" /* relu */ /* unpack the result */ \ - "9:\n" /* unpack */ /* trans 1 */ \ - "trn1 v0.4s, v16.4s, v17.4s\n" /* get a0,b0, a2,b2 */ \ - "trn2 v1.4s, v16.4s, v17.4s\n" /* get a1,b1, a3,b3 */ \ - "trn1 v2.4s, v18.4s, v19.4s\n" /* get c0,d0, c2,c2 */ \ - "trn2 v3.4s, v18.4s, v19.4s\n" /* get c1,d1, c3,d3 */ \ - "trn1 v4.4s, v20.4s, v21.4s\n" \ - "trn2 v5.4s, v20.4s, v21.4s\n" \ - "trn1 v6.4s, v22.4s, v23.4s\n" \ - "trn2 v7.4s, v22.4s, v23.4s\n" \ - "trn1 v8.4s, v24.4s, v25.4s\n" \ - "trn2 v9.4s, v24.4s, v25.4s\n" \ - "trn1 v10.4s, v26.4s, v27.4s\n" \ - "trn2 v11.4s, v26.4s, v27.4s\n" \ - "trn1 v12.4s, v28.4s, v29.4s\n" \ - "trn2 v13.4s, v28.4s, v29.4s\n" \ - "trn1 v14.4s, v30.4s, v31.4s\n" \ - "trn2 v15.4s, v30.4s, v31.4s\n" /* trans 2 */ \ - "trn1 v16.2d, v0.2d, v2.2d\n" /* get a0,b0, c0,d0 */ \ - "trn2 v18.2d, v0.2d, v2.2d\n" /* get a2,b2, c2,d2 */ \ - "trn1 v17.2d, v1.2d, v3.2d\n" /* get a1,b1, c1,d1 */ \ - "trn2 v19.2d, v1.2d, v3.2d\n" /* get a3,b3, c3,d3 */ \ - "trn1 v20.2d, v4.2d, v6.2d\n" \ - "trn2 v22.2d, v4.2d, v6.2d\n" \ - "trn1 v21.2d, v5.2d, v7.2d\n" \ - "trn2 v23.2d, v5.2d, v7.2d\n" \ - "trn1 v24.2d, v8.2d, v10.2d\n" \ - "trn2 v26.2d, v8.2d, v10.2d\n" \ - "trn1 v25.2d, v9.2d, v11.2d\n" \ - "trn2 v27.2d, v9.2d, v11.2d\n" \ - "trn1 v28.2d, v12.2d, v14.2d\n" \ - "trn2 v30.2d, v12.2d, v14.2d\n" \ - "trn1 v29.2d, v13.2d, v15.2d\n" \ - "trn2 v31.2d, v13.2d, v15.2d\n" /* shift */ \ - "ext v17.16b, v17.16b, v17.16b, #12\n" /* circular shift left 1 */ \ - "ext v18.16b, v18.16b, v18.16b, #8\n" /* circular shift left 2 */ \ - "ext v19.16b, v19.16b, v19.16b, #4\n" /* circular shift left 3 */ \ - "ext v21.16b, v21.16b, v21.16b, #12\n" /* circular shift left 1 */ \ - "ext v22.16b, v22.16b, v22.16b, #8\n" /* circular shift left 2 */ \ - "ext v23.16b, v23.16b, v23.16b, #4\n" /* circular shift left 3 */ \ - "ext v25.16b, v25.16b, v25.16b, #12\n" /* circular shift left 1 */ \ - "ext v26.16b, v26.16b, v26.16b, #8\n" /* circular shift left 2 */ \ - "ext v27.16b, v27.16b, v27.16b, #4\n" /* circular shift left 3 */ \ - "ext v29.16b, v29.16b, v29.16b, #12\n" /* circular shift left 1 */ \ - "ext 
v30.16b, v30.16b, v30.16b, #8\n" /* circular shift left 2 */ \ - "ext v31.16b, v31.16b, v31.16b, #4\n" /* circular shift left 3 */ \ - "trn1 v0.4s, v16.4s, v17.4s\n" /* get a0,b0, a2,b2 */ \ - "trn2 v1.4s, v16.4s, v17.4s\n" /* get a1,b1, a3,b3 */ \ - "trn1 v2.4s, v18.4s, v19.4s\n" /* get c0,d0, c2,c2 */ \ - "trn2 v3.4s, v18.4s, v19.4s\n" /* get c1,d1, c3,d3 */ \ - "trn1 v4.4s, v20.4s, v21.4s\n" \ - "trn2 v5.4s, v20.4s, v21.4s\n" \ - "trn1 v6.4s, v22.4s, v23.4s\n" \ - "trn2 v7.4s, v22.4s, v23.4s\n" \ - "trn1 v8.4s, v24.4s, v25.4s\n" \ - "trn2 v9.4s, v24.4s, v25.4s\n" \ - "trn1 v10.4s, v26.4s, v27.4s\n" \ - "trn2 v11.4s, v26.4s, v27.4s\n" \ - "trn1 v12.4s, v28.4s, v29.4s\n" \ - "trn2 v13.4s, v28.4s, v29.4s\n" \ - "trn1 v14.4s, v30.4s, v31.4s\n" \ - "trn2 v15.4s, v30.4s, v31.4s\n" /* trans 2 */ \ - "trn1 v16.2d, v0.2d, v2.2d\n" /* get a0,b0, c0,d0 */ \ - "trn2 v24.2d, v0.2d, v2.2d\n" /* get a2,b2, c2,d2 */ \ - "trn1 v20.2d, v1.2d, v3.2d\n" /* get a1,b1, c1,d1 */ \ - "trn2 v28.2d, v1.2d, v3.2d\n" /* get a3,b3, c3,d3 */ \ - "trn1 v17.2d, v4.2d, v6.2d\n" \ - "trn2 v25.2d, v4.2d, v6.2d\n" \ - "trn1 v21.2d, v5.2d, v7.2d\n" \ - "trn2 v29.2d, v5.2d, v7.2d\n" \ - "trn1 v18.2d, v8.2d, v10.2d\n" \ - "trn2 v26.2d, v8.2d, v10.2d\n" \ - "trn1 v22.2d, v9.2d, v11.2d\n" \ - "trn2 v30.2d, v9.2d, v11.2d\n" \ - "trn1 v19.2d, v12.2d, v14.2d\n" \ - "trn2 v27.2d, v12.2d, v14.2d\n" \ - "trn1 v23.2d, v13.2d, v15.2d\n" \ +#define GEMM_INT8_KERNEL \ + "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + "eor v16.16b, v8.16b, v8.16b\n" /* set 0 to out00 */ \ + "eor v17.16b, v9.16b, v9.16b\n" /* set 0 to out01 */ \ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ + "eor v18.16b, v10.16b, v10.16b\n" /* set 0 to out02 */ \ + "eor v19.16b, v11.16b, v11.16b\n" /* set 0 to out03 */ \ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ + "eor v20.16b, v8.16b, v8.16b\n" /* set 0 to out10 */ \ + "eor v21.16b, v9.16b, v9.16b\n" /* set 0 to out11 */ \ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ + "eor v22.16b, v10.16b, v10.16b\n" /* set 0 to out12 */ \ + "eor v23.16b, v11.16b, v11.16b\n" /* set 0 to out13 */ \ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ + "eor v24.16b, v8.16b, v8.16b\n" /* set 0 to out20 */ \ + "eor v25.16b, v9.16b, v9.16b\n" /* set 0 to out21 */ \ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \ + "eor v26.16b, v10.16b, v10.16b\n" /* set 0 to out22 */ \ + "eor v27.16b, v11.16b, v11.16b\n" /* set 0 to out23 */ \ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ + "eor v28.16b, v8.16b, v8.16b\n" /* set 0 to out30 */ \ + "eor v29.16b, v9.16b, v9.16b\n" /* set 0 to out31 */ \ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \ + "eor v30.16b, v10.16b, v10.16b\n" /* set 0 to out32 */ \ + "eor v31.16b, v11.16b, v11.16b\n" /* set 0 to out33 */ \ + "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes */ \ + "ins v1.h[3], v0.h[0]\n" /* insert element */ \ + "ins v1.h[7], v0.h[4]\n" /* insert element */ \ + "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ + "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload a*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \ + "cbz %w[k], 3f\n" /* if k = 0, jump to remains */ \ + /* 1st b0, b1 */ \ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 
*/ \ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ \ + "subs %w[k], %w[k], #1\n" /* loop count -1 */ \ + /* 2nd b0, b1 */ \ + "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ + "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ + "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ + "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ + "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ + "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ + "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ + "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ + "beq 8f\n" /* skip main loop */ \ + /* main loop*/ \ + "0:\n" \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + /* 1st b2, b3 */ \ + "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */\ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b2 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */\ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b3 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */\ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b2 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */\ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b3 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */\ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b2 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */\ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b3 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */\ + "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b2 = c02 */ \ + "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b3 = c03 */ \ + "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b2 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */\ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b3 = c33 */ \ + /* 2nd b2, b3 */ \ + "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b2 = c12 */ \ + "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b3 = c13 */ \ + "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b2 = c22 */ \ + "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes */ \ + "ins v1.h[3], v0.h[0]\n" /* insert element */ \ + "ins v1.h[7], v0.h[4]\n" /* insert element */ \ + "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b3 = c23 */ \ + "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b2 = c32 */ \ + "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b3 = c33 */ \ + /* pre-process a*/ \ + "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ + "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + /* 1st b0, b1 */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */\ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */\ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */\ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c00 */ \ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */\ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c01 */ \ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 
*/\ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c00 */ \ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */\ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c01 */ \ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */\ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c00 */ \ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */\ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c01 */ \ + "subs %w[k], %w[k], #1\n" /* loop count -1 */ \ + /* 2nd b0, b1 */ \ + "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ + "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ + "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ + "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ + "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ + "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ + "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ + "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ + "bgt 0b\n" /* jump to main loop */ \ + "8:\n" /* finish main loop */ \ + /* 1st b2, b3 */ \ + "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */\ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */\ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */\ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */\ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */\ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */\ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */\ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */\ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ /* 2nd b2, b3 */ \ + "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b0 = c02 */ \ + "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b1 = c03 */ \ + "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b0 = c12 */ \ + "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b1 = c23 */ \ + "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b0 = c13 */ \ + "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b1 = c32 */ \ + "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b0 = c22 */ \ + "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b1 = c33 */ \ + "cbz %w[rem], 5f\n" /* skip remain */ \ + "ld1 {v0.8b}, [%[a_ptr]]\n" /* load a to q0, final */ \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + "5:\n" /* no remain */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ + "3: \n" /* process remains */ \ + "cbz %w[rem], 7f\n" /* skip remain */ \ + /* process remain k */ \ + "4: \n" /* remain = 1, 2 */ \ + "ext v1.8b, v0.8b, v0.8b, #2\n" 
/* shift left 2bytes */ \ + "ext v2.8b, v0.8b, v0.8b, #4\n" /* shift left 4bytes */ \ + "ext v3.8b, v0.8b, v0.8b, #6\n" /* shift left 6bytes */ \ + /* 1st b0, b1 */ \ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ /* 1st b2, b3 */ \ + "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */\ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */\ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */\ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */\ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */\ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */\ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */\ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */\ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */\ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */\ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */\ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */\ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */\ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */\ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */\ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */\ + "7: \n" \ + /* trans 1 */ \ + "trn1 v0.4s, v16.4s, v17.4s\n" \ + "trn2 v1.4s, v16.4s, v17.4s\n" \ + "trn1 v2.4s, v18.4s, v19.4s\n" \ + "trn2 v3.4s, v18.4s, v19.4s\n" \ + "trn1 v4.4s, v20.4s, v21.4s\n" \ + "trn2 v5.4s, v20.4s, v21.4s\n" \ + "trn1 v6.4s, v22.4s, v23.4s\n" \ + "trn2 v7.4s, v22.4s, v23.4s\n" \ + "trn1 v8.4s, v24.4s, v25.4s\n" \ + "trn2 v9.4s, v24.4s, v25.4s\n" \ + "trn1 v10.4s, v26.4s, v27.4s\n" \ + "trn2 v11.4s, v26.4s, v27.4s\n" \ + "trn1 v12.4s, v28.4s, v29.4s\n" \ + "trn2 v13.4s, v28.4s, v29.4s\n" \ + "trn1 v14.4s, v30.4s, v31.4s\n" \ + "trn2 v15.4s, v30.4s, v31.4s\n" \ + /* trans 2 */ \ + "trn1 v16.2d, v0.2d, v2.2d\n" \ + "trn2 v18.2d, v0.2d, v2.2d\n" \ + "trn1 v17.2d, v1.2d, v3.2d\n" \ + "trn2 v19.2d, v1.2d, v3.2d\n" \ + "trn1 v20.2d, v4.2d, v6.2d\n" \ + "trn2 v22.2d, v4.2d, v6.2d\n" \ + "trn1 v21.2d, v5.2d, v7.2d\n" \ + "trn2 v23.2d, v5.2d, v7.2d\n" \ + "trn1 v24.2d, v8.2d, v10.2d\n" \ + "trn2 v26.2d, v8.2d, v10.2d\n" \ + "trn1 v25.2d, v9.2d, v11.2d\n" \ + "trn2 v27.2d, v9.2d, v11.2d\n" \ + "trn1 v28.2d, v12.2d, v14.2d\n" \ + "trn2 v30.2d, v12.2d, v14.2d\n" \ + "trn1 v29.2d, v13.2d, v15.2d\n" \ + "trn2 v31.2d, v13.2d, v15.2d\n" \ + /* shift */ \ + "ext v17.16b, v17.16b, v17.16b, #12\n" /* circular shift left 1 */ \ + "ext v18.16b, v18.16b, v18.16b, #8\n" /* circular shift left 2 */ \ + "ext v19.16b, v19.16b, v19.16b, 
#4\n" /* circular shift left 3 */ \ + "ext v21.16b, v21.16b, v21.16b, #12\n" /* circular shift left 1 */ \ + "ext v22.16b, v22.16b, v22.16b, #8\n" /* circular shift left 2 */ \ + "ext v23.16b, v23.16b, v23.16b, #4\n" /* circular shift left 3 */ \ + "ext v25.16b, v25.16b, v25.16b, #12\n" /* circular shift left 1 */ \ + "ext v26.16b, v26.16b, v26.16b, #8\n" /* circular shift left 2 */ \ + "ext v27.16b, v27.16b, v27.16b, #4\n" /* circular shift left 3 */ \ + "ext v29.16b, v29.16b, v29.16b, #12\n" /* circular shift left 1 */ \ + "ext v30.16b, v30.16b, v30.16b, #8\n" /* circular shift left 2 */ \ + "ext v31.16b, v31.16b, v31.16b, #4\n" /* circular shift left 3 */ \ + /* trans */ \ + "trn1 v0.4s, v16.4s, v17.4s\n" /* get a0,b0, a2,b2 */ \ + "trn2 v1.4s, v16.4s, v17.4s\n" /* get a1,b1, a3,b3 */ \ + "trn1 v2.4s, v18.4s, v19.4s\n" /* get c0,d0, c2,c2 */ \ + "trn2 v3.4s, v18.4s, v19.4s\n" /* get c1,d1, c3,d3 */ \ + "trn1 v4.4s, v20.4s, v21.4s\n" \ + "trn2 v5.4s, v20.4s, v21.4s\n" \ + "trn1 v6.4s, v22.4s, v23.4s\n" \ + "trn2 v7.4s, v22.4s, v23.4s\n" \ + "trn1 v8.4s, v24.4s, v25.4s\n" \ + "trn2 v9.4s, v24.4s, v25.4s\n" \ + "trn1 v10.4s, v26.4s, v27.4s\n" \ + "trn2 v11.4s, v26.4s, v27.4s\n" \ + "trn1 v12.4s, v28.4s, v29.4s\n" \ + "trn2 v13.4s, v28.4s, v29.4s\n" \ + "trn1 v14.4s, v30.4s, v31.4s\n" \ + "trn2 v15.4s, v30.4s, v31.4s\n" /* trans 2 */ \ + "trn1 v16.2d, v0.2d, v2.2d\n" /* get a0,b0, c0,d0 */ \ + "trn2 v24.2d, v0.2d, v2.2d\n" /* get a2,b2, c2,d2 */ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* get a1,b1, c1,d1 */ \ + "trn2 v28.2d, v1.2d, v3.2d\n" /* get a3,b3, c3,d3 */ \ + "trn1 v17.2d, v4.2d, v6.2d\n" \ + "trn2 v25.2d, v4.2d, v6.2d\n" \ + "trn1 v21.2d, v5.2d, v7.2d\n" \ + "trn2 v29.2d, v5.2d, v7.2d\n" \ + "trn1 v18.2d, v8.2d, v10.2d\n" \ + "trn2 v26.2d, v8.2d, v10.2d\n" \ + "trn1 v22.2d, v9.2d, v11.2d\n" \ + "trn2 v30.2d, v9.2d, v11.2d\n" \ + "trn1 v19.2d, v12.2d, v14.2d\n" \ + "trn2 v27.2d, v12.2d, v14.2d\n" \ + "trn1 v23.2d, v13.2d, v15.2d\n" \ "trn2 v31.2d, v13.2d, v15.2d\n" -// clang-format off -#define GEMM_INT8_INT32_OUT \ - /* store */ \ - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[c_ptr0]], #64\n" \ - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[c_ptr1]], #64\n" \ - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[c_ptr2]], #64\n" \ - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[c_ptr3]], #64\n" -// clang-format on +#define GEMM_INT8_RELU \ + /* do relu */ \ + "cbz %w[is_relu], 9f\n" /* skip relu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fmax v16.4s, v16.4s, v0.4s\n" /* relu */ \ + "fmax v17.4s, v17.4s, v0.4s\n" /* relu */ \ + "fmax v18.4s, v18.4s, v0.4s\n" /* relu */ \ + "fmax v19.4s, v19.4s, v0.4s\n" /* relu */ \ + "fmax v20.4s, v20.4s, v0.4s\n" /* relu */ \ + "fmax v21.4s, v21.4s, v0.4s\n" /* relu */ \ + "fmax v22.4s, v22.4s, v0.4s\n" /* relu */ \ + "fmax v23.4s, v23.4s, v0.4s\n" /* relu */ \ + "fmax v24.4s, v24.4s, v0.4s\n" /* relu */ \ + "fmax v25.4s, v25.4s, v0.4s\n" /* relu */ \ + "fmax v26.4s, v26.4s, v0.4s\n" /* relu */ \ + "fmax v27.4s, v27.4s, v0.4s\n" /* relu */ \ + "fmax v28.4s, v28.4s, v0.4s\n" /* relu */ \ + "fmax v29.4s, v29.4s, v0.4s\n" /* relu */ \ + "fmax v30.4s, v30.4s, v0.4s\n" /* relu */ \ + "fmax v31.4s, v31.4s, v0.4s\n" /* relu */ \ + "9:\n" + +#define GEMM_TRANS_INT32_TO_FP32 \ + "ldr q14, [%[bias]]\n" /* load scale */ \ + "ldr q15, [%[scale]]\n" /* load scale */ \ + "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \ + "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \ + "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \ + "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ 
\ + "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \ + "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \ + "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \ + "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \ + /* add bias */ \ + "dup v16.4s, v14.s[0]\n" \ + "dup v17.4s, v14.s[0]\n" \ + "dup v18.4s, v14.s[0]\n" \ + "dup v19.4s, v14.s[0]\n" \ + "dup v20.4s, v14.s[1]\n" \ + "dup v21.4s, v14.s[1]\n" \ + "dup v22.4s, v14.s[1]\n" \ + "dup v23.4s, v14.s[1]\n" \ + "fmla v16.4s, v0.4s, v15.s[0]\n" /* 00, mul scale */ \ + "fmla v17.4s, v1.4s, v15.s[0]\n" /* 01, mul scale */ \ + "fmla v18.4s, v2.4s, v15.s[0]\n" /* 02, mul scale */ \ + "fmla v19.4s, v3.4s, v15.s[0]\n" /* 03, mul scale */ \ + "fmla v20.4s, v4.4s, v15.s[1]\n" /* 10, mul scale */ \ + "fmla v21.4s, v5.4s, v15.s[1]\n" /* 11, mul scale */ \ + "fmla v22.4s, v6.4s, v15.s[1]\n" /* 12, mul scale */ \ + "fmla v23.4s, v7.4s, v15.s[1]\n" /* 13, mul scale */ \ + "scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \ + "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \ + "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \ + "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \ + "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \ + "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \ + "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \ + "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \ + "dup v24.4s, v14.s[2]\n" \ + "dup v25.4s, v14.s[2]\n" \ + "dup v26.4s, v14.s[2]\n" \ + "dup v27.4s, v14.s[2]\n" \ + "dup v28.4s, v14.s[3]\n" \ + "dup v29.4s, v14.s[3]\n" \ + "dup v30.4s, v14.s[3]\n" \ + "dup v31.4s, v14.s[3]\n" \ + "fmla v24.4s, v0.4s, v15.s[2]\n" /* 20, mul scale */ \ + "fmla v25.4s, v1.4s, v15.s[2]\n" /* 21, mul scale */ \ + "fmla v26.4s, v2.4s, v15.s[2]\n" /* 22, mul scale */ \ + "fmla v27.4s, v3.4s, v15.s[2]\n" /* 23, mul scale */ \ + "fmla v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale */ \ + "fmla v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale */ \ + "fmla v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale */ \ + "fmla v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale */ + +#define GEMM_INT8_FP32_OUT \ + GEMM_TRANS_INT32_TO_FP32 \ + GEMM_INT8_RELU \ + /* store result */ \ + "stp q16, q17, [%[c_ptr0]], #32\n" \ + "stp q18, q19, [%[c_ptr0]], #32\n" \ + "stp q20, q21, [%[c_ptr1]], #32\n" \ + "stp q22, q23, [%[c_ptr1]], #32\n" \ + "stp q24, q25, [%[c_ptr2]], #32\n" \ + "stp q26, q27, [%[c_ptr2]], #32\n" \ + "stp q28, q29, [%[c_ptr3]], #32\n" \ + "stp q30, q31, [%[c_ptr3]], #32\n" + +#define GEMM_INT8_INT8_OUT \ + GEMM_TRANS_INT32_TO_FP32 \ + GEMM_INT8_RELU \ + "fcvtas v0.4s, v16.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v17.4s\n" /* 01, cvt to int */ \ + "fcvtas v2.4s, v18.4s\n" /* 02, cvt to int */ \ + "fcvtas v3.4s, v19.4s\n" /* 03, cvt to int */ \ + "fcvtas v4.4s, v20.4s\n" /* 10, cvt to int */ \ + "fcvtas v5.4s, v21.4s\n" /* 11, cvt to int */ \ + "fcvtas v6.4s, v22.4s\n" /* 12, cvt to int */ \ + "fcvtas v7.4s, v23.4s\n" /* 13, cvt to int */ \ + "sqxtn v16.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v8.4s, v24.4s\n" /* 20, cvt to int */ \ + "sqxtn2 v16.8h, v1.4s\n" /* 01, cvt int32 to int16 */ \ + "fcvtas v9.4s, v25.4s\n" /* 21, cvt to int */ \ + "sqxtn v17.4h, v2.4s\n" /* 02, cvt int32 to int16 */ \ + "fcvtas v10.4s, v26.4s\n" /* 22, cvt to int */ \ + "sqxtn2 v17.8h, v3.4s\n" /* 03, cvt int32 to int16 */ \ + "fcvtas v11.4s, v27.4s\n" /* 23, cvt to int */ \ + "sqxtn v18.4h, v4.4s\n" /* 10, cvt int32 to int16 */ \ + "fcvtas v12.4s, v28.4s\n" /* 30, cvt to int */ \ + "sqxtn2 v18.8h, v5.4s\n" /* 11, cvt int32 to int16 */ \ + 
"fcvtas v13.4s, v29.4s\n" /* 31, cvt to int */ \ + "sqxtn v19.4h, v6.4s\n" /* 12, cvt int32 to int16 */ \ + "fcvtas v14.4s, v30.4s\n" /* 32, cvt to int */ \ + "sqxtn2 v19.8h, v7.4s\n" /* 13, cvt int32 to int16 */ \ + "fcvtas v15.4s, v31.4s\n" /* 33, cvt to int */ \ + "sqxtn v0.8b, v16.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn2 v0.16b, v17.8h\n" /* 02, 03, cvt int16 to int8 */ \ + "sqxtn v1.8b, v18.8h\n" /* 10, 11, cvt int16 to int8 */ \ + "sqxtn2 v1.16b, v19.8h\n" /* 12, 13, cvt int16 to int8 */ \ + "sqxtn v20.4h, v8.4s\n" /* 20, cvt int32 to int16 */ \ + "sqxtn2 v20.8h, v9.4s\n" /* 21, cvt int32 to int16 */ \ + "sqxtn v21.4h, v10.4s\n" /* 22, cvt int32 to int16 */ \ + "sqxtn2 v21.8h, v11.4s\n" /* 23, cvt int32 to int16 */ \ + "sqxtn v22.4h, v12.4s\n" /* 30, cvt int32 to int16 */ \ + "sqxtn2 v22.8h, v13.4s\n" /* 31, cvt int32 to int16 */ \ + "sqxtn v23.4h, v14.4s\n" /* 32, cvt int32 to int16 */ \ + "sqxtn2 v23.8h, v15.4s\n" /* 33, cvt int32 to int16 */ \ + "sqxtn v2.8b, v20.8h\n" /* 20, 21, cvt int16 to int8 */ \ + "sqxtn2 v2.16b, v21.8h\n" /* 22, 23, cvt int16 to int8 */ \ + "sqxtn v3.8b, v22.8h\n" /* 30, 31, cvt int16 to int8 */ \ + "sqxtn2 v3.16b, v23.8h\n" /* 32, 33, cvt int16 to int8 */ \ + "str q0, [%[c_ptr0]], #16\n" /* write r0 */ \ + "str q1, [%[c_ptr1]], #16\n" /* write r1 */ \ + "str q2, [%[c_ptr2]], #16\n" /* write r2 */ \ + "str q3, [%[c_ptr3]], #16\n" /* write r3 */ -#define GEMM_INT8_FP32_OUT \ - /* store */ \ - "ldr q15, [%[scale]]\n" /* load scale */ \ - "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \ - "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \ - "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \ - "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \ - "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \ - "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \ - "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \ - "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \ - "fmul v16.4s, v0.4s, v15.s[0]\n" /* 00, mul scale to get final result */ \ - "fmul v17.4s, v1.4s, v15.s[0]\n" /* 01, mul scale to get final result */ \ - "fmul v18.4s, v2.4s, v15.s[0]\n" /* 02, mul scale to get final result */ \ - "fmul v19.4s, v3.4s, v15.s[0]\n" /* 03, mul scale to get final result */ \ - "fmul v20.4s, v4.4s, v15.s[1]\n" /* 10, mul scale to get final result */ \ - "fmul v21.4s, v5.4s, v15.s[1]\n" /* 11, mul scale to get final result */ \ - "fmul v22.4s, v6.4s, v15.s[1]\n" /* 12, mul scale to get final result */ \ - "fmul v23.4s, v7.4s, v15.s[1]\n" /* 13, mul scale to get final result */ \ - "scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \ - "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \ - "stp q16, q17, [%[c_ptr0]], #32\n" /* write r0, 0,1 */ \ - "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \ - "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \ - "stp q18, q19, [%[c_ptr0]], #32\n" /* write r0, 2,3 */ \ - "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \ - "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \ - "stp q20, q21, [%[c_ptr1]], #32\n" /* write r1, 0,1 */ \ - "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \ - "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \ - "stp q22, q23, [%[c_ptr1]], #32\n" /* write r1, 2,3 */ \ - "fmul v24.4s, v0.4s, v15.s[2]\n" /* 20, mul scale to get final result */ \ - "fmul v25.4s, v1.4s, v15.s[2]\n" /* 21, mul scale to get final result */ \ - "fmul v26.4s, v2.4s, v15.s[2]\n" /* 22, mul scale to get final result */ \ - "fmul v27.4s, v3.4s, v15.s[2]\n" /* 23, mul scale to get final result */ \ - "fmul 
v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale to get final result */ \ - "fmul v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale to get final result */ \ - "stp q24, q25, [%[c_ptr2]], #32\n" /* write r2, 2,3 */ \ - "fmul v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale to get final result */ \ - "stp q26, q27, [%[c_ptr2]], #32\n" /* write r2, 2,3 */ \ - "fmul v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale to get final result */ \ - "stp q28, q29, [%[c_ptr3]], #32\n" /* write r3, 2,3 */ \ - "stp q30, q31, [%[c_ptr3]], #32\n" /* write r3, 2,3 */ - -#define GEMM_INT8_INT8_OUT \ - /* store */ \ - "ldr q15, [%[scale]]\n" /* load scale */ \ - "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \ - "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \ - "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \ - "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \ - "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \ - "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \ - "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \ - "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \ - "fmul v16.4s, v0.4s, v15.s[0]\n" /* 00, mul scale to get final result */ \ - "fmul v17.4s, v1.4s, v15.s[0]\n" /* 01, mul scale to get final result */ \ - "fmul v18.4s, v2.4s, v15.s[0]\n" /* 02, mul scale to get final result */ \ - "fmul v19.4s, v3.4s, v15.s[0]\n" /* 03, mul scale to get final result */ \ - "fmul v20.4s, v4.4s, v15.s[1]\n" /* 20, mul scale to get final result */ \ - "fmul v21.4s, v5.4s, v15.s[1]\n" /* 21, mul scale to get final result */ \ - "fmul v22.4s, v6.4s, v15.s[1]\n" /* 22, mul scale to get final result */ \ - "fmul v23.4s, v7.4s, v15.s[1]\n" /* 23, mul scale to get final result */ \ - "scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \ - "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \ - "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \ - "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \ - "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \ - "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \ - "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \ - "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \ - "fmul v24.4s, v0.4s, v15.s[2]\n" /* 20, mul scale to get final result */ \ - "fmul v25.4s, v1.4s, v15.s[2]\n" /* 21, mul scale to get final result */ \ - "fmul v26.4s, v2.4s, v15.s[2]\n" /* 22, mul scale to get final result */ \ - "fmul v27.4s, v3.4s, v15.s[2]\n" /* 23, mul scale to get final result */ \ - "fmul v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale to get final result */ \ - "fmul v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale to get final result */ \ - "fmul v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale to get final result */ \ - "fmul v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale to get final result */ \ - "fcvtas v0.4s, v16.4s\n" /* 00, cvt to int */ \ - "fcvtas v1.4s, v17.4s\n" /* 01, cvt to int */ \ - "fcvtas v2.4s, v18.4s\n" /* 02, cvt to int */ \ - "fcvtas v3.4s, v19.4s\n" /* 03, cvt to int */ \ - "fcvtas v4.4s, v20.4s\n" /* 10, cvt to int */ \ - "fcvtas v5.4s, v21.4s\n" /* 11, cvt to int */ \ - "fcvtas v6.4s, v22.4s\n" /* 12, cvt to int */ \ - "fcvtas v7.4s, v23.4s\n" /* 13, cvt to int */ \ - "sqxtn v16.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v8.4s, v24.4s\n" /* 20, cvt to int */ \ - "sqxtn2 v16.8h, v1.4s\n" /* 01, cvt int32 to int16 */ \ - "fcvtas v9.4s, v25.4s\n" /* 21, cvt to int */ \ - "sqxtn v17.4h, v2.4s\n" /* 02, cvt int32 to int16 */ \ - "fcvtas v10.4s, v26.4s\n" /* 22, cvt to int */ \ - "sqxtn2 v17.8h, v3.4s\n" /* 03, cvt int32 to int16 */ \ - "fcvtas v11.4s, v27.4s\n" /* 23, 
cvt to int */ \ - "sqxtn v18.4h, v4.4s\n" /* 10, cvt int32 to int16 */ \ - "fcvtas v12.4s, v28.4s\n" /* 30, cvt to int */ \ - "sqxtn2 v18.8h, v5.4s\n" /* 11, cvt int32 to int16 */ \ - "fcvtas v13.4s, v29.4s\n" /* 31, cvt to int */ \ - "sqxtn v19.4h, v6.4s\n" /* 12, cvt int32 to int16 */ \ - "fcvtas v14.4s, v30.4s\n" /* 32, cvt to int */ \ - "sqxtn2 v19.8h, v7.4s\n" /* 13, cvt int32 to int16 */ \ - "fcvtas v15.4s, v31.4s\n" /* 33, cvt to int */ \ - "sqxtn v0.8b, v16.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn2 v0.16b, v17.8h\n" /* 02, 03, cvt int16 to int8 */ \ - "sqxtn v1.8b, v18.8h\n" /* 10, 11, cvt int16 to int8 */ \ - "sqxtn2 v1.16b, v19.8h\n" /* 12, 13, cvt int16 to int8 */ \ - "sqxtn v20.4h, v8.4s\n" /* 20, cvt int32 to int16 */ \ - "sqxtn2 v20.8h, v9.4s\n" /* 21, cvt int32 to int16 */ \ - "sqxtn v21.4h, v10.4s\n" /* 22, cvt int32 to int16 */ \ - "sqxtn2 v21.8h, v11.4s\n" /* 23, cvt int32 to int16 */ \ - "sqxtn v22.4h, v12.4s\n" /* 30, cvt int32 to int16 */ \ - "sqxtn2 v22.8h, v13.4s\n" /* 31, cvt int32 to int16 */ \ - "sqxtn v23.4h, v14.4s\n" /* 32, cvt int32 to int16 */ \ - "sqxtn2 v23.8h, v15.4s\n" /* 33, cvt int32 to int16 */ \ - "sqxtn v2.8b, v20.8h\n" /* 20, 21, cvt int16 to int8 */ \ - "sqxtn2 v2.16b, v21.8h\n" /* 22, 23, cvt int16 to int8 */ \ - "sqxtn v3.8b, v22.8h\n" /* 30, 31, cvt int16 to int8 */ \ - "sqxtn2 v3.16b, v23.8h\n" /* 32, 33, cvt int16 to int8 */ \ - "str q0, [%[c_ptr0]], #16\n" /* write r0 */ \ - "str q1, [%[c_ptr1]], #16\n" /* write r1 */ \ - "str q2, [%[c_ptr2]], #16\n" /* write r2 */ \ - "str q3, [%[c_ptr3]], #16\n" /* write r3 */ +// clang-format on template <> inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - int32_t*& c_ptr0, // NOLINT - int32_t*& c_ptr1, // NOLINT - int32_t*& c_ptr2, // NOLINT - int32_t*& c_ptr3, // NOLINT - const float* scale, // NOLINT - bool is_relu, // NOLINT - int k, - int rem) { - asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT - : [a_ptr] "+r"(a_ptr), - [b_ptr] "+r"(b_ptr), - [c_ptr0] "+r"(c_ptr0), - [c_ptr1] "+r"(c_ptr1), - [c_ptr2] "+r"(c_ptr2), - [c_ptr3] "+r"(c_ptr3), - [k] "+r"(k) - : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31", - "cc"); -} -template <> -inline void gemm_int8_kernel(const int8_t* a_ptr, - const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - float*& c_ptr0, // NOLINT - float*& c_ptr1, // NOLINT - float*& c_ptr2, // NOLINT - float*& c_ptr3, // NOLINT - const float* scale, + const float* bias, + float32_t*& c_ptr0, // NOLINT + float32_t*& c_ptr1, // NOLINT + float32_t*& c_ptr2, // NOLINT + float32_t*& c_ptr3, // NOLINT + const float32_t* scale, bool is_relu, int k, int rem) { + // clang-format off asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), @@ -700,53 +644,27 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, [bias] "r"(bias), [rem] "r"(rem), [scale] "r"(scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31", - "cc"); + : "v0","v1","v2","v3","v4","v5","v6","v7","v8", + 
"v9","v10","v11","v12","v13","v14", + "v15","v16","v17","v18","v19","v20", + "v21","v22","v23","v24","v25","v26", + "v27","v28","v29","v30","v31","cc"); + // clang-format on } template <> inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, + const float* bias, int8_t*& c_ptr0, // NOLINT int8_t*& c_ptr1, // NOLINT int8_t*& c_ptr2, // NOLINT int8_t*& c_ptr3, // NOLINT - const float* scale, + const float32_t* scale, bool is_relu, int k, int rem) { + // clang-format off asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), @@ -759,98 +677,73 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, [bias] "r"(bias), [rem] "r"(rem), [scale] "r"(scale) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31", - "cc"); + : "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12", + "v13","v14","v15","v16","v17", + "v18","v19","v20","v21","v22", + "v23","v24","v25","v26","v27", + "v28","v29","v30","v31","cc"); + // clang-format on } #ifdef WITH_ARM_DOTPROD template -inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, - const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - Dtype*& c_ptr0, // NOLINT - Dtype*& c_ptr1, // NOLINT - Dtype*& c_ptr2, // NOLINT - Dtype*& c_ptr3, // NOLINT - Dtype*& c_ptr4, // NOLINT - Dtype*& c_ptr5, // NOLINT - Dtype*& c_ptr6, // NOLINT - Dtype*& c_ptr7, // NOLINT - const float32_t* scale, - bool is_relu, - int k, - int rem); - +inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const float* bias, + Dtype*& c_ptr0, // NOLINT + Dtype*& c_ptr1, // NOLINT + Dtype*& c_ptr2, // NOLINT + Dtype*& c_ptr3, // NOLINT + Dtype*& c_ptr4, // NOLINT + Dtype*& c_ptr5, // NOLINT + Dtype*& c_ptr6, // NOLINT + Dtype*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int rem); +#if 0 +// clang-format off #define GEMM_SDOT_INT8_KERNEL \ - "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ \ "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ \ "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \ - "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ \ - "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ \ - "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ \ - "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ \ - "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ \ + "eor v8.16b, v8.16b, v8.16b\n" /* out0 = 0 */ \ + "eor v9.16b, v9.16b, v9.16b\n" /* out1 = 0 */ \ + "eor v10.16b, v10.16b, v10.16b\n" /* out2 = 0 */ \ + "eor v11.16b, v11.16b, v11.16b\n" /* out3 = 0 */ \ + "eor v12.16b, v12.16b, v12.16b\n" /* out4 = 0 */ \ "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ - "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ \ + "eor v13.16b, v13.16b, v13.16b\n" /* out5 = 0 */ \ "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ - "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ \ + "eor v14.16b, v14.16b, v14.16b\n" /* out6 = 0 */ \ "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ - "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ \ + "eor v15.16b, v15.16b, v15.16b\n" /* out7 = 0 */ \ "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ - "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ \ + "eor v16.16b, v16.16b, v16.16b\n" /* out8 = 0 */ \ "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ - "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ \ + "eor v17.16b, v17.16b, v17.16b\n" /* out9 = 0 */ \ "prfm 
pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \
-    "dup v18.4s, v2.s[3]\n"             /* out10 = 0*/ \
+    "eor v18.16b, v18.16b, v18.16b\n"   /* out10 = 0 */ \
     "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \
-    "dup v19.4s, v2.s[3]\n"             /* out11 = 0*/ \
+    "eor v19.16b, v19.16b, v19.16b\n"   /* out11 = 0 */ \
     "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ \
-    "dup v20.4s, v3.s[0]\n"             /* out12 = 0*/ \
+    "eor v20.16b, v20.16b, v20.16b\n"   /* out12 = 0 */ \
     "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ \
-    "dup v21.4s, v3.s[0]\n"             /* out13 = 0*/ \
+    "eor v21.16b, v21.16b, v21.16b\n"   /* out13 = 0 */ \
     "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \
-    "dup v22.4s, v3.s[0]\n"             /* out14 = 0*/ \
-    "dup v23.4s, v3.s[1]\n"             /* out15 = 0*/ \
-    "dup v24.4s, v3.s[1]\n"             /* out16 = 0*/ \
-    "dup v25.4s, v3.s[1]\n"             /* out17 = 0*/ \
-    "dup v26.4s, v3.s[2]\n"             /* out18 = 0*/ \
-    "dup v27.4s, v3.s[2]\n"             /* out19 = 0*/ \
-    "dup v28.4s, v3.s[2]\n"             /* out20 = 0*/ \
-    "dup v29.4s, v3.s[3]\n"             /* out21 = 0*/ \
-    "dup v30.4s, v3.s[3]\n"             /* out22 = 0*/ \
-    "dup v31.4s, v3.s[3]\n"             /* out23 = 0*/ \
-    "cbz %w[k], 2f\n"                   /* check loop count > 0 */ \
+    "eor v22.16b, v22.16b, v22.16b\n"   /* out14 = 0 */ \
+    "eor v23.16b, v23.16b, v23.16b\n"   /* out15 = 0 */ \
+    "eor v24.16b, v24.16b, v24.16b\n"   /* out16 = 0 */ \
+    "eor v25.16b, v25.16b, v25.16b\n"   /* out17 = 0 */ \
+    "eor v26.16b, v26.16b, v26.16b\n"   /* out18 = 0 */ \
+    "eor v27.16b, v27.16b, v27.16b\n"   /* out19 = 0 */ \
+    "eor v28.16b, v28.16b, v28.16b\n"   /* out20 = 0 */ \
+    "eor v29.16b, v29.16b, v29.16b\n"   /* out21 = 0 */ \
+    "eor v30.16b, v30.16b, v30.16b\n"   /* out22 = 0 */ \
+    "eor v31.16b, v31.16b, v31.16b\n"   /* out23 = 0 */ \
+    "cbz %w[k], 2f\n"                   /* check loop count > 0 */ \
+    /* main loop, unroll 0 */ \
    "1:\n"                               /* main loop */ \
    "sdot v8.4s , v4.16b, v0.4b[0]\n"    /* out0 = b0 * a00[0], b0 = q4 */ \
    "sdot v11.4s , v4.16b, v0.4b[1]\n"   /* out1 = b0 * a00[1], b0 = q4 */ \
@@ -874,40 +767,42 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v10.4s, v6.16b, v0.4b[0]\n"    /* out16 = b2 * a00[0], b2 = q6*/ \
    "sdot v13.4s, v6.16b, v0.4b[1]\n"    /* out17 = b2 * a00[1], b2 = q6*/ \
    "prfm pldl1keep, [%[b_ptr], #384]\n" \
-    "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \
-    "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \
-    "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \
-    "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \
-    "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \
-    "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \
-    "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ \
-    "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7 */ \
-    "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7 */ \
-    "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7 */ \
+    "sdot v16.4s, v6.16b, v0.4b[2]\n"   /* out18 = b2 * a00[2], b2 = q6*/ \
+    "sdot v19.4s, v6.16b, v0.4b[3]\n"   /* out19 = b2 * a00[3], b2 = q6*/ \
+    "sdot v22.4s, v6.16b, v1.4b[0]\n"   /* out20 = b2 * a01[0], b2 = q6*/ \
+    "sdot v25.4s, v6.16b, v1.4b[1]\n"   /* out21 = b2 * a01[1], b2 = q6*/ \
+    "sdot v28.4s, v6.16b, v1.4b[2]\n"   /* out22 = b2 * a01[2], b2 = q6*/ \
+    "sdot v31.4s, v6.16b, v1.4b[3]\n"   /* out23 = b2 * a01[3], b2 = q6*/ \
+    "ldp q0, q1, [%[a_ptr]], #32\n"     /* load a00, a01 to q0, q1 */ \
+    /* unroll 1 */ \
+    "sdot v8.4s , v7.16b, v2.4b[0]\n"   /* out0 = b0 * a10[0], b0 = q7 */ \
+    "sdot v11.4s , v7.16b, v2.4b[1]\n"  /* out1 = b0 * a10[1], b0 = q7 */ \
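A reading aid for the dot-product blocks in this kernel: every "sdot vD.4s, vB.16b, vA.4b[i]" line folds four int8 multiplies into each of the four int32 lanes of an accumulator, so a single instruction retires a k-depth of 4 for a 4-wide column block. A scalar model of that one instruction (an illustrative sketch, not part of the patch; acc, b, and a stand in for one accumulator register, one 16-byte block of packed B, and the selected 4-byte element of packed A):

    #include <cstdint>

    // Scalar sketch of "sdot vD.4s, vB.16b, vA.4b[i]".
    inline void sdot_model(int32_t acc[4], const int8_t b[16], const int8_t a[4]) {
      for (int lane = 0; lane < 4; ++lane) {  // four int32 output lanes
        for (int j = 0; j < 4; ++j) {         // 4-way int8 dot product per lane
          acc[lane] += static_cast<int32_t>(b[4 * lane + j]) *
                       static_cast<int32_t>(a[j]);
        }
      }
    }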
"sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7 */ \ "prfm pldl1keep, [%[a_ptr], #256]\n" \ - "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7 */ \ - "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7 */ \ - "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7 */ \ - "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7 */ \ - "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7 */ \ - "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ \ - "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ \ - "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4 */ \ - "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \ - "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \ - "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \ - "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \ - "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \ - "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \ - "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \ - "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \ - "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \ - "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \ - "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \ - "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \ - "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \ - "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \ - "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ \ + "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7 */ \ + "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7 */ \ + "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7 */ \ + "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7 */ \ + "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7 */ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ \ + "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ \ + "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4 */ \ + "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \ + "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \ + "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \ + "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \ + "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \ + "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \ + "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \ + "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \ + "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \ + "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \ + "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \ + "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \ + "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \ + "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ \ + /* unrool 2*/ \ "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6 */ \ "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6 */ \ "ldp q2, 
    "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \
@@ -935,7 +830,8 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \
    "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \
    "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \
-    "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 */ /* unrool 3*/ \
+    "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \
+    /* unroll 3*/ \
    "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
    "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
    "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
@@ -964,10 +860,11 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "subs %w[k], %w[k], #1\n" /* loop count - 1*/ \
    "sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
    "sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
-    "bne 1b\n" \
-    "2:\n" /* process tail*/ \
-    "subs %w[tail], %w[tail], #1\n" /* tail--*/ \
-    "beq 3f\n" \
+    "bne 1b\n" /* jump back to main loop */ \
+    "2:\n" /* process tail; also the entry point when k is small */ \
+    "subs %w[tail], %w[tail], #1\n" /* tail--*/ \
+    "beq 3f\n" /*jump to tail = 1*/ \
+    /* final unroll 0, tail > 1*/ \
    "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ \
    "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4*/ \
    "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ \
@@ -996,7 +893,8 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \
    "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \
    "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \
-    "beq 4f\n" /*jump to tail = 2*/ /* unrool 1, tail > 2*/ \
+    "beq 4f\n" /*jump to tail = 2*/ \
+    /* unroll 1, tail > 2*/ \
    "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \
    "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ \
    "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7*/ \
@@ -1024,7 +922,8 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \
    "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \
    "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \
-    "beq 5f\n" /*jump to tail = 3*/ /* unrool 2, tail = 4*/ \
+    "beq 5f\n" /*jump to tail = 3*/ \
+    /* unroll 2, tail = 4*/ \
    "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ \
    "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ \
    "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6*/ \
@@ -1052,6 +951,7 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \
    "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \
    "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \
+    /* unroll 3, tail = 4*/ \
    "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
    "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
    "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
@@ -1156,36 +1056,117 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr,
    "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
    "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
    "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
-    "11: \n" /* check if relu */ \
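The write-back macros below (GEMM_SDOT_RELU, GEMM_SDOT_CVT_INT32_TO_FP32, GEMM_SDOT_INT8_OUT) replace the old int32 output path. Because the accumulators now start at zero and bias is fp32, dequantization, bias, relu, and (for the int8 kernel) requantization are fused at store time: scvtf converts the int32 accumulator, fmla folds the per-row scale on top of a bias-filled register, fmax clamps at zero, and fcvtas plus the sqxtn/sqxtn2 chain rounds and saturates down to int8. Per element, the intent is roughly the following scalar sketch (illustrative names, not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Sketch of GEMM_SDOT_CVT_INT32_TO_FP32 + GEMM_SDOT_RELU for one element
    // of output row r; the int8 kernel additionally narrows the result.
    inline float dequant(int32_t acc, float scale_r, float bias_r, bool relu) {
      float v = bias_r + scale_r * static_cast<float>(acc);  // scvtf + fmla
      return relu ? std::max(v, 0.f) : v;                    // fmax
    }

    inline int8_t requant_int8(float v) {
      float r = std::round(v);                   // fcvtas: nearest, ties away
      r = std::min(std::max(r, -128.f), 127.f);  // sqxtn/sqxtn2 saturation
      return static_cast<int8_t>(r);
    }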
"cbz %w[relu], 12f\n" /* skip relu */ \ - "movi v2.4s, #0\n" /* for relu*/ \ - "smax v8.4s, v8.4s, v2.4s\n" /* relu*/ \ - "smax v9.4s, v9.4s, v2.4s\n" /* relu*/ \ - "smax v10.4s, v10.4s, v2.4s\n" /* relu*/ \ - "smax v11.4s, v11.4s, v2.4s\n" /* relu*/ \ - "smax v12.4s, v12.4s, v2.4s\n" /* relu*/ \ - "smax v13.4s, v13.4s, v2.4s\n" /* relu*/ \ - "smax v14.4s, v14.4s, v2.4s\n" /* relu*/ \ - "smax v15.4s, v15.4s, v2.4s\n" /* relu*/ \ - "smax v16.4s,v16.4s,v2.4s\n" /* relu*/ \ - "smax v17.4s,v17.4s,v2.4s\n" /* relu*/ \ - "smax v18.4s, v18.4s, v2.4s\n" /* relu*/ \ - "smax v19.4s, v19.4s, v2.4s\n" /* relu*/ \ - "smax v20.4s, v20.4s, v2.4s\n" /* relu*/ \ - "smax v21.4s, v21.4s, v2.4s\n" /* relu*/ \ - "smax v22.4s, v22.4s, v2.4s\n" /* relu*/ \ - "smax v23.4s, v23.4s, v2.4s\n" /* relu*/ \ - "smax v24.4s, v24.4s, v2.4s\n" /* relu*/ \ - "smax v25.4s, v25.4s, v2.4s\n" /* relu*/ \ - "smax v26.4s, v26.4s, v2.4s\n" /* relu*/ \ - "smax v27.4s, v27.4s, v2.4s\n" /* relu*/ \ - "smax v28.4s, v28.4s, v2.4s\n" /* relu*/ \ - "smax v29.4s, v29.4s, v2.4s\n" /* relu*/ \ - "smax v30.4s, v30.4s, v2.4s\n" /* relu*/ \ - "smax v31.4s, v31.4s, v2.4s\n" /* relu*/ \ + "11: \n" /* end */ +#endif + +#define GEMM_SDOT_RELU \ + "cbz %w[relu], 12f\n" /* skip relu */ \ + "movi v2.4s, #0\n" /* for relu*/ \ + "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ \ + "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ \ + "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ \ + "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ \ + "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ \ + "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ \ + "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ \ + "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ \ + "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ \ + "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ \ + "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ \ + "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ \ + "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ \ + "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ \ + "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ \ + "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ \ + "fmax v24.4s, v24.4s, v2.4s\n" /* relu*/ \ + "fmax v25.4s, v25.4s, v2.4s\n" /* relu*/ \ + "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ \ + "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ \ + "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ \ + "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ \ + "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ \ + "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ \ "12: \n" -#define GEMM_SDOT_INT32_OUT \ +#define GEMM_SDOT_CVT_INT32_TO_FP32 \ + "ldp q0, q1, [%[scale]]\n" /* load scale */ \ + "ldp q2, q3, [%[bias_ptr]]\n" /* load bias */ \ + "scvtf v4.4s , v8.4s\n" /* 00, convert to fp32 */ \ + "scvtf v5.4s , v9.4s\n" /* 01, convert to fp32 */ \ + "scvtf v6.4s , v10.4s\n" /* 02, convert to fp32 */ \ + "dup v8.4s, v2.s[0]\n" /* fill with bias*/ \ + "dup v9.4s, v2.s[0]\n" /* fill with bias*/ \ + "dup v10.4s, v2.s[0]\n" /* fill with bias*/ \ + "fmla v8.4s, v4.4s, v0.s[0]\n" /* 00, mul scale to get final result */ \ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* 01, mul scale to get final result */ \ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* 02, mul scale to get final result */ \ + "scvtf v4.4s , v11.4s\n" /* 10, convert to fp32 */ \ + "scvtf v5.4s , v12.4s\n" /* 11, convert to fp32 */ \ + "scvtf v6.4s , v13.4s\n" /* 12, convert to fp32 */ \ + "dup v11.4s, v2.s[1]\n" /* fill with bias*/ \ + "dup v12.4s, v2.s[1]\n" /* fill with bias*/ \ + "dup v13.4s, v2.s[1]\n" /* fill with bias*/ \ + "fmla v11.4s, v4.4s, v0.s[1]\n" /* 10, mul scale to get final result */ \ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* 11, mul scale to get final result */ \ + "fmla v13.4s, v6.4s, 
v0.s[1]\n" /* 12, mul scale to get final result */ \ + "scvtf v4.4s , v14.4s\n" /* 20, convert to fp32 */ \ + "scvtf v5.4s , v15.4s\n" /* 21, convert to fp32 */ \ + "scvtf v6.4s , v16.4s\n" /* 22, convert to fp32 */ \ + "dup v14.4s, v2.s[2]\n" /* fill with bias*/ \ + "dup v15.4s, v2.s[2]\n" /* fill with bias*/ \ + "dup v16.4s, v2.s[2]\n" /* fill with bias*/ \ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* 20, mul scale to get final result */ \ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* 21, mul scale to get final result */ \ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* 22, mul scale to get final result */ \ + "scvtf v4.4s , v17.4s\n" /* 30, convert to fp32 */ \ + "scvtf v5.4s , v18.4s\n" /* 31, convert to fp32 */ \ + "scvtf v6.4s , v19.4s\n" /* 32, convert to fp32 */ \ + "dup v17.4s, v2.s[3]\n" /* fill with bias*/ \ + "dup v18.4s, v2.s[3]\n" /* fill with bias*/ \ + "dup v19.4s, v2.s[3]\n" /* fill with bias*/ \ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* 30, mul scale to get final result */ \ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* 31, mul scale to get final result */ \ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* 32, mul scale to get final result */ \ + "scvtf v4.4s , v20.4s\n" /* 40, convert to fp32 */ \ + "scvtf v5.4s , v21.4s\n" /* 41, convert to fp32 */ \ + "scvtf v6.4s , v22.4s\n" /* 42, convert to fp32 */ \ + "dup v20.4s, v3.s[0]\n" /* fill with bias*/ \ + "dup v21.4s, v3.s[0]\n" /* fill with bias*/ \ + "dup v22.4s, v3.s[0]\n" /* fill with bias*/ \ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* 40, mul scale to get final result */ \ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* 41, mul scale to get final result */ \ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* 42, mul scale to get final result */ \ + "scvtf v4.4s , v23.4s\n" /* 50, convert to fp32 */ \ + "scvtf v5.4s , v24.4s\n" /* 51, convert to fp32 */ \ + "scvtf v6.4s , v25.4s\n" /* 52, convert to fp32 */ \ + "dup v23.4s, v3.s[1]\n" /* fill with bias*/ \ + "dup v24.4s, v3.s[1]\n" /* fill with bias*/ \ + "dup v25.4s, v3.s[1]\n" /* fill with bias*/ \ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* 50, mul scale to get final result */ \ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* 51, mul scale to get final result */ \ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* 52, mul scale to get final result */ \ + "scvtf v4.4s , v26.4s\n" /* 60, convert to fp32 */ \ + "scvtf v5.4s , v27.4s\n" /* 61, convert to fp32 */ \ + "scvtf v6.4s , v28.4s\n" /* 62, convert to fp32 */ \ + "dup v26.4s, v3.s[2]\n" /* fill with bias*/ \ + "dup v27.4s, v3.s[2]\n" /* fill with bias*/ \ + "dup v28.4s, v3.s[2]\n" /* fill with bias*/ \ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* 60, mul scale to get final result */ \ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* 61, mul scale to get final result */ \ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* 62, mul scale to get final result */ \ + "scvtf v4.4s, v29.4s\n" /* 70, convert to fp32 */ \ + "scvtf v5.4s, v30.4s\n" /* 71, convert to fp32 */ \ + "scvtf v6.4s, v31.4s\n" /* 72, convert to fp32 */ \ + "dup v29.4s, v3.s[3]\n" /* fill with bias*/ \ + "dup v30.4s, v3.s[3]\n" /* fill with bias*/ \ + "dup v31.4s, v3.s[3]\n" /* fill with bias*/ \ + "fmla v29.4s, v4.4s,v1.s[3]\n" /* 70, mul scale to get final result */ \ + "fmla v30.4s, v5.4s,v1.s[3]\n" /* 71, mul scale to get final result */ \ + "fmla v31.4s, v6.4s,v1.s[3]\n" /* 72, mul scale to get final result */ + +#define GEMM_SDOT_FP32_OUT \ + GEMM_SDOT_CVT_INT32_TO_FP32 \ + GEMM_SDOT_RELU \ "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ \ "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \ "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \ @@ 
-1195,213 +1176,111 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ \ "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ -#define GEMM_SDOT_FP32_OUT \ - "ldp q0, q1, [%[scale]]\n" /* load scale */ \ - "scvtf v2.4s , v8.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v9.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v10.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v11.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v12.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v13.4s\n" /* 00, convert to fp32 */ \ - "fmul v8.4s, v2.4s, v0.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v9.4s, v3.4s, v0.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v10.4s, v4.4s, v0.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v11.4s, v5.4s, v0.s[1]\n" /* 00, mul scale to get final */ \ - "fmul v12.4s, v6.4s, v0.s[1]\n" /* 00, mul scale to get final */ \ - "fmul v13.4s, v7.4s, v0.s[1]\n" /* 00, mul scale to get final */ \ - "scvtf v2.4s , v14.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v15.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v16.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v17.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v18.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v19.4s\n" /* 00, convert to fp32 */ \ - "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ \ - "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \ - "fmul v14.4s, v2.4s, v0.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v15.4s, v3.4s, v0.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v16.4s, v4.4s, v0.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v17.4s, v5.4s, v0.s[3]\n" /* 00, mul scale to get final */ \ - "fmul v18.4s, v6.4s, v0.s[3]\n" /* 00, mul scale to get final */ \ - "fmul v19.4s, v7.4s, v0.s[3]\n" /* 00, mul scale to get final */ \ - "scvtf v2.4s , v20.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v21.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v22.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v23.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v24.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v25.4s\n" /* 00, convert to fp32 */ \ - "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \ - "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ \ - "fmul v20.4s, v2.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v21.4s, v3.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v22.4s, v4.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ - "fmul v23.4s, v5.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ - "fmul v24.4s, v6.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ - "fmul v25.4s, v7.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ - "scvtf v2.4s , v26.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v27.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v28.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v29.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v30.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v31.4s\n" /* 00, convert to fp32 */ \ - "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ \ - "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ \ - "fmul v26.4s, v2.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v27.4s, v3.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v28.4s, v4.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ - "fmul v29.4s, v5.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ - "fmul v30.4s, 
v6.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ - "fmul v31.4s, v7.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ - "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ \ - "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ +#define GEMM_SDOT_INT8_OUT \ + GEMM_SDOT_CVT_INT32_TO_FP32 \ + GEMM_SDOT_RELU \ + "fcvtas v0.4s, v8.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v9.4s\n" /* 01, cvt to int */ \ + "fcvtas v2.4s, v10.4s\n" /* 02, cvt to int */ \ + "fcvtas v3.4s, v11.4s\n" /* 10, cvt to int */ \ + "fcvtas v4.4s, v12.4s\n" /* 11, cvt to int */ \ + "fcvtas v5.4s, v13.4s\n" /* 12, cvt to int */ \ + "fcvtas v6.4s, v14.4s\n" /* 20, cvt to int */ \ + "fcvtas v7.4s, v15.4s\n" /* 21, cvt to int */ \ + "sqxtn v10.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v10.8h, v1.4s\n" /* 01, cvt int32 to int16 */ \ + "sqxtn v11.4h, v2.4s\n" /* 02, cvt int32 to int16 */ \ + "sqxtn v12.4h, v3.4s\n" /* 10, cvt int32 to int16 */ \ + "sqxtn2 v12.8h, v4.4s\n" /* 11, cvt int32 to int16 */ \ + "sqxtn v13.4h, v5.4s\n" /* 12, cvt int32 to int16 */ \ + "sqxtn v14.4h, v6.4s\n" /* 20, cvt int32 to int16 */ \ + "sqxtn2 v14.8h, v7.4s\n" /* 21, cvt int32 to int16 */ \ + "fcvtas v0.4s, v16.4s\n" /* 22, cvt to int */ \ + "fcvtas v1.4s, v17.4s\n" /* 30, cvt to int */ \ + "fcvtas v2.4s, v18.4s\n" /* 31, cvt to int */ \ + "fcvtas v3.4s, v19.4s\n" /* 32, cvt to int */ \ + "fcvtas v4.4s, v20.4s\n" /* 40, cvt to int */ \ + "fcvtas v5.4s, v21.4s\n" /* 41, cvt to int */ \ + "fcvtas v6.4s, v22.4s\n" /* 42, cvt to int */ \ + "fcvtas v7.4s, v23.4s\n" /* 50, cvt to int */ \ + "fcvtas v8.4s, v24.4s\n" /* 51, cvt to int */ \ + "fcvtas v9.4s, v25.4s\n" /* 52, cvt to int */ \ + "sqxtn v15.4h, v0.4s\n" /* 22, cvt int32 to int16 */ \ + "sqxtn v16.4h, v1.4s\n" /* 30, cvt int32 to int16 */ \ + "sqxtn2 v16.8h, v2.4s\n" /* 31, cvt int32 to int16 */ \ + "sqxtn v17.4h, v3.4s\n" /* 32, cvt int32 to int16 */ \ + "sqxtn v18.4h, v4.4s\n" /* 40, cvt int32 to int16 */ \ + "sqxtn2 v18.8h, v5.4s\n" /* 41, cvt int32 to int16 */ \ + "sqxtn v19.4h, v6.4s\n" /* 42, cvt int32 to int16 */ \ + "sqxtn v20.4h, v7.4s\n" /* 50, cvt int32 to int16 */ \ + "sqxtn2 v20.8h, v8.4s\n" /* 51, cvt int32 to int16 */ \ + "sqxtn v21.4h, v9.4s\n" /* 52, cvt int32 to int16 */ \ + "fcvtas v0.4s, v26.4s\n" /* 60, cvt to int */ \ + "fcvtas v1.4s, v27.4s\n" /* 61, cvt to int */ \ + "fcvtas v2.4s, v28.4s\n" /* 62, cvt to int */ \ + "fcvtas v3.4s, v29.4s\n" /* 70, cvt to int */ \ + "fcvtas v4.4s, v30.4s\n" /* 71, cvt to int */ \ + "fcvtas v5.4s, v31.4s\n" /* 72, cvt to int */ \ + "sqxtn v22.4h, v0.4s\n" /* 60, cvt int32 to int16 */ \ + "sqxtn2 v22.8h, v1.4s\n" /* 61, cvt int32 to int16 */ \ + "sqxtn v23.4h, v2.4s\n" /* 62, cvt int32 to int16 */ \ + "sqxtn v24.4h, v3.4s\n" /* 70, cvt int32 to int16 */ \ + "sqxtn2 v24.8h, v4.4s\n" /* 71, cvt int32 to int16 */ \ + "sqxtn v25.4h, v5.4s\n" /* 72, cvt int32 to int16 */ \ + "sqxtn v0.8b, v10.8h\n" /* 00, cvt int16 to int8 */ \ + "sqxtn v1.8b, v12.8h\n" /* 10, cvt int16 to int8 */ \ + "sqxtn v2.8b, v14.8h\n" /* 20, cvt int16 to int8 */ \ + "sqxtn v3.8b, v16.8h\n" /* 30, cvt int16 to int8 */ \ + "sqxtn v4.8b, v18.8h\n" /* 40, cvt int16 to int8 */ \ + "sqxtn v5.8b, v20.8h\n" /* 50, cvt int16 to int8 */ \ + "sqxtn v6.8b, v22.8h\n" /* 60, cvt int16 to int8 */ \ + "sqxtn v7.8b, v24.8h\n" /* 70, cvt int16 to int8 */ \ + "st1 {v0.8b},[%[c_ptr0]], #8\n" /* store r0 */ \ + "sqxtn v8.8b, v11.8h\n" /* 0, cvt int16 to int8 */ \ + "st1 {v1.8b},[%[c_ptr1]], #8\n" /* store r1 */ \ + "sqxtn v9.8b, v13.8h\n" /* 1, cvt 
int16 to int8 */ \ + "st1 {v2.8b},[%[c_ptr2]], #8\n" /* store r2 */ \ + "sqxtn v10.8b, v15.8h\n" /* 2, cvt int16 to int8 */ \ + "st1 {v3.8b},[%[c_ptr3]], #8\n" /* store r3 */ \ + "sqxtn v11.8b, v17.8h\n" /* 3, cvt int16 to int8 */ \ + "st1 {v4.8b},[%[c_ptr4]], #8\n" /* store r4 */ \ + "sqxtn v12.8b, v19.8h\n" /* 4, cvt int16 to int8 */ \ + "st1 {v5.8b},[%[c_ptr5]], #8\n" /* store r5 */ \ + "sqxtn v13.8b, v21.8h\n" /* 5, cvt int16 to int8 */ \ + "st1 {v6.8b},[%[c_ptr6]], #8\n" /* store r6 */ \ + "sqxtn v14.8b, v23.8h\n" /* 6, cvt int16 to int8 */ \ + "st1 {v7.8b},[%[c_ptr7]], #8\n" /* store r7 */ \ + "sqxtn v15.8b, v25.8h\n" /* 7, cvt int16 to int8 */ \ + "str s8,[%[c_ptr0]], #4\n" /* store r0 */ \ + "str s9,[%[c_ptr1]], #4\n" /* store r1 */ \ + "str s10,[%[c_ptr2]], #4\n" /* store r2 */ \ + "str s11,[%[c_ptr3]], #4\n" /* store r3 */ \ + "str s12,[%[c_ptr4]], #4\n" /* store r4 */ \ + "str s13,[%[c_ptr5]], #4\n" /* store r5 */ \ + "str s14,[%[c_ptr6]], #4\n" /* store r6 */ \ + "str s15,[%[c_ptr7]], #4\n" /* store r7 */ -#define GEMM_SDOT_INT8_OUT \ - "ldp q0, q1, [%[scale]]\n" /* load scale */ \ - "scvtf v2.4s , v8.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v9.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v10.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v11.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v12.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v13.4s\n" /* 00, convert to fp32 */ \ - "fmul v8.4s, v2.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v9.4s, v3.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v10.4s, v4.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v11.4s, v5.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ - "fmul v12.4s, v6.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ - "fmul v13.4s, v7.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ - "scvtf v2.4s , v14.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v15.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v16.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v17.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v18.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v19.4s\n" /* 00, convert to fp32 */ \ - "fmul v14.4s, v2.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v15.4s, v3.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v16.4s, v4.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v17.4s, v5.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ - "fmul v18.4s, v6.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ - "fmul v19.4s, v7.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ - "scvtf v2.4s , v20.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v21.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v22.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v23.4s\n" /* 03, convert to fp32 */ \ - "scvtf v6.4s , v24.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v25.4s\n" /* 00, convert to fp32 */ \ - "fmul v20.4s, v2.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v21.4s, v3.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v22.4s, v4.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ - "fmul v23.4s, v5.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ - "fmul v24.4s, v6.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ - "fmul v25.4s, v7.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ - "scvtf v2.4s , v26.4s\n" /* 00, convert to fp32 */ \ - "scvtf v3.4s , v27.4s\n" /* 01, convert to fp32 */ \ - "scvtf v4.4s , v28.4s\n" /* 02, convert to fp32 */ \ - "scvtf v5.4s , v29.4s\n" /* 03, convert to fp32 */ \ - "scvtf 
v6.4s , v30.4s\n" /* 00, convert to fp32 */ \ - "scvtf v7.4s , v31.4s\n" /* 00, convert to fp32 */ \ - "fmul v26.4s, v2.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v27.4s, v3.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v28.4s, v4.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ - "fmul v29.4s, v5.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ - "fmul v30.4s, v6.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ - "fmul v31.4s, v7.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ - "fcvtas v0.4s, v8.4s\n" /* 00, cvt to int */ \ - "fcvtas v1.4s, v9.4s\n" /* 00, cvt to int */ \ - "fcvtas v2.4s, v10.4s\n" /* 00, cvt to int */ \ - "fcvtas v3.4s, v11.4s\n" /* 00, cvt to int */ \ - "fcvtas v4.4s, v12.4s\n" /* 00, cvt to int */ \ - "fcvtas v5.4s, v13.4s\n" /* 00, cvt to int */ \ - "sqxtn v8.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v8.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v9.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v0.4s, v14.4s\n" /* 00, cvt to int */ \ - "fcvtas v1.4s, v15.4s\n" /* 00, cvt to int */ \ - "fcvtas v2.4s, v16.4s\n" /* 00, cvt to int */ \ - "sqxtn v11.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v11.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v12.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v3.4s, v17.4s\n" /* 00, cvt to int */ \ - "fcvtas v4.4s, v18.4s\n" /* 00, cvt to int */ \ - "fcvtas v5.4s, v19.4s\n" /* 00, cvt to int */ \ - "sqxtn v14.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v14.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v15.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v0.4s, v20.4s\n" /* 00, cvt to int */ \ - "fcvtas v1.4s, v21.4s\n" /* 00, cvt to int */ \ - "fcvtas v2.4s, v22.4s\n" /* 00, cvt to int */ \ - "sqxtn v17.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v17.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v18.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v3.4s, v23.4s\n" /* 00, cvt to int */ \ - "fcvtas v4.4s, v24.4s\n" /* 00, cvt to int */ \ - "fcvtas v5.4s, v25.4s\n" /* 00, cvt to int */ \ - "sqxtn v20.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v20.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v21.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v0.4s, v26.4s\n" /* 00, cvt to int */ \ - "fcvtas v1.4s, v27.4s\n" /* 00, cvt to int */ \ - "fcvtas v2.4s, v28.4s\n" /* 00, cvt to int */ \ - "sqxtn v23.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v23.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v24.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ - "fcvtas v3.4s, v29.4s\n" /* 00, cvt to int */ \ - "fcvtas v4.4s, v30.4s\n" /* 00, cvt to int */ \ - "fcvtas v5.4s, v31.4s\n" /* 00, cvt to int */ \ - "sqxtn v26.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v26.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v27.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v29.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn2 v29.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v30.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ - "sqxtn v4.8b, v8.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v0.8b, v9.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v5.8b, v11.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v1.8b, v12.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v6.8b, v14.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v2.8b, v15.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v7.8b, v17.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v3.8b, v18.8h\n" /* 00, 01, cvt int16 to int8 
*/ \ - "sqxtn v16.8b, v20.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v15.8b, v21.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v20.8b, v23.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v17.8b, v24.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v24.8b, v26.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v18.8b, v27.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v28.8b, v29.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "sqxtn v19.8b, v30.8h\n" /* 00, 01, cvt int16 to int8 */ \ - "st1 {v4.8b},[%[c_ptr0]], #8\n" /* store r0 */ \ - "st1 {v5.8b},[%[c_ptr1]], #8\n" /* store r0 */ \ - "st1 {v6.8b},[%[c_ptr2]], #8\n" /* store r0 */ \ - "st1 {v7.8b},[%[c_ptr3]], #8\n" /* store r0 */ \ - "st1 {v16.8b},[%[c_ptr4]], #8\n" /* store r0 */ \ - "st1 {v20.8b},[%[c_ptr5]], #8\n" /* store r0 */ \ - "st1 {v24.8b},[%[c_ptr6]], #8\n" /* store r0 */ \ - "st1 {v28.8b},[%[c_ptr7]], #8\n" /* store r0 */ \ - "str s0,[%[c_ptr0]], #4\n" /* store r0 */ \ - "str s1,[%[c_ptr1]], #4\n" /* store r0 */ \ - "str s2,[%[c_ptr2]], #4\n" /* store r0 */ \ - "str s3,[%[c_ptr3]], #4\n" /* store r0 */ \ - "str s15,[%[c_ptr4]], #4\n" /* store r0 */ \ - "str s17,[%[c_ptr5]], #4\n" /* store r0 */ \ - "str s18,[%[c_ptr6]], #4\n" /* store r0 */ \ - "str s19,[%[c_ptr7]], #4\n" /* store r0 */ +// clang-format on template <> -inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, - const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - int32_t*& c_ptr0, // NOLINT - int32_t*& c_ptr1, // NOLINT - int32_t*& c_ptr2, // NOLINT - int32_t*& c_ptr3, // NOLINT - int32_t*& c_ptr4, // NOLINT - int32_t*& c_ptr5, // NOLINT - int32_t*& c_ptr6, // NOLINT - int32_t*& c_ptr7, // NOLINT - const float32_t* scale, - bool is_relu, - int k, - int tail) { - asm volatile(_DECLARE_SDOT_ELEMENT GEMM_SDOT_INT8_KERNEL GEMM_SDOT_INT32_OUT +inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const float* bias, + float32_t*& c_ptr0, // NOLINT + float32_t*& c_ptr1, // NOLINT + float32_t*& c_ptr2, // NOLINT + float32_t*& c_ptr3, // NOLINT + float32_t*& c_ptr4, // NOLINT + float32_t*& c_ptr5, // NOLINT + float32_t*& c_ptr6, // NOLINT + float32_t*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int tail) { + // clang-format off + asm volatile(GEMM_SDOT_INT8_KERNEL + GEMM_SDOT_FP32_OUT : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), @@ -1415,122 +1294,30 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); + : "cc","memory","v0","v1","v2", + "v3","v4","v5","v6","v7","v8","v9","v10", + "v11","v12","v13","v14","v15","v16","v17", + "v18","v19","v20","v21","v22","v23","v24", + "v25","v26","v27","v28","v29","v30","v31"); + // clang-format on } template <> -inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, - const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - float32_t*& c_ptr0, // NOLINT - float32_t*& c_ptr1, // NOLINT - float32_t*& c_ptr2, // NOLINT - float32_t*& c_ptr3, // NOLINT - float32_t*& c_ptr4, // NOLINT - float32_t*& c_ptr5, // NOLINT - float32_t*& c_ptr6, // NOLINT - float32_t*& c_ptr7, // NOLINT - const float32_t* scale, - bool is_relu, - int k, - int tail) { - asm 
volatile(GEMM_SDOT_INT8_KERNEL GEMM_SDOT_FP32_OUT - : [a_ptr] "+r"(a_ptr), - [b_ptr] "+r"(b_ptr), - [k] "+r"(k), - [tail] "+r"(tail), - [c_ptr0] "+r"(c_ptr0), - [c_ptr1] "+r"(c_ptr1), - [c_ptr2] "+r"(c_ptr2), - [c_ptr3] "+r"(c_ptr3), - [c_ptr4] "+r"(c_ptr4), - [c_ptr5] "+r"(c_ptr5), - [c_ptr6] "+r"(c_ptr6), - [c_ptr7] "+r"(c_ptr7) - : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); -} -template <> -inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, - const int8_t*& b_ptr, // NOLINT - const int32_t* bias, - int8_t*& c_ptr0, // NOLINT - int8_t*& c_ptr1, // NOLINT - int8_t*& c_ptr2, // NOLINT - int8_t*& c_ptr3, // NOLINT - int8_t*& c_ptr4, // NOLINT - int8_t*& c_ptr5, // NOLINT - int8_t*& c_ptr6, // NOLINT - int8_t*& c_ptr7, // NOLINT - const float32_t* scale, - bool is_relu, - int k, - int tail) { +inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const float* bias, + int8_t*& c_ptr0, // NOLINT + int8_t*& c_ptr1, // NOLINT + int8_t*& c_ptr2, // NOLINT + int8_t*& c_ptr3, // NOLINT + int8_t*& c_ptr4, // NOLINT + int8_t*& c_ptr5, // NOLINT + int8_t*& c_ptr6, // NOLINT + int8_t*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int tail) { + // clang-format off asm volatile(GEMM_SDOT_INT8_KERNEL GEMM_SDOT_INT8_OUT : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), @@ -1545,352 +1332,409 @@ inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); + : "cc","memory","v0","v1","v2","v3", + "v4","v5","v6","v7","v8","v9","v10", + "v11","v12","v13","v14","v15","v16","v17", + "v18","v19","v20","v21","v22","v23","v24", + "v25","v26","v27","v28","v29","v30","v31"); + // clang-format on } #endif -#else // armv7 -// clang-format off -#define GEMM_INT8_KERNEL \ - "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \ - "vld1.8 {d4-d7}, [%[b_ptr]: 128]!\n" /* load 8x2x2 int8, B, k2x2 */ \ - "vld1.8 {d8-d9}, [%[bias]]\n" /* load int32x4 bias */ \ - "vext.8 q5, q4, q4, #4\n" /* bias shift 1 int32 */ \ - "vext.8 q6, q4, q4, #8\n" /* bias shift 2 int32 */ \ - "vext.8 q7, q4, q4, #12\n" /* bias shift 3 int32 */ \ - "pld [%[a_ptr]]\n" /* preload A */ \ - "vand q8, q4, q4\n" /* set bias to out00 */ \ - "vand q9, q4, q4\n" /* set bias to out01 */ \ - "pld [%[b_ptr]]\n" /* preload B */ \ - "vand q10, q5, q5\n" /* set bias to out10 */ \ - "vand q11, q5, q5\n" /* set bias to out11 */ \ - "pld [%[b_ptr], #64]\n" /* preload B */ \ - "vand q12, q6, q6\n" /* set bias to out20 */ \ - "vand q13, q6, q6\n" /* set bias to out21 */ \ - "pld [%[b_ptr], #128]\n" /* preload B */ \ - "vand q14, q7, q7\n" /* set bias to out30 */ \ - "vand q15, q7, q7\n" /* set bias to out31 */ \ - "pld [%[a_ptr], #64]\n" /* preload A */ \ - "vext.8 d2, d0, d0, #2\n" /* shift left circular by 2byte */ \ - "vext.8 d3, d1, d1, #2\n" /* shift left circular by 2byte */ \ 
- "pld [%[b_ptr], #192]\n" /* preload b */ \ - "pld [%[b_ptr], #256]\n" /* preload b */ \ - "pld [%[a_ptr], #128]\n" /* preload a */ \ - "cmp %[k], #0\n" /* check main loop count */ \ - "beq 3f\n" /* if k = 0, jump to remains */ /* 1st r0, r1 */ \ - "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ - "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ - "vmull.s8 q6, d2, d4\n" /* a1 * b0 = c10 */ \ - "vmull.s8 q7, d2, d5\n" /* a1 * b1 = c11 */ \ - "subs %[k], %[k], #1\n" /* loop count -1 */ /* 2nd r0, r1 */ \ - "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ - "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ - "vrev64.32 q0, q0\n" /* shift left circular by 4byte */ \ - "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ - "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ - "vrev64.32 q1, q1\n" /* shift left circular by 4byte */ \ - "beq 8f\n" /* skip main loop */ /* main loop*/ \ - "0:\n" /* main loop */ /* 1st r2, r3 */ \ - "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ - "vmull.s8 q4, d0, d4\n" /* a2 * b0 = c20 */ \ - "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ - "vmull.s8 q5, d0, d5\n" /* a2 * b1 = c21 */ \ - "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ - "vmull.s8 q6, d2, d4\n" /* a3 * b0 = c30 */ \ - "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ - "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ \ - "vld1.8 {d4-d5}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ - "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ - "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ - "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \ - "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ - "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ - "vld1.8 {d6-d7}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ - "vext.8 d2, d0, d0, #2\n" /* shift left circular by 2byte */ \ - "vext.8 d3, d1, d1, #2\n" /* shift left circular by 2byte */ \ - "vpadal.s16 q12,q4\n" /* pair add and accumulate to int32, c20 */ \ - "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ - "vpadal.s16 q13,q5\n" /* pair add and accumulate to int32, c21 */ \ - "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ - "vpadal.s16 q14,q6\n" /* pair add and accumulate to int32, c30 */ \ - "vmull.s8 q6, d2, d4\n" /* a1 * b0 = c10 */ \ - "vpadal.s16 q15,q7\n" /* pair add and accumulate to int32, c31 */ \ - "vmull.s8 q7, d2, d5\n" /* a1 * b1 = c11 */ \ - "subs %[k], %[k], #1\n" /* loop count -1 */ /* 2nd r0, r1 */ \ - "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ - "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ - "vrev64.32 q0, q0\n" /* shift left circular by 2 */ \ - "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ - "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ - "vrev64.32 q1, q1\n" /* shift left circular by 2 */ \ - "bgt 0b\n" /* jump to main loop */ \ - "8:\n" /* end of main loop */ /* 1st r2, r3 */ \ - "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ - "vmull.s8 q4, d0, d4\n" /* a2 * b0 = c20 */ \ - "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ - "vmull.s8 q5, d0, d5\n" /* a2 * b1 = c21 */ \ - "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ - "vmull.s8 q6, d2, d4\n" /* a3 * b0 = c30 */ \ - "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ - "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ /* 2nd r2, r3 */ \ - "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c20 */ \ - "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c21 */ \ - "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c30 */ \ - "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c31 */ \ - "cmp %[rem], #0\n" 
/* skip remain */ \ - "beq 5f\n" \ - "mov r0, #32\n" /* address offset */ \ - "vld1.8 {d0}, [%[a_ptr]]\n" /* load a to d0, final */ \ - "vld1.8 {d4-d5}, [%[b_ptr]], r0\n" /* load b to d4, d5 */ \ - "5:\n" /* skip rem */ \ - "vpadal.s16 q12, q4\n" /* pair add and accumulate to int32, c20 */ \ - "vpadal.s16 q13, q5\n" /* pair add and accumulate to int32, c21 */ \ - "vpadal.s16 q14, q6\n" /* pair add and accumulate to int32, c30 */ \ - "vpadal.s16 q15, q7\n" /* pair add and accumulate to int32, c31 */ \ - "3:\n" /* process remain k */ \ - "cmp %[rem], #0\n" /* skip remain */ \ - "beq 7f\n" /* process remain k */ \ - "vext.8 d1, d0, d0, #2\n" /* shift left 2bytes */ \ - "vext.8 d2, d0, d0, #4\n" /* shift left 4bytes */ \ - "vext.8 d3, d0, d0, #6\n" /* shift left 6bytes */ /* 1st r0, r1 */ \ - "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ - "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ - "vmull.s8 q6, d1, d4\n" /* a1 * b0 = c10 */ \ - "vmull.s8 q7, d1, d5\n" /* a1 * b1 = c11 */ /* 1st r2, r3 */ \ - "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ - "vmull.s8 q4, d2, d4\n" /* a2 * b0 = c20 */ \ - "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ - "vmull.s8 q5, d2, d5\n" /* a2 * b1 = c21 */ \ - "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ - "vmull.s8 q6, d3, d4\n" /* a3 * b0 = c30 */ \ - "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ - "vmull.s8 q7, d3, d5\n" /* a3 * b1 = c31 */ \ - "vpadal.s16 q12, q4\n" /* pair add and accumulate to int32, c20 */ \ - "vpadal.s16 q13, q5\n" /* pair add and accumulate to int32, c21 */ \ - "vpadal.s16 q14, q6\n" /* pair add and accumulate to int32, c30 */ \ - "vpadal.s16 q15, q7\n" /* pair add and accumulate to int32, c31 */ \ - "7: \n" /* do relu */ /* do relu */ \ - "cmp %[is_relu], #0\n" /* skip relu */ \ - "beq 9f\n" /* skip relu */ \ - "vmov.i32 q0, #0\n" /* for relu */ \ - "vmax.s32 q8, q8, q0\n" /* relu */ \ - "vmax.s32 q9, q9, q0\n" /* relu */ \ - "vmax.s32 q10,q10, q0\n" /* relu */ \ - "vmax.s32 q11,q11, q0\n" /* relu */ \ - "vmax.s32 q12,q12, q0\n" /* relu */ \ - "vmax.s32 q13,q13, q0\n" /* relu */ \ - "vmax.s32 q14,q14, q0\n" /* relu */ \ - "vmax.s32 q15,q15, q0\n" /* relu */ /* unpack the result */ \ - "9:\n" /* unpack */ /* trans 1 */ \ - "vtrn.32 q8, q10\n" /* get q8 */ \ - "vtrn.32 q12, q14\n" /* get q12 */ \ - "vtrn.32 q9, q11\n" /* get q9 */ \ - "vtrn.32 q13, q15\n" /* get q13*/ \ - "vswp d17, d24\n" /* get q8*/ \ - "vswp d21, d28\n" /* get q10 */ \ - "vswp d19, d26\n" /* get q9 */ \ - "vswp d23, d30\n" /* get q11 */ \ - "vext.8 q0, q10, q10, #12\n" /* circular shift left 1 q0 */ \ - "vext.8 q2, q12, q12, #8\n" /* circular shift left 2 q2 */ \ - "vext.8 q4, q14, q14, #4\n" /* circular shift left 3 q4 */ \ - "vext.8 q1, q11, q11, #12\n" /* circular shift left 1 q1 */ \ - "vext.8 q3, q13, q13, #8\n" /* circular shift left 2 q3 */ \ - "vext.8 q5, q15, q15, #4\n" /* circular shift left 3 q5 */ \ - "vtrn.32 q8, q0\n" /* get q8 */ \ - "vtrn.32 q2, q4\n" /* get q2 */ \ - "vtrn.32 q9, q1\n" /* get q9 */ \ - "vtrn.32 q3, q5\n" /* get q3 */ /* trans 2 */ \ - "vswp d17, d4\n" /* get q8 */ \ - "vswp d1, d8\n" /* get q0: a1*/ \ - "vswp d19, d6\n" /* get q9: */ \ - "vswp d3, d10\n" /* get q1: a3b3 */ - +#else // armv7 // clang-format off - -#define GEMM_INT8_INT32_OUT \ - /* write output */ \ - "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write outr0 */ \ - "vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write outr1 */ \ - "vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write outr2 */ \ - "vst1.32 {d8-d11}, 
[%[c_ptr3]]!\n" /* write outr3 */
-
-#define GEMM_INT8_FP32_OUT \
-  /* write output */ \
-  "vld1.32 {d12-d13}, [%[scale]]\n"   /* load scale */ \
-  "vcvt.f32.s32 q10, q8\n"            /* r00, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q11, q9\n"            /* r01, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q12, q0\n"            /* r10, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q13, q1\n"            /* r11, cvt int32 to fp32*/ \
-  "vmul.f32 q8, q10, d12[0]\n"        /* r00, mul scale to get final result */ \
-  "vmul.f32 q9, q11, d12[0]\n"        /* r01, mul scale to get final result */ \
-  "vmul.f32 q0, q12, d12[1]\n"        /* r10, mul scale to get final result */ \
-  "vmul.f32 q1, q13, d12[1]\n"        /* r11, mul scale to get final result */ \
-  "vcvt.f32.s32 q10, q2\n"            /* r20, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q11, q3\n"            /* r21, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q12, q4\n"            /* r30, cvt int32 to fp32*/ \
-  "vcvt.f32.s32 q13, q5\n"            /* r31, cvt int32 to fp32*/ \
-  "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \
-  "vmul.f32 q2, q10, d13[0]\n"        /* r20, mul scale to get final result */ \
-  "vmul.f32 q3, q11, d13[0]\n"        /* r21, mul scale to get final result */ \
-  "vst1.32 {d0-d3}, [%[c_ptr1]]!\n"   /* write r1, float32x4 x2 */ \
-  "vmul.f32 q4, q12, d13[1]\n"        /* r30, mul scale to get final result */ \
-  "vmul.f32 q5, q13, d13[1]\n"        /* r31, mul scale to get final result */ \
-  "vst1.32 {d4-d7}, [%[c_ptr2]]!\n"   /* write r2, float32x4 x2 */ \
+#define GEMM_INT8_KERNEL \
+  "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \
+  "vld1.8 {d4-d7}, [%[b_ptr]: 128]!\n" /* load 8x2x2 int8, B, k2x2 */ \
+  "pld [%[a_ptr]]\n"                   /* preload A */ \
+  "veor q8, q4, q4\n"                  /* out00 = 0 */ \
+  "veor q9, q4, q4\n"                  /* out01 = 0 */ \
+  "pld [%[b_ptr]]\n"                   /* preload B */ \
+  "veor q10, q5, q5\n"                 /* out10 = 0 */ \
+  "veor q11, q5, q5\n"                 /* out11 = 0 */ \
+  "pld [%[b_ptr], #64]\n"              /* preload B */ \
+  "veor q12, q6, q6\n"                 /* out20 = 0 */ \
+  "veor q13, q6, q6\n"                 /* out21 = 0 */ \
+  "pld [%[b_ptr], #128]\n"             /* preload B */ \
+  "veor q14, q7, q7\n"                 /* out30 = 0 */ \
+  "veor q15, q7, q7\n"                 /* out31 = 0 */ \
+  "pld [%[a_ptr], #64]\n"              /* preload A */ \
+  "vext.8 d2, d0, d0, #2\n"            /* shift left circular by 2byte */ \
+  "vext.8 d3, d1, d1, #2\n"            /* shift left circular by 2byte */ \
+  "pld [%[b_ptr], #192]\n"             /* preload b */ \
+  "pld [%[b_ptr], #256]\n"             /* preload b */ \
+  "pld [%[a_ptr], #128]\n"             /* preload a */ \
+  "cmp %[k], #0\n"                     /* check main loop count */ \
+  "beq 3f\n"                           /* if k = 0, jump to remains */ \
+  /* 1st r0, r1 */ \
+  "vmull.s8 q4, d0, d4\n"              /* a0 * b0 = c00 */ \
+  "vmull.s8 q5, d0, d5\n"              /* a0 * b1 = c01 */ \
+  "vmull.s8 q6, d2, d4\n"              /* a1 * b0 = c10 */ \
+  "vmull.s8 q7, d2, d5\n"              /* a1 * b1 = c11 */ \
+  "subs %[k], %[k], #1\n"              /* loop count -1 */ \
+  /* 2nd r0, r1 */ \
+  "vmlal.s8 q4, d1, d6\n"              /* a0 * b0 = c00 */ \
+  "vmlal.s8 q5, d1, d7\n"              /* a0 * b1 = c01 */ \
+  "vrev64.32 q0, q0\n"                 /* shift left circular by 4byte */ \
+  "vmlal.s8 q6, d3, d6\n"              /* a1 * b0 = c10 */ \
+  "vmlal.s8 q7, d3, d7\n"              /* a1 * b1 = c11 */ \
+  "vrev64.32 q1, q1\n"                 /* shift left circular by 4byte */ \
+  "beq 8f\n"                           /* skip main loop */ \
+  /* main loop*/ \
+  "0:\n"                               /* main loop */ \
+  /* 1st r2, r3 */ \
+  "vpadal.s16 q8, q4\n"                /* pair add and accumulate, c00 */ \
+  "vmull.s8 q4, d0, d4\n"              /* a2 * b0 = c20 */ \
+  "vpadal.s16 q9, q5\n"                /* pair add and accumulate, c01 */ \
+  "vmull.s8 q5, d0, d5\n"              /* a2 * b1 = c21 */ \
+  "vpadal.s16 q10,q6\n"                /* pair add and accumulate, c10 */ \
+  "vmull.s8 q6, d2, d4\n"              /* a3 * b0 = c30 */ \
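Without the dot-product extension, the armv7 GEMM_INT8_KERNEL here works on two k-steps at a time (the "k2x2" scheme): vmull.s8/vmlal.s8 produce int8-by-int8 products in int16 lanes, and vpadal.s16 immediately pairwise-adds those lanes into the int32 accumulators, so at most two products ever live in 16 bits before widening. A scalar model of one accumulator update (an illustrative sketch assuming the packed layout interleaves the two k-steps; not part of the patch):

    #include <cstdint>

    // Sketch of one vmull.s8 (k-step 0) + vmlal.s8 (k-step 1) + vpadal.s16
    // sequence: a and b each hold 2 k-steps x 4 values, interleaved.
    inline void k2x2_model(int32_t acc[4], const int8_t a[8], const int8_t b[8]) {
      for (int col = 0; col < 4; ++col) {
        int p0 = a[2 * col] * b[2 * col];          // vmull.s8 (int16 product)
        int p1 = a[2 * col + 1] * b[2 * col + 1];  // vmlal.s8 (int16 accumulate)
        acc[col] += p0 + p1;                       // vpadal.s16 pairwise add
      }
    }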
"vpadal.s16 q11,q7\n" /* pair add and accumulate, c11 */ \ + "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ \ + "vld1.8 {d4-d5}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ + /* 2nd r2, r3 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ + "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ + "vld1.8 {d6-d7}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ + /* pre process A */ \ + "vext.8 d2, d0, d0, #2\n" /* shift left circular by 2byte */ \ + "vext.8 d3, d1, d1, #2\n" /* shift left circular by 2byte */ \ + /* 1st r0, r1 */ \ + "vpadal.s16 q12,q4\n" /* pair add and accumulate, c20 */ \ + "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ + "vpadal.s16 q13,q5\n" /* pair add and accumulate, c21 */ \ + "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ + "vpadal.s16 q14,q6\n" /* pair add and accumulate, c30 */ \ + "vmull.s8 q6, d2, d4\n" /* a1 * b0 = c10 */ \ + "vpadal.s16 q15,q7\n" /* pair add and accumulate, c31 */ \ + "vmull.s8 q7, d2, d5\n" /* a1 * b1 = c11 */ \ + "subs %[k], %[k], #1\n" /* loop count -1 */ \ + /* 2nd r0, r1 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ + "vrev64.32 q0, q0\n" /* shift left circular by 2 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ + "vrev64.32 q1, q1\n" /* shift left circular by 2 */ \ + "bgt 0b\n" /* jump to main loop */ \ + "8:\n" /* end of main loop */ \ + /* 1st r2, r3 */ \ + "vpadal.s16 q8, q4\n" /* pair add and accumulate, c00 */ \ + "vmull.s8 q4, d0, d4\n" /* a2 * b0 = c20 */ \ + "vpadal.s16 q9, q5\n" /* pair add and accumulate, c01 */ \ + "vmull.s8 q5, d0, d5\n" /* a2 * b1 = c21 */ \ + "vpadal.s16 q10,q6\n" /* pair add and accumulate, c10 */ \ + "vmull.s8 q6, d2, d4\n" /* a3 * b0 = c30 */ \ + "vpadal.s16 q11,q7\n" /* pair add and accumulate, c11 */ \ + "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ \ + /* 2nd r2, r3 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c20 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c21 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c30 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c31 */ \ + "cmp %[rem], #0\n" /* skip remain */ \ + "beq 5f\n" \ + "mov r0, #32\n" /* address offset */ \ + "vld1.8 {d0}, [%[a_ptr]]\n" /* load a to d0, final */ \ + "vld1.8 {d4-d5}, [%[b_ptr]], r0\n" /* load b to d4, d5 */ \ + "5:\n" /* skip rem */ \ + "vpadal.s16 q12, q4\n" /* pair add and accumulate, c20 */ \ + "vpadal.s16 q13, q5\n" /* pair add and accumulate, c21 */ \ + "vpadal.s16 q14, q6\n" /* pair add and accumulate, c30 */ \ + "vpadal.s16 q15, q7\n" /* pair add and accumulate, c31 */ \ + /* process remain k */ \ + "3:\n" /* process remain k */ \ + "cmp %[rem], #0\n" /* skip remain */ \ + "beq 7f\n" \ + /* process remain k */ \ + "vext.8 d1, d0, d0, #2\n" /* shift left 2bytes */ \ + "vext.8 d2, d0, d0, #4\n" /* shift left 4bytes */ \ + "vext.8 d3, d0, d0, #6\n" /* shift left 6bytes */ \ + /* 1st r0, r1 */ \ + "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ + "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ + "vmull.s8 q6, d1, d4\n" /* a1 * b0 = c10 */ \ + "vmull.s8 q7, d1, d5\n" /* a1 * b1 = c11 */ \ + /* 1st r2, r3 */ \ + "vpadal.s16 q8, q4\n" /* pair add and accumulate, c00 */ \ + "vmull.s8 q4, d2, d4\n" /* a2 * b0 = c20 */ \ + "vpadal.s16 q9, q5\n" /* pair add and accumulate, c01 */ \ + "vmull.s8 q5, d2, d5\n" /* a2 * b1 = c21 */ \ + "vpadal.s16 q10,q6\n" /* pair add and 
accumulate, c10 */ \ + "vmull.s8 q6, d3, d4\n" /* a3 * b0 = c30 */ \ + "vpadal.s16 q11,q7\n" /* pair add and accumulate, c11 */ \ + "vmull.s8 q7, d3, d5\n" /* a3 * b1 = c31 */ \ + "vpadal.s16 q12, q4\n" /* pair add and accumulate, c20 */ \ + "vpadal.s16 q13, q5\n" /* pair add and accumulate, c21 */ \ + "vpadal.s16 q14, q6\n" /* pair add and accumulate, c30 */ \ + "vpadal.s16 q15, q7\n" /* pair add and accumulate, c31 */ \ + "7: \n" /* do relu */ \ + /* unpack the result */ \ + /* trans 1 */ \ + "vtrn.32 q8, q10\n" /* get q8: a0b0, a1b0, a2b2, a3b2;*/ \ + /* q10: a1b1, a2b1, a3b3, a0b3 */ \ + "vtrn.32 q12, q14\n" /* get q12: a2b0, a3b0, a0b2, a1b2;*/ \ + /* q14: a3b1, a0b1, a1b3, a2b3 */ \ + "vtrn.32 q9, q11\n" /* get q9: a0b0, a1b0, a2b2, a3b2;*/ \ + /* q11: a1b1, a2b1, a3b3, a0b3 */ \ + "vtrn.32 q13, q15\n" /* get q13: a2b0, a3b0, a0b2, a1b2;*/ \ + /* q15: a3b1, a0b1, a1b3, a2b3 */ \ + /* trans 2 */ \ + "vswp d17, d24\n" /* get q8: a0b0, a1b0, a2b0, a3b0;*/ \ + /* q12: a2b2, a3b2, a0b2, a1b2 */ \ + "vswp d21, d28\n" /* get q10: a1b1, a2b1, a3b1, a0b1;*/ \ + /* q14: a3b3, a0b3, a1b3, a2b3 */ \ + "vswp d19, d26\n" /* get q9: a0b0, a1b0, a2b0, a3b0;*/ \ + /* q13: a2b2, a3b2, a0b2, a1b2 */ \ + "vswp d23, d30\n" /* get q11: a1b1, a2b1, a3b1, a0b1;*/ \ + /* q15: a3b3, a0b3, a1b3, a2b3 */ \ + /* shift */ \ + "vext.8 q0, q10, q10, #12\n" /* circular shift left 1 */ \ + /* q0: a0b1, a1b1, a2b1, a3b1 */ \ + "vext.8 q2, q12, q12, #8\n" /* circular shift left 2 */ \ + /* q2: a0b2, a1b2, a2b2, a3b2 */ \ + "vext.8 q4, q14, q14, #4\n" /* circular shift left 3 */ \ + /* q4: a0b3, a1b3, a2b3, a3b3 */ \ + "vext.8 q1, q11, q11, #12\n" /* circular shift left 1 */ \ + /* q1: a0b1, a1b1, a2b1, a3b1 */ \ + "vext.8 q3, q13, q13, #8\n" /* circular shift left 2 */ \ + /* q3: a0b2, a1b2, a2b2, a3b2 */ \ + "vext.8 q5, q15, q15, #4\n" /* circular shift left 3 */ \ + /* q5: a0b3, a1b3, a2b3, a3b3 */ \ + /* trans 1 */ \ + "vtrn.32 q8, q0\n" /* get q8: a0b0, a0b1, a2b0, a2b1; */ \ + /* q0: a1b0, a1b1, a3b0, a3b1 */ \ + "vtrn.32 q2, q4\n" /* get q2: a0b2, a0b3, a2b2, a2b3; */ \ + /* q4: a1b2, a1b3, a3b2, a3b3 */ \ + "vtrn.32 q9, q1\n" /* get q9: a0b0, a0b1, a2b0, a2b1; */ \ + /* q1: a1b0, a1b1, a3b0, a3b1 */ \ + "vtrn.32 q3, q5\n" /* get q3: a0b2, a0b3, a2b2, a2b3; */ \ + /* q5: a1b2, a1b3, a3b2, a3b3 */ \ + /* trans 2 */ \ + "vswp d17, d4\n" /* get q8: a0b0, a0b1, a0b2, a0b3; */ \ + /* q2: a2b0, a2b1, a2b2, a2b3 */ \ + "vswp d1, d8\n" /* get q0: a1b0, a1b1, a1b2, a1b3; */ \ + /* q4: a3b0, a3b1, a3b2, a3b3 */ \ + "vswp d19, d6\n" /* get q9: a0b0, a0b1, a0b2, a0b3; */ \ + /* q3: a2b0, a2b1, a2b2, a2b3 */ \ + "vswp d3, d10\n" /* get q1: a1b0, a1b1, a1b2, a1b3; */ \ + /* q5: a3b0, a3b1, a3b2, a3b3 */ + +#define GEMM_INT8_TRANS_INT32_TO_FP32 \ + /* write output */ \ + "vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \ + "vld1.32 {d14-d15}, [%[bias]]\n" /* load bias */ \ + "vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q1\n" /* r11, cvt int32 to fp32*/ \ + "vdup.32 q8, d14[0]\n" \ + "vdup.32 q9, d14[0]\n" \ + "vdup.32 q0, d14[1]\n" \ + "vdup.32 q1, d14[1]\n" \ + "vmla.f32 q8, q10, d12[0]\n" /* r00, mul scale */ \ + "vmla.f32 q9, q11, d12[0]\n" /* r01, mul scale */ \ + "vmla.f32 q0, q12, d12[1]\n" /* r10, mul scale */ \ + "vmla.f32 q1, q13, d12[1]\n" /* r11, mul scale */ \ + "vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \ 
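+ /* note: the dequant step computes c_fp32 = bias[row] + c_int32 * scale[row]; the vdup lines broadcast each row's bias first so a single vmla per vector fuses the scale multiply and the bias add */ \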
+ "vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \ + "vdup.32 q2, d15[0]\n" \ + "vdup.32 q3, d15[0]\n" \ + "vdup.32 q4, d15[1]\n" \ + "vdup.32 q5, d15[1]\n" \ + "vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale */ \ + "vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale */ \ + "vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale */ \ + "vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */ + + +#define GEMM_INT8_RELU \ + /* do relu */ \ + "cmp %[is_relu], #0\n" /* skip relu */ \ + "beq 9f\n" /* skip relu */ \ + "vmov.i32 q15, #0\n" /* for relu */ \ + "vmax.f32 q8, q8, q15\n" /* relu */ \ + "vmax.f32 q9, q9, q15\n" /* relu */ \ + "vmax.f32 q0,q0, q15\n" /* relu */ \ + "vmax.f32 q1,q1, q15\n" /* relu */ \ + "vmax.f32 q2,q2, q15\n" /* relu */ \ + "vmax.f32 q3,q3, q15\n" /* relu */ \ + "vmax.f32 q4,q4, q15\n" /* relu */ \ + "vmax.f32 q5,q5, q15\n" /* relu */ \ + "9:\n" + + +#define GEMM_INT8_FP32_OUT \ + GEMM_INT8_TRANS_INT32_TO_FP32 \ + GEMM_INT8_RELU \ + "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \ + "vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \ + "vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \ "vst1.32 {d8-d11}, [%[c_ptr3]]!\n" /* write r3, float32x4 x2 */ -#define GEMM_INT8_INT8_OUT \ - /* write output */ \ - "vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \ - "vmov.f32 q7, #-0.5\n" /* neg offset */ \ - "vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q13, q1\n" /* r11, cvt int32 to fp32*/ \ - "vmov.f32 q8, #0.5\n" /* pos offset */ \ - "vmov.f32 q9, #0.5\n" /* pos offset */ \ - "vmov.f32 q0, #0.5\n" /* pos offset */ \ - "vmov.f32 q1, #0.5\n" /* pos offset */ \ - "vcgt.f32 q14, q10, #0\n" /* get pos mask */ \ - "vcgt.f32 q15, q11, #0\n" /* get pos mask */ \ - "vbif.f32 q8, q7, q14\n" /* get right offset */ \ - "vbif.f32 q9, q7, q15\n" /* get right offset */ \ - "vcgt.f32 q14, q12, #0\n" /* get pos mask */ \ - "vcgt.f32 q15, q13, #0\n" /* get pos mask */ \ - "vbif.f32 q0, q7, q14\n" /* get right offset */ \ - "vbif.f32 q1, q7, q15\n" /* get right offset */ \ - "vmla.f32 q8, q10, d12[0]\n" /* r00, mul scale to get final result */ \ - "vmla.f32 q9, q11, d12[0]\n" /* r01, mul scale to get final result */ \ - "vmla.f32 q0, q12, d12[1]\n" /* r10, mul scale to get final result */ \ - "vmla.f32 q1, q13, d12[1]\n" /* r11, mul scale to get final result */ \ - "vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \ - "vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \ - "vmov.f32 q2, #0.5\n" /* pos offset */ \ - "vmov.f32 q3, #0.5\n" /* pos offset */ \ - "vmov.f32 q4, #0.5\n" /* pos offset */ \ - "vmov.f32 q5, #0.5\n" /* pos offset */ \ - "vcgt.f32 q14, q10, #0\n" /* get pos mask */ \ - "vcgt.f32 q15, q11, #0\n" /* get pos mask */ \ - "vbif.f32 q2, q7, q14\n" /* get right offset */ \ - "vbif.f32 q3, q7, q15\n" /* get right offset */ \ - "vcgt.f32 q14, q12, #0\n" /* get pos mask */ \ - "vcgt.f32 q15, q13, #0\n" /* get pos mask */ \ - "vbif.f32 q4, q7, q14\n" /* get right offset */ \ - "vbif.f32 q5, q7, q15\n" /* get right offset */ \ - "vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale to get final result */ \ - "vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale to get final result */ \ - "vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale to get final result */ \ - 
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale to get final result */ \ - "vcvt.s32.f32 q6, q8\n" /* r00, fp32->int32 */ \ - "vcvt.s32.f32 q7, q9\n" /* r01, fp32->int32 */ \ - "vcvt.s32.f32 q10, q0\n" /* r10, fp32->int32 */ \ - "vcvt.s32.f32 q11, q1\n" /* r11, fp32->int32 */ \ - "vcvt.s32.f32 q12, q2\n" /* r20, fp32->int32 */ \ - "vcvt.s32.f32 q13, q3\n" /* r21, fp32->int32 */ \ - "vcvt.s32.f32 q14, q4\n" /* r30, fp32->int32 */ \ - "vcvt.s32.f32 q15, q5\n" /* r31, fp32->int32 */ \ - "vqmovn.s32 d0, q6\n" /* r00, int32 -> int16 */ \ - "vqmovn.s32 d1, q7\n" /* r01, int32 -> int16 */ \ - "vqmovn.s32 d2, q10\n" /* r10, int32 -> int16 */ \ - "vqmovn.s32 d3, q11\n" /* r11, int32 -> int16 */ \ - "vqmovn.s32 d4, q12\n" /* r00, int32 -> int16 */ \ - "vqmovn.s32 d5, q13\n" /* r01, int32 -> int16 */ \ - "vqmovn.s32 d6, q14\n" /* r10, int32 -> int16 */ \ - "vqmovn.s32 d7, q15\n" /* r11, int32 -> int16 */ \ - "vqmovn.s16 d8, q0\n" /* 0, int16 -> int8 */ \ - "vqmovn.s16 d9, q1\n" /* 1, int16 -> int8 */ \ - "vqmovn.s16 d10, q2\n" /* 2, int16 -> int8 */ \ - "vqmovn.s16 d11, q3\n" /* 3, int16 -> int8 */ \ - "vst1.32 {d8}, [%[c_ptr0]]!\n" /* write r0*/ \ - "vst1.32 {d9}, [%[c_ptr1]]!\n" /* write r1*/ \ - "vst1.32 {d10}, [%[c_ptr2]]!\n" /* write r2*/ \ + +#define GEMM_INT8_INT8_OUT \ + GEMM_INT8_TRANS_INT32_TO_FP32 \ + GEMM_INT8_RELU \ + "vmov.f32 q7, #-0.5\n" /* neg offset */ \ + "vmov.f32 q10, #0.5\n" /* pos offset */ \ + "vmov.f32 q11, #0.5\n" /* pos offset */ \ + "vmov.f32 q12, #0.5\n" /* pos offset */ \ + "vmov.f32 q13, #0.5\n" /* pos offset */ \ + "vcgt.f32 q14, q8, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q9, #0\n" /* get pos mask */ \ + "vbif.f32 q10, q7, q14\n" /* get right offset */ \ + "vbif.f32 q11, q7, q15\n" /* get right offset */ \ + "vcgt.f32 q14, q0, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q1, #0\n" /* get pos mask */ \ + "vbif.f32 q12, q7, q14\n" /* get right offset */ \ + "vbif.f32 q13, q7, q15\n" /* get right offset */ \ + "vadd.f32 q8, q10, q8\n" /* r00, add offset */ \ + "vadd.f32 q9, q11, q9\n" /* r01, add offset */ \ + "vadd.f32 q0, q12, q0\n" /* r10, add offset */ \ + "vadd.f32 q1, q13, q1\n" /* r11, add offset */ \ + "vmov.f32 q10, #0.5\n" /* pos offset */ \ + "vmov.f32 q11, #0.5\n" /* pos offset */ \ + "vmov.f32 q12, #0.5\n" /* pos offset */ \ + "vmov.f32 q13, #0.5\n" /* pos offset */ \ + "vcgt.f32 q14, q2, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q3, #0\n" /* get pos mask */ \ + "vbif.f32 q10, q7, q14\n" /* get right offset */ \ + "vbif.f32 q11, q7, q15\n" /* get right offset */ \ + "vcgt.f32 q14, q4, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q5, #0\n" /* get pos mask */ \ + "vbif.f32 q12, q7, q14\n" /* get right offset */ \ + "vbif.f32 q13, q7, q15\n" /* get right offset */ \ + "vadd.f32 q2, q10, q2\n" /* r20, add offset */ \ + "vadd.f32 q3, q11, q3\n" /* r21, add offset */ \ + "vadd.f32 q4, q12, q4\n" /* r30, add offset */ \ + "vadd.f32 q5, q13, q5\n" /* r31, add offset */ \ + "vcvt.s32.f32 q6, q8\n" /* r00, fp32->int32 */ \ + "vcvt.s32.f32 q7, q9\n" /* r01, fp32->int32 */ \ + "vcvt.s32.f32 q10, q0\n" /* r10, fp32->int32 */ \ + "vcvt.s32.f32 q11, q1\n" /* r11, fp32->int32 */ \ + "vcvt.s32.f32 q12, q2\n" /* r20, fp32->int32 */ \ + "vcvt.s32.f32 q13, q3\n" /* r21, fp32->int32 */ \ + "vcvt.s32.f32 q14, q4\n" /* r30, fp32->int32 */ \ + "vcvt.s32.f32 q15, q5\n" /* r31, fp32->int32 */ \ + "vqmovn.s32 d0, q6\n" /* r00, int32 -> int16 */ \ + "vqmovn.s32 d1, q7\n" /* r01, int32 -> int16 */ \ + "vqmovn.s32 d2, q10\n" /* r10, int32 -> int16 */ \ + "vqmovn.s32 d3, q11\n" /* r11, 
int32 -> int16 */ \ + "vqmovn.s32 d4, q12\n" /* r20, int32 -> int16 */ \ + "vqmovn.s32 d5, q13\n" /* r21, int32 -> int16 */ \ + "vqmovn.s32 d6, q14\n" /* r30, int32 -> int16 */ \ + "vqmovn.s32 d7, q15\n" /* r31, int32 -> int16 */ \ + "vqmovn.s16 d8, q0\n" /* 0, int16 -> int8 */ \ + "vqmovn.s16 d9, q1\n" /* 1, int16 -> int8 */ \ + "vqmovn.s16 d10, q2\n" /* 2, int16 -> int8 */ \ + "vqmovn.s16 d11, q3\n" /* 3, int16 -> int8 */ \ + "vst1.32 {d8}, [%[c_ptr0]]!\n" /* write r0*/ \ + "vst1.32 {d9}, [%[c_ptr1]]!\n" /* write r1*/ \ + "vst1.32 {d10}, [%[c_ptr2]]!\n" /* write r2*/ \ "vst1.32 {d11}, [%[c_ptr3]]!\n" /* write r3*/ -template <> -inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, int32_t*& c_ptr0, // NOLINT - int32_t*& c_ptr1, int32_t*& c_ptr2, // NOLINT - int32_t*& c_ptr3, const float* scale, bool is_relu, // NOLINT - int k, int rem) { - asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), - [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k) - : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc"); -} +// clang-format on template <> -inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, float*& c_ptr0, // NOLINT - float*& c_ptr1, float*& c_ptr2, float*& c_ptr3, // NOLINT - const float* scale, bool is_relu, int k, int rem) { +inline void gemm_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const float* bias, + float32_t*& c_ptr0, // NOLINT + float32_t*& c_ptr1, // NOLINT + float32_t*& c_ptr2, // NOLINT + float32_t*& c_ptr3, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int rem) { asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), - [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k) - : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem), + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [k] "+r"(k) + : [is_relu] "r"(is_relu), + [bias] "r"(bias), + [rem] "r"(rem), [scale] "r"(scale) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc"); + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "cc"); } template <> -inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr, // NOLINT - const int32_t* bias, int8_t*& c_ptr0, // NOLINT - int8_t*& c_ptr1, int8_t*& c_ptr2, int8_t*& c_ptr3, // NOLINT - const float* scale, bool is_relu, int k, int rem) { +inline void gemm_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const float* bias, + int8_t*& c_ptr0, // NOLINT + int8_t*& c_ptr1, // NOLINT + int8_t*& c_ptr2, // NOLINT + int8_t*& c_ptr3, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int rem) { asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), - [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k) - : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem), + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] 
"+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [k] "+r"(k) + : [is_relu] "r"(is_relu), + [bias] "r"(bias), + [rem] "r"(rem), [scale] "r"(scale) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc"); + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "cc"); } -#endif //__aarch64__ // NOLINT +#endif // __aarch64__ // NOLINT // gemm wrapper template void gemm_prepack_oth_int8(const int8_t* A_packed, const int8_t* B, - const int* bias, + const float* bias, Dtype* C, int M, int N, @@ -1921,11 +1765,11 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, auto* b_tmp = static_cast(workspace); - auto* zerobuf = static_cast(malloc(x_block * \ - (sizeof(int8_t) + sizeof(Dtype)))); + auto* zerobuf = + static_cast(malloc(x_block * (sizeof(int8_t) + sizeof(Dtype)))); memset(zerobuf, 0, x_block * sizeof(int8_t)); - auto* trash_ptr = reinterpret_cast(zerobuf + \ - x_block * sizeof(int8_t)); + auto* trash_ptr = + reinterpret_cast(zerobuf + x_block * sizeof(int8_t)); //! apanel is pre_compute outside gemm @@ -1960,7 +1804,7 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, Dtype* tmp2 = nullptr; Dtype* tmp3 = nullptr; float32_t scale_local[4]; - int32_t bias_local[4] = {0, 0, 0, 0}; + float32_t bias_local[4] = {0, 0, 0, 0}; if (is_bias) { bias_local[0] = bias[y]; bias_local[1] = bias[y + 1]; @@ -1998,9 +1842,17 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, c_ptr2 = out2; c_ptr3 = out3; } - gemm_int8_kernel(a_ptr_l, b_ptr, bias_local, - c_ptr0, c_ptr1, c_ptr2, c_ptr3, - scale_local, is_relu, k, k_rem); + gemm_int8_kernel(a_ptr_l, + b_ptr, + bias_local, + c_ptr0, + c_ptr1, + c_ptr2, + c_ptr3, + scale_local, + is_relu, + k, + k_rem); if (flag_rem && (xb == bblocks - 1)) { for (int i = 0; i < n_rem; ++i) { *(tmp0++) = out0[i]; @@ -2030,8 +1882,12 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, // e0,f0, e1,f1, e2,f2, e3,f3; // g0,h0, g1,h1, g2,h2, g3,h3; /***********************************************************************/ -void prepackA_m4k2x2_int8(int8_t* out, const int8_t* in, const int ldin, - const int m0, const int mmax, const int k0, +void prepackA_m4k2x2_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int m0, + const int mmax, + const int k0, const int kmax) { int y_len = mmax - m0; int x_len = kmax - k0; @@ -2064,6 +1920,7 @@ void prepackA_m4k2x2_int8(int8_t* out, const int8_t* in, const int ldin, int8_t* ptr_out = out + y * x_len_roundup; int i = 0; for (; i < x_len + 1 - 2 * KBLOCK_INT8; i += 2 * KBLOCK_INT8) { +// clang-format off #ifdef __aarch64__ asm volatile( "ld1 {v0.8b}, [%[ptr0]], #8\n" /* load r0, 8 int8 */ @@ -2107,7 +1964,8 @@ void prepackA_m4k2x2_int8(int8_t* out, const int8_t* in, const int ldin, [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) : : "q0", "q1", "cc", "memory"); -#endif //__aarch64 // NOLINT +#endif // __aarch64__ + // clang-format on } if (i + KBLOCK_INT8 <= x_len) { ptr_out[0] = ptr0[0]; @@ -2222,8 +2080,12 @@ void prepackA_m4k2x2_int8(int8_t* out, const int8_t* in, const int ldin, // a12,b12, a13,b13, a14,b14, a15,b15; // c12,d12, c13,d13, c14,d14, c15,d15;----block3 /***************************************************************************/ -void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin, - const int m0, const int mmax, const int k0, +void prepackA_m4k2x2_trans_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int m0, 
+ const int mmax, + const int k0, const int kmax) { int xlen = mmax - m0; int ylen = kmax - k0; @@ -2235,8 +2097,8 @@ void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin, int x_rem = xlen & (MUNROLL * MBLOCK_INT8_OTH - 1); int m_rem = (x_rem + MBLOCK_INT8_OTH - 1) / MBLOCK_INT8_OTH; - const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15}; + const uint8_t mask_buffer[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; int8x16_t vzero = vdupq_n_s8(0); uint8x16_t vmask = vcltq_u8(vld1q_u8(mask_buffer), vdupq_n_u8(x_rem)); @@ -2267,6 +2129,7 @@ void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin, } int k = mcnt; int rem = m_rem; +// clang-format off #ifdef __aarch64__ asm volatile( "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */ @@ -2504,7 +2367,8 @@ void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin, [ptr_out] "+r"(ptr_out) : [mask] "w"(vmask), [vzero] "w"(vzero), [stride] "r"(stride_out) : "q0", "q1", "q2", "q3", "cc"); -#endif //__aarch64__ // NOLINT +#endif // __aarch64__ + // clang-format on } free(zerobuf); } @@ -2542,12 +2406,17 @@ void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin, // c0,d0, c1,d1, c2,d2, c3,d3; // c4,d4, c5,d5, c6,d6, c7,d7; /***************************************************************************/ -void packb_int8(int8_t* out, const int8_t* in, const int ldin, const int k0, - const int kmax, const int n0, const int nmax, +void packb_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int k0, + const int kmax, + const int n0, + const int nmax, const int8_t* zerobuf) { const int8_t* inptr = in + k0 * ldin + n0; - const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15}; + const uint8_t mask_buffer[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; int x_len = nmax - n0; int y_len = kmax - k0; int kup = ROUNDUP(y_len, KBLOCK_INT8); @@ -2577,6 +2446,7 @@ void packb_int8(int8_t* out, const int8_t* in, const int ldin, const int k0, } int8_t* outptr_row_col = out + y * NBLOCK_INT8_OTH; int k = kcnt; +// clang-format off #ifdef __aarch64__ asm volatile( "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */ @@ -2748,7 +2618,8 @@ void packb_int8(int8_t* out, const int8_t* in, const int ldin, const int k0, : [rem] "r"(rem), [mask] "w"(vmask), [vzero] "w"(vzero), [stride] "r"(stride_out) : "q0", "q1", "cc"); -#endif //__aarch64__ // NOLINT +#endif // __aarch64__ + // clang-format on } } @@ -2786,15 +2657,20 @@ void packb_int8(int8_t* out, const int8_t* in, const int ldin, const int k0, // a6,b6, c6,d6, e6,f6, g6,h6; // a7,b7, c7,d7, e7,f7, g7,h7;--block3, address+32 /*******************************************************************/ -void packb_trans_int8(int8_t* out, const int8_t* in, const int ldin, - const int k0, const int kmax, const int n0, - const int nmax, const int8_t* zerobuf) { +void packb_trans_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int k0, + const int kmax, + const int n0, + const int nmax, + const int8_t* zerobuf) { const int KUNROLL = 4; const int NUNROLL = 8; const int RATIO = NBLOCK_INT8_OTH / NUNROLL; const int8_t* inptr = in + n0 * ldin + k0; - const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15}; + const uint8_t mask_buffer[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; int y_len = nmax - n0; int x_len = kmax - k0; int yup = ROUNDUP(y_len, NBLOCK_INT8_OTH); @@ -2848,6 +2724,7 
@@ void packb_trans_int8(int8_t* out, const int8_t* in, const int ldin, } int k = kcnt; int rem = k_rem; +// clang-format off #ifdef __aarch64__ asm volatile( "cbz %w[k], 1f\n" /* skip main loop */ @@ -3078,7 +2955,8 @@ void packb_trans_int8(int8_t* out, const int8_t* in, const int ldin, [k] "+r"(k), [rem] "+r"(rem) : [mask] "w"(vmask), [vzero] "w"(vzero) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc"); -#endif //__aarch64__ // NOLINT +#endif // __aarch64__ + // clang-format on } } @@ -3087,7 +2965,7 @@ void packb_trans_int8(int8_t* out, const int8_t* in, const int ldin, template void gemm_prepack_sdot_int8(const int8_t* A_packed, const int8_t* B, - const int* bias, + const float* bias, Dtype* C, int M, int N, @@ -3097,169 +2975,180 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed, bool is_transB, const float* scale, ARMContext* ctx) { - size_t llc_size = ctx->llc_size() / 4; - auto workspace = ctx->workspace_data(); - //! MBLOCK_INT8_DOT * x (result) + MBLOCK_INT8_DOT * k (A) + x * k (B) = l2 - int x_block = (llc_size - (MBLOCK_INT8_DOT * K)) / \ - (sizeof(int8_t) * (K + MBLOCK_INT8_DOT)); - x_block /= NBLOCK_INT8_DOT; - x_block *= NBLOCK_INT8_DOT; - int x_num = (N + (x_block - 1)) / x_block; - x_block = (N + x_num - 1) / x_num; - x_block = (x_block + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT; - x_block *= NBLOCK_INT8_DOT; - x_block = x_block < NBLOCK_INT8_DOT ? NBLOCK_INT8_DOT : x_block; - - int kup = ROUNDUP(K, KBLOCK_INT8); - // unroll 2 loop - int tail_pre = ((kup / 4) & (KBLOCK_INT8 - 1)); - int k_pre = (((kup / 4) + KBLOCK_INT8 - 1) / KBLOCK_INT8) - 1; - - bool flag_p_remain = false; - int remain = 0; - - //! apanel is pre_compute outside gemm - for (unsigned int x0 = 0; x0 < N; x0 += x_block) { - unsigned int xmax = x0 + x_block; - if (xmax > N) { - xmax = N; - } - int bblocks = (xmax - x0 + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT; - remain = xmax - x0 - (bblocks - 1) * NBLOCK_INT8_DOT; - if (remain > 0) { - flag_p_remain = true; + size_t llc_size = ctx->llc_size() / 4; + auto workspace = ctx->workspace_data(); + //! MBLOCK_INT8_DOT * x (result) + MBLOCK_INT8_DOT * k (A) + x * k (B) = l2 + int x_block = (llc_size - (MBLOCK_INT8_DOT * K)) / + (sizeof(int8_t) * (K + MBLOCK_INT8_DOT)); + x_block /= NBLOCK_INT8_DOT; + x_block *= NBLOCK_INT8_DOT; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT; + x_block *= NBLOCK_INT8_DOT; + x_block = x_block < NBLOCK_INT8_DOT ? NBLOCK_INT8_DOT : x_block; + + int kup = ROUNDUP(K, KBLOCK_INT8); + // unroll 2 loop + int tail_pre = ((kup / 4) & (KBLOCK_INT8 - 1)); + int k_pre = (((kup / 4) + KBLOCK_INT8 - 1) / KBLOCK_INT8) - 1; + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT; + remain = xmax - x0 - (bblocks - 1) * NBLOCK_INT8_DOT; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + auto b_pannel = static_cast(workspace); + if (!is_transB) { + // K * N + packb_sdot_int8(b_pannel, B, N, 0, K, x0, xmax); + } else { + // N X K + packb_sdot_trans_int8(b_pannel, B, K, 0, K, x0, xmax); + } +#pragma omp parallel for + for (unsigned int y = 0; y < M; y += MBLOCK_INT8_DOT) { + unsigned int ymax = y + MBLOCK_INT8_DOT; + if (ymax > M) { + ymax = M; + } + + float32_t bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + bias_local[6] = bias[y + 6]; + bias_local[7] = bias[y + 7]; + } + float32_t scale_local[8]; + if (scale) { + scale_local[0] = scale[y]; + scale_local[1] = scale[y + 1]; + scale_local[2] = scale[y + 2]; + scale_local[3] = scale[y + 3]; + scale_local[4] = scale[y + 4]; + scale_local[5] = scale[y + 5]; + scale_local[6] = scale[y + 6]; + scale_local[7] = scale[y + 7]; + } + + Dtype cout0[NBLOCK_INT8_DOT]; + Dtype cout1[NBLOCK_INT8_DOT]; + Dtype cout2[NBLOCK_INT8_DOT]; + Dtype cout3[NBLOCK_INT8_DOT]; + Dtype cout4[NBLOCK_INT8_DOT]; + Dtype cout5[NBLOCK_INT8_DOT]; + Dtype cout6[NBLOCK_INT8_DOT]; + Dtype cout7[NBLOCK_INT8_DOT]; + + Dtype* c_ptr0 = C + y * N + x0; + Dtype* c_ptr1 = c_ptr0 + N; + Dtype* c_ptr2 = c_ptr1 + N; + Dtype* c_ptr3 = c_ptr2 + N; + Dtype* c_ptr4 = c_ptr3 + N; + Dtype* c_ptr5 = c_ptr4 + N; + Dtype* c_ptr6 = c_ptr5 + N; + Dtype* c_ptr7 = c_ptr6 + N; + + Dtype* pout0 = c_ptr0; + Dtype* pout1 = c_ptr1; + Dtype* pout2 = c_ptr2; + Dtype* pout3 = c_ptr3; + Dtype* pout4 = c_ptr4; + Dtype* pout5 = c_ptr5; + Dtype* pout6 = c_ptr6; + Dtype* pout7 = c_ptr7; + + // const int8_t *a_ptr_l = A_packed + y * K; + const int8_t* a_ptr_l = A_packed + y * kup; + const int8_t* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + c_ptr1 = cout1; + case 5: + c_ptr2 = cout2; + case 4: + c_ptr3 = cout3; + case 3: + c_ptr4 = cout4; + case 2: + c_ptr5 = cout5; + case 1: + c_ptr6 = cout6; + case 0: + c_ptr7 = cout7; + default: + break; + } } - //! 
load bpanel - auto b_pannel = static_cast(workspace); - if (!is_transB) { - // K * N - packb_sdot_int8(b_pannel, B, N, 0, K, x0, xmax); - } else { - // N X K - packb_sdot_trans_int8(b_pannel, B, K, 0, K, x0, xmax); + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + pout6 = c_ptr6; + pout7 = c_ptr7; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + c_ptr6 = cout6; + c_ptr7 = cout7; } -#pragma omp parallel for - for (unsigned int y = 0; y < M; y += MBLOCK_INT8_DOT) { - unsigned int ymax = y + MBLOCK_INT8_DOT; - if (ymax > M) { - ymax = M; - } - - int32_t bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - if (is_bias) { - bias_local[0] = bias[y]; - bias_local[1] = bias[y + 1]; - bias_local[2] = bias[y + 2]; - bias_local[3] = bias[y + 3]; - bias_local[4] = bias[y + 4]; - bias_local[5] = bias[y + 5]; - bias_local[6] = bias[y + 6]; - bias_local[7] = bias[y + 7]; - } - float32_t scale_local[8]; - if (scale) { - scale_local[0] = scale[y]; - scale_local[1] = scale[y + 1]; - scale_local[2] = scale[y + 2]; - scale_local[3] = scale[y + 3]; - scale_local[4] = scale[y + 4]; - scale_local[5] = scale[y + 5]; - scale_local[6] = scale[y + 6]; - scale_local[7] = scale[y + 7]; - } - - Dtype cout0[NBLOCK_INT8_DOT]; - Dtype cout1[NBLOCK_INT8_DOT]; - Dtype cout2[NBLOCK_INT8_DOT]; - Dtype cout3[NBLOCK_INT8_DOT]; - Dtype cout4[NBLOCK_INT8_DOT]; - Dtype cout5[NBLOCK_INT8_DOT]; - Dtype cout6[NBLOCK_INT8_DOT]; - Dtype cout7[NBLOCK_INT8_DOT]; - - Dtype *c_ptr0 = C + y * N + x0; - Dtype *c_ptr1 = c_ptr0 + N; - Dtype *c_ptr2 = c_ptr1 + N; - Dtype *c_ptr3 = c_ptr2 + N; - Dtype *c_ptr4 = c_ptr3 + N; - Dtype *c_ptr5 = c_ptr4 + N; - Dtype *c_ptr6 = c_ptr5 + N; - Dtype *c_ptr7 = c_ptr6 + N; - - Dtype *pout0 = c_ptr0; - Dtype *pout1 = c_ptr1; - Dtype *pout2 = c_ptr2; - Dtype *pout3 = c_ptr3; - Dtype *pout4 = c_ptr4; - Dtype *pout5 = c_ptr5; - Dtype *pout6 = c_ptr6; - Dtype *pout7 = c_ptr7; - - // const int8_t *a_ptr_l = A_packed + y * K; - const int8_t *a_ptr_l = A_packed + y * kup; - const int8_t *b_ptr = b_pannel; - for (int xb = 0; xb < bblocks; xb++) { - if ((y + 7) >= ymax) { - switch ((y + 7) - ymax) { - case 6: - c_ptr1 = cout1; - case 5: - c_ptr2 = cout2; - case 4: - c_ptr3 = cout3; - case 3: - c_ptr4 = cout4; - case 2: - c_ptr5 = cout5; - case 1: - c_ptr6 = cout6; - case 0: - c_ptr7 = cout7; - default: - break; - } - } - if (flag_p_remain && (xb == bblocks - 1)) { - pout0 = c_ptr0; - pout1 = c_ptr1; - pout2 = c_ptr2; - pout3 = c_ptr3; - pout4 = c_ptr4; - pout5 = c_ptr5; - pout6 = c_ptr6; - pout7 = c_ptr7; - - c_ptr0 = cout0; - c_ptr1 = cout1; - c_ptr2 = cout2; - c_ptr3 = cout3; - c_ptr4 = cout4; - c_ptr5 = cout5; - c_ptr6 = cout6; - c_ptr7 = cout7; - } - const int8_t *a_ptr = a_ptr_l; - int tail = tail_pre; - int k = k_pre; - sgemm_sdot_int8_kernel(a_ptr, b_ptr, - bias_local, c_ptr0, c_ptr1, c_ptr2, c_ptr3, \ - c_ptr4, c_ptr5, c_ptr6, c_ptr7, scale_local, \ - is_relu, k, tail); - if (flag_p_remain && (xb == bblocks - 1)) { - for (int i = 0; i < remain; ++i) { - *pout0++ = cout0[i]; - *pout1++ = cout1[i]; - *pout2++ = cout2[i]; - *pout3++ = cout3[i]; - *pout4++ = cout4[i]; - *pout5++ = cout5[i]; - *pout6++ = cout6[i]; - *pout7++ = cout7[i]; - } - } - } + const int8_t* a_ptr = a_ptr_l; + int tail = tail_pre; + int k = k_pre; + gemm_sdot_int8_kernel(a_ptr, + b_ptr, + bias_local, + c_ptr0, + c_ptr1, + c_ptr2, + c_ptr3, + c_ptr4, + c_ptr5, + c_ptr6, + c_ptr7, + 
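/* remaining args: per-row scale and bias, fused relu flag, main loop count and k remainder */ +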
scale_local, + is_relu, + k, + tail); + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + *pout6++ = cout6[i]; + *pout7++ = cout7[i]; + } } + } } + } } void prepackA_m8k4_int8(int8_t* out, @@ -3269,46 +3158,47 @@ void prepackA_m8k4_int8(int8_t* out, const int mmax, const int k0, const int kmax) { - int x_len = (kmax - k0); - int8_t zerobuff[x_len]; //NOLINT - memset(zerobuff, 0, sizeof(int8_t) * x_len); - - int8_t *dout = out; - const int8_t *inptr = in; - int kup = ROUNDUP(x_len, KBLOCK_INT8); - int stride = kup * 8; - int remain = x_len % 4; + int x_len = (kmax - k0); + int8_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(int8_t) * x_len); + + int8_t* dout = out; + const int8_t* inptr = in; + int kup = ROUNDUP(x_len, KBLOCK_INT8); + int stride = kup * 8; + int remain = x_len % 4; #pragma omp parallel for - for (int y = m0; y < mmax; y += 8) { - int8_t* outptr = dout + stride * (y - m0) / 8; - const int8_t * inptr_row[8]; - inptr_row[0] = inptr + y * ldin + k0; - for (int i = 1; i < 8; i++) { - inptr_row[i] = inptr_row[i - 1] + ldin; - } - //! cope with row index exceed real size, set to zero buffer - if ((y + 7) >= mmax) { - switch ((y + 7) - mmax) { - case 6: - inptr_row[1] = zerobuff; - case 5: - inptr_row[2] = zerobuff; - case 4: - inptr_row[3] = zerobuff; - case 3: - inptr_row[4] = zerobuff; - case 2: - inptr_row[5] = zerobuff; - case 1: - inptr_row[6] = zerobuff; - case 0: - inptr_row[7] = zerobuff; - default: - break; - } - } + for (int y = m0; y < mmax; y += 8) { + int8_t* outptr = dout + stride * (y - m0) / 8; + const int8_t* inptr_row[8]; + inptr_row[0] = inptr + y * ldin + k0; + for (int i = 1; i < 8; i++) { + inptr_row[i] = inptr_row[i - 1] + ldin; + } + //! 
cope with row index exceed real size, set to zero buffer + if ((y + 7) >= mmax) { + switch ((y + 7) - mmax) { + case 6: + inptr_row[1] = zerobuff; + case 5: + inptr_row[2] = zerobuff; + case 4: + inptr_row[3] = zerobuff; + case 3: + inptr_row[4] = zerobuff; + case 2: + inptr_row[5] = zerobuff; + case 1: + inptr_row[6] = zerobuff; + case 0: + inptr_row[7] = zerobuff; + default: + break; + } + } + // clang-format off asm volatile( - "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0]] \n" "prfm pldl1keep, [%[ptr0], #64] \n" "prfm pldl1keep, [%[ptr1]] \n" "prfm pldl1keep, [%[ptr1], #64] \n" @@ -3410,17 +3300,18 @@ void prepackA_m8k4_int8(int8_t* out, ); x -= 4; } - if (x > 0) { - for (int i = 0; i < 8; i++) { - for (int j = x; j > 0; j--) { - *outptr++ = *inptr_row[i]++; - } - for (int j = 0; j < 4 - remain; j++) { - *outptr++ = 0; - } - } + // clang-format on + if (x > 0) { + for (int i = 0; i < 8; i++) { + for (int j = x; j > 0; j--) { + *outptr++ = *inptr_row[i]++; } + for (int j = 0; j < 4 - remain; j++) { + *outptr++ = 0; + } + } } + } } void prepackA_m8k4_trans_int8(int8_t* out, @@ -3430,37 +3321,37 @@ void prepackA_m8k4_trans_int8(int8_t* out, const int mmax, const int k0, const int kmax) { - int8_t *outptr = out; - const int8_t *inptr = in + k0 * ldin + m0; - int x_len = mmax - m0; - int y_len = kmax - k0; - int right_remain = x_len % 8; - int kup = ROUNDUP(y_len, KBLOCK_INT8); - - int stride_out = 8 * kup; - int8_t zerobuff[x_len]; //NOLINT - memset(zerobuff, 0, sizeof(int8_t) * x_len); - printf("right_remain: %d \n", right_remain); + int8_t* outptr = out; + const int8_t* inptr = in + k0 * ldin + m0; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len % 8; + int kup = ROUNDUP(y_len, KBLOCK_INT8); + + int stride_out = 8 * kup; + int8_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(int8_t) * x_len); #pragma omp parallel for - for (int y = 0; y < y_len; y += 4) { - const int8_t* inptr0 = inptr + y * ldin; - const int8_t* inptr1 = inptr0 + ldin; - const int8_t* inptr2 = inptr1 + ldin; - const int8_t* inptr3 = inptr2 + ldin; - - if (y + 4 > y_len) { - switch (y + 4 - y_len) { - case 3: - inptr1 = zerobuff; - case 2: - inptr2 = zerobuff; - case 1: - inptr3 = zerobuff; - default: - break; - } - } + for (int y = 0; y < y_len; y += 4) { + const int8_t* inptr0 = inptr + y * ldin; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + if (y + 4 > y_len) { + switch (y + 4 - y_len) { + case 3: + inptr1 = zerobuff; + case 2: + inptr2 = zerobuff; + case 1: + inptr3 = zerobuff; + default: + break; + } + } + // clang-format off asm volatile( "prfm pldl1keep, [%[ptr0]] \n" "prfm pldl1keep, [%[ptr0], #64] \n" @@ -3515,63 +3406,65 @@ void prepackA_m8k4_trans_int8(int8_t* out, ); outptr_row += stride_out; } - if (right_remain > 0) { - int8_t *out0 = outptr_row; - for (; x < x_len; x++) { - *out0++ = *inptr0++; - *out0++ = *inptr1++; - *out0++ = *inptr2++; - *out0++ = *inptr3++; - } - for (int i = 0; i < 8 - right_remain; i++) { - *out0++ = 0; - *out0++ = 0; - *out0++ = 0; - *out0++ = 0; - } - } + // clang-format on + if (right_remain > 0) { + int8_t* out0 = outptr_row; + for (; x < x_len; x++) { + *out0++ = *inptr0++; + *out0++ = *inptr1++; + *out0++ = *inptr2++; + *out0++ = *inptr3++; + } + for (int i = 0; i < 8 - right_remain; i++) { + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + } } + } } void packb_sdot_int8(int8_t* out, - const int8_t* in, - const int ldin, - const int k0, - 
const int kmax, - const int n0, - const int nmax) { - int y_len = kmax - k0; - int x_len = nmax - n0; - int kup = ROUNDUP(y_len, KBLOCK_INT8); // 4k - int8_t zerobuff[x_len]; //NOLINT - memset(zerobuff, 0, sizeof(int8_t) * x_len); - int8_t *outptr = out; - const int8_t *inptr = in + k0 * ldin + n0; - - int stride_out = 12 * kup; - // int stride_y = 48; - int remain = x_len % 12; - - // data B is not transposed, transpose B to k * 12 + const int8_t* in, + const int ldin, + const int k0, + const int kmax, + const int n0, + const int nmax) { + int y_len = kmax - k0; + int x_len = nmax - n0; + int kup = ROUNDUP(y_len, KBLOCK_INT8); // 4k + int8_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(int8_t) * x_len); + int8_t* outptr = out; + const int8_t* inptr = in + k0 * ldin + n0; + + int stride_out = 12 * kup; + // int stride_y = 48; + int remain = x_len % 12; + +// data B is not transposed, transpose B to k * 12 #pragma omp parallel for - for (int y = 0; y < y_len; y += 4) { - // cope with row index exceed real size, set to zero - const int8_t *inptr0 = inptr + y * ldin; - const int8_t *inptr1 = inptr0 + ldin; - const int8_t *inptr2 = inptr1 + ldin; - const int8_t *inptr3 = inptr2 + ldin; - if (y + 4 > y_len) { - switch (y + 4 - y_len) { - case 3: - inptr1 = zerobuff; - case 2: - inptr2 = zerobuff; - case 1: - inptr3 = zerobuff; - default: - break; - } - } + for (int y = 0; y < y_len; y += 4) { + // cope with row index exceed real size, set to zero + const int8_t* inptr0 = inptr + y * ldin; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + if (y + 4 > y_len) { + switch (y + 4 - y_len) { + case 3: + inptr1 = zerobuff; + case 2: + inptr2 = zerobuff; + case 1: + inptr3 = zerobuff; + default: + break; + } + } + // clang-format off asm volatile( "prfm pldl1keep, [%[inptr0]] \n" "prfm pldl1keep, [%[inptr0], #64] \n" @@ -3650,20 +3543,21 @@ void packb_sdot_int8(int8_t* out, ); outptr_row += stride_out; } - int8_t* out0 = outptr_row; // outptr + stride_out + y * remain; - for (; x < x_len; x++) { - *out0++ = *inptr0++; - *out0++ = *inptr1++; - *out0++ = *inptr2++; - *out0++ = *inptr3++; - } - for (int i = 0; i < 12 - remain; i++) { - *out0++ = 0; - *out0++ = 0; - *out0++ = 0; - *out0++ = 0; - } + // clang-format on + int8_t* out0 = outptr_row; // outptr + stride_out + y * remain; + for (; x < x_len; x++) { + *out0++ = *inptr0++; + *out0++ = *inptr1++; + *out0++ = *inptr2++; + *out0++ = *inptr3++; + } + for (int i = 0; i < 12 - remain; i++) { + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; } + } } void packb_sdot_trans_int8(int8_t* out, @@ -3673,34 +3567,35 @@ void packb_sdot_trans_int8(int8_t* out, const int kmax, const int n0, const int nmax) { - int8_t *outptr = out; - const int8_t *inptr = in + n0 * ldin + k0; - int y_len = nmax - n0; - int x_len = kmax - k0; + int8_t* outptr = out; + const int8_t* inptr = in + n0 * ldin + k0; + int y_len = nmax - n0; + int x_len = kmax - k0; - int kup = ROUNDUP(x_len, KBLOCK_INT8); // 4 + int kup = ROUNDUP(x_len, KBLOCK_INT8); // 4 - int8_t zerobuff[kup]; //NOLINT - memset(zerobuff, 0, sizeof(int8_t) * kup); + int8_t zerobuff[kup]; // NOLINT + memset(zerobuff, 0, sizeof(int8_t) * kup); - int stride_y = 48; - int stride_out = kup; + int stride_y = 48; + int stride_out = kup; - int remain = x_len % 8; + int remain = x_len % 8; #pragma omp parallel for - for (int y = 0; y < y_len; y += 12) { - const int8_t *inptr_row[12]; - inptr_row[0] = inptr + y * ldin; - for (int i = 1; i 
< 12; i++) { - inptr_row[i] = inptr_row[i - 1] + ldin; - } - if (y + 12 > y_len) { - for (int i = y + 12 - y_len; i > 0; i--) { - // inptr_row[12 - i] = zero_ptr[12 - i - 1]; - inptr_row[12 - i] = zerobuff; - } - } + for (int y = 0; y < y_len; y += 12) { + const int8_t* inptr_row[12]; + inptr_row[0] = inptr + y * ldin; + for (int i = 1; i < 12; i++) { + inptr_row[i] = inptr_row[i - 1] + ldin; + } + if (y + 12 > y_len) { + for (int i = y + 12 - y_len; i > 0; i--) { + // inptr_row[12 - i] = zero_ptr[12 - i - 1]; + inptr_row[12 - i] = zerobuff; + } + } + // clang-format off asm volatile( "prfm pldl1keep, [%[ptr0]] \n" "prfm pldl1keep, [%[ptr1]] \n" @@ -3833,24 +3728,25 @@ void packb_sdot_trans_int8(int8_t* out, ); right_remain -= 4; } - if (right_remain > 0) { - for (int i = 0; i < 12; i++) { - for (int x = 0; x < right_remain; x++) { - *out0++ = *inptr_row[i]++; - } - for (int x = 0; x < 4 - right_remain; x++) { - *out0++ = 0; - } - } + // clang-format on + if (right_remain > 0) { + for (int i = 0; i < 12; i++) { + for (int x = 0; x < right_remain; x++) { + *out0++ = *inptr_row[i]++; + } + for (int x = 0; x < 4 - right_remain; x++) { + *out0++ = 0; } + } } + } } -#endif //dotprod //NOLINT +#endif // dotprod //NOLINT template <> void gemm_prepack_int8(const int8_t* A_packed, const int8_t* B, - const int* bias, + const float* bias, float32_t* C, int M, int N, @@ -3862,25 +3758,22 @@ void gemm_prepack_int8(const int8_t* A_packed, ARMContext* ctx) { #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) if (ctx->has_dot()) { - gemm_prepack_sdot_int8(A_packed, - B, bias, C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } else { - gemm_prepack_oth_int8(A_packed, B, - bias, C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } + gemm_prepack_sdot_int8( + A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); + } else { + gemm_prepack_oth_int8( + A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); + } #else - gemm_prepack_oth_int8(A_packed, B, - bias, C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); + gemm_prepack_oth_int8( + A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); #endif } template <> void gemm_prepack_int8(const int8_t* A_packed, const int8_t* B, - const int* bias, + const float* bias, int8_t* C, int M, int N, @@ -3892,47 +3785,15 @@ void gemm_prepack_int8(const int8_t* A_packed, ARMContext* ctx) { #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) if (ctx->has_dot()) { - gemm_prepack_sdot_int8(A_packed, B, bias, - C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } else { - gemm_prepack_oth_int8(A_packed, B, bias, - C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } -#else - gemm_prepack_oth_int8(A_packed, B, bias, - C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); -#endif -} - -template <> -void gemm_prepack_int8(const int8_t* A_packed, - const int8_t* B, - const int* bias, - int32_t* C, - int M, - int N, - int K, - bool is_bias, - bool is_relu, - bool is_transB, - const float* scale, - ARMContext* ctx) { -#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) - if (ctx->has_dot()) { - gemm_prepack_sdot_int8(A_packed, B, - bias, C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } else { - gemm_prepack_oth_int8(A_packed, B, - bias, C, M, N, K, is_bias, is_relu, - is_transB, scale, ctx); - } + gemm_prepack_sdot_int8( + A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); + } else { + gemm_prepack_oth_int8( + A_packed, B, bias, C, M, N, K, is_bias, 
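/* bias is now passed as fp32 (was const int*) */ 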
is_relu, is_transB, scale, ctx); + } #else - gemm_prepack_oth_int8(A_packed, B, bias, - C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); + gemm_prepack_oth_int8( + A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); #endif } diff --git a/lite/backends/arm/math/gemm_prepacked_int8.h b/lite/backends/arm/math/gemm_prepacked_int8.h index 7f54eea3988c6bb22a88c28f1f5ddb9faf0a98f8..c0c8ea6c35b905e29a52c114148a952558a6cae2 100644 --- a/lite/backends/arm/math/gemm_prepacked_int8.h +++ b/lite/backends/arm/math/gemm_prepacked_int8.h @@ -15,7 +15,6 @@ #pragma once #include #include "lite/core/context.h" -#include "lite/core/device_info.h" #include "lite/core/tensor.h" namespace paddle { @@ -34,7 +33,7 @@ const int NBLOCK_INT8_OTH = 16; const int MBLOCK_INT8_DOT = 8; const int NBLOCK_INT8_DOT = 12; -inline int get_hblock_int8(const ARMContext* ctx) { +inline int get_hblock_int8(ARMContext* ctx) { #ifdef WITH_ARM_DOTPROD if (ctx->has_dot()) { return MBLOCK_INT8_DOT; @@ -51,7 +50,7 @@ inline int get_hblock_int8(const ARMContext* ctx) { const int MBLOCK_INT8_OTH = 4; const int NBLOCK_INT8_OTH = 8; -inline int get_hblock_int8(const ARMContext* ctx) { return 4; } +inline int get_hblock_int8(ARMContext* ctx) { return 4; } #endif // __aarch64__ void prepackA_int8(void* out, @@ -75,7 +74,7 @@ void prepackA_int8(TensorLite* tout, template void gemm_prepack_int8(const int8_t* A_packed, const int8_t* B, - const int* bias, + const float* bias, dtype* C, int M, int N, @@ -87,7 +86,6 @@ void gemm_prepack_int8(const int8_t* A_packed, ARMContext* ctx); #define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) - } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/gemm_s8.cc b/lite/backends/arm/math/gemm_s8.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bc3f5f4647ea0cc78131ff07837f1ff0ae39d56 --- /dev/null +++ b/lite/backends/arm/math/gemm_s8.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/gemm_s8.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void gemm_s8(bool is_transA, + bool is_transB, + int M, + int N, + int K, + const int8_t* A, + const int8_t* B, + Dtype* C, + const float* bias, + bool is_bias, + bool is_relu, + const float* scale, + ARMContext* ctx) { + int hblock = get_hblock_int8(ctx); + int m_roundup = hblock * ((M + hblock - 1) / hblock); + auto packed_A = static_cast( + TargetMalloc(TargetType::kARM, m_roundup * K * sizeof(int8_t))); + + int lda = is_transA ? 
M : K; + prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx); + + gemm_prepack_int8( + packed_A, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); + TargetFree(TargetType::kARM, packed_A); +} + +template void gemm_s8(bool is_transA, + bool is_transB, + int M, + int N, + int K, + const int8_t* A, + const int8_t* B, + float* C, + const float* bias, + bool is_bias, + bool is_relu, + const float* scale, + ARMContext* ctx); + +template void gemm_s8(bool is_transA, + bool is_transB, + int M, + int N, + int K, + const int8_t* A, + const int8_t* B, + int8_t* C, + const float* bias, + bool is_bias, + bool is_relu, + const float* scale, + ARMContext* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/gemm_s8.h b/lite/backends/arm/math/gemm_s8.h new file mode 100644 index 0000000000000000000000000000000000000000..0a37c5e3a488e491a3bf4a7277775681c657feb2 --- /dev/null +++ b/lite/backends/arm/math/gemm_s8.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/arm/math/gemm_prepacked_int8.h" +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void gemm_s8(bool is_transA, + bool is_transB, + int M, + int N, + int K, + const int8_t* A, + const int8_t* B, + Dtype* C, + const float* bias, + bool is_bias, + bool is_relu, + const float* scale, + ARMContext* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc index dff3024ba452b507e30578b417911353bef097d8..dab42cdeca28d40622590632985603ce8eab1fb9 100644 --- a/lite/backends/arm/math/gemv_arm_int8.cc +++ b/lite/backends/arm/math/gemv_arm_int8.cc @@ -22,36 +22,90 @@ namespace arm { namespace math { template -inline void write_gemv_out(const int* in, dtype* out, const float* scale); - -template <> -inline void write_gemv_out(const int* in, int* out, const float* scale) { - out[0] = in[0]; -} +inline void write_gemv_out(const int* in, + dtype* out, + const float* scale, + const float* bias, + int size, + bool is_relu); template <> -inline void write_gemv_out(const int* in, float* out, const float* scale) { - out[0] = in[0] * scale[0]; +inline void write_gemv_out(const int* in, + float* out, + const float* scale, + const float* bias, + int size, + bool is_relu) { + int i = 0; + float32x4_t vzero = vdupq_n_f32(0.f); + for (; i < size - 7; i += 8) { + float32x4_t vout0 = bias ? vld1q_f32(bias) : vdupq_n_f32(0.f); + float32x4_t vout1 = bias ? 
vld1q_f32(bias + 4) : vdupq_n_f32(0.f); + int32x4_t vin0 = vld1q_s32(in); + int32x4_t vin1 = vld1q_s32(in + 4); + float32x4_t vscale0 = vld1q_f32(scale); + float32x4_t vscale1 = vld1q_f32(scale + 4); + float32x4_t vinf0 = vcvtq_f32_s32(vin0); + float32x4_t vinf1 = vcvtq_f32_s32(vin1); + vout0 = vmlaq_f32(vout0, vinf0, vscale0); + vout1 = vmlaq_f32(vout1, vinf1, vscale1); + if (is_relu) { + vout0 = vmaxq_f32(vout0, vzero); + vout1 = vmaxq_f32(vout1, vzero); + } + vst1q_f32(out, vout0); + vst1q_f32(out + 4, vout1); + bias += 8; + in += 8; + out += 8; + scale += 8; + } + for (; i < size; ++i) { + out[0] = *(in++) * *(scale)++; + out[0] += bias ? *(bias++) : 0.f; + out[0] = is_relu ? (out[0] > 0.f ? out[0] : 0.f) : out[0]; + out++; + } } template <> inline void write_gemv_out(const int* in, signed char* out, - const float* scale) { - out[0] = saturate_cast(roundf(in[0] * scale[0])); + const float* scale, + const float* bias, + int size, + bool flag_relu) { + if (bias) { + for (int i = 0; i < size; ++i) { + out[0] = + saturate_cast(roundf(*(in++) * *(scale++) + *(bias++))); + if (flag_relu) { + out[0] = out[0] > 0 ? out[0] : 0; + } + out++; + } + } else { + for (int i = 0; i < size; ++i) { + out[0] = saturate_cast(roundf(*(in++) * *(scale++))); + if (flag_relu) { + out[0] = out[0] > 0 ? out[0] : 0; + } + out++; + } + } } template -bool gemv_int8(const int8_t* A, - const int8_t* x, - dtype* y, - bool transA, - int M, - int N, - const float* scale, - bool is_bias, - const int* bias, - bool is_relu) { +bool gemv_int8_oth(const int8_t* A, + const int8_t* x, + dtype* y, + bool transA, + int M, + int N, + const float* scale, + bool is_bias, + const float* bias, + bool is_relu) { if (transA) { LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; return false; @@ -61,7 +115,6 @@ bool gemv_int8(const int8_t* A, const int8_t* weights_ptr = A; int cnt = N >> 4; int tail = N & 15; - int flag_bias = is_bias ? 1 : 0; #ifdef __aarch64__ int out_cnt = M >> 3; @@ -80,7 +133,7 @@ bool gemv_int8(const int8_t* A, const int8_t* ptr_w5 = ptr_w4 + N; const int8_t* ptr_w6 = ptr_w5 + N; const int8_t* ptr_w7 = ptr_w6 + N; - const int* bias_ptr = is_bias ? (bias + out_idx) : nullptr; + auto bias_ptr = is_bias ? bias + out_idx : nullptr; int cnt_loop = cnt; asm volatile( "prfm pldl1keep, [%[in]] \n" /* preload din */ @@ -153,13 +206,6 @@ bool gemv_int8(const int8_t* A, "addp v12.4s, v8.4s , v9.4s \n" /* pair add to 4 int32*/ "addp v13.4s, v10.4s, v11.4s \n" /* pair add to 4 int32*/ - "cmp %w[bias], #1 \n" /* check whether has bias */ - "blt 0f \n" /* jump to tail */ - "ldp q8, q9, [%[bias_ptr]]\n" /* load bias to q8, q9*/ - "add v12.4s, v12.4s, v8.4s \n" /* add bias */ - "add v13.4s, v13.4s, v9.4s \n" /* add bias */ - "0: \n" /* end of add bias */ - /* write to output */ "stp q12, q13, [%[out]] \n" /* save result */ : [in] "+r"(ptr_in), @@ -172,7 +218,7 @@ bool gemv_int8(const int8_t* A, [w6] "+r"(ptr_w6), [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop) - : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr), [bias] "r"(flag_bias) + : [out] "r"(ptr_out) : "cc", "memory", "v0", @@ -211,25 +257,8 @@ bool gemv_int8(const int8_t* A, ptr_out[6] += ptr_in[i] * ptr_w6[i]; ptr_out[7] += ptr_in[i] * ptr_w7[i]; } - if (is_relu) { - ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; - ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0; - ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0; - ptr_out[3] = ptr_out[3] > 0 ? ptr_out[3] : 0; - ptr_out[4] = ptr_out[4] > 0 ? ptr_out[4] : 0; - ptr_out[5] = ptr_out[5] > 0 ? 
ptr_out[5] : 0; - ptr_out[6] = ptr_out[6] > 0 ? ptr_out[6] : 0; - ptr_out[7] = ptr_out[7] > 0 ? ptr_out[7] : 0; - } - write_gemv_out(ptr_out, out_ptr, scale_ptr); - write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1); - write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2); - write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3); - write_gemv_out(ptr_out + 4, out_ptr + 4, scale_ptr + 4); - write_gemv_out(ptr_out + 5, out_ptr + 5, scale_ptr + 5); - write_gemv_out(ptr_out + 6, out_ptr + 6, scale_ptr + 6); - write_gemv_out(ptr_out + 7, out_ptr + 7, scale_ptr + 7); + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu); } //! deal with remains @@ -242,12 +271,11 @@ bool gemv_int8(const int8_t* A, const int8_t* ptr_in = data_in; const int8_t* ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; - int bias0 = is_bias ? bias[j] : 0; + auto bias_ptr = is_bias ? bias + j : nullptr; asm volatile( "prfm pldl1keep, [%[in]] \n" /* preload din */ "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ "movi v0.4s, #0 \n" /* set out0 to 0 */ - "fmov s0, %w[bias0] \n" /* set bias */ /* check main loop */ "cmp %w[cnt], #1 \n" /* check whether has main loop */ "blt 2f \n" /* jump to tail */ @@ -269,17 +297,14 @@ bool gemv_int8(const int8_t* A, /* write to output */ "str s8, [%[out]] \n" /* save result */ : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out) : "cc", "memory", "v0", "v8", "v9", "v18"); for (int i = 0; i < tail; ++i) { ptr_out[0] += ptr_in[i] * ptr_w0[i]; } - if (is_relu) { - ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; - } - write_gemv_out(ptr_out, out_ptr, scale_ptr); + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); } -#else //__aarch64__ // NOLINT +#else // __aarch64__ int out_cnt = M >> 2; #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { @@ -293,10 +318,7 @@ bool gemv_int8(const int8_t* A, const int8_t* ptr_w2 = ptr_w1 + N; const int8_t* ptr_w3 = ptr_w2 + N; int cnt_loop = cnt; - int bias0 = is_bias ? bias[out_idx] : 0; - int bias1 = is_bias ? bias[out_idx + 1] : 0; - int bias2 = is_bias ? bias[out_idx + 2] : 0; - int bias3 = is_bias ? bias[out_idx + 3] : 0; + auto bias_ptr = is_bias ? bias + out_idx : nullptr; asm volatile( "pld [%[in]] @ preload cache line, input\n" "pld [%[w0]] @ preload cache line, weights r0\n" @@ -307,10 +329,6 @@ bool gemv_int8(const int8_t* A, "vmov.u32 q1, #0 @ set q1 to 0\n" "vmov.u32 q2, #0 @ set q2 to 0\n" "vmov.u32 q3, #0 @ set q3 to 0\n" - "vmov s0, %[bias0] @ set q0 to bias0\n" - "vmov s4, %[bias1] @ set q1 to bias1\n" - "vmov s8, %[bias2] @ set q2 to bias2\n" - "vmov s12,%[bias3] @ set q3 to bias3\n" // "vld1.32 {d20-d21}, %[bias] @ load bias data" "cmp %[cnt], #1 @ check whether has main loop\n" "blt 2f @ jump to pair add\n" @@ -355,11 +373,7 @@ bool gemv_int8(const int8_t* A, [w2] "+r"(ptr_w2), [w3] "+r"(ptr_w3), [cnt] "+r"(cnt_loop) - : [bias0] "r"(bias0), - [bias1] "r"(bias1), - [bias2] "r"(bias2), - [bias3] "r"(bias3), - [out] "r"(ptr_out) + : [out] "r"(ptr_out) : "cc", "memory", "q0", @@ -382,16 +396,7 @@ bool gemv_int8(const int8_t* A, ptr_out[2] += ptr_in[i] * ptr_w2[i]; ptr_out[3] += ptr_in[i] * ptr_w3[i]; } - if (is_relu) { - ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; - ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0; - ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0; - ptr_out[3] = ptr_out[3] > 0 ? 
ptr_out[3] : 0; - } - write_gemv_out(ptr_out, out_ptr, scale_ptr); - write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1); - write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2); - write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3); + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, is_relu); } //! deal with remains #pragma omp parallel for @@ -402,13 +407,11 @@ bool gemv_int8(const int8_t* A, const int8_t* ptr_in = data_in; const int8_t* ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; - int bias0 = is_bias ? bias[j] : 0; + auto bias_ptr = is_bias ? bias + j : nullptr; asm volatile( - "pld [%[in]] @ preload cache line, " - "input\n" + "pld [%[in]] @ preload cache line, input\n" "pld [%[w0]] @ preload cache line, weights r0\n" "vmov.u32 q0, #0 @ set q0 to 0\n" - "vmov s0, %[bias0] @ set q0 to bias0\n" "cmp %[cnt], #1 @ check whether has main loop\n" "blt 2f @ jump to tail\n" /* main loop */ @@ -429,50 +432,263 @@ bool gemv_int8(const int8_t* A, /* write output */ "vst1.32 {d0[0]}, [%[out]] @ save result\n" : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop) - : [bias0] "r"(bias0), [out] "r"(ptr_out) + : [out] "r"(ptr_out) : "cc", "memory", "q0", "q1", "q12", "q13"); for (int i = 0; i < tail; ++i) { ptr_out[0] += ptr_in[i] * ptr_w0[i]; } - if (is_relu) { - ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); + } +#endif // __aarch64__ + return true; +} + +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) +template +bool gemv_int8_sdot(const int8_t* A, + const int8_t* x, + dtype* y, + bool transA, + int M, + int N, + const float* scale, + bool is_bias, + const float* bias, + bool is_relu) { + if (transA) { + LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; + return false; + } + dtype* data_out = y; + const int8_t* data_in = x; + const int8_t* weights_ptr = A; + int cnt = N >> 4; + int tail = N & 15; + int size_m = (M >> 3) << 3; +#pragma omp parallel for + for (int j = 0; j < M - 7; j += 8) { + dtype* out_ptr = data_out + j; + const float* scale_ptr = scale + j; + auto bias_ptr = is_bias ? 
bias + j : nullptr; + int ptr_out[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const int8_t* ptr_in = data_in; + const int8_t* ptr_w0 = weights_ptr + (N * j); + const int8_t* ptr_w1 = ptr_w0 + N; + const int8_t* ptr_w2 = ptr_w1 + N; + const int8_t* ptr_w3 = ptr_w2 + N; + const int8_t* ptr_w4 = ptr_w3 + N; + const int8_t* ptr_w5 = ptr_w4 + N; + const int8_t* ptr_w6 = ptr_w5 + N; + const int8_t* ptr_w7 = ptr_w6 + N; + int cnt_loop = cnt; + if (cnt > 0) { + asm volatile( + "prfm pldl1keep, [%[in]] \n" /* preload din */ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ + "prfm pldl1keep, [%[w1]] \n" /* preload w1 */ + "prfm pldl1keep, [%[w2]] \n" /* preload w2 */ + "prfm pldl1keep, [%[w3]] \n" /* preload w3 */ + "prfm pldl1keep, [%[w4]] \n" /* preload w4 */ + "prfm pldl1keep, [%[w5]] \n" /* preload w5 */ + "prfm pldl1keep, [%[w6]] \n" /* preload w6 */ + "prfm pldl1keep, [%[w7]] \n" /* preload w7 */ + "movi v0.4s, #0 \n" /* set out0 to 0 */ + "movi v1.4s, #0 \n" /* set out1 to 0 */ + "movi v2.4s, #0 \n" /* set out2 to 0 */ + "movi v3.4s, #0 \n" /* set out3 to 0 */ + "movi v4.4s, #0 \n" /* set out4 to 0 */ + "movi v5.4s, #0 \n" /* set out5 to 0 */ + "movi v6.4s, #0 \n" /* set out6 to 0 */ + "movi v7.4s, #0 \n" /* set out7 to 0 */ + /* main loop */ + "1: \n" /* main loop */ + "ldr q8, [%[in]], #16 \n" /* load input, 16 int8 */ + "ldr q9, [%[w0]], #16 \n" /* load w0, 16 int8 */ + "ldr q10, [%[w1]], #16 \n" /* load w0, 16 int8 */ + "ldr q11, [%[w2]], #16 \n" /* load w0, 16 int8 */ + "ldr q12, [%[w3]], #16 \n" /* load w0, 16 int8 */ + "ldr q13, [%[w4]], #16 \n" /* load w0, 16 int8 */ + "ldr q14, [%[w5]], #16 \n" /* load w0, 16 int8 */ + "ldr q15, [%[w6]], #16 \n" /* load w0, 16 int8 */ + "ldr q16, [%[w7]], #16 \n" /* load w0, 16 int8 */ + + ".word 0x4e899500 // sdot v0.4s, v8.16b, v9.16b \n" /* out0, out1, + out2, out3 + */ + ".word 0x4e8a9501 // sdot v1.4s, v8.16b, v10.16b \n" /* out4, out5, + out6, out7 + */ + ".word 0x4e8b9502 // sdot v2.4s, v8.16b, v11.16b \n" /* out0, out1, + out2, out3 + */ + ".word 0x4e8c9503 // sdot v3.4s, v8.16b, v12.16b \n" /* out4, out5, + out6, out7 + */ + "subs %w[cnt], %w[cnt], #1 \n" + ".word 0x4e8d9504 // sdot v4.4s, v8.16b, v13.16b \n" /* out0, out1, + out2, out3 + */ + ".word 0x4e8e9505 // sdot v5.4s, v8.16b, v14.16b \n" /* out4, out5, + out6, out7 + */ + ".word 0x4e8f9506 // sdot v6.4s, v8.16b, v15.16b \n" /* out0, out1, + out2, out3 + */ + ".word 0x4e909507 // sdot v7.4s, v8.16b, v16.16b \n" /* out4, out5, + out6, out7 + */ + "bne 1b \n" /* jump to main loop */ + /* pair add to final result */ + "2: \n" /* reduce to scale */ + "addp v10.4s , v0.4s , v1.4s \n" /* pair add to 4 int32*/ + "addp v11.4s , v2.4s , v3.4s \n" /* pair add to 4 int32*/ + "addp v12.4s , v4.4s , v5.4s \n" /* pair add to 4 int32*/ + "addp v13.4s , v6.4s , v7.4s \n" /* pair add to 4 int32*/ + + "addp v0.4s , v10.4s , v11.4s \n" /* pair add to 4 int32*/ + "addp v1.4s , v12.4s , v13.4s \n" /* pair add to 4 int32*/ + /* write to output */ + "stp q0, q1, [%[out]] \n" /* save result */ + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop) + : [out] "r"(ptr_out) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); + } + for (int i = 0; i < tail; ++i) { + ptr_out[0] += ptr_in[i] * ptr_w0[i]; + ptr_out[1] += ptr_in[i] * 
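
The `.word 0x4e89950x` literals above are hand-encoded SDOT instructions, emitted as raw opcodes so the file assembles even on toolchains that do not enable the dot-product extension. Where `+dotprod` is available, the same accumulation step can be written with ACLE intrinsics; a hedged equivalent, not part of this patch:

```cpp
#if defined(__ARM_FEATURE_DOTPROD)
#include <arm_neon.h>
// One SDOT step: each of the four int32 lanes accumulates a dot product of
// four int8 pairs; equivalent to
// ".word 0x4e899500  // sdot v0.4s, v8.16b, v9.16b".
static inline int32x4_t sdot_step(int32x4_t acc, int8x16_t in, int8x16_t w) {
  return vdotq_s32(acc, in, w);
}
#endif
```
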
ptr_w1[i]; + ptr_out[2] += ptr_in[i] * ptr_w2[i]; + ptr_out[3] += ptr_in[i] * ptr_w3[i]; + ptr_out[4] += ptr_in[i] * ptr_w4[i]; + ptr_out[5] += ptr_in[i] * ptr_w5[i]; + ptr_out[6] += ptr_in[i] * ptr_w6[i]; + ptr_out[7] += ptr_in[i] * ptr_w7[i]; + } + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu); + } +//! deal with remains +#pragma omp parallel for + for (int j = size_m; j < M; j++) { + // int *ptr_out = data_out + j; + dtype* out_ptr = data_out + j; + const float* scale_ptr = scale + j; + int ptr_out[1] = {0}; + const int8_t* ptr_in = data_in; + const int8_t* ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + auto bias_ptr = is_bias ? bias + j : nullptr; + asm volatile( + "prfm pldl1keep, [%[in]] \n" /* preload din */ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ + "cmp %w[cnt], #1 \n" /* check whether has main loop */ + "movi v0.4s, #0 \n" /* set out0 to 0 */ + /* check main loop */ + "blt 2f \n" /* jump to tail */ + /* main loop */ + "1: \n" /* main loop */ + "ldr q8, [%[in]], #16 \n" /* load input, 16 int8 */ + "ldr q9, [%[w0]], #16 \n" /* load w0, 16 int8 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ + /* mul, lower 8 int8 * int8 = int16 */ + ".word 0x4e899500 // sdot v0.4s, v8.16b, v9.16b \n" + "bne 1b \n" /* jump to main loop */ + /* pair add to final result */ + "2: \n" /* reduce to scale */ + "addp v1.4s, v0.4s, v0.4s \n" /* reduction to out0 */ + "addp v2.4s, v1.4s, v1.4s \n" /* reduction to out0 */ + /* write to output */ + "str s2, [%[out]] \n" /* save result */ + : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop) + : [out] "r"(ptr_out) + : "cc", "memory", "v0", "v1", "v2", "v8", "v9", "v18"); + for (int i = 0; i < tail; ++i) { + ptr_out[0] += ptr_in[i] * ptr_w0[i]; } - write_gemv_out(ptr_out, out_ptr, scale_ptr); + write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); } -#endif //__aarch64__ // NOLINT return true; } +#endif // __aarch64__ && sdot + +template <> +bool gemv_int8(const int8_t* A, + const int8_t* x, + float* y, + bool transA, + int M, + int N, + const float* scale, + bool is_bias, + const float* bias, + bool is_relu, + const ARMContext* ctx) { +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) + if (ctx->has_dot()) { + return gemv_int8_sdot( + A, x, y, transA, M, N, scale, is_bias, bias, is_relu); + } else { + return gemv_int8_oth( + A, x, y, transA, M, N, scale, is_bias, bias, is_relu); + } +#else + return gemv_int8_oth( + A, x, y, transA, M, N, scale, is_bias, bias, is_relu); +#endif +} -template bool gemv_int8(const int8_t* A, - const int8_t* x, - float* y, - bool transA, - int M, - int N, - const float* scale, - bool is_bias, - const int* bias, - bool is_relu); -template bool gemv_int8(const int8_t* A, - const int8_t* x, - int* y, - bool transA, - int M, - int N, - const float* scale, - bool is_bias, - const int* bias, - bool is_relu); -template bool gemv_int8(const int8_t* A, - const int8_t* x, - signed char* y, - bool transA, - int M, - int N, - const float* scale, - bool is_bias, - const int* bias, - bool is_relu); +template <> +bool gemv_int8(const int8_t* A, + const int8_t* x, + int8_t* y, + bool transA, + int M, + int N, + const float* scale, + bool is_bias, + const float* bias, + bool is_relu, + const ARMContext* ctx) { +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) + if (ctx->has_dot()) { + return gemv_int8_sdot( + A, x, y, transA, M, N, scale, is_bias, bias, is_relu); + } else { + return gemv_int8_oth( + A, x, y, transA, M, N, scale, is_bias, bias, is_relu); + } +#else + 
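+    // Note: this branch compiles when __aarch64__ or WITH_ARM_DOTPROD is
+    // missing, so the generic kernel runs even on cores where ctx->has_dot()
+    // would report SDOT support; dispatch to the sdot kernel needs both the
+    // build flag and the runtime capability.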
return gemv_int8_oth<int8_t>(
+      A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
+#endif
+}
 
 }  // namespace math
 }  // namespace arm
diff --git a/lite/backends/arm/math/gemv_arm_int8.h b/lite/backends/arm/math/gemv_arm_int8.h
index 302112069507f2dcf973cae0b4a90c6c69196a58..51c10ea18fe398091cf86fe4319eb03e2564fd93 100644
--- a/lite/backends/arm/math/gemv_arm_int8.h
+++ b/lite/backends/arm/math/gemv_arm_int8.h
@@ -14,7 +14,7 @@
 #pragma once
 #include
-#include "lite/core/device_info.h"
+#include "lite/core/context.h"
 
 namespace paddle {
 namespace lite {
@@ -30,9 +30,10 @@ bool gemv_int8(const int8_t* A,
                int M,
                int N,
                const float* scale,
-               bool is_bias = false,
-               const int* bias = nullptr,
-               bool is_relu = false);
+               bool is_bias,
+               const float* bias,
+               bool is_relu,
+               const ARMContext* ctx);
 
 }  // namespace math
 }  // namespace arm
diff --git a/lite/backends/arm/math/increment.cc b/lite/backends/arm/math/increment.cc
index 094fe78de9cbb66445dc2e486e246d5503b06869..583ff52077e720510e66fcdb9604d1dc8992a90d 100644
--- a/lite/backends/arm/math/increment.cc
+++ b/lite/backends/arm/math/increment.cc
@@ -21,10 +21,10 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-void increment(const int* input,
+void increment(const float* input,
                const int n,
                const float step,
-               int* out,
+               float* out,
                Context<TARGET(kARM)>* ctx) {
   for (int i = 0; i < n; i++) {
     out[i] = input[i] + step;
diff --git a/lite/backends/arm/math/increment.h b/lite/backends/arm/math/increment.h
index 80aec628854d37f40c5167268e12749ddd0c4974..028db0fd55e9507aa4f161339e4a8b0cd2e59ffe 100644
--- a/lite/backends/arm/math/increment.h
+++ b/lite/backends/arm/math/increment.h
@@ -21,10 +21,10 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-void increment(const int* input,
+void increment(const float* input,
                const int n,
                const float step,
-               int* out,
+               float* out,
                Context<TARGET(kARM)>* ctx);
 
 }  // namespace math
diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc
index 9770ee3d90ae2c13fb0fc0a48f9bd3aca320eb34..f89410ad11590c60bf5542702b60fa883298d3e6 100644
--- a/lite/backends/arm/math/interpolate.cc
+++ b/lite/backends/arm/math/interpolate.cc
@@ -83,8 +83,8 @@ void bilinear_interp(const float* src,
       beta[dy * 2 + 1] = fy;
     }
   } else {
-    scale_x = static_cast<float>(w_in / w_out);
-    scale_y = static_cast<float>(h_in / h_out);
+    scale_x = static_cast<float>(w_in) / w_out;
+    scale_y = static_cast<float>(h_in) / h_out;
     // calculate x axis coordinate
     for (int dx = 0; dx < w_out; dx++) {
       fx = scale_x * (dx + 0.5f) - 0.5f;
@@ -459,8 +459,10 @@ void nearest_interp(const float* src,
 #pragma omp parallel for collapse(2) schedule(static)
   for (int h = 0; h < h_out; ++h) {
     for (int w = 0; w < w_out; ++w) {
-      int near_x = static_cast<int>(scale_w_new * w + 0.5);
-      int near_y = static_cast<int>(scale_h_new * h + 0.5);
+      int near_x = (with_align) ? static_cast<int>(scale_w_new * w + 0.5)
+                                : static_cast<int>(scale_w_new * w);
+      int near_y = (with_align) ? static_cast<int>(scale_h_new * h + 0.5)
+                                : static_cast<int>(scale_h_new * h);
       near_x = near_x < 0 ? 0 : near_x;
       near_y = near_y < 0 ?
0 : near_y; dst[h * w_out + w] = src[near_y * w_in + near_x]; diff --git a/lite/backends/arm/math/norm.cc b/lite/backends/arm/math/norm.cc index 4780ef68c131ab1de2fcd028006dd5707ebd2e60..6114c919cc686d55713ec9ad34e2183480c65e32 100644 --- a/lite/backends/arm/math/norm.cc +++ b/lite/backends/arm/math/norm.cc @@ -15,6 +15,7 @@ #include "lite/backends/arm/math/norm.h" #include #include +#include "lite/backends/arm/math/funcs.h" #include "lite/utils/cp_logging.h" namespace paddle { @@ -43,7 +44,143 @@ void norm(const float* input, } } } - LOG(INFO) << "norm math finished"; +} + +void matrix_norm_row(const float* x_data, + const float* scale_data, + const float* bias_data, + float* out_data, + float* mean_out, + float* var_out, + float epsilon, + int batch_size, + int feature_size) { + int cnt = feature_size >> 4; + int remain = feature_size & 0xf; +#pragma omp parallel for + + for (int bi = 0; bi < batch_size; ++bi) { + int offset = bi * feature_size; + const float* x_ptr = x_data + offset; + float mean = 0.f; + float variance = 0.f; + + // get mean and variance + float32x4_t mean_v = vdupq_n_f32(0); + float32x4_t var_v = vdupq_n_f32(0); + for (int oi = 0; oi < cnt; ++oi) { + float32x4_t odim1 = vld1q_f32(x_ptr); + float32x4_t odim2 = vld1q_f32(x_ptr + 4); + float32x4_t odim3 = vld1q_f32(x_ptr + 8); + float32x4_t odim4 = vld1q_f32(x_ptr + 12); + + mean_v = vaddq_f32(mean_v, odim1); + mean_v = vaddq_f32(mean_v, odim2); + mean_v = vaddq_f32(mean_v, odim3); + mean_v = vaddq_f32(mean_v, odim4); + + var_v = vmlaq_f32(var_v, odim1, odim1); + var_v = vmlaq_f32(var_v, odim2, odim2); + var_v = vmlaq_f32(var_v, odim3, odim3); + var_v = vmlaq_f32(var_v, odim4, odim4); + + x_ptr += 16; + } + mean = vgetq_lane_f32(mean_v, 0) + vgetq_lane_f32(mean_v, 1) + + vgetq_lane_f32(mean_v, 2) + vgetq_lane_f32(mean_v, 3); + variance = vgetq_lane_f32(var_v, 0) + vgetq_lane_f32(var_v, 1) + + vgetq_lane_f32(var_v, 2) + vgetq_lane_f32(var_v, 3); + for (int i = 0; i < remain; ++i) { + mean += *x_ptr; + variance += (*x_ptr) * (*x_ptr); + ++x_ptr; + } + mean /= feature_size; + variance = variance / feature_size - mean * mean; + mean_out[bi] = mean; + var_out[bi] = variance; + + variance = sqrtf(variance + epsilon); + float rvar = 1 / variance; + // compute norm_out + float* out_ptr = out_data + offset; + x_ptr = x_data + offset; + + auto* scale_ptr = scale_data; + auto* bias_ptr = bias_data; + + float32x4_t vneg = vdupq_n_f32(-1); + + float32x4_t scale1 = vdupq_n_f32(1); + float32x4_t scale2 = vdupq_n_f32(1); + float32x4_t scale3 = vdupq_n_f32(1); + float32x4_t scale4 = vdupq_n_f32(1); + + float32x4_t bias1 = vdupq_n_f32(0); + float32x4_t bias2 = vdupq_n_f32(0); + float32x4_t bias3 = vdupq_n_f32(0); + float32x4_t bias4 = vdupq_n_f32(0); + + for (int oi = 0; oi < cnt; ++oi) { + float32x4_t odim1 = vld1q_f32(x_ptr); + float32x4_t odim2 = vld1q_f32(x_ptr + 4); + float32x4_t odim3 = vld1q_f32(x_ptr + 8); + float32x4_t odim4 = vld1q_f32(x_ptr + 12); + + odim1 = vmlaq_n_f32(odim1, vneg, mean); + odim2 = vmlaq_n_f32(odim2, vneg, mean); + odim3 = vmlaq_n_f32(odim3, vneg, mean); + odim4 = vmlaq_n_f32(odim4, vneg, mean); + + if (scale_data) { + scale1 = vld1q_f32(scale_ptr); + scale2 = vld1q_f32(scale_ptr + 4); + scale3 = vld1q_f32(scale_ptr + 8); + scale4 = vld1q_f32(scale_ptr + 12); + scale_ptr += 16; + } + if (bias_data) { + bias1 = vld1q_f32(bias_ptr); + bias2 = vld1q_f32(bias_ptr + 4); + bias3 = vld1q_f32(bias_ptr + 8); + bias4 = vld1q_f32(bias_ptr + 12); + bias_ptr += 16; + } + + float32x4_t os1 = vmulq_n_f32(scale1, rvar); + 
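
For reference, `matrix_norm_row` implements a one-pass row normalization: per row, mean = Σx/n and var = Σx²/n − mean², then out = (x − mean)/√(var+ε) · scale + bias with scale and bias optional. A plain scalar version of the same contract, illustrative only and not part of the patch:

```cpp
#include <cmath>

// Scalar reference for matrix_norm_row (same math, no NEON, no OpenMP).
void matrix_norm_row_ref(const float* x, const float* scale, const float* bias,
                         float* out, float* mean_out, float* var_out,
                         float epsilon, int batch_size, int feature_size) {
  for (int b = 0; b < batch_size; ++b) {
    const float* row = x + b * feature_size;
    float sum = 0.f, sq_sum = 0.f;
    for (int i = 0; i < feature_size; ++i) {
      sum += row[i];
      sq_sum += row[i] * row[i];
    }
    const float mean = sum / feature_size;
    const float var = sq_sum / feature_size - mean * mean;  // E[x^2] - E[x]^2
    mean_out[b] = mean;
    var_out[b] = var;
    const float rstd = 1.f / std::sqrt(var + epsilon);
    float* o = out + b * feature_size;
    for (int i = 0; i < feature_size; ++i) {
      float v = (row[i] - mean) * rstd;
      if (scale) v *= scale[i];
      if (bias) v += bias[i];
      o[i] = v;
    }
  }
}
```

The single-pass E[x²] − E[x]² form trades a second pass over the row for some precision when |mean| is large relative to the spread, a standard trade-off for kernels of this kind.
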
float32x4_t os2 = vmulq_n_f32(scale2, rvar); + float32x4_t os3 = vmulq_n_f32(scale3, rvar); + float32x4_t os4 = vmulq_n_f32(scale4, rvar); + + odim1 = vmlaq_f32(bias1, odim1, os1); + odim2 = vmlaq_f32(bias2, odim2, os2); + odim3 = vmlaq_f32(bias3, odim3, os3); + odim4 = vmlaq_f32(bias4, odim4, os4); + + vst1q_f32(out_ptr, odim1); + vst1q_f32(out_ptr + 4, odim2); + vst1q_f32(out_ptr + 8, odim3); + vst1q_f32(out_ptr + 12, odim4); + + x_ptr += 16; + out_ptr += 16; + } + for (int i = 0; i < remain; ++i) { + auto out_value = (*x_ptr - mean) / variance; + if (scale_data) { + out_value = out_value * (*scale_ptr); + ++scale_ptr; + } + if (bias_data) { + out_value = out_value + *bias_ptr; + ++bias_ptr; + } + *out_ptr = out_value; + + ++out_ptr; + ++x_ptr; + } + } // for bi } } // namespace math diff --git a/lite/backends/arm/math/norm.h b/lite/backends/arm/math/norm.h index 503d2c5af4840d21f4c7fc19ce9ad8c006499fd4..63d28b301e48f47cc85f3f4dfa7e2b23a55a6eec 100644 --- a/lite/backends/arm/math/norm.h +++ b/lite/backends/arm/math/norm.h @@ -29,6 +29,15 @@ void norm(const float* input, float* out, Context* ctx); +void matrix_norm_row(const float* x_data, + const float* scale_data, + const float* bias_data, + float* out_data, + float* mean_out, + float* var_out, + float epsilon, + int batch_size, + int feature_size); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/packed_sgemm.cc b/lite/backends/arm/math/packed_sgemm.cc index 77b3beae80b92d000c257c637e6211c5da646fc1..0d6eed9904902aa9539caf95172b0e4109e11f7d 100644 --- a/lite/backends/arm/math/packed_sgemm.cc +++ b/lite/backends/arm/math/packed_sgemm.cc @@ -169,7 +169,7 @@ void prepackA(TensorLite *tout, int group, bool is_trans, ARMContext *ctx) { - int hblock = get_hblock(ctx->arch()); + int hblock = get_hblock(ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int group_size_round_up = ((m_roundup * k + 15) / 16) * 16; if (tout->numel() < group_size_round_up * group) { @@ -1516,6 +1516,7 @@ void loadb_trans( } } for (; x > 7; x -= 8) { + // clang-format off asm volatile( "ldp q0, q1, [%[inptr0]], #32\n" /* r0, a0~a7 */ "ldp q2, q3, [%[inptr1]], #32\n" /* r1, b0~b7 */ @@ -1638,40 +1639,12 @@ void loadb_trans( [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31", - "cc", - "memory"); + : "v0","v1","v2","v3","v4","v5", + "v6","v7","v8","v9","v10","v11","v12", + "v13","v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25","v26", + "v27","v28","v29","v30","v31","cc","memory"); + // clang-format on } for (; x > 0; x--) { @@ -2135,7 +2108,7 @@ void sgemm_prepacked_8x12(bool is_transB, const float *a_ptr = a_ptr_l; int tail = tail_pre; int k = k_pre; - + // clang-format off asm volatile( "prfm pldl1keep, [%[a_ptr]]\n" /* preload a*/ "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ @@ -2596,40 +2569,13 @@ void sgemm_prepacked_8x12(bool is_transB, [relu] "r"(has_relu), [has_beta] "r"(has_beta), [beta] "r"(beta) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); + : 
"cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + "v26","v27","v28","v29","v30","v31"); + // clang-format on if (flag_p_remain && (xb == bblocks - 1)) { for (int i = 0; i < remain; ++i) { *pout0++ = cout0[i]; @@ -2799,6 +2745,7 @@ void sgemm_prepacked_6x8(bool is_transB, const float* a_ptr = a_ptr_l; int tails = tail_pre; int k = k_pre; + // clang-format off asm volatile( // sgemm 6x8 "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n" @@ -2826,7 +2773,7 @@ void sgemm_prepacked_6x8(bool is_transB, "pld [%[b_ptr], #320] @ preload b\n" "vdup.i32 q11,d3[1] @ out31=0\n" "pld [%[b_ptr], #384] @ preload b\n" - "cmp %[has_beta], #0\n" + "cmp %[beta], #0\n" "beq 11f\n" /* check beta == 0? */ /* process beta */ "vdup.32 q3, %[beta]\n" /* beta to vector */ @@ -3082,26 +3029,11 @@ void sgemm_prepacked_6x8(bool is_transB, [tails] "+r"(tails) : [bias_ptr] "r"(bias_local), [relu] "r"(has_relu), - [has_beta] "r"(has_beta), [beta] "r"(beta) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "cc", - "memory"); + : "q0","q1","q2","q3","q4", + "q5","q6","q7","q8","q9","q10","q11", + "q12","q13","q14","q15","cc","memory"); + // clang-format on if (flag_p_remain && (xb == bblocks - 1)) { for (int i = 0; i < remain; ++i) { @@ -3243,6 +3175,7 @@ void sgemm_prepacked_4x8(bool is_transB, const float* a_ptr = a_ptr_l; int tails = tail_pre; int k = k_pre; + // clang-format off asm volatile( "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n" "vdup.32 q8, d4[0] @ add bias to out00\n" @@ -3260,7 +3193,7 @@ void sgemm_prepacked_4x8(bool is_transB, "pld [%[b_ptr], #128] @ preload b\n" "vdup.32 q15, d5[1] @ add bias to out31\n" "pld [%[b_ptr], #192] @ preload b\n" - "cmp %[has_beta], #0\n" + "cmp %[beta], #0\n" "beq 11f\n" /* check beta == 0? 
*/ /* process beta */ "vdup.32 q4, %[beta]\n" /* beta to vector */ @@ -3440,27 +3373,11 @@ void sgemm_prepacked_4x8(bool is_transB, [tails] "+r"(tails) : [bias_ptr] "r"(bias_local), [relu] "r"(has_relu), - [has_beta] "r"(has_beta), [beta] "r"(beta) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "cc", - "memory"); - + : "q0","q1","q2","q3", + "q4","q5","q6","q7","q8","q9","q10", + "q11","q12","q13","q14","q15","cc","memory"); + // clang-format on if (flag_p_remain && (xb == bblocks - 1)) { for (int i = 0; i < remain; ++i) { *pout0++ = cout0[i]; diff --git a/lite/backends/arm/math/packed_sgemm.h b/lite/backends/arm/math/packed_sgemm.h index 396ca7beb9cd3e8310c11f69946bd0f3ce6b7017..6c14cdb2ef62558a53c765719107d68da678246b 100644 --- a/lite/backends/arm/math/packed_sgemm.h +++ b/lite/backends/arm/math/packed_sgemm.h @@ -16,7 +16,6 @@ #include #include "lite/core/context.h" -#include "lite/core/device_info.h" #include "lite/core/tensor.h" namespace paddle { @@ -28,14 +27,14 @@ namespace math { constexpr int MBLOCK = 8; constexpr int NBLOCK = 12; constexpr int KBLOCK = 4; -inline int get_hblock(ARMArch arch) { return MBLOCK; } +inline int get_hblock(ARMContext* ctx) { return MBLOCK; } #else constexpr int MBLOCK_A73 = 4; constexpr int MBLOCK_OTH = 6; constexpr int NBLOCK = 8; constexpr int KBLOCK = 4; -inline int get_hblock(ARMArch arch) { - if (arch == kA73) { +inline int get_hblock(ARMContext* ctx) { + if (ctx->arch() == kA73) { return MBLOCK_A73; } else { return MBLOCK_OTH; diff --git a/lite/backends/arm/math/prior_box.cc b/lite/backends/arm/math/prior_box.cc index f262e6e1d7318e6c42e8b666239be4ae500788fe..6daab69ebf00da24d67132afba4b9abef0afbd39 100644 --- a/lite/backends/arm/math/prior_box.cc +++ b/lite/backends/arm/math/prior_box.cc @@ -63,7 +63,8 @@ void density_prior_box(const lite::Tensor* input, int prior_num_, bool is_flip_, bool is_clip_, - const std::vector& order_) { + const std::vector& order_, + bool min_max_aspect_ratios_order) { // compute output shape int win1 = input->dims()[3]; int hin1 = input->dims()[2]; @@ -284,12 +285,21 @@ void density_prior_box(const lite::Tensor* input, //! 
ymax com_buf[com_idx++] = (center_y + box_height / 2.f) / img_height; } - memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx); - idx += min_idx; - memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx); - idx += com_idx; - memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx); - idx += max_idx; + if (min_max_aspect_ratios_order) { + memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx); + idx += min_idx; + memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx); + idx += max_idx; + memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx); + idx += com_idx; + } else { + memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx); + idx += min_idx; + memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx); + idx += com_idx; + memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx); + idx += max_idx; + } } fast_free(min_buf); fast_free(max_buf); @@ -333,7 +343,8 @@ void prior_box(const lite::Tensor* input, int prior_num, bool is_flip, bool is_clip, - const std::vector& order) { + const std::vector& order, + bool min_max_aspect_ratios_order) { density_prior_box(input, image, boxes, @@ -353,7 +364,8 @@ void prior_box(const lite::Tensor* input, prior_num, is_flip, is_clip, - order); + order, + min_max_aspect_ratios_order); } } // namespace math diff --git a/lite/backends/arm/math/prior_box.h b/lite/backends/arm/math/prior_box.h index ffa821b75e54ee3e2329e4dcced8ddee2a003802..03fd62751081e491ddfb23f196d52153db5d3a5f 100644 --- a/lite/backends/arm/math/prior_box.h +++ b/lite/backends/arm/math/prior_box.h @@ -42,7 +42,8 @@ void density_prior_box(const lite::Tensor* input, int prior_num_, bool is_flip_, bool is_clip_, - const std::vector& order_); + const std::vector& order_, + bool min_max_aspect_ratios_order); void prior_box(const lite::Tensor* input, const lite::Tensor* image, @@ -60,7 +61,8 @@ void prior_box(const lite::Tensor* input, int prior_num, bool is_flip, bool is_clip, - const std::vector& order); + const std::vector& order, + bool min_max_aspect_ratios_order); } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/sgemm.cc b/lite/backends/arm/math/sgemm.cc index 93f64445e289dcfd17bf2de48e89fcb2c907a7a9..f3123ddd718ee61b6430d2b7f14480b79435291a 100644 --- a/lite/backends/arm/math/sgemm.cc +++ b/lite/backends/arm/math/sgemm.cc @@ -36,8 +36,7 @@ void sgemm(bool is_transA, bool is_bias, bool is_relu, ARMContext* ctx) { - auto arch = ctx->arch(); - int hblock = get_hblock(arch); + int hblock = get_hblock(ctx); int m_roundup = hblock * ((M + hblock - 1) / hblock); auto packed_A = static_cast( diff --git a/lite/backends/arm/math/yolo_box.cc b/lite/backends/arm/math/yolo_box.cc index 72e67cf69331ac2e0fa6edc7c8cd4a99ee763071..7ddb262480bbc427cda68b199a39fdef50a214c3 100644 --- a/lite/backends/arm/math/yolo_box.cc +++ b/lite/backends/arm/math/yolo_box.cc @@ -108,7 +108,7 @@ void yolobox(lite::Tensor* X, auto anchors_data = anchors.data(); const float* X_data = X->data(); - float* ImgSize_data = ImgSize->mutable_data(); + int* ImgSize_data = ImgSize->mutable_data(); float* Boxes_data = Boxes->mutable_data(); @@ -116,8 +116,8 @@ void yolobox(lite::Tensor* X, float box[4]; for (int i = 0; i < n; i++) { - int img_height = static_cast(ImgSize_data[2 * i]); - int img_width = static_cast(ImgSize_data[2 * i + 1]); + int img_height = ImgSize_data[2 * i]; + int img_width = ImgSize_data[2 * i + 1]; for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { diff --git a/lite/backends/cuda/CMakeLists.txt b/lite/backends/cuda/CMakeLists.txt index 
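
The reordered memcpy block above only changes how the three per-cell buffers are concatenated; the box contents themselves are unchanged:

```cpp
// Output layout per spatial location in density_prior_box:
//   min_max_aspect_ratios_order == true  : [min_buf | max_buf | com_buf]
//   min_max_aspect_ratios_order == false : [min_buf | com_buf | max_buf]  (legacy)
```

so models trained against either prior ordering can consume the same output tensor.
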
03ca30ecf04dcaa415dccdecfe5e34e19087ec7d..a6c3fcc66a789f159cd3a756ed893627b393e1fe 100644 --- a/lite/backends/cuda/CMakeLists.txt +++ b/lite/backends/cuda/CMakeLists.txt @@ -1,7 +1,10 @@ if(NOT LITE_WITH_CUDA) return() endif() +set(cuda_static_deps cudnn_static cublas_static curand_static + culibos_static cudart_static) -nv_library(target_wrapper_cuda SRCS target_wrapper.cc) -nv_library(cuda_blas SRCS blas.cc) +nv_library(target_wrapper_cuda SRCS target_wrapper.cc DEPS ${cuda_static_deps}) +nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_static_deps}) +add_subdirectory(math) diff --git a/lite/backends/cuda/blas.h b/lite/backends/cuda/blas.h index f73bb576b8dd5ecad178ba69a9208b2286c050ab..058b961f3678197a3a6719a3337e0decac78564f 100644 --- a/lite/backends/cuda/blas.h +++ b/lite/backends/cuda/blas.h @@ -30,10 +30,8 @@ namespace cuda { * Some basic methods. */ struct BlasBase { - /* BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); } ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); } - */ void SetStream(cudaStream_t stream) { CUBLAS_CHECK(cublasSetStream(handle_, stream)); diff --git a/lite/backends/cuda/cuda_utils.h b/lite/backends/cuda/cuda_utils.h index 0db3c4b179d3e395dd4379ca1603733fba0a55db..13bf8190efe1592e7509039a569d31f6bddc5b66 100644 --- a/lite/backends/cuda/cuda_utils.h +++ b/lite/backends/cuda/cuda_utils.h @@ -17,6 +17,7 @@ #include #include #include +#include #include "lite/utils/cp_logging.h" /* @@ -46,6 +47,15 @@ << "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \ } +#define CUDNN_VERSION_MIN(major, minor, patch) \ + (CUDNN_VERSION >= (major * 1000 + minor * 100 + patch)) + +#define CUDNN_CHECK(condition) \ + { \ + cudnnStatus_t status = condition; \ + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << CudnnGetErrorInfo(status); \ + } + namespace paddle { namespace lite { namespace cuda { @@ -71,6 +81,44 @@ static const char* CublasErrorInfo(int error) { } } +static const char* CudnnGetErrorInfo(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; +#if CUDNN_VERSION_MIN(6, 0, 0) + case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING: + return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING"; +#endif +#if CUDNN_VERSION_MIN(7, 0, 0) + case CUDNN_STATUS_RUNTIME_IN_PROGRESS: + return "CUDNN_STATUS_RUNTIME_IN_PROGRESS"; + case CUDNN_STATUS_RUNTIME_FP_OVERFLOW: + return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW"; +#endif + } + return "Unknown cudnn status"; +} + } // namespace cuda } // namespace lite } // namespace paddle diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5ee25643b4c87c9488df5b2acaead26773855a9 --- /dev/null +++ b/lite/backends/cuda/math/CMakeLists.txt @@ -0,0 +1,26 @@ +if(NOT LITE_WITH_CUDA) + 
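
The `CUDNN_CHECK` macro and `CudnnGetErrorInfo` helper introduced in cuda_utils.h above give cuDNN calls the same fail-fast treatment that CUDA runtime calls get. A minimal usage sketch, illustrative and not part of the patch:

```cpp
#include <cudnn.h>
#include "lite/backends/cuda/cuda_utils.h"

// On failure, CHECK_EQ fires and CudnnGetErrorInfo() prints the symbolic
// cudnnStatus_t name instead of a bare integer.
void cudnn_smoke_test() {
  cudnnHandle_t handle = nullptr;
  CUDNN_CHECK(cudnnCreate(&handle));
  CUDNN_CHECK(cudnnDestroy(handle));
}
```
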
return() +endif() + +set(cuda_static_deps cudnn_static cublas_static curand_static + culibos_static cudart_static) + +nv_library(cuda_activation SRCS activation.cu DEPS ${cuda_static_deps}) +nv_library(cuda_scale SRCS scale.cu DEPS ${cuda_static_deps}) +nv_library(cuda_type_trans SRCS type_trans.cu DEPS ${cuda_static_deps}) +nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps}) +nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale +cuda_type_trans ${cuda_static_deps}) +nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps}) + +set ( + math_cuda + cudnn_conv + cuda_activation + cuda_scale + cuda_type_trans + cuda_transpose + cuda_elementwise +) + +set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda") diff --git a/lite/backends/cuda/math/activation.cu b/lite/backends/cuda/math/activation.cu new file mode 100644 index 0000000000000000000000000000000000000000..508da6a2b470ad346063eb35e6d5b9cfdcf0f6e6 --- /dev/null +++ b/lite/backends/cuda/math/activation.cu @@ -0,0 +1,442 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/cuda/math/activation.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +__global__ void relu_kernel(const int num, + const T alpha, + const T* input, + T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { +#if __CUDA_ARCH__ >= 350 + output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) + : __ldg(input + index) * alpha; +#else + output[index] = input[index] >= 0 ? input[index] : input[index] * alpha; +#endif + } +} + +template +__global__ void bias_relu_kernel(const int num, + const T alpha, + const T* input, + T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { +#if __CUDA_ARCH__ >= 350 + output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) + : __ldg(input + index) * alpha; +#else + output[index] = input[index] >= 0 ? input[index] : input[index] * alpha; +#endif + } +} + +template +__global__ void bias_relu_int8_nhwc_kernel(int num, + const float* in, + const float* bias, + Dtype* out, + int N, + int C, + int H, + int W, + const float* scale, + float alpha) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int idx = tid % C; +#if __CUDA_ARCH__ >= 350 + float temp = __ldg(in + tid) * __ldg(scale + idx) + __ldg(bias + idx); + out[tid] = + temp > 0 ? from_float(temp) : from_float(temp * alpha); +#else + float temp = in[tid] * scale[idx] + bias[idx]; + out[tid] = + temp > 0 ? 
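
A note on the branch-free pattern used throughout the nhwc4 kernels in this file: for 0 ≤ alpha < 1, `fmaxf(v * alpha, v)` is exactly leaky-ReLU, since v > 0 selects v and v ≤ 0 selects alpha·v. Scalar reference:

```cpp
// Identity behind fmaxf(v * alpha, v), valid for 0 <= alpha < 1
// (alpha == 0 reduces it to plain ReLU).
inline float leaky_relu_ref(float v, float alpha) {
  return v > 0.f ? v : alpha * v;
}
```
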
from_float(temp) : from_float(temp * alpha); +#endif + } +} + +__global__ void bias_relu_int8_nhwc4_kernel(int num, + const float4* in, + const float4* bias, + float4* out, + int N, + int K, + int H, + int W, + const float4* scale, + float alpha) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int bias_idx = tid % K; + const float4 bias_ptr = bias[bias_idx]; + const float4 scale_ptr = scale[bias_idx]; + const float4 in_ptr = in[tid]; + + float4 packed_val; + packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x; + packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x); + packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y; + packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y); + packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z; + packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z); + packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w; + packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w); + out[tid] = packed_val; + } +} + +__global__ void bias_relu_int8_nhwc4_kernel(int num, + const float4* in, + const float4* bias, + char4* out, + int N, + int K, + int H, + int W, + const float4* scale, + float alpha) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int bias_idx = tid % K; + const float4 bias_ptr = bias[bias_idx]; + const float4 scale_ptr = scale[bias_idx]; + const float4 in_ptr = in[tid]; + + float4 packed_val; + char4 result_val; + packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x; + result_val.x = + from_float(fmaxf(packed_val.x * alpha, packed_val.x)); + packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y; + result_val.y = + from_float(fmaxf(packed_val.y * alpha, packed_val.y)); + packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z; + result_val.z = + from_float(fmaxf(packed_val.z * alpha, packed_val.z)); + packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w; + result_val.w = + from_float(fmaxf(packed_val.w * alpha, packed_val.w)); + + out[tid] = result_val; + } +} + +template +__global__ void bias_int8_nhwc_kernel(int num, + const float* in, + const float* bias, + Dtype* out, + int N, + int C, + int H, + int W, + const float* scale) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int idx = tid % C; +#if __CUDA_ARCH__ >= 350 + float temp = __ldg(in + tid) * __ldg(scale + idx) + __ldg(bias + idx); + out[tid] = from_float(temp); +#else + float temp = in[tid] * scale[idx] + bias[idx]; + out[tid] = from_float(temp); +#endif + } +} + +__global__ void relu_int8_nhwc4_kernel(int num, + const float4* in, + float4* out, + int N, + int K, + int H, + int W, + const float4* scale, + float alpha) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int scale_idx = tid % K; + const float4 scale_ptr = scale[scale_idx]; + const float4 in_ptr = in[tid]; + + float4 packed_val; + packed_val.x = in_ptr.x * scale_ptr.x; + packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x); + packed_val.y = in_ptr.y * scale_ptr.y; + packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y); + packed_val.z = in_ptr.z * scale_ptr.z; + packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z); + packed_val.w = in_ptr.w * scale_ptr.w; + packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w); + out[tid] = packed_val; + } +} + +__global__ void relu_int8_nhwc4_kernel(int num, + const float4* in, + char4* out, + int N, + int K, + int H, + int W, + const float4* scale, + float alpha) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int scale_idx = tid % K; + const float4 scale_ptr = 
scale[scale_idx]; + const float4 in_ptr = in[tid]; + + float4 packed_val; + char4 result_val; + packed_val.x = in_ptr.x * scale_ptr.x; + result_val.x = + from_float(fmaxf(packed_val.x * alpha, packed_val.x)); + packed_val.y = in_ptr.y * scale_ptr.y; + result_val.y = + from_float(fmaxf(packed_val.y * alpha, packed_val.y)); + packed_val.z = in_ptr.z * scale_ptr.z; + result_val.z = + from_float(fmaxf(packed_val.z * alpha, packed_val.z)); + packed_val.w = in_ptr.w * scale_ptr.w; + result_val.w = + from_float(fmaxf(packed_val.w * alpha, packed_val.w)); + + out[tid] = result_val; + } +} + +template <> +void bias_relu_int8_nhwc(int num, + const void* in, + const void* bias, + void* out, + int N, + int C, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream) { + int thread = 256; + if (C % 4 == 0) { + int block = (num / 4 + thread - 1) / thread; + bias_relu_int8_nhwc4_kernel<<>>( + num / 4, + static_cast(in), + static_cast(bias), + static_cast(out), + N, + C / 4, + H, + W, + static_cast(scale), + alpha); + } else { + int block = (num + thread - 1) / thread; + bias_relu_int8_nhwc_kernel<<>>( + num, + static_cast(in), + static_cast(bias), + static_cast(out), + N, + C, + H, + W, + static_cast(scale), + alpha); + } +} + +template <> +void bias_relu_int8_nhwc(int num, + const void* in, + const void* bias, + void* out, + int N, + int C, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream) { + int thread = 256; + if (C % 4 == 0) { + int block = (num / 4 + thread - 1) / thread; + bias_relu_int8_nhwc4_kernel<<>>( + num / 4, + static_cast(in), + static_cast(bias), + static_cast(out), + N, + C / 4, + H, + W, + static_cast(scale), + alpha); + } else { + int block = (num + thread - 1) / thread; + bias_relu_int8_nhwc_kernel<<>>( + num, + static_cast(in), + static_cast(bias), + static_cast(out), + N, + C, + H, + W, + static_cast(scale), + alpha); + } +} + +template +void bias_int8_nhwc(int num, + const void* in, + const void* bias, + void* out, + int N, + int C, + int H, + int W, + const void* scale, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + bias_int8_nhwc_kernel<<>>( + num, + static_cast(in), + static_cast(bias), + static_cast(out), + N, + C, + H, + W, + static_cast(scale)); +} + +template void bias_int8_nhwc(int, + const void*, + const void* bias, + void*, + int, + int, + int, + int, + const void*, + cudaStream_t); +template void bias_int8_nhwc(int, + const void*, + const void* bias, + void*, + int, + int, + int, + int, + const void*, + cudaStream_t); + +template <> +void relu_int8_nhwc4(int num, + const void* in, + void* out, + int N, + int K, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + relu_int8_nhwc4_kernel<<>>( + num, + static_cast(in), + static_cast(out), + N, + K, + H, + W, + static_cast(scale), + alpha); +} + +template <> +void relu_int8_nhwc4(int num, + const void* in, + void* out, + int N, + int K, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + relu_int8_nhwc4_kernel<<>>( + num, + static_cast(in), + static_cast(out), + N, + K, + H, + W, + static_cast(scale), + alpha); +} + +template +void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + relu_kernel<<>>(num, alpha, din, dout); + cudaError_t error = cudaGetLastError(); + if 
(error != cudaSuccess) std::cout << cudaGetErrorString(error); +} + +template +void bias_relu(int num, + const T* din, + const float* bias, + T* dout, + float alpha, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + relu_kernel<<>>(num, alpha, din, dout); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); +} +template void relu(int, const float*, float*, float, cudaStream_t); +template void bias_relu( + int, const float*, const float* bias, float*, float, cudaStream_t); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/activation.h b/lite/backends/cuda/math/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..273374a4ccddd6927010014d5e5544b97ee5e23c --- /dev/null +++ b/lite/backends/cuda/math/activation.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +// fp32 +template +void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream); + +template +void relu_int8_nhwc4(int num, + const void* in, + void* out, + int N, + int K, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream); + +template +void bias_relu(int num, + const T* din, + const float* bias, + T* dout, + float alpha, + cudaStream_t stream); + +// For int8 +template +void bias_relu_int8_nhwc(int num, + const void* in, + const void* bias, + void* out, + int N, + int C, + int H, + int W, + const void* scale, + float alpha, + cudaStream_t stream); + +template +void bias_int8_nhwc(int num, + const void* in, + const void* bias, + void* out, + int N, + int C, + int H, + int W, + const void* scale, + cudaStream_t stream); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/conv_op_cache_cudnn.h b/lite/backends/cuda/math/conv_op_cache_cudnn.h new file mode 100644 index 0000000000000000000000000000000000000000..e1428ef00a00cceea45d6ea37e629b44d74e3c14 --- /dev/null +++ b/lite/backends/cuda/math/conv_op_cache_cudnn.h @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
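
One behavioral caveat in activation.cu as landed: the fp32 `bias_relu` wrapper above launches `relu_kernel` and never reads its `bias` argument, and `bias_relu_kernel` itself takes no bias parameter, so the float bias+relu path currently skips the bias. A bias-aware kernel would need roughly the following shape; this is hypothetical, since the patch does not define one, and the per-element bias layout is an assumption:

```cpp
// Hypothetical sketch (not in this patch): fold the bias in before the
// leaky-ReLU select; assumes one bias entry per output element.
template <typename T>
__global__ void bias_relu_fused_kernel(
    const int num, const T alpha, const T* input, const float* bias, T* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
    T v = input[index] + static_cast<T>(bias[index]);
    output[index] = v >= 0 ? v : v * alpha;
  }
}
```
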
*/ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +// Not thread-safe. Should be owned per-kernel. +template +class AlgorithmsCache { + public: + AlgorithmsCache() : search_times_(0) { hash_.clear(); } + // Caches the best algorithm for a given + // combination of tensor dimensions & compute data type. + TAlgorithm GetAlgorithm( + const std::vector& dims1, + const std::vector& dims2, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, // can set for different data type + std::function gen_func); + + TAlgorithm GetAlgorithm(int64_t area, + int search_times, + int algorithmFlags, + std::function gen_func); + + private: + std::unordered_map hash_; + int search_times_; +}; + +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + const std::vector& dims1, + const std::vector& dims2, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, + std::function gen_func) { + int64_t seed = 0; + // Hash all of the inputs, use to try and look up a previously + // discovered algorithm, or fall back to generating a new one. + std::hash hashFn; + // do hash like boost + // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x + for (const auto num : dims1) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + for (const auto num : dims2) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; + } + + for (const auto num : strides) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 2; + } + + for (const auto num : paddings) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 3; + } + + for (const auto num : dilations) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 4; + } + + seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + + (seed << 6) + (seed >> 2) + 5; + + VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size(); + + if (seed == 0) return gen_func(); + + if (hash_.find(seed) == hash_.end()) { + TAlgorithm value = gen_func(); + hash_[seed] = value; + } + return hash_[seed]; +} + +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + int64_t area, + int search_times, + int algorithmFlags, + std::function gen_func) { + if (hash_.find(area) != hash_.end()) { + return hash_[area]; + } + if (search_times_ < search_times) { + auto algo = gen_func(); + hash_[area] = algo; + ++search_times_; + return algo; + } + TAlgorithm algo{}; + int64_t min = static_cast(INT_MAX); + for (const auto& m : hash_) { + if (m.first < min) { + min = m.first; + algo = m.second; + } + } + return algo; +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..72ed3951f6b9b22a5ae1ee6caef8c69708102885 --- /dev/null +++ b/lite/backends/cuda/math/cudnn_conv.cc @@ -0,0 +1,565 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
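
`AlgorithmsCache` above mixes every geometry vector into one boost-style combined hash (the 0x9e3779b9 golden-ratio constant, with a distinct +1 through +5 offset per field so identical values in different vectors do not collide) and runs the generator lambda only on a cache miss. The intended call pattern, matching its use in cudnn_conv.cc below, with a stand-in lambda body:

```cpp
// Illustrative only: memoize cuDNN's best forward algorithm per conv shape so
// the exhaustive cudnnFind search runs once per geometry, not per inference.
AlgorithmsCache<cudnnConvolutionFwdAlgo_t> algo_cache;
cudnnConvolutionFwdAlgo_t algo =
    algo_cache.GetAlgorithm(x_dims.Vectorize(),   // input dims
                            w_dims.Vectorize(),   // filter dims
                            param.strides,
                            param.paddings,
                            param.dilations,
                            /*algorithmFlags=*/0,
                            [&]() { return search_func(); });  // stand-in
```
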
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/cudnn_conv.h" +#include "lite/backends/cuda/math/activation.h" +#include "lite/backends/cuda/math/conv_op_cache_cudnn.h" +#include "lite/backends/cuda/math/scale.h" +#include "lite/backends/cuda/math/type_trans.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template <> +bool CudnnConv2D::create(const operators::ConvParam& param, + Context* ctx) { + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + int batch = x_dims[0]; + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int kw = w_dims[3]; + int kh = w_dims[2]; + int sw = param.strides[1]; + int sh = param.strides[0]; + int pw = param.paddings[1]; + int ph = param.paddings[0]; + int dw = param.dilations[1]; + int dh = param.dilations[0]; + + CHECK(ic % param.groups == 0) + << "The conv input channel shoud be divide group number."; + + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_, + CUDNN_TENSOR_NCHW, + CUDNN_DATA_FLOAT, + batch, + ic, + ih, + iw)); + CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_, + CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NCHW, + oc, + ic / param.groups, + kh, + kw)); + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_, + ph, + pw, + sh, + sw, + dh, + dw, + CUDNN_CROSS_CORRELATION, + CUDNN_DATA_FLOAT)); + CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups)); + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_, + CUDNN_TENSOR_NCHW, + CUDNN_DATA_FLOAT, + batch, + oc, + oh, + ow)); + + if (param.activation_param.has_active && with_relu_act_) { + CUDNN_CHECK(cudnnSetActivationDescriptor( + this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0)); + } + + if (ic == param.groups && ic == oc && ic != 1) { + this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + } else if (1) { + const auto* i_data = param.x->data(); + const auto* w_data = param.filter->data(); + auto* o_data = param.output->mutable_data(TARGET(kCUDA)); + int workspace_size_limit = 256 * 1024 * 1024; + + auto search_func = [&]() { + int returned_algo_count; + std::array + fwd_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx( + this->handle_, + this->input_desc_, + i_data, + this->filter_desc_, + w_data, + this->conv_desc_, + this->output_desc_, + o_data, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT, + &returned_algo_count, + fwd_perf_stat.data(), + cudnn_workspace, + workspace_size_limit)); + }; + + ResetWorkSpace(); + CUDA_CALL(cudaMalloc(&this->workspace_data_, workspace_size_limit)); + cudnn_find_func(this->workspace_data_); + ResetWorkSpace(); + + VLOG(2) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = fwd_perf_stat[i]; + VLOG(2) << stat.algo << ": " << stat.status << " " << stat.time << " " + << stat.memory; + } + return fwd_perf_stat[0].algo; + }; + AlgorithmsCache algo_cache; + this->fwd_algo_ = 
algo_cache.GetAlgorithm(x_dims.Vectorize(), + w_dims.Vectorize(), + param.strides, + param.paddings, + param.dilations, + 0, + search_func); + + } else { + CUDNN_CHECK( + cudnnGetConvolutionForwardAlgorithm(this->handle_, + this->input_desc_, + this->filter_desc_, + this->conv_desc_, + this->output_desc_, + this->preference_, + this->workspace_limit_bytes_, + &this->fwd_algo_)); + } + CUDNN_CHECK( + cudnnGetConvolutionForwardWorkspaceSize(this->handle_, + this->input_desc_, + this->filter_desc_, + this->conv_desc_, + this->output_desc_, + this->fwd_algo_, + &this->workspace_fwd_sizes_)); + if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) { + this->workspace_size_inbytes_ = this->workspace_fwd_sizes_; + ResetWorkSpace(); + cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_); + this->workspace_ = reinterpret_cast(this->workspace_data_); + } + if (param.bias) { + int dim_bias[] = {1, oc, 1, 1}; + int stride_bias[] = {oc, 1, 1, 1}; + cudnnSetTensorNdDescriptor( + this->bias_desc_, CUDNN_DATA_FLOAT, 4, dim_bias, stride_bias); + } + return true; +} + +template <> +bool CudnnConv2D::init(const operators::ConvParam& param, + Context* ctx) { + this->workspace_size_inbytes_ = 0; + this->workspace_data_ = NULL; + this->workspace_fwd_sizes_ = 0; + + this->stream_ = ctx->exec_stream(); + CUDNN_CHECK(cudnnCreate(&this->handle_)); + CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_)); + + this->workspace_ = NULL; + + cudnnCreateTensorDescriptor(&this->input_desc_); + cudnnCreateTensorDescriptor(&this->output_desc_); + cudnnCreateFilterDescriptor(&this->filter_desc_); + cudnnCreateConvolutionDescriptor(&this->conv_desc_); + cudnnCreateTensorDescriptor(&this->bias_desc_); + + if (param.activation_param.has_active) { + if (param.activation_param.active_type == lite_api::ActivationType::kRelu) { + cudnnCreateActivationDescriptor(&this->act_desc_); + } else { + this->with_relu_act_ = false; + } + } + return create(param, ctx); +} + +template <> +bool CudnnConv2D::run(const operators::ConvParam& param) { + const auto* i_data = param.x->data(); + const auto* w_data = param.filter->data(); + const auto* b_data = param.bias ? 
param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(TARGET(kCUDA)); + + if (param.activation_param.has_active && with_relu_act_) { + if (b_data) { + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle_, + &alpha, + input_desc_, + i_data, + filter_desc_, + w_data, + conv_desc_, + fwd_algo_, + workspace_, + workspace_fwd_sizes_, + &beta, + output_desc_, + o_data, + bias_desc_, + b_data, + act_desc_, + output_desc_, + o_data)); + } else { + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CHECK(cudnnConvolutionForward(handle_, + &alpha, + input_desc_, + i_data, + filter_desc_, + w_data, + conv_desc_, + fwd_algo_, + workspace_, + workspace_fwd_sizes_, + &beta, + output_desc_, + o_data)); + + CUDNN_CHECK(cudnnActivationForward(handle_, + act_desc_, + &alpha, + output_desc_, + o_data, + &beta, + output_desc_, + o_data)); + } + } else { + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CHECK(cudnnConvolutionForward(handle_, + &alpha, + input_desc_, + i_data, + filter_desc_, + w_data, + conv_desc_, + fwd_algo_, + workspace_, + workspace_fwd_sizes_, + &beta, + output_desc_, + o_data)); + if (b_data) { + CUDNN_CHECK(cudnnAddTensor( + handle_, &alpha, bias_desc_, b_data, &alpha, output_desc_, o_data)); + } + } + + if (!with_relu_act_) { + CHECK(param.activation_param.active_type == + lite_api::ActivationType::kLeakyRelu) + << "Only support leaky relu now."; + auto out_dims = param.output->dims(); + int n = out_dims[0], c = out_dims[1], h = out_dims[2], w = out_dims[3]; + int num = n * h * w * c; + float alpha = param.activation_param.Leaky_relu_alpha; + + relu(num, o_data, o_data, alpha, this->stream_); + } + return true; +} + +template +bool CudnnConv2DInt8::create(const operators::ConvParam& param, + Context* ctx) { + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int batch = x_dims[0]; + + int iw = x_dims[2]; // nchw + int ih = x_dims[1]; + int ic = x_dims[3]; + int ow = o_dims[2]; + int oh = o_dims[1]; + int oc = o_dims[3]; + + int kw = w_dims[2]; + int kh = w_dims[1]; + + int sw = param.strides[1]; + int sh = param.strides[0]; + int pw = param.paddings[1]; + int ph = param.paddings[0]; + int dw = param.dilations[1]; + int dh = param.dilations[0]; + + std::vector weight_scale = param.weight_scale; + float input_scale = param.input_scale; + float output_scale = param.output_scale; + CHECK(weight_scale.size() == static_cast(oc)) + << "the num of the weight_scale should be equals to the output channel."; + if (Ptype_out == PRECISION(kInt8)) { + this->temp_tensor_.Resize(o_dims); + this->temp_tensor_.template mutable_data(TARGET(kCUDA)); + for (size_t i = 0; i < weight_scale.size(); i++) { + weight_scale[i] = (weight_scale[i] * input_scale) / output_scale; + } + + auto* b_data = param.bias ? 
param.bias->mutable_data() : nullptr; + if (b_data) { + scale(param.bias->numel(), b_data, b_data, 1.f / output_scale); + } + } else { + for (size_t i = 0; i < weight_scale.size(); i++) { + weight_scale[i] = (weight_scale[i] * input_scale); + } + } + this->scale_.Resize({oc}); + this->scale_.template Assign( + weight_scale.data(), this->scale_.dims()); + + CHECK(ic % param.groups == 0) + << "The conv input channel shoud be divide group number."; + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_, + CUDNN_TENSOR_NHWC, + CUDNN_DATA_INT8, + batch, + ic, + ih, + iw)); + CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_, + CUDNN_DATA_INT8, + CUDNN_TENSOR_NHWC, + oc, + ic / param.groups, + kh, + kw)); + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_, + ph, + pw, + sh, + sw, + dh, + dw, + CUDNN_CROSS_CORRELATION, + CUDNN_DATA_INT32)); + + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_, + CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, + batch, + oc, + oh, + ow)); + if (ic % 4 == 0 && oc % 4 == 0) { + this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + } else { + this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + } + CUDNN_CHECK( + cudnnGetConvolutionForwardWorkspaceSize(this->handle_, + this->input_desc_, + this->filter_desc_, + this->conv_desc_, + this->output_desc_, + this->fwd_algo_, + &this->workspace_fwd_sizes_)); + + if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) { + this->workspace_size_inbytes_ = this->workspace_fwd_sizes_; + if (this->workspace_data_ != NULL) { + CUDA_CALL(cudaFree(this->workspace_data_)); + } + CUDA_CALL( + cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_)); + this->workspace_ = reinterpret_cast(this->workspace_data_); + } + + return true; +} + +template +bool CudnnConv2DInt8::init(const operators::ConvParam& param, + Context* ctx) { + this->workspace_size_inbytes_ = 0; // 64Mb + this->workspace_data_ = NULL; + this->workspace_fwd_sizes_ = 0; + + this->stream_ = ctx->exec_stream(); + CUDNN_CHECK(cudnnCreate(&this->handle_)); + CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_)); + + this->workspace_ = NULL; + + cudnnCreateTensorDescriptor(&this->input_desc_); + cudnnCreateTensorDescriptor(&this->output_desc_); + cudnnCreateFilterDescriptor(&this->filter_desc_); + cudnnCreateConvolutionDescriptor(&this->conv_desc_); + cudnnCreateTensorDescriptor(&this->bias_desc_); + + if (param.activation_param.has_active) { + if (!(param.activation_param.active_type == + lite_api::ActivationType::kRelu)) { + this->with_relu_act_ = false; + } + } + return create(param, ctx); +} + +template +bool CudnnConv2DInt8::run(const operators::ConvParam& param) { + const auto* i_data = param.x->data(); + const auto* w_data = param.filter->data(); + const auto* b_data = param.bias ? 
param.bias->data() : nullptr; + float* temp_out; + float* scale = this->scale_.template mutable_data(TARGET(kCUDA)); + if (Ptype_out == PRECISION(kInt8)) { + temp_out = this->temp_tensor_.template mutable_data(TARGET(kCUDA)); + } else { + // LOG(INFO) << param.output->dims().repr(); + temp_out = param.output->mutable_data(TARGET(kCUDA)); + } + + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CHECK(cudnnConvolutionForward(this->handle_, + &alpha, + this->input_desc_, + i_data, + this->filter_desc_, + w_data, + this->conv_desc_, + this->fwd_algo_, + this->workspace_, + this->workspace_fwd_sizes_, + &beta, + this->output_desc_, + temp_out)); + + auto out_dims = param.output->dims(); + int n = out_dims[0], h = out_dims[1], w = out_dims[2], c = out_dims[3]; + int num = n * h * w * c; + + if (!param.activation_param.has_active && !b_data) { + if (Ptype_out == PRECISION(kInt8)) { + auto* out = param.output->mutable_data(TARGET(kCUDA)); + fp32_to_int8_nhwc(num, + static_cast(temp_out), + static_cast(out), + static_cast(scale), + n, + c, + h, + w, + this->stream_); + } else { + fp32_scale_nhwc(num, + static_cast(temp_out), + static_cast(temp_out), + static_cast(scale), + n, + c, + h, + w, + this->stream_); + } + return true; + } + + if (b_data) { + if (param.activation_param.has_active) { + float alpha = 0.0; + if (!this->with_relu_act_) + alpha = param.activation_param.Leaky_relu_alpha; + if (Ptype_out == PRECISION(kInt8)) { + auto* out = param.output->mutable_data(TARGET(kCUDA)); + bias_relu_int8_nhwc(num, + static_cast(temp_out), + static_cast(b_data), + static_cast(out), + n, + c, + h, + w, + static_cast(scale), + alpha, + this->stream_); + } else { + bias_relu_int8_nhwc(num, + static_cast(temp_out), + static_cast(b_data), + static_cast(temp_out), + n, + c, + h, + w, + static_cast(scale), + alpha, + this->stream_); + } + return true; + } else { + if (Ptype_out == PRECISION(kInt8)) { + auto* out = param.output->mutable_data(TARGET(kCUDA)); + bias_int8_nhwc(num, + static_cast(temp_out), + static_cast(b_data), + static_cast(out), + n, + c, + h, + w, + static_cast(scale), + this->stream_); + } else { + bias_int8_nhwc(num, + static_cast(temp_out), + static_cast(b_data), + static_cast(temp_out), + n, + c, + h, + w, + static_cast(scale), + this->stream_); + } + return true; + } + } + + CHECK(false) + << "Conv Int8 support Conv, Conv + bias + relu, Conv + bias + leaky_relu"; +} + +template class CudnnConv2DInt8; +template class CudnnConv2DInt8; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_conv.h b/lite/backends/cuda/math/cudnn_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..5800d13c19677e624d9d52216fd44fee29813909 --- /dev/null +++ b/lite/backends/cuda/math/cudnn_conv.h @@ -0,0 +1,140 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
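Editor's note on the int8 path just above: cuDNN only produces an fp32 accumulator; the scale/bias/activation/quantize epilogue runs in the hand-written NHWC kernels, with all scales folded at create() time (fused_scale[c] = weight_scale[c] * input_scale / output_scale, and the bias pre-divided by output_scale when the output is int8). A per-element, host-side restatement of that epilogue under those assumptions — quantize_one and its argument names are illustrative, not part of this patch:

#include <cmath>
#include <cstdint>

// acc        : fp32 accumulator produced by cudnnConvolutionForward above
// fused_scale: weight_scale[c] * input_scale / output_scale, folded in create()
// bias       : bias[c], already divided by output_scale in create()
// alpha      : 0.f selects plain ReLU; a nonzero value is the leaky-ReLU slope
int8_t quantize_one(float acc, float fused_scale, float bias, float alpha) {
  float v = acc * fused_scale + bias;
  if (v < 0.f) v *= alpha;                      // epilogue of bias_relu_int8_nhwc
  v = std::fmax(v, -128.f);                     // clamp, mirroring from_float<int8_t>
  v = std::fmin(v, 127.f);
  return static_cast<int8_t>(std::lrint(v));    // round to nearest (cf. __float2int_rn)
}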
+ +#pragma once +#include <cudnn.h> +#include <string> +#include <vector> +#include "lite/api/paddle_place.h" +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template <PrecisionType Ptype_out> +class CudnnConv2DBase { + public: + CudnnConv2DBase() + : handle_(NULL), + fwd_algo_((cudnnConvolutionFwdAlgo_t)0), + input_desc_(NULL), + output_desc_(NULL), + bias_desc_(NULL), + filter_desc_(NULL), + conv_desc_(NULL), + act_desc_(NULL), + workspace_data_(NULL), + workspace_(NULL), + workspace_fwd_sizes_(0), + workspace_size_inbytes_(0) {} + + ~CudnnConv2DBase() { + if (conv_desc_) { + CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc_)); + } + if (input_desc_) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_)); + } + if (output_desc_) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_)); + } + if (act_desc_) { + CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_)); + } + if (bias_desc_) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_)); + } + if (filter_desc_) { + CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc_)); + } + if (handle_ != NULL) { + CUDNN_CHECK(cudnnDestroy(handle_)); + } + ResetWorkSpace(); + } + + protected: + void ResetWorkSpace() { + if (workspace_data_ != NULL) { + CUDA_CALL(cudaFree(workspace_data_)); + } + workspace_data_ = NULL; + } + + protected: + cudaStream_t stream_; + cudnnHandle_t handle_; + cudnnConvolutionFwdAlgo_t fwd_algo_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnTensorDescriptor_t bias_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + + // activation descriptor + cudnnActivationDescriptor_t act_desc_; + bool with_relu_act_{true}; + + void* workspace_data_; // underlying storage + void* workspace_; // aliases into _workspaceData + size_t workspace_fwd_sizes_; + size_t workspace_size_inbytes_; // size of underlying storage + + const bool use_tensor_core_ = true; + const size_t workspace_limit_bytes_ = 4 * 1024 * 1024; + const cudnnConvolutionFwdPreference_t preference_ = + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; + + // For int8 + Tensor temp_tensor_; + Tensor scale_; +}; + +template <PrecisionType Ptype_out> +class CudnnConv2D : public CudnnConv2DBase<Ptype_out> { + public: + CudnnConv2D() : CudnnConv2DBase<Ptype_out>() {} + virtual ~CudnnConv2D() = default; + virtual bool init(const operators::ConvParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool create(const operators::ConvParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool run(const operators::ConvParam& param); +}; + +template <PrecisionType Ptype_out> +class CudnnConv2DInt8 : CudnnConv2DBase<Ptype_out> { + public: + CudnnConv2DInt8() : CudnnConv2DBase<Ptype_out>() {} + virtual ~CudnnConv2DInt8() = default; + virtual bool init(const operators::ConvParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool create(const operators::ConvParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool run(const operators::ConvParam& param); +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_helper.h b/lite/backends/cuda/math/cudnn_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..b7f9b2cf69cadf5abfacd244ee07788f4a3ce525 --- /dev/null +++ b/lite/backends/cuda/math/cudnn_helper.h @@ -0,0 +1,24 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math {} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/elementwise.cu b/lite/backends/cuda/math/elementwise.cu new file mode 100644 index 0000000000000000000000000000000000000000..57c9ec022a6e49551fd2d56a9b2036de13bf5a2c --- /dev/null +++ b/lite/backends/cuda/math/elementwise.cu @@ -0,0 +1,129 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/elementwise.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +__global__ void elementwise_add_kernel(const size_t total, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { +#if __CUDA_ARCH__ >= 350 + out_data[tid] = __ldg(x_data + tid) + __ldg(y_data + tid); +#else + out_data[tid] = x_data[tid] + y_data[tid]; +#endif + } +} + +__global__ void elementwise_add_int8_kernel(const size_t total, + const float* x_data, + const float* y_data, + const float alpha, + int8_t* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + float temp_d; +#if __CUDA_ARCH__ >= 350 + temp_d = __ldg(x_data + tid) + __ldg(y_data + tid); +#else + temp_d = x_data[tid] + y_data[tid]; +#endif + out_data[tid] = from_float(temp_d * alpha); + } +} + +__global__ void elementwise_add_nhwc4_int8_kernel(const size_t total, + const float4* x_data, + const float4* y_data, + const float alpha, + char4* out_data) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + const float4 x_d = x_data[tid]; + const float4 y_d = y_data[tid]; + + float4 packed_val; + char4 result_val; + packed_val.x = (x_d.x + y_d.x) * alpha; + result_val.x = from_float(packed_val.x); + packed_val.y = (x_d.y + y_d.y) * alpha; + result_val.y = from_float(packed_val.y); + packed_val.z = (x_d.z + y_d.z) * alpha; + result_val.z = from_float(packed_val.z); + packed_val.w = (x_d.w + y_d.w) * alpha; + result_val.w = from_float(packed_val.w); + out_data[tid] = result_val; + } +} + +template +void elementwise_add(int num, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + 
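+  // (editor note) ceil-division: enough 256-thread blocks to cover all `num`
+  // elements; the kernel's `tid < total` guard makes the overhang benign. The
+  // launch below is presumably <<<block, thread, 0, stream>>> — no dynamic
+  // shared memory, enqueued on the caller-supplied stream.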
elementwise_add_kernel<<>>( + num, x_data, y_data, out_data); +} + +template void elementwise_add( + int, const float*, const float*, float*, cudaStream_t); + +// input type is float32 +// output type is int8 +void elementwise_add_int8(int num, + const float* x_data, + const float* y_data, + const float alpha, + int8_t* out_data, + cudaStream_t stream) { + int thread = 256; + int block = (num + thread - 1) / thread; + // elementwise_add_int8_kernel<<>>( + elementwise_add_int8_kernel<<>>( + num, x_data, y_data, alpha, out_data); +} + +void elementwise_add_nhwc4_int8(int num, + const void* x_data, + const void* y_data, + const float alpha, + void* out_data, + cudaStream_t stream) { + int thread = 512; + int block = (num + thread - 1) / thread; + // elementwise_add_nhwc4_int8_kernel<<>>( + elementwise_add_nhwc4_int8_kernel<<>>( + num, + static_cast(x_data), + static_cast(y_data), + alpha, + static_cast(out_data)); +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/elementwise.h b/lite/backends/cuda/math/elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..7fcdf95021ff21379bf94298ed06328dd6d2db09 --- /dev/null +++ b/lite/backends/cuda/math/elementwise.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +void elementwise_add(int num, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + cudaStream_t stream); + +void elementwise_add_int8(int num, + const float* x_data, + const float* y_data, + const float alpha, + int8_t* out_data, + cudaStream_t stream); +// input type is float32 +// output type is int8 +void elementwise_add_nhwc4_int8(int num, + const void* x_data, + const void* y_data, + const float alpha, + void* out_data, + cudaStream_t stream); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/scale.cu b/lite/backends/cuda/math/scale.cu new file mode 100644 index 0000000000000000000000000000000000000000..806a3697a2eb19354a81056f0a7ab6272ed991a1 --- /dev/null +++ b/lite/backends/cuda/math/scale.cu @@ -0,0 +1,159 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
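Why the NHWC4 variants above read float4 and write char4: grouping four adjacent channels per thread turns four 1-byte stores into a single aligned 4-byte store, keeping global writes coalesced and amortizing the index math. A hedged call-site sketch (pointer names illustrative; the count is assumed to be in float4 groups, matching the kernel's indexing):

// x_gpu, y_gpu: device fp32 NHWC buffers of n*h*w*c elements (c % 4 == 0);
// out_gpu: device int8 buffer; alpha would be 1 / output_scale under the
// quantization scheme used in this patch.
paddle::lite::cuda::math::elementwise_add_nhwc4_int8(
    n * h * w * c / 4,  // float4 groups: each thread consumes 4 channel values
    x_gpu,
    y_gpu,
    alpha,
    out_gpu,
    stream);            // cudaStream_t the kernel is enqueued on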
+ +#include "iostream" +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/backends/cuda/math/scale.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void scale_kernel(int count, + const T* in_data, + T* out_data, + const T* scale_data, + const T* bias_data, + const int scale_dim, + const int inner_dim) { + CUDA_KERNEL_LOOP(tid, count) { + int scale_id = (tid / inner_dim) % scale_dim; + T scale = scale_data[scale_id]; + if (bias_data == nullptr) { + out_data[tid] = scale * in_data[tid]; + } else { + out_data[tid] = scale * in_data[tid] + bias_data[scale_id]; + } + } +} + +template +__global__ void scale_kernel( + int count, const T* in_data, T* out_data, const T scale, const T bias) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; } +} + +__global__ void fp32_scale_nhwc4_kernel(int num, + const float4* in, + float4* out, + const float4* scale, + int N, + int K, + int H, + int W) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int scale_idx = tid % K; + const float4 scale_ptr = scale[scale_idx]; + const float4 in_ptr = in[tid]; + float4 packed_val; + + packed_val.x = in_ptr.x * scale_ptr.x; + packed_val.y = in_ptr.y * scale_ptr.y; + packed_val.z = in_ptr.z * scale_ptr.z; + packed_val.w = in_ptr.w * scale_ptr.w; + out[tid] = packed_val; + } +} + +__global__ void fp32_scale_nhwc_kernel(int num, + const float* in, + float* out, + const float* scale, + int N, + int C, + int H, + int W) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int idx = tid % C; +#if __CUDA_ARCH__ >= 350 + out[tid] = __ldg(in + tid) * __ldg(scale + idx); +#else + out[tid] = in[tid] * scale[idx]; +#endif + } +} + +void fp32_scale_nhwc(int num, + const void* in, + void* out, + const void* scale, + int N, + int C, + int H, + int W, + cudaStream_t stream) { + int thread = 256; + if (C % 4 == 0) { + int block = (num / 4 + thread - 1) / thread; + fp32_scale_nhwc4_kernel<<>>( + num / 4, + static_cast(in), + static_cast(out), + static_cast(scale), + N, + C / 4, + H, + W); + } else { + int block = (num + thread - 1) / thread; + fp32_scale_nhwc_kernel<<>>( + num, + static_cast(in), + static_cast(out), + static_cast(scale), + N, + C, + H, + W); + } + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); +} + +template +void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) { + int thread = 256; + int block = (num + thread - 1) / thread; + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); +} + +template +void scale(int num, const T* in, T* out, T scale, T bias) { + int thread = 256; + int block = (num + thread - 1) / thread; + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); +} + +template void scale(int num, const float*, float*, float, cudaStream_t, float); +template void scale(int num, const float*, float*, float, float); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/scale.h 
b/lite/backends/cuda/math/scale.h new file mode 100644 index 0000000000000000000000000000000000000000..52ed1d38ae79ce11cac50a9abef0f57e6de1352c --- /dev/null +++ b/lite/backends/cuda/math/scale.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +void fp32_scale_nhwc(int num, + const void* din, + void* dout, + const void* scale, + int N, + int K, + int H, + int W, + cudaStream_t stream); + +template +void scale( + int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0); + +template +void scale(int num, const T* in, T* out, T scale, T bias = 0); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/transpose.cu b/lite/backends/cuda/math/transpose.cu new file mode 100644 index 0000000000000000000000000000000000000000..cebcece812dc584d0921edea2fef8f129e430b56 --- /dev/null +++ b/lite/backends/cuda/math/transpose.cu @@ -0,0 +1,194 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/backends/cuda/math/transpose.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +constexpr int kTileDim = 32; +constexpr int kBlockRows = 8; +constexpr int CUDA_NUM_THREADS = 128; + +// Splits the original matrix into submatrices with size 32 * 32. +// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/ +template +__global__ void BatchTranspose2DCUDAKernel(const int N, + const int H, + const int W, + const int dh, + const int dw, + const T* input, + T* out) { + __shared__ T tile[kTileDim][kTileDim + 1]; // plus 1 to prevent bank confict. 
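+  // (editor note) the "+ 1" pads each 32-element row so that the column walk
+  // tile[threadIdx.x][...] on the transposed read maps to 32 distinct
+  // shared-memory banks instead of one: with 32 four-byte banks, a 33-float
+  // row stride shifts successive rows of a column into successive banks,
+  // avoiding a 32-way bank conflict.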
+ const int n = blockIdx.x / (dh * dw); + const int k = blockIdx.x % (dh * dw); + const int r = k / dw; + const int c = k % dw; + const int offset = n * H * W; + int x = c * kTileDim + threadIdx.x; + int y = r * kTileDim + threadIdx.y; + if (x < W) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + tile[threadIdx.y + i][threadIdx.x] = + __ldg(input + offset + (y + i) * W + x); +#else + tile[threadIdx.y + i][threadIdx.x] = input[offset + (y + i) * W + x]; +#endif + } + } + __syncthreads(); + x = r * kTileDim + threadIdx.x; + y = c * kTileDim + threadIdx.y; + if (x < H) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) { + out[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i]; + } + } +} + +template +void BatchTranspose2DCUDAImpl(const int N, + const int H, + const int W, + const T* input, + T* out, + CUDAContext* ctx) { + const int dh = (H + kTileDim - 1) / kTileDim; + const int dw = (W + kTileDim - 1) / kTileDim; + BatchTranspose2DCUDAKernel< + T><<exec_stream()>>>( + N, H, W, dh, dw, input, out); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \ + template <> \ + void NCHW2NHWC(const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + T* Y, \ + CUDAContext* ctx) { \ + BatchTranspose2DCUDAImpl(N, C, HxW, X, Y, ctx); \ + } +TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float) +TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t) +#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC + +#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \ + template <> \ + void NHWC2NCHW(const int N, \ + const int C, \ + const int HxW, \ + const T* X, \ + T* Y, \ + CUDAContext* ctx) { \ + BatchTranspose2DCUDAImpl(N, HxW, C, X, Y, ctx); \ + } +TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float) +TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t) +#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW + +template +__global__ void TransposeCUDAKernel(const int size, + const int ndim, + const int* X_strides, + const int* Y_dims, + const T* X, + T* Y) { + const int Y_index = blockIdx.x * CUDA_NUM_THREADS + threadIdx.x; + if (Y_index < size) { + int X_index = 0; + int v = Y_index; +#pragma unroll + for (int i = ndim - 1; i >= 0; --i) { + X_index += v % Y_dims[i] * X_strides[i]; + v /= Y_dims[i]; + } +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + Y[Y_index] = __ldg(X + X_index); +#else + Y[Y_index] = X[X_index]; +#endif + } +} + +template +void TransposeCUDAImpl(const std::vector& X_dims, + const std::vector& axes, + const T* X, + T* Y, + CUDAContext* ctx) { + CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal"; + int ndim = X_dims.size(); + std::vector strides(ndim, 0); + std::vector Y_dims(ndim, 0); + std::vector buf(ndim, 0); + int cur_stride = 1; + for (int i = ndim - 1; i >= 0; --i) { + buf[i] = cur_stride; + cur_stride *= X_dims[i]; + } + for (int i = 0; i < ndim; ++i) { + strides[i] = buf[axes[i]]; + } + int size = 1; + for (int i = 0; i < ndim; ++i) { + Y_dims[i] = static_cast(X_dims[axes[i]]); + size *= X_dims[i]; + } + + lite::Tensor Y_dims_, strides_; + Y_dims_.Resize(std::vector({ndim})); + int* d_y_dims = Y_dims_.mutable_data(TARGET(kCUDA)); + CopySync( + d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD); + + strides_.Resize(std::vector({ndim})); + int* d_strides = strides_.mutable_data(TARGET(kCUDA)); + CopySync(d_strides, + strides.data(), + sizeof(int) * strides.size(), + 
IoDirection::HtoD); + + const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; + TransposeCUDAKernel<<exec_stream()>>>( + size, ndim, d_strides, d_y_dims, X, Y); + auto e = cudaGetLastError(); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); +} + +#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \ + template <> \ + void Transpose(const std::vector& X_dims, \ + const std::vector& axes, \ + const T* X, \ + T* Y, \ + CUDAContext* ctx) { \ + TransposeCUDAImpl(X_dims, axes, X, Y, ctx); \ + } +TYPE_SPECIALIZED_CUDA_TRANSPOSE(float) +#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/transpose.h b/lite/backends/cuda/math/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2464547b587f44cd9b0ce287a0d40d37d46411 --- /dev/null +++ b/lite/backends/cuda/math/transpose.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "lite/core/context.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context); + +template +void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context); + +template +void Transpose(const std::vector& X_dims, + const std::vector& axes, + const T* X, + T* Y, + CUDAContext* ctx); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/type_trans.cu b/lite/backends/cuda/math/type_trans.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d884e5cb5ec9a86fdfb5bbc0d6752396a6e026a --- /dev/null +++ b/lite/backends/cuda/math/type_trans.cu @@ -0,0 +1,103 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
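The generic Transpose above precomputes, for each output axis i, the input stride of the source axis axes[i]; TransposeCUDAKernel then peels each output index digit by digit in the mixed radix of Y_dims. A host-side restatement of that index arithmetic with a worked example (function name illustrative):

#include <vector>

// Decompose y_index in the mixed radix of y_dims (least-significant digit
// last) and accumulate the permuted input strides, as the kernel does.
int x_index_for(int y_index, const std::vector<int>& y_dims,
                const std::vector<int>& x_strides) {
  int x = 0, v = y_index;
  for (int i = static_cast<int>(y_dims.size()) - 1; i >= 0; --i) {
    x += (v % y_dims[i]) * x_strides[i];
    v /= y_dims[i];
  }
  return x;
}

// Example: X dims {2,3,4} with axes {0,2,1} gives x_strides {12,1,4} and
// Y_dims {2,4,3}; output element (0,1,2), i.e. y_index 5, reads the input at
// 0*12 + 1*1 + 2*4 = 9, which is X coordinate (0,2,1) as expected.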
+ +#include "lite/backends/cuda/math/type_trans.h" +#include "lite/backends/cuda/math/utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +__global__ void fp32_to_int8_nhwc_kernel(int num, + const float* in, + int8_t* out, + const float* scale, + int N, + int C, + int H, + int W) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int idx = tid % C; +#if __CUDA_ARCH__ >= 350 + out[tid] = from_float(__ldg(in + tid) * __ldg(scale + idx)); +#else + out[tid] = from_float(in[tid] * scale[idx]); +#endif + } +} + +__global__ void fp32_to_int8_nhwc4_kernel(int num, + const float4* in, + char4* out, + const float4* scale, + int N, + int K, + int H, + int W) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num) { + int scale_idx = tid % K; + const float4 scale_ptr = scale[scale_idx]; + const float4 in_ptr = in[tid]; + char4 result_val; + + result_val.x = from_float(in_ptr.x * scale_ptr.x); + result_val.y = from_float(in_ptr.y * scale_ptr.y); + result_val.z = from_float(in_ptr.z * scale_ptr.z); + result_val.w = from_float(in_ptr.w * scale_ptr.w); + out[tid] = result_val; + } +} + +void fp32_to_int8_nhwc(int num, + const void* in, + void* out, + const void* scale, + int N, + int C, + int H, + int W, + cudaStream_t stream) { + int thread = 256; + if (C % 4 == 0) { + int block = (num / 4 + thread - 1) / thread; + fp32_to_int8_nhwc4_kernel<<>>( + num / 4, + static_cast(in), + static_cast(out), + static_cast(scale), + N, + C / 4, + H, + W); + } else { + int block = (num + thread - 1) / thread; + fp32_to_int8_nhwc_kernel<<>>( + num, + static_cast(in), + static_cast(out), + static_cast(scale), + N, + C, + H, + W); + } +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/type_trans.h b/lite/backends/cuda/math/type_trans.h new file mode 100644 index 0000000000000000000000000000000000000000..87c0a191e011c370bbfe110631f9c2f20bf277fe --- /dev/null +++ b/lite/backends/cuda/math/type_trans.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +void fp32_to_int8_nhwc(int num, + const void* din, + void* dout, + const void* scale, + int N, + int C, + int H, + int W, + cudaStream_t stream); + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/utils.h b/lite/backends/cuda/math/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b4cd82fd8df6df063d92df709311f3c90e7cf4b6 --- /dev/null +++ b/lite/backends/cuda/math/utils.h @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +__device__ T from_float(float x); + +template <> +__device__ __forceinline__ float from_float(float x) { + return x; +} + +template <> +__device__ __forceinline__ half from_float(float x) { + return __float2half(x); +} + +template <> +__device__ __forceinline__ int8_t from_float(float x) { + x = fmaxf(x, std::numeric_limits::min()); + x = fminf(x, std::numeric_limits::max()); + return __float2int_rn(x); +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/target_wrapper.cc b/lite/backends/cuda/target_wrapper.cc index b1aaadf027a56d485286a68638520f45d78d9468..a79eb7539318b52e21683bdf97bd534f7cc75fb5 100644 --- a/lite/backends/cuda/target_wrapper.cc +++ b/lite/backends/cuda/target_wrapper.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/backends/cuda/target_wrapper.h" +#include "lite/backends/cuda/cuda_utils.h" namespace paddle { namespace lite { @@ -25,13 +26,11 @@ size_t TargetWrapperCuda::num_devices() { void* TargetWrapperCuda::Malloc(size_t size) { void* ptr{}; - CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size)); + CUDA_CALL(cudaMalloc(&ptr, size)); return ptr; } -void TargetWrapperCuda::Free(void* ptr) { - CHECK_EQ(cudaSuccess, cudaFree(ptr)); -} +void TargetWrapperCuda::Free(void* ptr) { CUDA_CALL(cudaFree(ptr)); } void TargetWrapperCuda::MemcpySync(void* dst, const void* src, @@ -39,14 +38,13 @@ void TargetWrapperCuda::MemcpySync(void* dst, IoDirection dir) { switch (dir) { case IoDirection::DtoD: - CHECK(cudaSuccess == - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice)); + CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice)); break; case IoDirection::HtoD: - CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice)); break; case IoDirection::DtoH: - CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); break; default: LOG(FATAL) << "Unsupported IoDirection " << static_cast(dir); @@ -60,21 +58,32 @@ void TargetWrapperCuda::MemcpyAsync(void* dst, const stream_t& stream) { switch (dir) { case IoDirection::DtoD: - CHECK(cudaSuccess == - cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL( + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream)); break; case IoDirection::HtoD: - CHECK(cudaSuccess == - cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream)); + CUDA_CALL( + cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream)); break; case IoDirection::DtoH: - CHECK(cudaSuccess == - cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream)); + CUDA_CALL( + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream)); break; default: LOG(FATAL) << "Unsupported IoDirection " << static_cast(dir); } } +void TargetWrapperCuda::MemsetSync(void* devPtr, int value, size_t count) { + 
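+  // (editor note) cudaMemset fills `count` *bytes*, interpreting `value` as an
+  // unsigned char — standard memset semantics, not a per-element fill — which
+  // is fine for the zero-initialization these wrappers are meant for.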
CUDA_CALL(cudaMemset(devPtr, value, count)); +} + +void TargetWrapperCuda::MemsetAsync(void* devPtr, + int value, + size_t count, + const stream_t& stream) { + CUDA_CALL(cudaMemsetAsync(devPtr, value, count, stream)); +} + } // namespace lite } // namespace paddle diff --git a/lite/backends/cuda/target_wrapper.h b/lite/backends/cuda/target_wrapper.h index 50063007ce30cca7642a668f6c315903daf026bc..5b57ddf0043c59219aded9836cc0b1ad982eec2d 100644 --- a/lite/backends/cuda/target_wrapper.h +++ b/lite/backends/cuda/target_wrapper.h @@ -59,6 +59,13 @@ class TargetWrapper { size_t size, IoDirection dir, const stream_t& stream); + + static void MemsetSync(void* devPtr, int value, size_t count); + + static void MemsetAsync(void* devPtr, + int value, + size_t count, + const stream_t& stream); }; } // namespace lite } // namespace paddle diff --git a/lite/backends/fpga/KD/fpga_cv.cpp b/lite/backends/fpga/KD/fpga_cv.cpp index e3a1eed1ff8ca9c08e41b8479f48346e1a8f6fed..15a20e368b09f193e3f43b574ff3682ce96782ad 100644 --- a/lite/backends/fpga/KD/fpga_cv.cpp +++ b/lite/backends/fpga/KD/fpga_cv.cpp @@ -23,9 +23,7 @@ void fpga_resize(float* input, uint8_t* output, int output_width, int output_height) { - paddle::zynqmp::InplaceArgs inplace_args = { - .relu_enable = 0, .power_enable = 0, - }; + paddle::zynqmp::InplaceArgs inplace_args = {0, 0, 0}; paddle::zynqmp::config_inplace(inplace_args); paddle::zynqmp::ImageInputArgs input_args = {nullptr}; diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp index 6e7c1cd03027ae28b2977fcbe217b6cfb06378a0..1f1226ead3d4e9b50100f4de574104a5d6f777b2 100644 --- a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp +++ b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp @@ -39,10 +39,14 @@ static size_t memory_size_max = 0; static size_t memory_size = 0; static inline int do_ioctl(uint64_t req, const void *arg) { + int ret = -1; #ifdef PADDLE_LITE_OS_LINUX - return ioctl(fd, req, arg); + ret = ioctl(fd, req, arg); + if (ret != 0) { + throw - 1; + } #else - return -1; + return ret; #endif } diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.h b/lite/backends/fpga/KD/llapi/zynqmp_api.h index 3dd7f1e981ac37fa687fdafa883409e6ad8439c9..7d22de95a2272862c6fe781295bdaab7177a92fe 100644 --- a/lite/backends/fpga/KD/llapi/zynqmp_api.h +++ b/lite/backends/fpga/KD/llapi/zynqmp_api.h @@ -46,6 +46,15 @@ struct VersionArgs { struct DeviceInfo { uint32_t filter_cap; + uint32_t version; + uint16_t device_type; + uint32_t reserved0; + uint32_t reserved1; + uint32_t reserved2; + uint32_t reserved3; + uint32_t reserved4; + uint32_t reserved5; + uint32_t reserved6; }; struct MemoryCopyArgs { @@ -191,6 +200,7 @@ struct NormalizeParameterArgs { }; struct InplaceArgs { + bool leaky_relu_enable; bool relu_enable; bool power_enable; bool normalize_enable; diff --git a/lite/backends/fpga/lite_tensor.h b/lite/backends/fpga/lite_tensor.h index 77f6a7ad822a0071539f54d4cb29d69983a99f7e..2f9df3abb08dd15641323f4a3c59d6175f2e481b 100644 --- a/lite/backends/fpga/lite_tensor.h +++ b/lite/backends/fpga/lite_tensor.h @@ -57,7 +57,7 @@ class DDimLite { DDimLite Slice(int start, int end) const; - DDimLite Flattern2D(int col) const { + DDimLite Flatten2D(int col) const { return DDimLite(std::vector( {Slice(0, col).production(), Slice(col, size()).production()})); } @@ -118,6 +118,13 @@ class TensorLite { const LoD &lod() const { return lod_; } LoD *mutable_lod() { return &lod_; } + void set_lod(const LoD &lod) { lod_ = lod; } + + PrecisionType precision() const { return precision_; } 
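+  // (editor note) these precision/persistable accessors appear to mirror the
+  // core lite::TensorLite API, presumably so framework code that reads tensor
+  // metadata keeps compiling against the FPGA tensor type.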
+ void set_precision(PrecisionType precision) { precision_ = precision; } + + bool persistable() const { return persistable_; } + void set_persistable(bool persistable) { persistable_ = persistable; } // T is the data type and R is the return type // For OpenCL, the return type can be cl::Buffer // and the data type can be float/int8_t. @@ -147,6 +154,9 @@ class TensorLite { void CopyDataFrom(const TensorLite &other); + template + TensorLite Slice(int64_t begin, int64_t end) const; + TargetType target() const { return target_; } zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_; } @@ -168,6 +178,11 @@ class TensorLite { LoD lod_; size_t memory_size_{}; + size_t offset_{0}; + + PrecisionType precision_{PrecisionType::kUnk}; + bool persistable_{false}; + zynqmp::Tensor *zynq_tensor_ = new zynqmp::Tensor(); template @@ -219,6 +234,18 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) { if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; return true; } - +template +TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { + int64_t base = numel() / dims_[0]; + + TensorLite dst; + dst.buffer_ = buffer_; + dst.target_ = target_; + auto dst_dims = dims_; + dst_dims[0] = end - begin; + dst.Resize(dst_dims); + dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); + return dst; +} } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/CMakeLists.txt b/lite/backends/npu/CMakeLists.txt index abe567566b942ce85f149915e7eb8dfe4771351e..426ff5698146c773c818b2bfd598d6bbbdf7867f 100644 --- a/lite/backends/npu/CMakeLists.txt +++ b/lite/backends/npu/CMakeLists.txt @@ -2,5 +2,5 @@ if(NOT LITE_WITH_NPU) return() endif() -lite_cc_library(npu_helper SRCS npu_helper.cc DEPS ${npu_ddk_libs}) -add_subdirectory(bridge) +lite_cc_library(npu_runtime SRCS runtime.cc DEPS ${npu_runtime_libs}) +lite_cc_library(npu_builder SRCS builder.cc DEPS ${npu_builder_libs} npu_runtime tensor op scope) diff --git a/lite/backends/npu/bridge/fc_op.cc b/lite/backends/npu/bridge/fc_op.cc deleted file mode 100644 index 1321498db68eadd85b596b79c88201ad6bbe979b..0000000000000000000000000000000000000000 --- a/lite/backends/npu/bridge/fc_op.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/operators/fc_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" - -namespace paddle { -namespace lite { -namespace npu { -namespace bridge { - -node_map_type FCConverter(const std::shared_ptr fc_op, - const node_map_type& inputs_map) { - LOG(INFO) << "Converting fc..."; - lite::Scope* scope = fc_op->scope(); - const lite::OpInfo* op_info = fc_op->op_info(); - auto output_node = std::make_shared(UniqueName("fc")); - - auto x_var_name = op_info->Input("Input").front(); - auto w_var_name = op_info->Input("W").front(); - - int in_num_col_dims = op_info->GetAttr("in_num_col_dims"); - auto* xtensor = scope->FindVar(x_var_name)->GetMutable(); - auto* wtensor = scope->FindVar(w_var_name)->GetMutable(); - auto x_dims = xtensor->dims(); - auto w_dims = wtensor->dims(); - - CHECK_GE(x_dims.size(), 2UL); - CHECK_EQ(w_dims.size(), 2UL); - - int m = x_dims.Slice(0, in_num_col_dims).production(); - int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production(); - int n = w_dims[1]; - - CHECK(inputs_map.count(x_var_name)); - CHECK(!inputs_map.count(w_var_name)); - - LOG(INFO) << "m:" << m << ",n:" << n << ",k:" << k; - LOG(INFO) << "x_var_name:" << x_var_name - << ", is data: " << inputs_map.count(x_var_name); - LOG(INFO) << "w_var_name:" << w_var_name - << ", is data: " << inputs_map.count(w_var_name); - - auto xsrc = inputs_map.at(x_var_name); - auto reshapex = std::make_shared(x_var_name + "_reshape"); - reshapex->set_input_tensor(*xsrc); - reshapex->set_attr_shape({m, k}); - reshapex->set_attr_axis(0); - OpList::Global().add(xsrc); - OpList::Global().add(reshapex); - output_node->set_input_x(*reshapex); - - auto wconst = std::make_shared(w_var_name); - ge::TensorDesc wdesc(ge::Shape({k, n}), ge::FORMAT_NCHW, ge::DT_FLOAT); - auto size = wdesc.GetShape().GetShapeSize(); - CHECK_EQ(size, w_dims.production()); - ge::TensorPtr ptensor = std::make_shared(); - ptensor->SetTensorDesc(wdesc); - auto* pdata = reinterpret_cast(wtensor->mutable_data()); - ptensor->SetData(pdata, size * sizeof(float)); - wconst->set_attr_value(ptensor); - OpList::Global().add(wconst); - output_node->set_input_w(*wconst); - - if (HasInputArg(op_info, scope, "Bias")) { - auto b_var_name = op_info->Input("Bias").front(); - auto* btensor = scope->FindVar(b_var_name)->GetMutable(); - - LOG(INFO) << "b_var_name:" << b_var_name - << ", is data: " << inputs_map.count(b_var_name); - CHECK(!inputs_map.count(b_var_name)); - CHECK_EQ(btensor->numel(), n); - - auto bconst = std::make_shared(b_var_name); - ge::TensorDesc bdesc( - ge::Shape({1, n, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - auto size = bdesc.GetShape().GetShapeSize(); - CHECK_EQ(size, n); - ge::TensorPtr ptensor = std::make_shared(); - ptensor->SetTensorDesc(bdesc); - auto* pdata = reinterpret_cast(btensor->mutable_data()); - ptensor->SetData(pdata, size * sizeof(float)); - bconst->set_attr_value(ptensor); - OpList::Global().add(bconst); - output_node->set_input_bias(*bconst); - output_node->set_attr_has_bias(ge::AttrValue::BOOL{true}); - } - - OpList::Global().add(output_node); - - node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; - return outputs_map; -} - -} // namespace bridge -} // 
namespace npu -} // namespace lite -} // namespace paddle - -REGISTER_NPU_BRIDGE(fc, paddle::lite::npu::bridge::FCConverter); diff --git a/lite/backends/npu/bridge/utils.h b/lite/backends/npu/bridge/utils.h deleted file mode 100644 index 169b7ca80c254704f968cf70725ac4ca45fe7b8f..0000000000000000000000000000000000000000 --- a/lite/backends/npu/bridge/utils.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/core/mir/node.h" -#include "lite/core/op_lite.h" -#include "lite/core/target_wrapper.h" -#include "lite/core/tensor.h" - -namespace paddle { -namespace lite { -namespace npu { -namespace bridge { - -std::string UniqueName(const std::string& prefix); - -ge::DataType PrecisionConverter(PrecisionType itype); - -ge::Format DataLayoutConverter(DataLayoutType itype); - -ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor, - std::vector out_shape = {}, - PrecisionType in_ptype = PRECISION(kFloat), - DataLayoutType in_ltype = DATALAYOUT(kNCHW)); - -template -ge::TensorPtr CreateTensorAndFillData(std::vector data, - std::vector shape = {}, - ge::Format format = ge::FORMAT_NCHW) { - const std::type_info& info = typeid(T); - ge::DataType type = ge::DT_FLOAT; - if (info == typeid(float)) { - type = ge::DT_FLOAT; - } else if (info == typeid(int8_t)) { - type = ge::DT_INT8; - } else if (info == typeid(int32_t)) { - type = ge::DT_INT32; - } else { - LOG(FATAL) << "Unknow value type " << info.name(); - } - if (shape.empty()) { - shape = {static_cast(data.size())}; - } else { - int size = 1; - for (auto i : shape) { - size *= i; - } - CHECK_EQ(data.size(), size); - } - ge::TensorDesc desc(ge::Shape(shape), format, type); - ge::TensorPtr tensor = std::make_shared(); - tensor->SetTensorDesc(desc); - tensor->SetData(reinterpret_cast(data.data()), - data.size() * sizeof(T)); - return tensor; -} - -template -ge::TensorPtr CreateTensorAndFillData(T value, - std::vector shape = {1}, - ge::Format format = ge::FORMAT_NCHW) { - int64_t size = 1; - for (auto i : shape) { - size *= i; - } - std::vector data(size, value); - return CreateTensorAndFillData(data, shape, format); -} - -bool HasInputArg(const OpInfo* op_info, - const Scope* scope, - const std::string& argname); - -} // namespace bridge -} // namespace npu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/npu/bridge/utils.cc b/lite/backends/npu/builder.cc similarity index 76% rename from lite/backends/npu/bridge/utils.cc rename to lite/backends/npu/builder.cc index 8abd7dbda45b2f2ac493ccfb4928252b01ab63e4..80ab6e486b6cd9a67f4162ffb11d7bdac959eca9 100644 --- a/lite/backends/npu/bridge/utils.cc +++ b/lite/backends/npu/builder.cc @@ -12,19 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. 
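The renamed builder (below) folds HIAI model generation into a single BuildModel helper. A hypothetical call-site fragment, grounded in the signature declared below — the two ge::Operator vectors and their contents are illustrative:

// Build an om model from the HIAI IR graph endpoints and keep the bytes in a
// lite::Tensor so the NPU subgraph op can load them at runtime.
lite::Tensor om_model_data;
std::vector<ge::Operator> graph_inputs{/* e.g. ge::op::Data nodes */};
std::vector<ge::Operator> graph_outputs{/* converted output ops */};
if (!paddle::lite::npu::BuildModel(graph_inputs, graph_outputs, &om_model_data)) {
  LOG(WARNING) << "[NPU] om model generation failed, falling back to CPU kernels";
}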
-#include "lite/backends/npu/bridge/utils.h" -#include +#include "lite/backends/npu/builder.h" #include // NOLINT -#include -#include -#include "ai_ddk_lib/include/graph/op/all_ops.h" // for ge::op::Data -#include "ai_ddk_lib/include/graph/tensor.h" // for ge::TensorUtils -#include "lite/core/op_lite.h" +#include +#include "lite/backends/npu/runtime.h" namespace paddle { namespace lite { namespace npu { -namespace bridge { + +// Build HIAI IR graph to om model, and store om model data into lite tensor +bool BuildModel(std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + lite::Tensor* model_data) { + LOG(INFO) << "[NPU] Build model."; + CHECK_GT(inputs.size(), 0); + CHECK_GT(outputs.size(), 0); + CHECK_NE(model_data, 0); + // build IR graph to om model + ge::Graph ir_graph("graph"); + ir_graph.SetInputs(inputs).SetOutputs(outputs); + ge::Model om_model("model", "model"); + om_model.SetGraph(ir_graph); + domi::HiaiIrBuild ir_build; + domi::ModelBufferData om_model_buf; + if (!ir_build.CreateModelBuff(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] CreateModelBuff failed!"; + return false; + } + if (!ir_build.BuildIRModel(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] BuildIRModel failed!"; + return false; + } + // store om model into tensor + model_data->Resize({om_model_buf.length}); + memcpy(model_data->mutable_data(), + om_model_buf.data, + om_model_buf.length); + ir_build.ReleaseModelBuff(om_model_buf); + return true; +} std::string UniqueName(const std::string& prefix) { static std::mutex counter_mtx; @@ -131,7 +158,6 @@ bool HasInputArg(const OpInfo* op_info, } } -} // namespace bridge } // namespace npu } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/builder.h b/lite/backends/npu/builder.h new file mode 100644 index 0000000000000000000000000000000000000000..a245a8517b1c8e20a4630d370da5ca0b203adb71 --- /dev/null +++ b/lite/backends/npu/builder.h @@ -0,0 +1,254 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "ai_ddk_lib/include/hiai_ir_build.h" +#include "lite/core/op_lite.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/tensor.h" + +// Extended Ops of HIAI DDK +namespace ge { +/** + * Multiply the matrix x1 by the matrix x2 to generate x1 * x2. + * The inputs must be two-dimensional matrices and the inner dimension of "x1" + * (after being transposed if transpose_x1 is true) must match the outer + * dimension of "x2" (after being transposed if transposed_x2 is true). + * x : the first input tensor, must be non const op. + * w : the second input tensor, must be const op. 
+ * bias: the optional bias tensor, must be const op. + * + * y : the output tensor. + * + * has_bias: If true, enable input bias. + */ +REG_OP(MatMul) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(w, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT})) // bias must be const input + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(has_bias, AttrValue::BOOL{false}) // when has input::bias,set true + .OP_END(); + +/** + * Computes the gradients of convolution with respect to the input. + * + * input_sizes : An integer vector representing the shape of input, + * where input is a 4-D [batch, height, width, channels] tensor. + * filter : the filter tensor, with shape [H , W, filter_channel, + * filter_number], filter_channel must be same as x channel. + * x : The input tensor. + * + * y : The output tensor. + * + * format: 0: NCHW. 1: NHWC + * group : 1: default + * num_output : 0: default, num_output must be equal to + * (filter_channel * group) + * pad : Padding for the beginning and ending along each axis + * stride : Stride along each axis. + * dilation : dilation value along each axis of the filter. + * pad_mode : 0:NOTSET, 5:VALID 6:SAME. defaul value is 0:NOTSET + * bias_term : 0: default + * kernel : The shape of the convolution kernel + */ +REG_OP(Deconvolution) + .INPUT(input_sizes, TensorType({DT_UINT8})) + .INPUT(filter, TensorType({DT_FLOAT})) + .INPUT(x, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(b, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(mode, AttrValue::INT{1}) + .ATTR(format, AttrValue::INT{1}) + .ATTR(group, AttrValue::INT{1}) + .ATTR(num_output, AttrValue::INT{0}) + .ATTR(pad, AttrValue::LIST_INT({0, 0, 0, 0})) + .ATTR(stride, AttrValue::LIST_INT({1, 1})) + .ATTR(dilation, AttrValue::LIST_INT({1, 1})) + .ATTR(pad_mode, AttrValue::INT{0}) + .ATTR(bias_term, AttrValue::INT{0}) + .ATTR(kernel, AttrValue::LIST_INT({0, 0})) + .OP_END(); + +/** + * Resize images to size using bilinear interpolation. + * + * x : The tensor of 4-D + * w : A int32 Tensor of 2 elements: [height, width]. + * + * y : the output tensor + * + * align_corners : If true, the centers of the 4 corner pixels of the + * input and output tensors are aligned, preserving the values at the corner + * pixels. + * output_dim_mode : Defaults 2, including 0: zoom_factor , 1: + * shrink_factor, 2: height/width. when output_dim_mode=2, the output-dim is + * controled by the [height, width] of w. + * shrink_factor : shrink factor. + * zoom_factor : zoom factor. + * pad_begin : begin of pad. + * pad_end : end of pad. + */ +REG_OP(ResizeBilinear) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32})) + .INPUT(w, TensorType({DT_FLOAT, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32})) + .ATTR(align_corners, AttrValue::BOOL{false}) + .ATTR(output_dim_mode, AttrValue::INT{2}) + .ATTR(shrink_factor, AttrValue::INT{1}) + .ATTR(zoom_factor, AttrValue::INT{1}) + .ATTR(pad_begin, AttrValue::INT{0}) + .ATTR(pad_end, AttrValue::INT{0}) + .OP_END(); + +/** + * Resize images to size using nearest neighbor interpolation. + * + * image : Resize images to size using nearest neighbor interpolation. + * size : Must be one dimension and two elements + * + * output : the output tensor + * + * align_corners : If true, the centers of the 4 corner pixels of the + * input and output tensors are aligned, preserving the values at the corner + * pixels. 
Defaults to false + */ +REG_OP(ResizeNearestNeighbor) + .INPUT(image, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL})) + .INPUT(size, TensorType({DT_INT32})) + .OUTPUT(output, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL})) + .ATTR(align_corners, AttrValue::BOOL{false}) + .OP_END(); + +/** + * Pads a tensor. + * + * x : the input tensor + * padding : the input tensor must be 2-D + * constant_values : constant values must be a scalar + * + * output : the output tensor + * + * t_paddings : Default DT_INT32 , t_paddings must be the same with + * datatype of the padding + * mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC + * T : datatype of constant_values DT_INT32:3 DT_FLOAT:0 + */ +REG_OP(Pad) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32})) + .INPUT(padding, TensorType({DT_INT32})) + .OPTIONAL_INPUT(constant_values, TensorType({DT_INT32, DT_FLOAT})) + .OUTPUT(output, TensorType({DT_FLOAT, DT_INT32})) + .ATTR(t_paddings, AttrValue::INT{3}) + .ATTR(mode, AttrValue::INT{0}) + .REQUIRED_ATTR(T, AttrValue::INT) + .OP_END(); + +} // namespace ge + +namespace paddle { +namespace lite { +namespace npu { + +class OpList { + public: + static OpList& Global() { + static thread_local OpList x; + return x; + } + void clear() { lists_.clear(); } + void add(std::shared_ptr p) { lists_.push_back(p); } + + private: + std::vector> lists_; +}; + +// Build HIAI IR graph to om model, and store om model data into lite tensor +bool BuildModel(std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + lite::Tensor* model_data); + +std::string UniqueName(const std::string& prefix); + +ge::DataType PrecisionConverter(PrecisionType itype); + +ge::Format DataLayoutConverter(DataLayoutType itype); + +ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor, + std::vector out_shape = {}, + PrecisionType in_ptype = PRECISION(kFloat), + DataLayoutType in_ltype = DATALAYOUT(kNCHW)); + +template +ge::TensorPtr CreateTensorAndFillData(std::vector data, + std::vector shape = {}, + ge::Format format = ge::FORMAT_NCHW) { + const std::type_info& info = typeid(T); + ge::DataType type = ge::DT_FLOAT; + if (info == typeid(float)) { + type = ge::DT_FLOAT; + } else if (info == typeid(int8_t)) { + type = ge::DT_INT8; + } else if (info == typeid(int32_t)) { + type = ge::DT_INT32; + } else { + LOG(FATAL) << "Unknow value type " << info.name(); + } + if (shape.empty()) { + shape = {static_cast(data.size())}; + } else { + int size = 1; + for (auto i : shape) { + size *= i; + } + CHECK_EQ(data.size(), size); + } + ge::TensorDesc desc(ge::Shape(shape), format, type); + ge::TensorPtr tensor = std::make_shared(); + tensor->SetTensorDesc(desc); + tensor->SetData(reinterpret_cast(data.data()), + data.size() * sizeof(T)); + return tensor; +} + +template +ge::TensorPtr CreateTensorAndFillData(T value, + std::vector shape = {1}, + ge::Format format = ge::FORMAT_NCHW) { + int64_t size = 1; + for (auto i : shape) { + size *= i; + } + std::vector data(size, value); + return CreateTensorAndFillData(data, shape, format); +} + +bool HasInputArg(const OpInfo* op_info, + const Scope* scope, + const std::string& argname); + +} // namespace npu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/npu/npu_helper.cc b/lite/backends/npu/npu_helper.cc deleted file mode 100644 index 688c62c7f65b50386612077ef2633bb1ac880254..0000000000000000000000000000000000000000 --- a/lite/backends/npu/npu_helper.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
diff --git a/lite/backends/npu/npu_helper.cc b/lite/backends/npu/npu_helper.cc
deleted file mode 100644
index 688c62c7f65b50386612077ef2633bb1ac880254..0000000000000000000000000000000000000000
--- a/lite/backends/npu/npu_helper.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/backends/npu/npu_helper.h"
-#include <fstream>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "ai_ddk_lib/include/HiAiModelManagerService.h"
-#include "ai_ddk_lib/include/graph/buffer.h"
-#include "ai_ddk_lib/include/graph/model.h"
-#include "ai_ddk_lib/include/hiai_ir_build.h"
-
-namespace paddle {
-namespace lite {
-namespace npu {
-
-bool SaveNPUModel(const void* om_model_data,
-                  const size_t om_model_size,
-                  const std::string& om_file_path) {
-  std::FILE* fp;
-  fp = std::fopen(om_file_path.c_str(), "wb");
-  if (fp == NULL) {
-    LOG(WARNING) << "[NPU] " << om_file_path << " open failed!";
-    return false;
-  }
-
-  size_t write_size = std::fwrite(om_model_data, 1, om_model_size, fp);
-  if (write_size != om_model_size) {
-    std::fclose(fp);
-    LOG(WARNING) << "[NPU] Write NPU model failed: " << om_file_path;
-    return false;
-  }
-  std::fclose(fp);
-  return true;
-}
-
-bool BuildNPUClient(const void* om_model_data,
-                    const size_t om_model_size,
-                    const std::string& name) {
-  std::unique_ptr<hiai::AiModelMngerClient> client(
-      new hiai::AiModelMngerClient);
-  int ret = client->Init(nullptr);
-  if (ret != hiai::AI_SUCCESS) {
-    LOG(WARNING) << "[NPU] Failed building NPU client " << name
-                 << ", ret: " << ret;
-    throw std::runtime_error("");
-    return false;
-  }
-
-  auto desc = std::make_shared<hiai::AiModelDescription>(
-      name,
-      DeviceInfo::Global().freq_level(),
-      DeviceInfo::Global().framework_type(),
-      DeviceInfo::Global().model_type(),
-      DeviceInfo::Global().device_type());
-  desc->SetModelBuffer(om_model_data, om_model_size);
-
-  std::vector<std::shared_ptr<hiai::AiModelDescription>> model_desc;
-  model_desc.push_back(desc);
-  if (client->Load(model_desc) != hiai::AI_SUCCESS) {
-    LOG(WARNING) << "[NPU] Model Load Failed: " << desc->GetName();
-    throw std::runtime_error("");
-    return false;
-  }
-
-  DeviceInfo::Global().Insert(name, std::move(client));
-  return true;
-}
-
-// If built from inputs and outputs, the npu offline model will also be saved
-bool BuildNPUClient(std::vector<ge::Operator>& inputs,   // NOLINT
-                    std::vector<ge::Operator>& outputs,  // NOLINT
-                    const std::string& name) {
-  LOG(INFO) << "[NPU] Building Client";
-  ge::Graph npu_subgraph("npu_subgraph" + name);
-  npu_subgraph.SetInputs(inputs).SetOutputs(outputs);
-
-  ge::Model npu_model("model", "npu_model" + name);
-  npu_model.SetGraph(npu_subgraph);
-
-  // compile IR graph and output om model to memory
-  domi::HiaiIrBuild ir_build;
-  domi::ModelBufferData om_model_buffer;
-  if (!ir_build.CreateModelBuff(npu_model, om_model_buffer)) {
-    LOG(WARNING) << "[NPU] Failed CreateModelBuff: " << npu_model.GetName();
-    return false;
-  }
-  if (!ir_build.BuildIRModel(npu_model, om_model_buffer)) {
-    LOG(WARNING) << "[NPU] Failed BuildIRModel: " << npu_model.GetName();
-    return false;
-  }
-
-  if (BuildNPUClient(om_model_buffer.data, om_model_buffer.length, name)) {
-    // save npu offline model
-    if (!SaveNPUModel(om_model_buffer.data, om_model_buffer.length, name)) {
-      LOG(WARNING) << "[NPU] Save model " << name << " failed.";
-    }
-    ir_build.ReleaseModelBuff(om_model_buffer);
-    return true;
-  }
-  return false;
-}
false; -} - -// If build from path will not save the npu offline model -bool BuildNPUClient(const std::string& om_model_file_path, - const std::string& name) { - // load om model from file - std::ifstream file(om_model_file_path, std::ios::binary); - CHECK(file.is_open()) << "[NPU] Unable to open om model file: " - << om_model_file_path; - const auto fbegin = file.tellg(); - file.seekg(0, std::ios::end); - const auto fend = file.tellg(); - size_t om_model_size = fend - fbegin; - VLOG(5) << "[NPU] om model file size: " << om_model_size; - file.seekg(0, std::ios::beg); - std::vector om_model_data(om_model_size); - file.read(om_model_data.data(), om_model_size); - - return BuildNPUClient( - reinterpret_cast(om_model_data.data()), om_model_size, name); -} - -} // namespace npu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/npu/npu_helper.h b/lite/backends/npu/npu_helper.h deleted file mode 100644 index 95c290315b6f43e76eb39c5cb630d4dc16d1427b..0000000000000000000000000000000000000000 --- a/lite/backends/npu/npu_helper.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#include -#include "ai_ddk_lib/include/HiAiModelManagerService.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/utils/cp_logging.h" - -namespace paddle { -namespace lite { -namespace npu { - -class DeviceInfo { - public: - static DeviceInfo& Global() { - static DeviceInfo x; - return x; - } - DeviceInfo() {} - void Insert(const std::string& name, - std::unique_ptr client) { - if (clients_.find(name) != clients_.end()) { - LOG(WARNING) << "[NPU] Already insert " << name; - return; - } - clients_.emplace(std::make_pair(name, std::move(client))); - } - - void Clear() { clients_.clear(); } - - hiai::AiModelMngerClient* client(const std::string& model_name) const { - if (clients_.find(model_name) != clients_.end()) { - return clients_.at(model_name).get(); - } else { - return nullptr; - } - } - std::vector AllClientNames() { - std::vector names; - for (auto& i : clients_) { - names.push_back(i.first); - } - return names; - } - - int freq_level() { return freq_level_; } - int framework_type() { return framework_type_; } - int model_type() { return model_type_; } - int device_type() { return device_type_; } - - private: - int freq_level_{3}; - int framework_type_{0}; - int model_type_{0}; - int device_type_{0}; - // TODO(TJ): find better place - std::unordered_map> - clients_; -}; - -class OpList { - public: - static OpList& Global() { - static thread_local OpList x; - return x; - } - void clear() { lists_.clear(); } - void add(std::shared_ptr p) { lists_.push_back(p); } - - private: - std::vector> lists_; -}; - -bool SaveNPUModel(const void* om_model_data, - const size_t om_model_size, - const std::string& om_file_path); - -// If build from inputs and outputs will save the npu offline 
-bool BuildNPUClient(std::vector<ge::Operator>& inputs,   // NOLINT
-                    std::vector<ge::Operator>& outputs,  // NOLINT
-                    const std::string& name);
-
-// If built from path, the npu offline model will not be saved
-bool BuildNPUClient(const std::string& om_model_file_path,
-                    const std::string& name);
-
-bool BuildNPUClient(const void* om_model_data,
-                    const size_t om_model_size,
-                    const std::string& name);
-
-} // namespace npu
-} // namespace lite
-} // namespace paddle
diff --git a/lite/backends/npu/runtime.cc b/lite/backends/npu/runtime.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3485f63c7c8bb91081fd1969d0d41733417149d9
--- /dev/null
+++ b/lite/backends/npu/runtime.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/npu/runtime.h"
+#include <memory>
+#include <vector>
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace npu {
+
+// Create hiai model manager to load om model from lite tensor, and return the
+// manager and a unique model name
+bool LoadModel(const lite::Tensor &model_data,
+               std::shared_ptr<hiai::AiModelMngerClient> *model_client,
+               std::string *model_name) {
+  LOG(INFO) << "[NPU] Load model.";
+  auto model_data_ptr = model_data.data<int8_t>();
+  auto model_data_size = model_data.numel() * sizeof(int8_t);
+  if (model_data_ptr == nullptr || model_data_size == 0) {
+    return false;
+  }
+  *model_client = std::make_shared<hiai::AiModelMngerClient>();
+  int ret = (*model_client)->Init(nullptr);
+  if (ret != hiai::AI_SUCCESS) {
+    LOG(WARNING) << "[NPU] AiModelMngerClient init failed(" << ret << ")!";
+    return false;
+  }
+  *model_name = "model.om";
+  auto model_desc = std::make_shared<hiai::AiModelDescription>(
+      *model_name,
+      DeviceInfo::Global().freq_level(),
+      DeviceInfo::Global().framework_type(),
+      DeviceInfo::Global().model_type(),
+      DeviceInfo::Global().device_type());
+  model_desc->SetModelBuffer(model_data_ptr, model_data_size);
+  std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs;
+  model_descs.push_back(model_desc);
+  if ((*model_client)->Load(model_descs) != hiai::AI_SUCCESS) {
+    LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
+    return false;
+  }
+  return true;
+}
+
+} // namespace npu
+} // namespace lite
+} // namespace paddle
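The runtime contract is easiest to see from the caller's side. Below is a hedged sketch of how a subgraph kernel could drive LoadModel and then execute the om model through the returned client; the AiTensor and Process usage follows the stock HiAI AiModelMngerClient flow, and RunOmModel plus all buffer handling are illustrative assumptions rather than code from this patch.

// Hypothetical caller of LoadModel (illustrative only).
#include <memory>
#include <string>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "lite/backends/npu/runtime.h"

bool RunOmModel(const paddle::lite::Tensor &model_data) {
  std::shared_ptr<hiai::AiModelMngerClient> client;
  std::string model_name;
  if (!paddle::lite::npu::LoadModel(model_data, &client, &model_name)) {
    return false;
  }
  // Ask the client for the I/O shapes the om model was built with.
  std::vector<hiai::TensorDimension> idims, odims;
  if (client->GetModelIOTensorDim(model_name, idims, odims) !=
      hiai::AI_SUCCESS) {
    return false;
  }
  std::vector<std::shared_ptr<hiai::AiTensor>> itensors, otensors;
  for (auto &d : idims) {
    auto t = std::make_shared<hiai::AiTensor>();
    t->Init(&d);  // allocates a host buffer matching this input
    itensors.push_back(t);
  }
  for (auto &d : odims) {
    auto t = std::make_shared<hiai::AiTensor>();
    t->Init(&d);
    otensors.push_back(t);
  }
  // ... copy input data into itensors, then run synchronously.
  hiai::AiContext context;
  context.AddPara("model_name", model_name);
  int istamp;
  return client->Process(context, itensors, otensors, 1000, istamp) ==
         hiai::AI_SUCCESS;
}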
diff --git a/lite/backends/npu/runtime.h b/lite/backends/npu/runtime.h
new file mode 100644
index 0000000000000000000000000000000000000000..8b1ad51518d8626d9a6ecd6203a70b2637bb6004
--- /dev/null
+++ b/lite/backends/npu/runtime.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <string>
+#include "ai_ddk_lib/include/HiAiModelManagerService.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace npu {
+
+class DeviceInfo {
+ public:
+  static DeviceInfo &Global() {
+    static DeviceInfo x;
+    return x;
+  }
+  DeviceInfo() {}
+
+  int freq_level() { return freq_level_; }
+  int framework_type() { return framework_type_; }
+  int model_type() { return model_type_; }
+  int device_type() { return device_type_; }
+
+ private:
+  int freq_level_{3};
+  int framework_type_{0};
+  int model_type_{0};
+  int device_type_{0};
+};
+
+bool LoadModel(const lite::Tensor &model_data,
+               std::shared_ptr<hiai::AiModelMngerClient> *model_client,
+               std::string *model_name);
+} // namespace npu
+} // namespace lite
+} // namespace paddle
diff --git a/lite/backends/opencl/cl_caller.cc b/lite/backends/opencl/cl_caller.cc
index ae755b756d2c1d7eaf619469367f46550ec36e14..4926a53c43d54b4e2b4d802a7d8ef289c7e87fc5 100644
--- a/lite/backends/opencl/cl_caller.cc
+++ b/lite/backends/opencl/cl_caller.cc
@@ -31,8 +31,8 @@ static void CopyImageData(CLContext* context,
   float* image_data = new float[height * width * 4];
   cl::Image* image = cl_image.cl_image();
-  const std::array<size_t, 3> origin{0, 0, 0};
-  const std::array<size_t, 3> region{
+  cl::array<cl::size_type, 3> origin = {0, 0, 0};
+  cl::array<cl::size_type, 3> region = {
       static_cast<size_t>(width), static_cast<size_t>(height), 1};
   cl_int err = context->GetCommandQueue().enqueueReadImage(
       *image, CL_TRUE, origin, region, 0, 0, image_data, nullptr, nullptr);
diff --git a/lite/backends/opencl/cl_functions_test.cc b/lite/backends/opencl/cl_functions_test.cc
index b041952b34c43ca98237ee33e9dceccdd58a431b..b9f6648c9956e1952b65f66abfa40d912a99ee67 100644
--- a/lite/backends/opencl/cl_functions_test.cc
+++ b/lite/backends/opencl/cl_functions_test.cc
@@ -15,7 +15,6 @@ limitations under the License. */
 #include
 #include
 #include
-#include <array>
 #include
 #include
 #include
@@ -395,51 +394,74 @@ TEST(cl_test, target_wrapper_buffer_test) {
 }
 
 TEST(cl_test, target_wrapper_image_test) {
-  const std::array<size_t, 2> image_shape{28, 32};
+  const size_t cl_image2d_width = 28;
+  const size_t cl_image2d_height = 32;
+  const size_t cl_image2d_row_pitch{0};
+  const size_t cl_image2d_slice_pitch{0};
   auto *d_image = static_cast<cl::Image2D *>(
-      TargetWrapperCL::MallocImage(image_shape, PRECISION(kFloat)));
-  std::array<size_t, 2> image_pitch;
+      TargetWrapperCL::MallocImage<float>(cl_image2d_width, cl_image2d_height));
 
   // Map/Unmap test
-  auto *h_image = static_cast<float *>(
-      TargetWrapperCL::MapImage(d_image, image_shape, &image_pitch));
-  // row_pitch = 448 = 28 * 4 (RGBA: 4 floats) * 4 (float in bytes)
-  // slice_pitch = 0
-  size_t row_pitch = image_pitch[0];
-  size_t slice_pitch = image_pitch[1];
-  CHECK_EQ(row_pitch, 448);
-  CHECK_EQ(slice_pitch, 0);
-  LOG(INFO) << "row_pitch = " << row_pitch << ", slice_pitch " << slice_pitch;
+  auto *h_image =
+      static_cast<float *>(TargetWrapperCL::MapImage(d_image,
+                                                     cl_image2d_width,
+                                                     cl_image2d_height,
+                                                     cl_image2d_row_pitch,
+                                                     cl_image2d_slice_pitch));
+  CHECK_EQ(
+      cl_image2d_row_pitch,
+      cl_image2d_width * 4 *
+          4);  // row_pitch = 448 = 28 * 4 (RGBA: 4 floats) * 4 (float in bytes)
+  CHECK_EQ(cl_image2d_slice_pitch, 0);  // slice_pitch = 0
+  LOG(INFO) << "cl_image2d_row_pitch = " << cl_image2d_row_pitch
+            << ", cl_image2d_slice_pitch " << cl_image2d_slice_pitch;
 
   for (int i = 0; i < 10; i++) {
     h_image[i] = 3.14f * i;
   }
   TargetWrapperCL::Unmap(d_image, h_image);
 
-  auto *h_ptr = static_cast<float *>(
-      TargetWrapperCL::MapImage(d_image, image_shape, &image_pitch));
+  auto *h_ptr =
+      static_cast<float *>(TargetWrapperCL::MapImage(d_image,
+                                                     cl_image2d_width,
+                                                     cl_image2d_height,
+                                                     cl_image2d_row_pitch,
+                                                     cl_image2d_slice_pitch));
   for (int i = 0; i < 10; i++) {
     EXPECT_NEAR(h_ptr[i], 3.14f * i, 1e-6);
   }
   TargetWrapperCL::Unmap(d_image, h_ptr);
 
   // Imagecpy test
-  std::vector<float> h_image_cpy(28 * 4 * 32);
-  for (int i = 0; i < 28 * 4 * 32; i++) {
+  std::vector<float> h_image_cpy(cl_image2d_width * 4 *
+                                 cl_image2d_height);  // 4 for RGBA channels
+  for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) {
     h_image_cpy[i] = 3.14f;
   }
-  TargetWrapperCL::ImgcpySync(
-      d_image, h_image_cpy.data(), image_shape, image_pitch, IoDirection::HtoD);
+  TargetWrapperCL::ImgcpySync(d_image,
+                              h_image_cpy.data(),
+                              cl_image2d_width,
+                              cl_image2d_height,
+                              cl_image2d_row_pitch,
+                              cl_image2d_slice_pitch,
+                              IoDirection::HtoD);
   auto *d_image_cpy = static_cast<cl::Image2D *>(
-      TargetWrapperCL::MallocImage(image_shape, PRECISION(kFloat)));
-  TargetWrapperCL::ImgcpySync(
-      d_image_cpy, d_image, image_shape, image_pitch, IoDirection::DtoD);
+      TargetWrapperCL::MallocImage<float>(cl_image2d_width, cl_image2d_height));
+  TargetWrapperCL::ImgcpySync(d_image_cpy,
+                              d_image,
+                              cl_image2d_width,
+                              cl_image2d_height,
+                              cl_image2d_row_pitch,
+                              cl_image2d_slice_pitch,
+                              IoDirection::DtoD);
   std::fill(h_image_cpy.begin(), h_image_cpy.end(), 0);
   TargetWrapperCL::ImgcpySync(h_image_cpy.data(),
                               d_image_cpy,
-                              image_shape,
-                              image_pitch,
+                              cl_image2d_width,
+                              cl_image2d_height,
+                              cl_image2d_row_pitch,
+                              cl_image2d_slice_pitch,
                               IoDirection::DtoH);
-  for (int i = 0; i < 28 * 4 * 32; i++) {
+  for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) {
     EXPECT_NEAR(h_image_cpy[i], 3.14f, 1e-6);
   }
diff --git a/lite/backends/opencl/cl_image.cc b/lite/backends/opencl/cl_image.cc
index f6dcd4bbefc7735e0b18df5f536c2236fbda1809..b67f4040bff4cac15624c1440ca741d2b9dfa6ba 100644
--- a/lite/backends/opencl/cl_image.cc
+++ b/lite/backends/opencl/cl_image.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "lite/backends/opencl/cl_image.h"
-#include <array>
 #include "lite/backends/opencl/cl_runtime.h"
 #include "lite/backends/opencl/cl_utility.h"
 #include "lite/utils/cp_logging.h"
@@ -27,8 +26,9 @@ std::ostream& operator<<(std::ostream& os, const CLImage& cl_image) {
   float* image_data = new float[height * width * 4];
   cl::Image* image = cl_image.cl_image();
-  const std::array<size_t, 3> origin{0, 0, 0};
-  const std::array<size_t, 3> region{
+
+  cl::array<cl::size_type, 3> origin = {0, 0, 0};
+  cl::array<cl::size_type, 3> region = {
       static_cast<size_t>(width), static_cast<size_t>(height), 1};
   cl_int err = CLRuntime::Global()->command_queue().enqueueReadImage(
       *image, CL_TRUE, origin, region, 0, 0, image_data, nullptr, nullptr);
diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..c9c16581d67db0c9143e91e13249edfd5901ddb8
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl
@@ -0,0 +1,116 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cl_common.h>
+
+// buffer -> image2d
+__kernel void buffer_to_image2d(__global CL_DTYPE *in,
+                                __write_only image2d_t output_image,
+                                __private const int out_H,
+                                __private const int out_W,
+                                __private const int out_C,
+                                __private const int Stride0,
+                                __private const int Stride1,
+                                __private const int Stride2) {
+
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+  const int out_n = out_nh / out_H;
+  const int out_h = out_nh % out_H;
+
+  const int in_n = out_n;
+  const int in_c0 = out_c * 4 + 0;
+  const int in_c1 = out_c * 4 + 1;
+  const int in_c2 = out_c * 4 + 2;
+  const int in_c3 = out_c * 4 + 3;
+  const int in_h = out_h;
+  const int in_w = out_w;
+
+  int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos3 = in_n * Stride2 + in_c3 * Stride1 + in_h * Stride0 + in_w;
+
+  int2 output_pos;
+  output_pos.x = out_c * out_W + out_w;
+  output_pos.y = out_nh;
+
+  CL_DTYPE4 output = (CL_DTYPE4)0.0f;
+  output.x = convert_float(in[input_pos0]);
+  if (out_C - 4 * out_c >= 2) {
+    output.y = convert_float(in[input_pos1]);
+  }
+  if (out_C - 4 * out_c >= 3) {
+    output.z = convert_float(in[input_pos2]);
+  }
+  if (out_C - 4 * out_c >= 4) {
+    output.w = convert_float(in[input_pos3]);
+  }
+  write_imagef(output_image, output_pos, output);
+}
+
+// image2d -> buffer
+__kernel void image2d_to_buffer(__read_only image2d_t input,
+                                __private const int in_width,
+                                __private const int in_height,
+                                __global CL_DTYPE* out,
+                                __private const int size_ch,
+                                __private const int size_block,
+                                __private const int size_batch,
+                                __private const int C) {
+  const int in_c = get_global_id(0);
+  const int in_w = get_global_id(1);
+  const int in_nh = get_global_id(2);
+  const int in_n = in_nh / in_height;
+  const int in_h = in_nh % in_height;
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  const int pos_x = mad24(in_c, in_width, in_w);
+  CL_DTYPE4 in = read_imagef(input, sampler, (int2)(pos_x, in_nh));
+
+  const int index =
+      in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
+  out[index] = convert_float(in.x);
+  if (C - 4 * in_c >= 2) {
+    out[index + size_ch] = convert_float(in.y);
+  }
+  if (C - 4 * in_c >= 3) {
+    out[index + size_ch * 2] = convert_float(in.z);
+  }
+  if (C - 4 * in_c >= 4) {
+    out[index + size_ch * 3] = convert_float(in.w);
+  }
+}
+
+// image2d -> buffer
+__kernel void image2d_to_buffer_2d(__private const int in_height,
+                                   __private const int in_width,
+                                   __read_only image2d_t input,
+                                   __global CL_DTYPE* out) {
+  const int in_w = get_global_id(1);
+  const int in_h = get_global_id(2);
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  CL_DTYPE4 in = read_imagef(input, sampler, (int2)(in_w, in_h));
+
+  const int index = (in_h * in_width + in_w) * 4;
+  out[index] = convert_float(in.x);
+  out[index + 1] = convert_float(in.y);
+  out[index + 2] = convert_float(in.z);
+  out[index + 3] = convert_float(in.w);
+}
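A worked example of the buffer-to-image packing implemented above may help: an NCHW buffer is packed into an RGBA image2d of size (ceil(C/4) * W) x (N * H), with the channel folded into the four texel components. The numbers below are chosen for illustration, and the strides are assumed to be passed from the host as Stride0 = W, Stride1 = H * W, Stride2 = C * H * W; the kernels above remain the authoritative mapping.

#include <cstdio>

int main() {
  const int N = 1, C = 6, H = 4, W = 5;
  const int n = 0, c = 5, h = 2, w = 3;  // an arbitrary element
  // Strides as presumably passed to buffer_to_image2d.
  const int stride0 = W, stride1 = H * W, stride2 = C * H * W;
  const int buffer_index = n * stride2 + c * stride1 + h * stride0 + w;
  // Image coordinates: x spans channel-block * W + w, y spans n * H + h;
  // the component within the texel is c % 4.
  const int x = (c / 4) * W + w;
  const int y = n * H + h;
  const int component = c % 4;  // 0->x, 1->y, 2->z, 3->w of the CL_DTYPE4
  std::printf("buffer[%d] -> image2d(%d, %d).%c\n",
              buffer_index, x, y, "xyzw"[component]);
  return 0;
}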
diff --git a/lite/backends/opencl/cl_kernel/cl_common.h b/lite/backends/opencl/cl_kernel/cl_common.h
index ec67aa676d40cae421309661b36eddc8f56485d5..7f901fc994ffd82ccfe99f59614a3422260d0dc5 100644
--- a/lite/backends/opencl/cl_kernel/cl_common.h
+++ b/lite/backends/opencl/cl_kernel/cl_common.h
@@ -16,10 +16,35 @@ limitations under the License. */
 
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
+// Data type: pass exactly one of these macros from the host:
+// [CL_DTYPE_float, CL_DTYPE_half]
+#ifdef CL_DTYPE_float
+#define CL_DTYPE float
+#define CL_DTYPE_CHAR f
+#endif
+
+#ifdef CL_DTYPE_half
+#define CL_DTYPE half
+#define CL_DTYPE_CHAR h
+#endif
+
+// Note: macro-name replacement needs a two-level expansion
 #define GET_VEC_TYPE(type__, size__) type__##size__
 #define VECTORIZED_TYPE(type__, size__) GET_VEC_TYPE(type__, size__)
 #define CL_DTYPE4 VECTORIZED_TYPE(CL_DTYPE, 4)
 
+#define _CONVERT_TYPE_TO(value, type) convert_##type(value)
+#define CONVERT_TYPE_TO(value, type) _CONVERT_TYPE_TO(value, type)
+
+#define _WRITE_IMG_TYPE(type_char, img, pos, value) \
+  write_image##type_char(img, pos, value)
+#define WRITE_IMG_TYPE(type_char, img, pos, value) \
+  _WRITE_IMG_TYPE(type_char, img, pos, value)
+
+#define _READ_IMG_TYPE(type_char, img, pos, sampler) \
+  read_image##type_char(img, sampler, pos)
+#define READ_IMG_TYPE(type_char, img, pos, sampler) \
+  _READ_IMG_TYPE(type_char, img, pos, sampler)
+
 inline CL_DTYPE activation(CL_DTYPE in
 #ifdef PRELU
                            ,
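The two-level macros are needed because ## pastes its operands before argument expansion; the extra indirection forces CL_DTYPE_CHAR to expand first. The behavior is plain C preprocessor, shown standalone here (same semantics in OpenCL C):

// Why the indirection is needed: with a one-level macro, the argument is
// pasted literally; with two levels it is macro-expanded first.
#define CL_DTYPE_CHAR f

#define DIRECT(type_char, img, pos, v) write_image##type_char(img, pos, v)
#define _INDIRECT(type_char, img, pos, v) write_image##type_char(img, pos, v)
#define INDIRECT(type_char, img, pos, v) _INDIRECT(type_char, img, pos, v)

// DIRECT(CL_DTYPE_CHAR, out, p, v)   expands to write_imageCL_DTYPE_CHAR(out, p, v)
// INDIRECT(CL_DTYPE_CHAR, out, p, v) expands to write_imagef(out, p, v)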
diff --git a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..a99ac79d32bcedb48354d2e179ef6c8c1ff7f997
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cl_common.h>
+
+__kernel void relu(__read_only image2d_t input,
+                   __write_only image2d_t output) {
+
+  const int x = get_global_id(0);  // image_width
+  const int y = get_global_id(1);  // image_height
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+
+  CL_DTYPE4 in = read_imagef(input, sampler, (int2)(x, y));
+  in = max((CL_DTYPE4)(0.0f), in);
+  write_imagef(output, (int2)(x, y), in);
+}
diff --git a/lite/backends/opencl/target_wrapper.cc b/lite/backends/opencl/target_wrapper.cc
index eb324fcb0f0872ede6f4476ce5921e2094d45be7..575f87d0f8d0192345c6ab111db46715a809a976 100644
--- a/lite/backends/opencl/target_wrapper.cc
+++ b/lite/backends/opencl/target_wrapper.cc
@@ -14,11 +14,9 @@
 
 #include "lite/backends/opencl/target_wrapper.h"
 #include
-#include <array>
 #include "lite/backends/opencl/cl_include.h"
 #include "lite/backends/opencl/cl_runtime.h"
 #include "lite/backends/opencl/cl_utility.h"
-
 namespace paddle {
 namespace lite {
 
@@ -58,18 +56,61 @@ void TargetWrapperCL::Free(void *ptr) {
   }
 }
 
-void *TargetWrapperCL::MallocImage(const std::array<size_t, 2> &image_shape,
-                                   PrecisionType data_type) {
-  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(data_type));
+template <>
+void *TargetWrapperCL::MallocImage<float>(const size_t cl_image2d_width,
+                                          const size_t cl_image2d_height) {
+  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFloat)));
+  cl_int status;
+  cl::Image2D *cl_image =
+      new cl::Image2D(CLRuntime::Global()->context(),
+                      CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
+                      img_format,
+                      cl_image2d_width,
+                      cl_image2d_height,
+                      0,
+                      nullptr,
+                      &status);
+  if (status != CL_SUCCESS) {
+    delete cl_image;
+    cl_image = nullptr;
+  }
+  CL_CHECK_FATAL(status);
+  return cl_image;
+}
+
+template <>
+void *TargetWrapperCL::MallocImage<int8_t>(const size_t cl_image2d_width,
+                                           const size_t cl_image2d_height) {
+  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt8)));
+  cl_int status;
+  cl::Image2D *cl_image =
+      new cl::Image2D(CLRuntime::Global()->context(),
+                      CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
+                      img_format,
+                      cl_image2d_width,
+                      cl_image2d_height,
+                      0,
+                      nullptr,
+                      &status);
+  if (status != CL_SUCCESS) {
+    delete cl_image;
+    cl_image = nullptr;
+  }
+  CL_CHECK_FATAL(status);
+  return cl_image;
+}
+
+template <>
+void *TargetWrapperCL::MallocImage<int32_t>(const size_t cl_image2d_width,
+                                            const size_t cl_image2d_height) {
+  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt32)));
   cl_int status;
-  size_t width = image_shape[0];
-  size_t height = image_shape[1];
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
                       CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
                       img_format,
-                      width,
-                      height,
+                      cl_image2d_width,
+                      cl_image2d_height,
                       0,
                       nullptr,
                       &status);
@@ -108,15 +149,13 @@ void *TargetWrapperCL::Map(void *buffer, size_t offset, size_t size) {
 }
 
 void *TargetWrapperCL::MapImage(void *image,
-                                const std::array<size_t, 2> &image_shape,
-                                std::array<size_t, 2> *image_pitch) {
+                                const size_t cl_image2d_width,
+                                const size_t cl_image2d_height,
+                                size_t cl_image2d_row_pitch,
+                                size_t cl_image2d_slice_pitch) {
   cl::Image2D *cl_image = static_cast<cl::Image2D *>(image);
-  size_t width = image_shape[0];
-  size_t height = image_shape[1];
-  size_t *row_pitch = image_pitch->data();
-  size_t *slice_pitch = image_pitch->data() + 1;
-  std::array<size_t, 3> origin{{0, 0, 0}};
-  std::array<size_t, 3> region{{width, height, 1}};
+  cl::array<cl::size_type, 3> origin = {0, 0, 0};
+  cl::array<cl::size_type, 3> region = {cl_image2d_width, cl_image2d_height, 1};
  cl_int status;
  void *mapped_ptr =
CLRuntime::Global()->command_queue().enqueueMapImage( *cl_image, @@ -124,8 +163,8 @@ void *TargetWrapperCL::MapImage(void *image, CL_MAP_READ | CL_MAP_WRITE, origin, region, - row_pitch, - slice_pitch, + &cl_image2d_row_pitch, + &cl_image2d_slice_pitch, nullptr, nullptr, &status); @@ -231,15 +270,13 @@ void TargetWrapperCL::MemcpyAsync(void *dst, void TargetWrapperCL::ImgcpySync(void *dst, const void *src, - const std::array &image_shape, - const std::array &image_pitch, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch, IoDirection dir) { - size_t width = image_shape[0]; - size_t height = image_shape[1]; - size_t row_pitch = image_pitch[0]; - size_t slice_pitch = image_pitch[1]; - std::array origin{{0, 0, 0}}; - std::array region{{width, height, 1}}; + cl::array origin = {0, 0, 0}; + cl::array region = {cl_image2d_width, cl_image2d_height, 1}; cl_int status; cl::Event event; auto stream = CLRuntime::Global()->command_queue(); @@ -260,8 +297,8 @@ void TargetWrapperCL::ImgcpySync(void *dst, CL_TRUE, origin, region, - row_pitch, - slice_pitch, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, src, nullptr, nullptr); @@ -272,8 +309,8 @@ void TargetWrapperCL::ImgcpySync(void *dst, CL_TRUE, origin, region, - row_pitch, - slice_pitch, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, dst, nullptr, nullptr); @@ -286,16 +323,14 @@ void TargetWrapperCL::ImgcpySync(void *dst, void TargetWrapperCL::ImgcpyAsync(void *dst, const void *src, - const std::array &image_shape, - const std::array &image_pitch, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch, IoDirection dir, const stream_t &stream) { - size_t width = image_shape[0]; - size_t height = image_shape[1]; - size_t row_pitch = image_pitch[0]; - size_t slice_pitch = image_pitch[1]; - std::array origin{{0, 0, 0}}; - std::array region{{width, height, 1}}; + cl::array origin = {0, 0, 0}; + cl::array region = {cl_image2d_width, cl_image2d_height, 1}; cl_int status; switch (dir) { case IoDirection::DtoD: @@ -313,8 +348,8 @@ void TargetWrapperCL::ImgcpyAsync(void *dst, CL_FALSE, origin, region, - row_pitch, - slice_pitch, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, src, nullptr, nullptr); @@ -325,8 +360,8 @@ void TargetWrapperCL::ImgcpyAsync(void *dst, CL_FALSE, origin, region, - row_pitch, - slice_pitch, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, dst, nullptr, nullptr); diff --git a/lite/backends/opencl/target_wrapper.h b/lite/backends/opencl/target_wrapper.h index 8ff8e6fd4027bd543bdfc221b1fdf5af9d7c4000..7753448052e17ac739f730c9fabcaf9533e0045e 100644 --- a/lite/backends/opencl/target_wrapper.h +++ b/lite/backends/opencl/target_wrapper.h @@ -14,7 +14,6 @@ #pragma once -#include #include "lite/backends/opencl/cl_include.h" #include "lite/core/target_wrapper.h" @@ -47,14 +46,17 @@ class TargetWrapper { static void* Malloc(size_t size); static void Free(void* ptr); - static void* MallocImage(const std::array& image_shape, - PrecisionType data_type); + template + static void* MallocImage(const size_t cl_image2d_width, + const size_t cl_image2d_height); static void FreeImage(void* image); static void* Map(void* buffer, size_t offset, size_t size); static void* MapImage(void* image, - const std::array& image_shape, - std::array* image_pitch); + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t 
cl_image2d_slice_pitch); static void Unmap(void* cl_obj, void* mapped_ptr); static void MemcpySync(void* dst, @@ -68,13 +70,17 @@ class TargetWrapper { const stream_t& stream); static void ImgcpySync(void* dst, const void* src, - const std::array& image_shape, - const std::array& image_pitch, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch, IoDirection dir); static void ImgcpyAsync(void* dst, const void* src, - const std::array& image_shape, - const std::array& image_pitch, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch, IoDirection dir, const stream_t& stream); }; diff --git a/lite/backends/x86/CMakeLists.txt b/lite/backends/x86/CMakeLists.txt index 992bf5536eb7cfae392e2f62313ec9df1286e1c0..63b41ae77d0f3949e3d1de13f9db5ca99b4f1c41 100644 --- a/lite/backends/x86/CMakeLists.txt +++ b/lite/backends/x86/CMakeLists.txt @@ -4,10 +4,12 @@ endif() configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) configure_file(warpctc_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/warpctc_lib_path.h) - -lite_cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -#lite_cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) lite_cc_library(target_wrapper_x86 SRCS target_wrapper.cc) +if (LITE_ON_MODEL_OPTIMIZE_TOOL) + return() +endif(LITE_ON_MODEL_OPTIMIZE_TOOL) +lite_cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) +lite_cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) lite_cc_library(x86_cpu_info SRCS cpu_info.cc DEPS xbyak) add_subdirectory(jit) diff --git a/lite/backends/x86/dynamic_loader.cc b/lite/backends/x86/dynamic_loader.cc index 3a3e0e1dd422093d7ea8e1c04dcaa2733890d51b..75bb528f38664fc1061653e1036b73eed74daae9 100644 --- a/lite/backends/x86/dynamic_loader.cc +++ b/lite/backends/x86/dynamic_loader.cc @@ -54,8 +54,8 @@ DEFINE_string( DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); namespace paddle { -namespace platform { -namespace dynload { +namespace lite { +namespace x86 { static constexpr char cupti_lib_path[] = CUPTI_LIB_PATH; static constexpr char warpctc_lib_path[] = WARPCTC_LIB_PATH; @@ -153,16 +153,18 @@ static inline void* GetDsoHandleFromSearchPath(const std::string& search_root, dso_handle = GetDsoHandleFromDefaultPath(dlPath, dynload_flags); } } - auto error_msg = - "Failed to find dynamic library: %s ( %s ) \n Please specify " - "its path correctly using following ways: \n Method. set " - "environment variable LD_LIBRARY_PATH on Linux or " - "DYLD_LIBRARY_PATH on Mac OS. \n For instance, issue command: " - "export LD_LIBRARY_PATH=... \n Note: After Mac OS 10.11, " - "using the DYLD_LIBRARY_PATH is impossible unless System " - "Integrity Protection (SIP) is disabled."; +/* +auto error_msg = + "Failed to find dynamic library: %s ( %s ) \n Please specify " + "its path correctly using following ways: \n Method. set " + "environment variable LD_LIBRARY_PATH on Linux or " + "DYLD_LIBRARY_PATH on Mac OS. \n For instance, issue command: " + "export LD_LIBRARY_PATH=... 
\n Note: After Mac OS 10.11, " + "using the DYLD_LIBRARY_PATH is impossible unless System " + "Integrity Protection (SIP) is disabled."; +*/ #if !defined(_WIN32) - auto errorno = dlerror(); +// auto errorno = dlerror(); #else auto errorno = GetLastError(); #endif // !_WIN32 @@ -258,6 +260,6 @@ void* GetMKLMLDsoHandle() { #endif } -} // namespace dynload -} // namespace platform +} // namespace x86 +} // namespace lite } // namespace paddle diff --git a/lite/backends/x86/jit/more/CMakeLists.txt b/lite/backends/x86/jit/more/CMakeLists.txt index 94927dd66bc5f9cf5e319772c8d4debb20cf029a..2ddbbcd16a3ffef560581592e3a009c61844d4d5 100644 --- a/lite/backends/x86/jit/more/CMakeLists.txt +++ b/lite/backends/x86/jit/more/CMakeLists.txt @@ -4,9 +4,9 @@ function(USE_JITKERNEL_MORE TARGET TYPE) endfunction() # enable it latter -# if(WITH_MKLML) -# add_subdirectory(mkl) -# endif() + if(WITH_MKLML) + add_subdirectory(mkl) + endif() if(WITH_AVX) add_subdirectory(intrinsic) diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt index 35208c2c7aefe996247dc83d3055223f931c61ee..2dea4364d5ee2d11d6d266935fad2a1180954369 100644 --- a/lite/backends/x86/math/CMakeLists.txt +++ b/lite/backends/x86/math/CMakeLists.txt @@ -16,7 +16,7 @@ function(math_library TARGET) endif() list(LENGTH cc_srcs cc_srcs_len) - lite_cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps} eigen3) + lite_cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps} eigen3 dynload_mklml) endfunction() # please add new math_library in alphabetical order @@ -30,13 +30,13 @@ math_library(sample_prob) math_library(sampler) math_library(gru_compute DEPS activation_functions math_function) -## math_library(lstm_compute DEPS activation_functions) +math_library(lstm_compute DEPS activation_functions) -lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3) -math_library(math_function DEPS blas) +lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3 dynload_mklml) +math_library(math_function DEPS blas dynload_mklml) math_library(maxouting) math_library(pooling) -# math_library(selected_rows_functor DEPS selected_rows math_function blas) +math_library(selected_rows_functor DEPS selected_rows math_function blas) math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function jit_kernel_helper) diff --git a/lite/backends/x86/math/beam_search.cc b/lite/backends/x86/math/beam_search.cc index 93726afcc22446526952f3d7d9641f4abcfc10ee..bbe35b4de5508c70496e5c8566c8d1b982a7155c 100644 --- a/lite/backends/x86/math/beam_search.cc +++ b/lite/backends/x86/math/beam_search.cc @@ -49,6 +49,7 @@ class BeamSearchFunctor { end_id, is_accumulated); auto selected_items = ToMap(items, high_level.back()); + /* if (FLAGS_v == 3) { VLOG(3) << "selected_items:"; for (size_t i = 0; i < selected_items.size(); ++i) { @@ -58,6 +59,7 @@ class BeamSearchFunctor { } } } + */ PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id); // calculate the output tensor's height @@ -69,7 +71,8 @@ class BeamSearchFunctor { // the output tensor shape should be [num_instances, 1] // auto dims = framework::make_ddim( // std::vector({static_cast(num_instances), 1})); - lite::DDim dims(std::vector({num_instances, 1L})); + lite::DDim dims( + std::vector({static_cast(num_instances), 1L})); selected_ids->Resize(dims); auto *selected_ids_data = selected_ids->mutable_data(TARGET(kX86)); @@ -296,7 +299,7 @@ class BeamSearchFunctor { 
result.emplace_back(top_beam); } - + /* if (FLAGS_v == 3) { VLOG(3) << "SelectTopBeamSizeItems result size " << result.size(); for (auto &items : result) { @@ -306,7 +309,7 @@ class BeamSearchFunctor { } } } - + */ return result; } }; diff --git a/lite/backends/x86/math/blas_impl.h b/lite/backends/x86/math/blas_impl.h index 36d76c783cfe06d38b65d548e5dd4dbb16304521..72d0736268f342187f0be8c6348f5bed75df30ea 100644 --- a/lite/backends/x86/math/blas_impl.h +++ b/lite/backends/x86/math/blas_impl.h @@ -463,9 +463,9 @@ void Blas::MatMul(const lite::Tensor &mat_a, auto dim_out = mat_out->dims(); PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target(), - "The targets of matrices must be same"); + // PADDLE_ENFORCE( + // mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target(), + // "The targets of matrices must be same"); int M = dim_out[0]; int N = dim_out[1]; @@ -483,7 +483,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data()); + mat_out->mutable_data()); } template <> @@ -759,7 +759,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data()); + mat_out->mutable_data()); } else { PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); @@ -773,7 +773,7 @@ void Blas::MatMul(const lite::Tensor &mat_a, mat_a.data(), mat_b.data(), beta, - mat_out->data(), + mat_out->mutable_data(), dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); diff --git a/lite/backends/x86/math/detail/activation_functions.h b/lite/backends/x86/math/detail/activation_functions.h index cb215df72205ed59c22698fc1fb914cd3736ce22..6a13a3d471e10970b36120a12b21a36256350803 100644 --- a/lite/backends/x86/math/detail/activation_functions.h +++ b/lite/backends/x86/math/detail/activation_functions.h @@ -45,8 +45,10 @@ inline ActivationType GetActivationType(const std::string &type) { } else if (type == "identity" || type == "") { return ActivationType::kIdentity; } - PADDLE_ENFORCE(false, "Not support type %s", type); + LOG(ERROR) << "Not support type " << type; + // PADDLE_ENFORCE(false, "Not support type %s", type); // PADDLE_THROW("Not support type %s.", type); + return ActivationType(); } namespace forward { diff --git a/lite/backends/x86/math/detail/lstm_cpu_kernel.h b/lite/backends/x86/math/detail/lstm_cpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3091cc5679e053d0d93855822b14abc1f412b753 --- /dev/null +++ b/lite/backends/x86/math/detail/lstm_cpu_kernel.h @@ -0,0 +1,431 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "lite/backends/x86/math/detail/activation_functions.h" +#include "lite/backends/x86/math/lstm_compute.h" + +#if defined(_WIN32) +#if defined(__AVX2__) || defined(__AVX__) +inline __m256 operator+=(__m256 a, __m256 b) { return _mm256_add_ps(a, b); } +#endif +#endif + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { +namespace detail { + +#ifndef __NVCC__ + +template +void naive_lstm_forward_one_sequence(Op op, + LstmMetaValue value, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_checkI; + T r_checkF; + T r_checkO; + T r_state; + T r_prev_state = 0; + T r_state_atv; + T r_out; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? value.check_og[i] : 0; + + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; + } + + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_prev_state, + &r_state, + &r_state_atv, + &r_out, + &r_checkI, + &r_checkF, + &r_checkO, + &cell_clip, + active_node, + active_gate, + active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + value.state_value[i] = r_state; + value.state_active_value[i] = r_state_atv; + value.output_value[i] = r_out; + } +} + +template +void naive_lstm_backward_one_sequence(Op op, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_grad_in; + T r_grad_ig; + T r_grad_fg; + T r_grad_og; + T r_prev_state = 0; + T r_prev_state_grad; + T r_state; + T r_state_grad; + T r_state_atv; + T r_output_grad; + T r_checkI; + T r_checkF; + T r_checkO; + T r_checkIGrad; + T r_checkFGrad; + T r_checkOGrad; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + T *grad_in = grad.gate_grad; + T *grad_ig = grad.gate_grad + frame_size; + T *grad_fg = grad.gate_grad + frame_size * 2; + T *grad_og = grad.gate_grad + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? 
value.check_og[i] : 0; + r_state = value.state_value[i]; + r_state_atv = value.state_active_value[i]; + r_output_grad = grad.output_grad[i]; + r_state_grad = grad.state_grad[i]; + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; + } + + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_grad_in, + &r_grad_ig, + &r_grad_fg, + &r_grad_og, + &r_prev_state, + &r_prev_state_grad, + &r_state, + &r_state_grad, + &r_state_atv, + &r_output_grad, + &r_checkI, + &r_checkF, + &r_checkO, + &r_checkIGrad, + &r_checkFGrad, + &r_checkOGrad, + &cell_clip, + active_node, + active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + grad.state_grad[i] = r_state_grad; + + if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad; + if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad; + } + if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad; + } +} + +template +void avx_lstm_forward_one_sequence(Op op, + LstmMetaValue value, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { +#ifdef __AVX__ + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_state; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_state_atv; + __m256 r_out; + + __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); + __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + __m256 *value_fg = + reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); + __m256 *value_og = + reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = (reinterpret_cast<__m256 *>(value.check_ig))[i]; + r_checkF = (reinterpret_cast<__m256 *>(value.check_fg))[i]; + r_checkO = (reinterpret_cast<__m256 *>(value.check_og))[i]; + } + + if (value.prev_state_value) { + r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; + } + + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_prev_state, + &r_state, + &r_state_atv, + &r_out, + &r_checkI, + &r_checkF, + &r_checkO, + &cell_clip, + active_node, + active_gate, + active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + (reinterpret_cast<__m256 *>(value.state_value))[i] = r_state; + (reinterpret_cast<__m256 *>(value.state_active_value))[i] = r_state_atv; + (reinterpret_cast<__m256 *>(value.output_value))[i] = r_out; + } +#endif +} + +template +void avx_lstm_backward_one_sequence(Op op, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { +#ifdef __AVX__ + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_grad_in; + __m256 r_grad_ig; + __m256 r_grad_fg; + __m256 r_grad_og; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_prev_state_grad; + __m256 r_state_grad; + __m256 r_state; + __m256 r_state_atv; + __m256 r_output_grad; + __m256 
r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_checkIGrad; + __m256 r_checkFGrad; + __m256 r_checkOGrad; + + __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); + __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + __m256 *value_fg = + reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); + __m256 *value_og = + reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); + __m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad); + __m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size); + __m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2); + __m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = (reinterpret_cast<__m256 *>(value.check_ig))[i]; + r_checkF = (reinterpret_cast<__m256 *>(value.check_fg))[i]; + r_checkO = (reinterpret_cast<__m256 *>(value.check_og))[i]; + } + r_state = (reinterpret_cast<__m256 *>(value.state_value))[i]; + r_state_atv = (reinterpret_cast<__m256 *>(value.state_active_value))[i]; + r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i]; + r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i]; + if (value.prev_state_value) { + r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; + } + + op(&r_value_in, + &r_value_ig, + &r_value_fg, + &r_value_og, + &r_grad_in, + &r_grad_ig, + &r_grad_fg, + &r_grad_og, + &r_prev_state, + &r_prev_state_grad, + &r_state, + &r_state_grad, + &r_state_atv, + &r_output_grad, + &r_checkI, + &r_checkF, + &r_checkO, + &r_checkIGrad, + &r_checkFGrad, + &r_checkOGrad, + &cell_clip, + active_node, + active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + (reinterpret_cast<__m256 *>(grad.state_grad))[i] = r_state_grad; + + if (grad.prev_state_grad) + (reinterpret_cast<__m256 *>(grad.prev_state_grad))[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) + (reinterpret_cast<__m256 *>(grad.check_ig_grad))[i] += r_checkIGrad; + if (grad.check_fg_grad) + (reinterpret_cast<__m256 *>(grad.check_fg_grad))[i] += r_checkFGrad; + } + if (grad.check_og_grad) + (reinterpret_cast<__m256 *>(grad.check_og_grad))[i] += r_checkOGrad; + } +#endif +} + +template +void cpu_lstm_forward(Op op, + LstmMetaValue value, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, + value, + frame_size, + cell_clip, + active_node, + active_gate, + active_state); + } else { + naive_lstm_forward_one_sequence(op, + value, + frame_size, + cell_clip, + active_node, + active_gate, + active_state); + } +} + +template +void cpu_lstm_backward(Op op, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + T cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, + value, + grad, + frame_size, + cell_clip, + active_node, + active_gate, + active_state); + } else { + naive_lstm_backward_one_sequence(op, + value, + 
grad, + frame_size, + cell_clip, + active_node, + active_gate, + active_state); + } +} + +#endif + +} // namespace detail +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/detail/lstm_kernel.h b/lite/backends/x86/math/detail/lstm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1286f2e8b70b85c0a5aea709b99c854771fea72f --- /dev/null +++ b/lite/backends/x86/math/detail/lstm_kernel.h @@ -0,0 +1,236 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/backends/x86/math/detail/activation_functions.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { +namespace detail { + +namespace forward { + +template +class lstm { + public: + HOSTDEVICE void operator()(T *value_in, + T *value_ig, + T *value_fg, + T *value_og, + T *prev_state, + T *state, + T *state_atv, + T *output, + T *checkI, + T *checkF, + T *checkO, + T *cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + *value_in = activation(*value_in, active_node); + *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); + *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); + *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); + + if (*cell_clip > 0.0) { + if (*state < -1.0 * (*cell_clip)) { + *state = -1.0 * (*cell_clip); + } + if (*state > *cell_clip) { + *state = *cell_clip; + } + } + *value_og = activation(*value_og + (*state) * (*checkO), active_gate); + *state_atv = activation(*state, active_state); + *output = (*value_og) * (*state_atv); + } +#ifndef __NVCC__ +#ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default + static const bool avx = false; +#else + // Only float support AVX optimization + static const bool avx = std::is_same::value; + + HOSTDEVICE void operator()(__m256 *value_in, + __m256 *value_ig, + __m256 *value_fg, + __m256 *value_og, + __m256 *prev_state, + __m256 *state, + __m256 *state_atv, + __m256 *output, + __m256 *checkI, + __m256 *checkF, + __m256 *checkO, + T *cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + *value_in = activation(*value_in, active_node); + *value_ig = activation( + _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)), + active_gate); + *value_fg = activation( + _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)), + active_gate); + *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), + _mm256_mul_ps(*prev_state, *value_fg)); + + if (*cell_clip > 0.0f) { + __m256 min = _mm256_set1_ps(0.0f - *cell_clip); + __m256 max = _mm256_set1_ps(*cell_clip); + *state = _mm256_min_ps(max, *state); + *state = _mm256_max_ps(min, *state); + } + *value_og = activation( + _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate); + *state_atv = activation(*state, active_state); + *output = _mm256_mul_ps(*value_og, *state_atv); + } +#endif +#endif +}; + +} // namespace forward + +namespace backward { + +template +class lstm { + public: + HOSTDEVICE void operator()(T *value_in, + T *value_ig, + T *value_fg, + T *value_og, + T *grad_in, + T *grad_ig, + T *grad_fg, + T *grad_og, + T *prev_state, + T *prev_state_grad, + T *state, + T *state_grad, + T *state_atv, + T *output_grad, + T *checkI, + T *checkF, + T *checkO, + T *checkIGrad, + T *checkFGrad, + T *checkOGrad, + T *cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + *grad_og = + activation((*output_grad) * (*state_atv), *value_og, active_gate); + if (*cell_clip > 0.0f) { + if (*state >= (*cell_clip) || *state <= (0.0f - (*cell_clip))) { + *state_grad = 0.0f; + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + + *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); + *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); + *grad_fg = + activation((*state_grad) * (*prev_state), *value_fg, active_gate); + *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) + + (*state_grad) * (*value_fg); + *checkIGrad = (*grad_ig) * (*prev_state); + *checkFGrad = (*grad_fg) * (*prev_state); + *checkOGrad = (*grad_og) * (*state); + } +#ifndef __NVCC__ +#ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default + static const bool avx = false; +#else + // Only float support AVX optimization + static const bool avx = std::is_same::value; + HOSTDEVICE void operator()(__m256 *value_in, + __m256 *value_ig, + __m256 *value_fg, + __m256 *value_og, + __m256 *grad_in, + __m256 *grad_ig, + __m256 *grad_fg, + __m256 *grad_og, + __m256 *prev_state, + __m256 *prev_state_grad, + __m256 *state, + __m256 *state_grad, + __m256 *state_atv, + __m256 *output_grad, + __m256 *checkI, + __m256 *checkF, + __m256 *checkO, + __m256 *checkIGrad, + __m256 *checkFGrad, + __m256 *checkOGrad, + T *cell_clip, + ActivationType active_node, + ActivationType active_gate, + ActivationType active_state) { + *grad_og = activation( + _mm256_mul_ps(*output_grad, *state_atv), *value_og, active_gate); + if (*cell_clip > 0.0f) { + T *state_ = reinterpret_cast(state); + if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) { + *state_grad = _mm256_set1_ps(0.0f); + } else { + *state_grad = + _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), + *state_atv, + active_state), + *state_grad); + *state_grad = + _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); + } + } + *grad_in = activation( + _mm256_mul_ps(*state_grad, *value_ig), *value_in, active_node); + *grad_ig = activation( + _mm256_mul_ps(*state_grad, *value_in), *value_ig, active_gate); + *grad_fg = activation( + _mm256_mul_ps(*state_grad, *prev_state), *value_fg, active_gate); + *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI), + _mm256_mul_ps(*grad_fg, *checkF)); + *prev_state_grad = + _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad); + *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state); + *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state); + *checkOGrad = _mm256_mul_ps(*grad_og, *state); + } +#endif +#endif +}; + +} // namespace backward + +} // namespace detail +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/lstm_compute.cc b/lite/backends/x86/math/lstm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..639aff02fa84f2d3a3acc726915a2349365bb0f2 --- /dev/null +++ b/lite/backends/x86/math/lstm_compute.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "lite/backends/x86/math/lstm_compute.h" +#include "lite/backends/x86/math/detail/lstm_cpu_kernel.h" +#include "lite/backends/x86/math/detail/lstm_kernel.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +struct LstmUnitFunctor { + static void compute(const lite::X86Context& context, + LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + const detail::ActivationType& gate_act, + const detail::ActivationType& cell_act, + const detail::ActivationType& cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_forward(detail::forward::lstm(), + value, + frame_size, + cell_clip, + cand_act, + gate_act, + cell_act); + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; + } + } + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(const lite::X86Context& context, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + const detail::ActivationType& gate_act, + const detail::ActivationType& cell_act, + const detail::ActivationType& cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_backward(detail::backward::lstm(), + value, + grad, + frame_size, + cell_clip, + cand_act, + gate_act, + cell_act); + + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; + } + + grad.gate_grad += frame_size * 4; + grad.state_grad += frame_size; + grad.state_active_grad += frame_size; + grad.output_grad += frame_size; + if (grad.prev_state_grad) { + grad.prev_state_grad += frame_size; + } + } + } +}; + +template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; +template class LstmUnitGradFunctor; + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/lstm_compute.h b/lite/backends/x86/math/lstm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..ddb7bea9995ebcca978be97f8295eb07b0e4e17e --- /dev/null +++ b/lite/backends/x86/math/lstm_compute.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "lite/backends/x86/math/detail/activation_functions.h" +#include "lite/core/context.h" +#include "lite/utils/paddle_enforce.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +struct LstmMetaValue { + T *gate_value; + T *prev_state_value; + T *state_value; + T *state_active_value; + T *output_value; + T *check_ig; + T *check_fg; + T *check_og; +}; + +template +struct LstmMetaGrad { + T *gate_grad; + T *prev_state_grad; + T *state_grad; + T *state_active_grad; + T *output_grad; + T *check_ig_grad; + T *check_fg_grad; + T *check_og_grad; +}; + +template +class LstmUnitFunctor { + public: + static void compute(const lite::Context &context, + LstmMetaValue value, + int frame_size, + int batch_size, + T cell_clip, + const detail::ActivationType &gate_act, + const detail::ActivationType &cell_act, + const detail::ActivationType &cand_act); +}; + +template +class LstmUnitGradFunctor { + public: + static void compute(const lite::Context &context, + LstmMetaValue value, + LstmMetaGrad grad, + int frame_size, + int batch_size, + T cell_clip, + const detail::ActivationType &gate_act, + const detail::ActivationType &cell_act, + const detail::ActivationType &cand_act); +}; + +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/pooling.cc b/lite/backends/x86/math/pooling.cc index e700c5f7c794ae07dc9ddc218f732a3eff049acf..9da239f9c63371350403cc0bd0eecc94eab87590 100644 --- a/lite/backends/x86/math/pooling.cc +++ b/lite/backends/x86/math/pooling.cc @@ -30,7 +30,7 @@ template class Pool2dFunctor { public: void operator()(const lite::X86Context& context, - const lite::Tensor& input, + const lite::Tensor* input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, @@ -38,9 +38,9 @@ class Pool2dFunctor { bool exclusive, bool adaptive, lite::Tensor* output) { - const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; + const int batch_size = input->dims()[0]; + const int input_height = input->dims()[2]; + const int input_width = input->dims()[3]; const int output_channels = output->dims()[1]; const int output_height = output->dims()[2]; const int output_width = output->dims()[3]; @@ -54,7 +54,7 @@ class Pool2dFunctor { const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; - const T* input_data = input.data(); + const T* input_data = input->data(); T* output_data = output->mutable_data(lite::TargetType::kX86); int hstart, hend; diff --git a/lite/backends/x86/math/pooling.h b/lite/backends/x86/math/pooling.h index 64015e32c883a95ba8b3c1419035f04d325bb1e0..394522559b6859cac4a036717e3632a9e7b3090e 100644 --- a/lite/backends/x86/math/pooling.h +++ b/lite/backends/x86/math/pooling.h @@ -94,24 +94,12 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { * This is different from average pooling. So we rewrite the max_pool_grad: * MaxPool2dGradFunctor, MaxPool3dGradFunctor. 
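 *
 * For example, if a 1-D window covers {1, 3, 2} and the output gradient is
 * g, average pooling scatters g/3 to all three inputs, while max pooling
 * must route the entire g back only to the element that produced the max
 * (the 3). The max-pool grad functors therefore also take the original
 * input and output tensors, so they can locate the argmax, instead of
 * working from the output gradient alone.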
*/ -//#ifdef PADDLE_WITH_CUDA -// template -// class Pool2dDirectCUDAFunctor { -// public: -// void operator()(const T* input, const std::vector& input_shape, -// const std::vector& output_shape, -// const std::vector& ksize, -// const std::vector& strides, -// const std::vector& paddings, PoolProcess pool_compute, -// bool exclusive, T* output, cudaStream_t stream); -//}; -//#endif template class Pool2dFunctor { public: void operator()(const lite::Context& context, - const lite::Tensor& input, + const lite::Tensor* input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, diff --git a/lite/backends/x86/math/selected_rows_functor.cc b/lite/backends/x86/math/selected_rows_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8f1b42361832771ba04d1bdc8b3e2e05f954e29 --- /dev/null +++ b/lite/backends/x86/math/selected_rows_functor.cc @@ -0,0 +1,437 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/selected_rows_functor.h" + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +struct SelectedRowsAdd { + void operator()(const lite::X86Context& context, + const fluid::SelectedRows& input1, + const fluid::SelectedRows& input2, + fluid::SelectedRows* output) { + auto in1_height = input1.height(); + PADDLE_ENFORCE_EQ(in1_height, input2.height()); + output->set_height(in1_height); + + auto& in1_rows = input1.rows(); + auto& in2_rows = input2.rows(); + std::vector out_rows; + out_rows.reserve(in1_rows.size() + in2_rows.size()); + + // concat rows + out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); + out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); + output->set_rows(out_rows); + + auto* out_value = output->mutable_value(); + auto& in1_value = input1.value(); + auto& in2_value = input2.value(); + + auto in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); + PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); + + auto* out_data = out_value->mutable_data(); + auto* in1_data = in1_value.data(); + std::copy_n(in1_data, in1_value.numel(), out_data); + + auto* in2_data = in2_value.data(); + std::copy_n(in2_data, in2_value.numel(), out_data + in1_value.numel()); + } +}; + +template struct SelectedRowsAdd; +template struct SelectedRowsAdd; + +template +struct SelectedRowsAddTensor { + void operator()(const lite::X86Context& context, + const fluid::SelectedRows& input1, + const lite::Tensor& input2, + lite::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / 
in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); + PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); + + SetConstant functor; + functor(context, output, 0.0); + + auto* in1_data = in1_value.data(); + auto* out_data = output->mutable_data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + out_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + + auto out_eigen = fluid::EigenVector::Flatten(*output); + auto in2_eigen = fluid::EigenVector::Flatten(input2); + out_eigen.device(lite::fluid::EigenDeviceType()) = + out_eigen + in2_eigen; + } +}; + +template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; + +template +struct SelectedRowsAddTo { + void operator()(const lite::X86Context& context, + const fluid::SelectedRows& input1, + const int64_t input2_offset, + fluid::SelectedRows* input2) { + auto in1_height = input1.height(); + PADDLE_ENFORCE_EQ(in1_height, input2->height()); + + auto& in1_rows = input1.rows(); + auto& in2_rows = *(input2->mutable_rows()); + + auto& in1_value = input1.value(); + auto* in2_value = input2->mutable_value(); + + // concat rows + in2_rows.reserve(in2_rows.size() + + size_t(in1_rows.end() - in1_rows.begin())); + in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end()); + + auto* in1_data = in1_value.data(); + auto* in2_data = in2_value->mutable_data(); + std::copy_n(in1_data, in1_value.numel(), in2_data + input2_offset); + } +}; + +template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; + +template +struct SelectedRowsSumTo { + void operator()(const lite::X86Context& context, + const std::vector& input1, + const std::vector& input2_offsets, + fluid::SelectedRows* input2) { + // Ensure all selected rows have the same height + size_t size = 0u; + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + auto& in_rows = (*iter)->rows(); + size += in_rows.end() - in_rows.begin(); + auto in1_height = (*iter)->height(); + PADDLE_ENFORCE_EQ(in1_height, input2->height()); + } + // concat rows + std::vector in2_rows; + in2_rows.reserve(in2_rows.size() + size); + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + const std::vector& in_rows = (*iter)->rows(); + in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end()); + } + input2->set_rows(in2_rows); + + auto* in2_value = input2->mutable_value(); + T* in2_data = in2_value->mutable_data(); + auto blas = math::GetBlas(context); + size_t offset = 0u; + for (size_t i = 0u; i != input1.size(); ++i) { + auto& in_value = input1[i]->value(); + const T* in_data = in_value.data(); + offset += input2_offsets[i]; + blas.VCOPY(in_value.numel(), in_data, in2_data + offset); + } + } +}; + +template struct SelectedRowsSumTo; +template struct SelectedRowsSumTo; + +template +struct SelectedRowsAddToTensor { + void operator()(const lite::X86Context& context, + const fluid::SelectedRows& input1, + lite::Tensor* input2) { + CHECK(input1.rows().size() != 0) << "input selected rows is empty!"; + + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* input2_data = 
input2->mutable_data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + } +}; + +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; + +// This is a separated namespace for manipulate SelectedRows typed +// data. Like merge duplicated rows, adding two SelectedRows etc. +// +// Another group of functors is called "scatter updates", which means +// use SelectedRows to update a dense tensor with different Ops, like +// add or mul. +namespace scatter { + +template +typename std::enable_if< + std::is_floating_point::value && + std::is_same::value>::type +elementwise_add_to(const DeviceContext& ctx, + BlasT* blas, + size_t data_len, + const T* in, + T* out) { + blas->AXPY(data_len, 1., in, out); +} + +template +typename std::enable_if< + !std::is_floating_point::value && + std::is_same::value>::type +elementwise_add_to(const DeviceContext& ctx, + BlasT* blas, + size_t data_len, + const T* in, + T* out) { + for (size_t i = 0; i < data_len; i++) { + out[i] += in[i]; + } +} + +template +struct MergeAdd { + fluid::SelectedRows operator()(const lite::X86Context& context, + const fluid::SelectedRows& input, + const bool sorted_result = false) { + fluid::SelectedRows out; + (*this)(context, input, &out, sorted_result); + return out; + } + + void operator()(const lite::X86Context& context, + const fluid::SelectedRows& input, + fluid::SelectedRows* output, + const bool sorted_result = false) { + std::vector inputs; + inputs.push_back(&input); + (*this)(context, inputs, output, sorted_result); + } + + void operator()(const lite::X86Context& context, + const std::vector& inputs, + fluid::SelectedRows* output, + const bool sorted_result = false) { + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const fluid::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (in->rows().size() > 0) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! 
just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); + fluid::SelectedRows& out = *output; + std::set merged_row_set; + size_t row_num = 0; + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + PADDLE_ENFORCE_EQ(input_width, + input->value().dims()[1], + "all input should have same " + "dimension except for the first one"); + PADDLE_ENFORCE_EQ( + input_height, input->height(), "all input should have same height"); + row_num += input->rows().size(); + merged_row_set.insert(input->rows().begin(), input->rows().end()); + } + + out.set_height(input_height); + lite::DDim dims(std::vector( + {static_cast(merged_row_set.size()), input_width})); + out.mutable_value()->Resize(dims); + auto* out_data = out.mutable_value()->mutable_data(); + + if (merged_row_set.size() == row_num && !sorted_result) { + // no duplicated ids, just concat the result together + std::vector merge_rows; + merge_rows.reserve(row_num); + // concat rows + for (auto* in : inputs) { + merge_rows.insert( + merge_rows.end(), in->rows().begin(), in->rows().end()); + } + out.set_rows(merge_rows); + int64_t copied_numel = 0; + for (auto* in : inputs) { + auto* in_data = in->value().data(); + auto in_numel = in->value().numel(); + std::copy_n(in_data, in_numel, out_data + copied_numel); + copied_numel += in_numel; + } + } else { + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); + + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); + } + + out.set_rows(merge_rows); + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + std::unordered_map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; + } + + auto blas = math::GetBlas(context); + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + elementwise_add_to( + context, + &blas, + static_cast(input_width), + &input_data[i * input_width], + &out_data[out_i * input_width]); + } + } + } + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +struct UpdateToTensor { + void operator()(const lite::X86Context& context, + const ScatterOps& op, + const fluid::SelectedRows& input1, + lite::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + // FIXME(typhoonzero): use macro fix the below messy code. 
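+    // Note: INLINE_FOR2(sizei, sizej), defined in selected_rows_functor.h,
+    // simply expands to two nested loops over i and j. For instance, the
+    // ScatterOps::ADD case below executes, for every selected row i and
+    // every column j,
+    //   input2_data[in1_rows[i] * in1_row_numel + j] +=
+    //       in1_data[i * in1_row_numel + j];
+    // i.e. sparse row i is scattered into dense row in1_rows[i].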
+ switch (op) { + case ScatterOps::ASSIGN: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::ADD: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUB: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] -= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUBBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] - + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + case ScatterOps::MUL: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] *= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIV: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] /= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIVBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] / + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + } + } +}; + +} // namespace scatter +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/selected_rows_functor.h b/lite/backends/x86/math/selected_rows_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..fc3636e1e6adb4aaf812deba4131a9cbff5cbdc4 --- /dev/null +++ b/lite/backends/x86/math/selected_rows_functor.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#pragma once + +#include +#include + +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/core/context.h" +#include "lite/fluid/eigen.h" +#include "lite/fluid/selected_rows.h" + +#define INLINE_FOR2(sizei, sizej) \ + for (int64_t i = 0; i < sizei; i++) \ + for (int64_t j = 0; j < sizej; j++) + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +template +struct SelectedRowsAdd { + void operator()(const lite::Context& context, + const fluid::SelectedRows& input1, + const fluid::SelectedRows& input2, + fluid::SelectedRows* output); +}; + +template +struct SelectedRowsAddTensor { + void operator()(const lite::Context& context, + const fluid::SelectedRows& input1, + const lite::Tensor& input2, + lite::Tensor* output); +}; + +// input2 = input1 + input2 +template +struct SelectedRowsAddTo { + void operator()(const lite::Context& context, + const fluid::SelectedRows& input1, + const int64_t input2_offset, + fluid::SelectedRows* input2); +}; + +// input2 = [all input in input1] + input2 +template +struct SelectedRowsSumTo { + void operator()(const lite::Context& context, + const std::vector& input1, + const std::vector& input2_offsets, + fluid::SelectedRows* input2); +}; + +// FIXME: The result of SelectedRowsAddToTensor maybe non deterministic, +// because it uses CudaAtomicAdd. +// input2 = input1 + input2 +template +struct SelectedRowsAddToTensor { + void operator()(const lite::Context& context, + const fluid::SelectedRows& input1, + lite::Tensor* input2); +}; + +namespace scatter { +// functors for manuplating SelectedRows data +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + fluid::SelectedRows operator()(const lite::Context& context, + const fluid::SelectedRows& input, + const bool sorted_result = false); + void operator()(const lite::Context& context, + const fluid::SelectedRows& input, + fluid::SelectedRows* output, + const bool sorted_result = false); + void operator()(const lite::Context& context, + const std::vector& inputs, + fluid::SelectedRows* output, + const bool sorted_result = false); +}; + +enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; + +// out = selected_rows_in / tensor +template +struct UpdateToTensor { + void operator()(const lite::Context& context, + const ScatterOps& op, + const fluid::SelectedRows& input1, + lite::Tensor* input2); +}; + +} // namespace scatter +} // namespace math +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/sequence2batch.h b/lite/backends/x86/math/sequence2batch.h index 807558e9d82add3dffaa34ee880390c73e1e8112..a97bfaf66607e5ea2efbd6f26f311fb4cd9dab67 100644 --- a/lite/backends/x86/math/sequence2batch.h +++ b/lite/backends/x86/math/sequence2batch.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "lite/core/context.h" #include "lite/core/tensor.h" #include "lite/fluid/eigen.h" -#include "lite/fluid/lod.h" +// #include "lite/fluid/lod.h" #include "lite/utils/paddle_enforce.h" namespace paddle { diff --git a/lite/backends/x86/math/softmax_impl.h b/lite/backends/x86/math/softmax_impl.h index c8432c42cc52c7b656397f8ac09bed99b1957732..ae997a8680b9012435d80b4aa9f592a775e62e85 100644 --- a/lite/backends/x86/math/softmax_impl.h +++ b/lite/backends/x86/math/softmax_impl.h @@ -17,6 +17,7 @@ limitations under the License. 
*/
#include "lite/backends/x86/cpu_info.h"
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
+#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/cpu_vec.h"
#include "lite/core/tensor.h"
#include "lite/fluid/eigen.h"
diff --git a/lite/backends/x86/math/tree2col.cc b/lite/backends/x86/math/tree2col.cc
index 8a34bebef05e35880548ccc30f7519e2a84dbd1a..20b913331308c8b8c95d190b6b0b3d76ccac354b 100644
--- a/lite/backends/x86/math/tree2col.cc
+++ b/lite/backends/x86/math/tree2col.cc
@@ -107,7 +107,8 @@ class Tree2ColFunctor {
       //  patch->mutable_data<T>({static_cast<int64_t>(patch_size),
       //                          static_cast<int64_t>(patch_elem_size)},
       //                         cpu_place);
-      patch->Resize({static_cast<int64_t>(patch_size, patch_elem_size)});
+      patch->Resize({static_cast<int64_t>(patch_size),
+                     static_cast<int64_t>(patch_elem_size)});
       auto *patch_data = patch->mutable_data<T>(lite::TargetType::kX86);
       constant(context, patch, 0);
       const T *features = node_features.data<T>();
diff --git a/lite/backends/xpu/CMakeLists.txt b/lite/backends/xpu/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f911f8e0e7c61481e1d4e309bc0635718be11206
--- /dev/null
+++ b/lite/backends/xpu/CMakeLists.txt
@@ -0,0 +1,6 @@
+if(NOT LITE_WITH_XPU)
+  return()
+endif()
+
+lite_cc_library(xpu_runtime SRCS runtime.cc DEPS ${xpu_runtime_libs})
+lite_cc_library(xpu_builder SRCS builder.cc DEPS ${xpu_builder_libs} xpu_runtime tensor op scope)
diff --git a/lite/backends/xpu/builder.cc b/lite/backends/xpu/builder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..796eaf9c46ceb3d29f1ffdc4c86ac45509f07ba1
--- /dev/null
+++ b/lite/backends/xpu/builder.cc
@@ -0,0 +1,189 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
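+//
+// builder.cc implements the helpers declared in builder.h: conversion of
+// Lite precision types, shapes and tensors to their XTCL counterparts,
+// unique-name generation for compiled models, and BuildModel(), which
+// compiles an XPU subgraph and registers the resulting runtime instance
+// with DeviceInfo under a generated key (see runtime.h).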
+ +#include "lite/backends/xpu/builder.h" +#include // NOLINT +#include +#include "lite/backends/xpu/runtime.h" + +namespace paddle { +namespace lite { +namespace xpu { + +bool HasInputArg(const OpInfo* op_info, + const Scope* scope, + const std::string& argname) { + auto iarg_names = op_info->input_argnames(); + if (std::find(iarg_names.begin(), iarg_names.end(), argname) != + iarg_names.end()) { + auto inputs = op_info->Input(argname); + if (inputs.empty()) { + return false; + } + auto var_name = inputs.front(); + auto var = scope->FindVar(var_name); + return var != nullptr; + } else { + return false; + } +} + +std::string UniqueName(const std::string& prefix) { + static std::mutex counter_mtx; + static std::unordered_map counter_map; + std::unique_lock counter_lck(counter_mtx); + int counter = 1; + auto it = counter_map.find(prefix); + if (it == counter_map.end()) { + counter_map[prefix] = counter; + } else { + counter = ++(it->second); + } + return prefix + "_" + std::to_string(counter); +} + +xtcl::DataType CvtPrecisionType(PrecisionType in_type) { + xtcl::DataType out_type = ::xtcl::Float(32); + switch (in_type) { + case PRECISION(kFloat): + out_type = ::xtcl::Float(32); + break; + case PRECISION(kInt8): + out_type = ::xtcl::Int(8); + break; + case PRECISION(kInt32): + out_type = ::xtcl::Int(32); + break; + default: + LOG(FATAL) << "Can not convert precision type(" << PrecisionToStr(in_type) + << ") from Lite to XPU"; + break; + } + return out_type; +} + +DLDataType CvtDataType(PrecisionType in_type) { + DLDataType out_type = {kDLFloat, 32, 1}; + switch (in_type) { + case PRECISION(kFloat): + out_type = {kDLFloat, 32, 1}; + break; + case PRECISION(kInt8): + out_type = {kDLInt, 8, 1}; + break; + case PRECISION(kInt32): + out_type = {kDLInt, 32, 1}; + break; + default: + LOG(FATAL) << "Can not convert data type(" << PrecisionToStr(in_type) + << ") from Lite to XPU"; + break; + } + return out_type; +} + +xtcl::Array CvtShape(const std::vector& in_shape) { + xtcl::Array out_shape; + for (auto dim : in_shape) { + out_shape.push_back(dim); + } + return out_shape; +} + +xtcl::Array CvtShape(const std::vector& in_shape) { + return CvtShape(std::vector(in_shape.begin(), in_shape.end())); +} + +xtcl::Array CvtShape(const DDim& in_dims) { + return CvtShape(in_dims.Vectorize()); +} + +std::shared_ptr CvtTensor(lite::Tensor* in_tensor, + std::vector out_shape, + PrecisionType in_ptype, + DataLayoutType in_ltype) { + uint8_t* in_data = nullptr; + auto in_size = in_tensor->dims().production(); + auto in_shape = in_tensor->dims().Vectorize(); + if (out_shape.empty()) { + out_shape = in_shape; + } + int in_bytes; + if (in_ptype == PRECISION(kFloat)) { + in_data = reinterpret_cast(in_tensor->mutable_data()); + in_bytes = in_size * sizeof(float); + } else if (in_ptype == PRECISION(kInt32)) { + in_data = reinterpret_cast(in_tensor->mutable_data()); + in_bytes = in_size * sizeof(int32_t); + } else if (in_ptype == PRECISION(kInt8)) { + in_data = reinterpret_cast(in_tensor->mutable_data()); + in_bytes = in_size * sizeof(int8_t); + } else { + LOG(FATAL) << "Unknow precision type " << PrecisionToStr(in_ptype); + } + auto out_tensor = std::make_shared( + xtcl::xNDArray::Empty(out_shape, CvtDataType(in_ptype), {kDLCPU, 0})); + auto out_data = + reinterpret_cast(out_tensor->ToDLPack()->dl_tensor.data); + std::memcpy(out_data, in_data, in_bytes); + return out_tensor; +} + +// Build the XPU subgraph to the XPU model, store the model data into the +// weight tensor of the graph op, and the model data will be 
loaded again +// by the graph computing kernel when the graph op is executed for inference. +// Due to the lack of XPU APIs for building and outputing the model data, +// the compiled XPU runtime object will be managed by the global variable +// 'DeviceInfo' and the key name for finding the runtime object will be +// stored in the weight tensor of graph op. +// TODO(hong19860320) Compile the XPU subgraph and output the compiled model +// data to the weight tensor of graph op. +bool BuildModel( + std::shared_ptr builder, + std::shared_ptr params, + std::vector>* outputs, + lite::Tensor* model) { + LOG(INFO) << "[XPU] Build Model."; + CHECK(builder != nullptr); + CHECK(outputs != nullptr); + CHECK_GT(outputs->size(), 0); + CHECK(model != nullptr); + + // build graph and fill all of constant params + xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0])); + auto target = xtcl::Target::Create("llvm"); + auto compiler = xtcl::network::xTensorCompiler(network, target); + compiler.SetParams(*params); // set the data of constant tensors + compiler.Build(); + + // create and register runtime + auto runtime = std::make_shared( + compiler.CreateRuntimeInstance()); + if (runtime == nullptr) { + LOG(WARNING) << "[XPU] Build Model failed!"; + return false; + } + std::string name = UniqueName("xpu"); + LOG(INFO) << "[XPU] Model Name: " << name; + DeviceInfo::Global().Insert(name, runtime); + model->Resize({static_cast(name.length() + 1)}); + memcpy(model->mutable_data(), + reinterpret_cast(name.c_str()), + name.length() + 1); + return true; +} + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/builder.h b/lite/backends/xpu/builder.h new file mode 100644 index 0000000000000000000000000000000000000000..f0ac2b303aac7fa7f827e6e2f8f0fdf614b604b5 --- /dev/null +++ b/lite/backends/xpu/builder.h @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
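+//
+// Rough usage sketch of the helpers declared below (illustrative only; the
+// names graph_builder, graph_params, graph_outputs and model_tensor are
+// hypothetical):
+//
+//   auto w = CvtTensor(weight_tensor);      // lite::Tensor -> xtcl::xNDArray
+//   auto shape = CvtShape(input_tensor->dims());
+//   ...
+//   BuildModel(graph_builder, graph_params, &graph_outputs, model_tensor);
+//   // Later, at execution time, the graph kernel recovers the runtime:
+//   std::shared_ptr<xtcl::network::xRuntimeInstance> runtime;
+//   LoadModel(*model_tensor, &runtime);     // declared in runtime.h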
+ +#pragma once + +#include +#include +#include +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace xpu { + +bool HasInputArg(const OpInfo* op_info, + const Scope* scope, + const std::string& argname); + +std::string UniqueName(const std::string& prefix); + +xtcl::DataType CvtPrecisionType(PrecisionType in_type); + +DLDataType CvtDataType(PrecisionType in_type); + +xtcl::Array CvtShape(const std::vector& in_shape); + +xtcl::Array CvtShape(const std::vector& in_shape); + +xtcl::Array CvtShape(const DDim& in_dims); + +std::shared_ptr CvtTensor( + Tensor* in_tensor, + std::vector out_shape = {}, + PrecisionType in_ptype = PRECISION(kFloat), + DataLayoutType in_ltype = DATALAYOUT(kNCHW)); + +bool BuildModel( + std::shared_ptr builder, + std::shared_ptr params, + std::vector>* outputs, + lite::Tensor* model); + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/runtime.cc b/lite/backends/xpu/runtime.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2c34b95758e8abf81c8294507d0ca60aad7c021 --- /dev/null +++ b/lite/backends/xpu/runtime.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/xpu/runtime.h" +#include +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace xpu { + +// Extract the model data and recover the XPU model for inference, the function +// is called by the graph computing kernel when the graph op is executed. +// Due to the lack of XPU APIs for loading and recovering the XPU model from +// memory, the key name is obtained from the weight tensor of graph op, to get +// the runtime object for inference from the global variable 'DeviceInfo'. +// TODO(hong19860320) Recover the XPU model from the weight tensor of graph op. +bool LoadModel(const lite::Tensor &model, + std::shared_ptr *runtime) { + LOG(INFO) << "[XPU] Load Model."; + CHECK_GT(model.dims().production(), 0); + std::string name(reinterpret_cast(model.data())); + LOG(INFO) << "[XPU] Model Name: " << name; + CHECK(runtime != nullptr); + *runtime = DeviceInfo::Global().Find(name); + if (*runtime == nullptr) { + LOG(WARNING) << "[XPU] Load Model failed!"; + return false; + } + return true; +} + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/runtime.h b/lite/backends/xpu/runtime.h new file mode 100644 index 0000000000000000000000000000000000000000..4ff8d75bce6156d51a4988d427058da34460443f --- /dev/null +++ b/lite/backends/xpu/runtime.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace xpu { + +class DeviceInfo { + public: + static DeviceInfo& Global() { + static DeviceInfo x; + return x; + } + DeviceInfo() {} + + void Insert(const std::string& name, + std::shared_ptr runtime) { + if (runtimes_.find(name) != runtimes_.end()) { + LOG(WARNING) << "[XPU] Model " << name << " already exists."; + return; + } + runtimes_.emplace(std::make_pair(name, runtime)); + } + + void Clear() { runtimes_.clear(); } + + std::shared_ptr Find( + const std::string& name) const { + if (runtimes_.find(name) != runtimes_.end()) { + return runtimes_.at(name); + } else { + return nullptr; + } + } + + private: + int device_id_{0}; + std::string device_name_{"default"}; + std::unordered_map> + runtimes_; +}; + +bool LoadModel(const lite::Tensor& model, + std::shared_ptr* runtime); + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt index 19d973fc1e3c19f94c32e4c8f5390b8a4916f1c0..5eecf1d815d30fe0ef10a55c6b6b351795fe63ae 100644 --- a/lite/core/CMakeLists.txt +++ b/lite/core/CMakeLists.txt @@ -33,31 +33,73 @@ lite_cc_library(scope SRCS scope.cc DEPS tensor) lite_cc_library(device_info SRCS device_info.cc DEPS tensor) if (LITE_WITH_ARM) -lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS ${npu_ddk_libs}) +lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime) else() -lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags) +lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags XPU_DEPS xpu_runtime) endif() +#-------------------------------------------- GET CODE META INFO ------------------------------------------ +execute_process( + COMMAND git describe --tags --exact-match + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE PADDLE_LITE_TAG + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +execute_process( + COMMAND git rev-parse --abbrev-ref HEAD + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE PADDLE_LITE_BRANCH + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +execute_process( + COMMAND git log -1 --format=%h + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE PADDLE_LITE_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +message(STATUS "tag: ${PADDLE_LITE_TAG}") +message(STATUS "branch: ${PADDLE_LITE_BRANCH}") +message(STATUS "commit: ${PADDLE_LITE_COMMIT}") + +configure_file(version.h.in version.h) #----------------------------------------------- NOT CHANGE ----------------------------------------------- # A trick to generate the paddle_use_kernels.h add_custom_command( COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_kernel_registry.py ${kernels_src_list} ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h - OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h + "${LITE_OPTMODEL_DIR}/.tailored_kernels_list" + LITE_BUILD_TAILOR + OUTPUT 
kernels.h # not a real path to the output, to force it to execute every time.
  )
# A trick to generate the paddle_use_ops.h
add_custom_command(
  COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_op_registry.py
  ${ops_src_list}
  ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h
-  OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h
+  "${LITE_OPTMODEL_DIR}/.tailored_ops_list"
+  LITE_BUILD_TAILOR
+  OUTPUT ops.h # not a real path to the output, to force it to execute every time.
  )
-add_custom_target(op_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h)
-add_custom_target(kernel_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h)
+# generate fake kernels for memory_optimize_tool
+add_custom_command(
+  COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/create_fake_kernel_registry.py
+  ${kernels_src_list}
+  ${CMAKE_BINARY_DIR}/all_kernel_faked.cc
+  ${CMAKE_BINARY_DIR}/kernel_src_map.h
+  OUTPUT all_kernel_faked.cc # not a real path to the output, to force it to execute every time.
+  )
+add_custom_target(op_list_h DEPENDS ops.h)
+add_custom_target(kernel_list_h DEPENDS kernels.h)
+add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc)
#----------------------------------------------- NOT CHANGE -----------------------------------------------
-lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor
+lite_cc_library(kernel SRCS kernel.cc
+    DEPS context type_system target_wrapper any op_params tensor
+    PROFILE_DEPS basic_profiler
 )
lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel cpp_op_desc tensor
diff --git a/lite/core/arena/CMakeLists.txt b/lite/core/arena/CMakeLists.txt
index 854d2f4172544a4170d19cd8a77fd2ea8fc4753e..bc77afd81e0859b9492b2068ce681098a9393923 100644
--- a/lite/core/arena/CMakeLists.txt
+++ b/lite/core/arena/CMakeLists.txt
@@ -3,8 +3,8 @@ if(NOT WITH_TESTING)
   return()
 endif()
-lite_cc_library(arena_framework SRCS framework.cc DEPS program)
+lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
-if(NOT LITE_WITH_OPENCL AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
   lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h
index 48a8571a199292a850ec5e0a1b379f3fa1df5882..412ac0c167b8abe6d196dc25d1bc5b193d02965d 100644
--- a/lite/core/arena/framework.h
+++ b/lite/core/arena/framework.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include  // NOLINT
+#include
 #include
 #include
 #include
@@ -41,6 +42,7 @@ class TestCase {
       : place_(place), scope_(new Scope), alias_(alias) {
     ctx_ = ContextScheduler::Global().NewContext(place_.target);
   }
+  virtual ~TestCase() {}
   void Prepare() {
     PrepareScopes();
@@ -137,20 +139,18 @@ class TestCase {
   }
  private:
+  Place place_;
   std::shared_ptr<Scope> scope_;
+  std::string alias_;
   // The workspace for the Instruction.
   Scope* inst_scope_{};
   // The workspace for the baseline implementation.
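   // (inst_scope_ and base_scope_ are kept separate so that the kernel under
   // test and the reference implementation never alias each other's tensors;
   // Arena then compares their outputs within the abs_error_ tolerance.)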
Scope* base_scope_{}; std::unique_ptr op_desc_; std::unique_ptr instruction_; - Place place_; - std::string alias_; }; class Arena { - float abs_error_{}; - public: Arena(std::unique_ptr&& tester, const Place& place, @@ -202,12 +202,14 @@ class Arena { default: LOG(FATAL) << "not support type " << PrecisionToStr(type->precision()); + return false; } } private: std::unique_ptr tester_; Place place_; + float abs_error_; }; template diff --git a/lite/core/context.h b/lite/core/context.h index 4109c3333410604f03eaf3818adf183ff407a26f..68d8e29329a7ca8a8a257b46ddac6b13485879da 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -26,7 +26,10 @@ #include "lite/backends/opencl/cl_runtime.h" #endif #ifdef LITE_WITH_NPU -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/runtime.h" +#endif +#ifdef LITE_WITH_XPU +#include "lite/backends/xpu/runtime.h" #endif #include @@ -55,6 +58,7 @@ using X86Context = Context; using CUDAContext = Context; using ARMContext = Context; using NPUContext = Context; +using XPUContext = Context; using OpenCLContext = Context; using FPGAContext = Context; @@ -81,9 +85,20 @@ class Context { NPUContext& operator=(const NPUContext& ctx) {} std::string name() const { return "NPUContext"; } - hiai::AiModelMngerClient* client(const std::string& model_name) const { - return npu::DeviceInfo::Global().client(model_name); - } +}; +#endif + +#ifdef LITE_WITH_XPU +template <> +class Context { + public: + Context() {} + explicit Context(const NPUContext& ctx); + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + void CopySharedTo(XPUContext* ctx) {} + + std::string name() const { return "XPUContext"; } }; #endif @@ -160,9 +175,9 @@ class Context { cublas_fp32_ = std::make_shared>(); } void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) { - CHECK_GT(devs.size(), 0) + CHECK_GT(devs.size(), 0UL) << "Env is not initialized or current target is not exit!"; - if (dev_id >= devs.size()) { + if (dev_id >= static_cast(devs.size())) { LOG(WARNING) << "device index exceeds the number of devices, set to " "default device(0)!"; device_id_ = 0; @@ -192,10 +207,10 @@ class Context { ctx->cublas_fp32_ = cublas_fp32_; } - const cudaStream_t exec_stream() { return exec_stream_; } + const cudaStream_t& exec_stream() const { return exec_stream_; } void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; } - const cudaStream_t io_stream() { return io_stream_; } + const cudaStream_t& io_stream() const { return io_stream_; } void SetIoStream(cudaStream_t stream) { io_stream_ = stream; } std::shared_ptr> cublas_fp32() { return cublas_fp32_; } @@ -240,8 +255,6 @@ class Context { public: Context() {} - Context(Context&& ctx) {} - // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} @@ -261,7 +274,7 @@ template <> class Context { std::shared_ptr cl_context_; using WaitListType = - std::unordered_map(nullptr)), + std::unordered_map(nullptr)), std::shared_ptr>; std::shared_ptr cl_wait_list_; @@ -343,6 +356,12 @@ class ContextScheduler { &ctx->As()); break; #endif +#ifdef LITE_WITH_XPU + case TARGET(kXPU): + kernel_contexts_[TargetType::kXPU].As().CopySharedTo( + &ctx->As()); + break; +#endif #ifdef LITE_WITH_OPENCL case TARGET(kOpenCL): kernel_contexts_[TargetType::kOpenCL].As().CopySharedTo( @@ -356,7 +375,10 @@ class ContextScheduler { break; #endif default: +#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL LOG(FATAL) << "unsupported target " << TargetToStr(target); +#endif + break; } return ctx; } @@ -386,6 +408,9 @@ 
class ContextScheduler { #endif #ifdef LITE_WITH_NPU InitContext(); +#endif +#ifdef LITE_WITH_XPU + InitContext(); #endif } diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc index de53d9ba6735c622f041767579ed4a079139f828..896f6c8d33a8665c4c94786dd08af1a097942608 100644 --- a/lite/core/device_info.cc +++ b/lite/core/device_info.cc @@ -35,6 +35,9 @@ #include #include #endif +#ifdef LITE_WITH_ANDROID +#include +#endif #if __APPLE__ #include "TargetConditionals.h" #if LITE_WITH_IPHONE @@ -218,6 +221,7 @@ void get_cpu_arch(std::vector* archs, const int cpu_num) { #ifdef LITE_WITH_LINUX std::string get_cpu_name() { + std::string cpu_name; FILE* fp = fopen("/proc/cpuinfo", "rb"); if (!fp) { return ""; @@ -229,73 +233,93 @@ std::string get_cpu_name() { break; } if (strstr(line, "Hardware") != NULL) { - fclose(fp); - return std::string(line); + cpu_name = std::string(line); } } +#ifdef LITE_WITH_ANDROID + // cpu name concat board name, platform name and chip name + char board_name[128]; + char platform_name[128]; + char chip_name[128]; + __system_property_get("ro.product.board", board_name); + __system_property_get("ro.board.platform", platform_name); + __system_property_get("ro.chipname", chip_name); + cpu_name = + cpu_name + "_" + board_name + "_" + platform_name + "_" + chip_name; +#endif + std::transform(cpu_name.begin(), cpu_name.end(), cpu_name.begin(), ::toupper); fclose(fp); - return ""; + return cpu_name; } -void get_cpu_max_min_freq(int cpu_id, int* max_freq, int* min_freq) { - *max_freq = 0; - *min_freq = 0; +int get_min_freq_khz(int cpuid) { + // first try, for all possible cpu + char path[256]; + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", + cpuid); + FILE* fp = fopen(path, "rb"); + if (!fp) { + return -1; + } + + int min_freq_khz = -1; + fscanf(fp, "%d", &min_freq_khz); + fclose(fp); + return min_freq_khz; +} + +int get_max_freq_khz(int cpuid) { // first try, for all possible cpu char path[256]; snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", - cpu_id); + cpuid); + FILE* fp = fopen(path, "rb"); if (!fp) { // second try, for online cpu snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", - cpu_id); + cpuid); fp = fopen(path, "rb"); - if (!fp) { - // third try, for online cpu - // get max_freq - snprintf(path, - sizeof(path), - "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", - cpu_id); - fp = fopen(path, "rb"); - if (!fp) { - return; + } + + int max_freq_khz = 0; + if (fp) { + while (!feof(fp)) { + int freq_khz = 0; + int nscan = fscanf(fp, "%d %*d", &freq_khz); + if (nscan != 1) { + break; } - fscanf(fp, "%d", max_freq); - fclose(fp); - // get min_freq - snprintf(path, - sizeof(path), - "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_min_freq", - cpu_id); - fp = fopen(path, "rb"); - if (!fp) { - return; + + if (freq_khz > max_freq_khz) { + max_freq_khz = freq_khz; } - fscanf(fp, "%d", min_freq); - fclose(fp); - return; } } - *min_freq = std::numeric_limits::max(); - while (!feof(fp)) { - int freq = 0; - int nscan = fscanf(fp, "%d %*d", &freq); - if (nscan != 1) { - break; - } - if (freq > *max_freq) { - *max_freq = freq; - } - if (freq < *min_freq) { - *min_freq = freq; + if (max_freq_khz == 0 || !fp) { + // third try, for online cpu + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", + cpuid); + fp = fopen(path, "rb"); + if (!fp) { + return -1; } + int max_freq_khz = -1; + fscanf(fp, "%d", 
&max_freq_khz); + fclose(fp); + return max_freq_khz; } + fclose(fp); + return max_freq_khz; } void sort_cpuid_by_max_freq(const std::vector& max_freqs, @@ -771,7 +795,9 @@ bool DeviceInfo::SetCPUInfoByName() { cluster_ids_ = {0, 0, 0, 0}; SetArchInfo(1, kA53); return true; - } else if (dev_name_.find("KIRIN980") != std::string::npos) { // Kirin 980 + } else if (dev_name_.find("KIRIN980") != std::string::npos || + dev_name_.find("KIRIN990") != + std::string::npos) { // Kirin 980, Kirin 990 core_num_ = 8; core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; big_core_ids_ = {4, 5, 6, 7}; @@ -835,7 +861,7 @@ void DeviceInfo::RequestPowerHighMode(int thread_num) { active_ids_ = big_core_ids_; } else { for (int i = 0; i < thread_num; ++i) { - active_ids_.push_back(big_core_ids_[i]); + active_ids_.push_back(big_core_ids_[big_core_size - 1 - i]); } } } else { @@ -972,8 +998,8 @@ int DeviceInfo::Setup() { #ifdef LITE_WITH_LINUX // get max&min freq for (int i = 0; i < core_num_; ++i) { - int max_freq, min_freq; - get_cpu_max_min_freq(i, &max_freq, &min_freq); + int max_freq = get_max_freq_khz(i); + int min_freq = get_min_freq_khz(i); max_freqs_[i] = max_freq / 1000; min_freqs_[i] = min_freq / 1000; } @@ -982,13 +1008,6 @@ int DeviceInfo::Setup() { if (!SetCPUInfoByName()) { SetCPUInfoByProb(); } - core_ids_.resize(core_num_); - cluster_ids_.resize(core_num_); - for (int i = 0; i < core_num_; ++i) { - max_freqs_[i] = 1000000; - min_freqs_[i] = 1000000; - cluster_ids_[i] = 0; - } #else #ifdef TARGET_IOS dev_name_ = "Apple"; @@ -1102,13 +1121,14 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) { SetCacheInfo(0, 1, l1size); SetCacheInfo(1, 1, l2size); SetCacheInfo(2, 1, l3size); - workspace_.Resize({2 * (l1size + l2size)}); + workspace_.Resize({llc_size()}); + workspace_.mutable_data(); } -bool DeviceInfo::ExtendWorkspace(int size) { - workspace_.Resize({size + llc_size()}); - workspace_.mutable_data(); - return true; +bool DeviceInfo::ExtendWorkspace(size_t size) { + workspace_.Resize( + {static_cast(size + static_cast(llc_size()))}); + return workspace_.mutable_data() != nullptr; } #endif // LITE_WITH_ARM diff --git a/lite/core/device_info.h b/lite/core/device_info.h index 96f46801351aafe8eaf7388412bbc31200ee02e2..81c0ac4bf9a9a134de448efa92ac0cb2c1a06454 100644 --- a/lite/core/device_info.h +++ b/lite/core/device_info.h @@ -73,7 +73,7 @@ class DeviceInfo { T* workspace_data() { return reinterpret_cast(workspace_.mutable_data()); } - bool ExtendWorkspace(int size); + bool ExtendWorkspace(size_t size); private: int core_num_; @@ -167,7 +167,7 @@ class Device { int id() { return idx_; } int max_stream() { return max_stream_; } - int SetId(int idx) { idx_ = idx; } + void SetId(int idx) { idx_ = idx; } std::string name() { return device_prop_.name; } int core_num() { return device_prop_.multiProcessorCount; } float max_memory() { return device_prop_.totalGlobalMem / 1048576.; } @@ -186,8 +186,8 @@ class Device { void GetInfo(); private: - int max_stream_; int idx_{0}; + int max_stream_; cudaDeviceProp device_prop_; std::string device_name_; float max_memory_; diff --git a/lite/core/framework.proto b/lite/core/framework.proto index 6c60a041a191f1db4a755c1c5714724342053791..5adf2a18b98c2a2d3e2f6e8f7dd5688150674dc6 100644 --- a/lite/core/framework.proto +++ b/lite/core/framework.proto @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ syntax = "proto2"; -// option optimize_for = LITE_RUNTIME; +option optimize_for = LITE_RUNTIME; package paddle.framework.proto; // Any incompatible changes to ProgramDesc and its dependencies should @@ -166,6 +166,9 @@ message VarDesc { required string name = 1; required VarType type = 2; optional bool persistable = 3 [ default = false ]; + // True if the variable is an input data and + // have to check the feed data shape and dtype + optional bool need_check_feed = 4 [ default = false ]; } message BlockDesc { @@ -176,13 +179,39 @@ message BlockDesc { optional int32 forward_block_idx = 5 [ default = -1 ]; } +// CompatibleInfo is used to determine if a feature is compatible and +// provides the information. +message CompatibleInfo { + enum Type { + COMPATIBLE = 0; + DEFINITELY_NOT = 1; + POSSIBLE = 2; + BUG_FIX = 3; + PRECISION_CHANGE = 4; + } + required string version = 1; + required Type type = 2; +} + +// In some cases, Paddle Fluid may perform operator definition iterations, +// and the operator uses OpCompatibleMap for compatibility testing. +message OpCompatibleMap { + message OpCompatiblePair { + required string op_name = 1; + required CompatibleInfo compatible_info = 2; + } + repeated OpCompatiblePair pair = 1; + optional string default_required_version = 2; +} + // Please refer to // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md // for more details. // TODO(panyx0718): A model can have multiple programs. Need a // way to distinguish them. Maybe ID or name? message ProgramDesc { + reserved 2; // For backward compatibility. repeated BlockDesc blocks = 1; - - optional Version version = 2; + optional Version version = 4; + optional OpCompatibleMap op_compatible_map = 3; } diff --git a/lite/core/kernel.h b/lite/core/kernel.h index 92eca6af54991da45f81954399202d8627fb16a2..05d7a6b333810a8dc988d84a281f096babe8929f 100644 --- a/lite/core/kernel.h +++ b/lite/core/kernel.h @@ -30,6 +30,10 @@ #include "lite/utils/all.h" #include "lite/utils/replace_stl/stream.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/basic_profiler.h" +#endif // LITE_WITH_PROFILE + namespace paddle { namespace lite { @@ -43,20 +47,29 @@ class KernelBase { const std::map& input_types, const std::string& out_arg)>; - protected: /// Run some initialization before `Run`, it will invoke after `SetParam` and /// `SetContext`, that is both the param_ and context_ are valid. virtual void PrepareForRun() {} + /// Run kernel initialization if needed at every run (eg. input shape changed) + virtual void ReInitWhenNeeded() {} + /// Run the kernel. Before Run, both the param_ and context_ should be valid. virtual void Run() = 0; - public: +#ifdef LITE_WITH_PROFILE + void SetProfileID(uint32_t id) { profile_id_ = id; } +#endif + void Launch() { + /// First run, init kernel, do weights transform once if (is_first_epoch_) { PrepareForRun(); is_first_epoch_ = false; } + /// re-init the kernel if needed (input shape should be checked in conv + /// kernel) + ReInitWhenNeeded(); // Reset the workspace to make every kernel in the same thread to share the // temporary memory. @@ -67,7 +80,15 @@ class KernelBase { #if defined(LITE_WITH_CUDA) WorkSpace::Global_CUDA().AllocReset(); #endif + +#ifdef LITE_WITH_PROFILE + if (profile_id_ >= 0) { + profile::ProfileBlock x(profile_id_, "kernel"); + Run(); + } +#else Run(); +#endif } void SetContext(std::unique_ptr&& ctx) { @@ -152,6 +173,10 @@ class KernelBase { // is the unique ID for the kernel. 
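  // (The alias distinguishes kernel variants registered for the same op and
  // place; together with the op type it forms the kernel's unique ID, as the
  // comment above notes.)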
std::string alias_{}; bool is_first_epoch_{true}; + +#ifdef LITE_WITH_PROFILE + int profile_id_{-1}; +#endif }; // Light-weight kernel implementation. diff --git a/lite/core/memory.cc b/lite/core/memory.cc index 463e10b9f9a4df2cb23cc176ccb7923e014eda60..b3cb18b33630de6615812471e1acaab59c8e99b0 100644 --- a/lite/core/memory.cc +++ b/lite/core/memory.cc @@ -105,5 +105,23 @@ void TargetCopy(TargetType target, void* dst, const void* src, size_t size) { } } +#ifdef LITE_WITH_OPENCL +void TargetCopyImage2D(TargetType target, + void* dst, + const void* src, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch) { + TargetWrapperCL::ImgcpySync(dst, + src, + cl_image2d_width, + cl_image2d_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoD); +} +#endif + } // namespace lite } // namespace paddle diff --git a/lite/core/memory.h b/lite/core/memory.h index 31d7fd34e1cff5fed562996137be059172ff8151..cb4ac044e7af6994e5e404f379eeb12290e34778 100644 --- a/lite/core/memory.h +++ b/lite/core/memory.h @@ -38,6 +38,15 @@ void LITE_API TargetFree(TargetType target, void* data); // Copy a buffer from host to another target. void TargetCopy(TargetType target, void* dst, const void* src, size_t size); +#ifdef LITE_WITH_OPENCL +void TargetCopyImage2D(TargetType target, + void* dst, + const void* src, + const size_t cl_image2d_width, + const size_t cl_image2d_height, + const size_t cl_image2d_row_pitch, + const size_t cl_image2d_slice_pitch); +#endif // LITE_WITH_OPENCL template void CopySync(void* dst, const void* src, size_t size, IoDirection dir) { @@ -87,6 +96,25 @@ class Buffer { void ResizeLazy(size_t size) { ResetLazy(target_, size); } +#ifdef LITE_WITH_OPENCL + template + void ResetLazyImage2D(TargetType target, + const size_t img_w, + const size_t img_h) { + size_t size = sizeof(T) * img_w * img_h * + 4; // 4 for RGBA, un-used for opencl Image2D + if (target != target_ || cl_image2d_width_ < img_w || + cl_image2d_height_ < img_h) { + Free(); + data_ = TargetWrapperCL::MallocImage(img_w, img_h); + target_ = target; + space_ = size; // un-used for opencl Image2D + cl_image2d_width_ = img_w; + cl_image2d_height_ = img_h; + } + } +#endif + void Free() { if (space_ > 0) { TargetFree(target_, data_); @@ -107,6 +135,8 @@ class Buffer { private: // memory it actually malloced. 
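  // (space_ records the capacity in bytes actually allocated; ResetLazy and
  // ResizeLazy only re-allocate when a larger size is requested, so shrinking
  // re-uses the current allocation. The cl_image2d_* fields below apply the
  // same lazy policy to the OpenCL Image2D allocation.)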
size_t space_{0}; + size_t cl_image2d_width_{0}; // only used for OpenCL Image2D + size_t cl_image2d_height_{0}; // only used for OpenCL Image2D void* data_{nullptr}; TargetType target_{TargetType::kHost}; }; diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index d96a67f52e917962bd07a6207e4343014ca9f0c6..a44b8348716449519486d37f6784e31ecc39f554 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -1,6 +1,6 @@ lite_cc_library(mir_node SRCS node.cc DEPS kernel) lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program) -lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) +lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph) lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) @@ -13,6 +13,7 @@ lite_cc_library(mir_passes fusion/fc_fuse_pass.cc fusion/shuffle_channel_fuse_pass.cc fusion/transpose_softmax_transpose_fuse_pass.cc + fusion/interpolate_fuse_pass.cc fusion/conv_elementwise_fuse_pass.cc fusion/conv_activation_fuse_pass.cc fusion/conv_bn_fuse_pass.cc @@ -30,6 +31,7 @@ lite_cc_library(mir_passes argument_type_display_pass.cc demo_pass.cc runtime_context_assign_pass.cc + memory_optimize_pass.cc DEPS mir_pass types context ${mir_fusers} ${subgraph_passes}) # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS diff --git a/lite/core/mir/argument_type_display_pass.cc b/lite/core/mir/argument_type_display_pass.cc index d53d705a2d772371a1cbd7a01db60ff90498c4be..2ed63b360c955b53eaa37af2f1e4832d0f88fd03 100644 --- a/lite/core/mir/argument_type_display_pass.cc +++ b/lite/core/mir/argument_type_display_pass.cc @@ -42,4 +42,5 @@ class ArgumentTypeDisplayPass : public DebugPass { } // namespace paddle REGISTER_MIR_PASS(argument_type_display_pass, - paddle::lite::mir::ArgumentTypeDisplayPass); + paddle::lite::mir::ArgumentTypeDisplayPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/demo_pass.cc b/lite/core/mir/demo_pass.cc index 837a5a1cbcce42bb2ce97767aa982263cd0228c6..0e0858332c9d10382d71fe7b50b3b2beb6ac257b 100644 --- a/lite/core/mir/demo_pass.cc +++ b/lite/core/mir/demo_pass.cc @@ -34,4 +34,5 @@ bool RegisterDemoPass() { } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass); +REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc index 07d8dfd3f5b637bb8e1a49a4f538bc285496f30a..acea48c742522d5b6b5f1f3b570fcbfe0c4be08d 100644 --- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc +++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -69,4 +69,5 @@ class IdentityScaleEliminatePass : public ProgramPass { } // namespace paddle REGISTER_MIR_PASS(identity_scale_eliminate_pass, - paddle::lite::mir::IdentityScaleEliminatePass); + paddle::lite::mir::IdentityScaleEliminatePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 92421a2cf8fdaa34134faa04f005260893f139a9..5ac52837551f0b78d67dfe1733fe354ee2cf7f01 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -22,6 +22,9 @@ lite_cc_library(fuse_quant_dequant lite_cc_library(fuse_transpose_softmax_transpose SRCS transpose_softmax_transpose_fuser.cc DEPS pattern_matcher_high_api) 
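`ResetLazyImage2D` above only reallocates when the requested image grows past the cached width/height, so repeated inference on the same (or smaller) shapes reuses the existing OpenCL image. A minimal sketch of that grow-only policy, with a hypothetical `AllocImage`/`FreeImage` pair standing in for the `TargetWrapperCL` calls:

```cpp
#include <cstdio>
#include <cstdlib>

// Grow-only lazy reallocation, as in Buffer::ResetLazyImage2D. AllocImage and
// FreeImage are hypothetical stand-ins for the OpenCL wrapper calls.
struct LazyImageBuffer {
  void* data{nullptr};
  size_t w{0}, h{0};

  void ResetLazy(size_t img_w, size_t img_h) {
    if (w < img_w || h < img_h) {  // realloc only when the image grows
      FreeImage();
      data = AllocImage(img_w, img_h);
      w = img_w;
      h = img_h;
      std::printf("alloc %zux%zu\n", img_w, img_h);
    }
  }

  void* AllocImage(size_t img_w, size_t img_h) {
    return std::malloc(img_w * img_h * 4);  // 4 channels (RGBA)
  }
  void FreeImage() {
    std::free(data);
    data = nullptr;
  }
  ~LazyImageBuffer() { FreeImage(); }
};

int main() {
  LazyImageBuffer buf;
  buf.ResetLazy(64, 64);   // allocates
  buf.ResetLazy(32, 32);   // fits, no realloc
  buf.ResetLazy(128, 64);  // width grew, reallocates
}
```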
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index d96a67f52e917962bd07a6207e4343014ca9f0c6..a44b8348716449519486d37f6784e31ecc39f554 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -1,6 +1,6 @@
 lite_cc_library(mir_node SRCS node.cc DEPS kernel)
 lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program)
-lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
+lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph)
 lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
 lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
@@ -13,6 +13,7 @@ lite_cc_library(mir_passes
     fusion/fc_fuse_pass.cc
     fusion/shuffle_channel_fuse_pass.cc
     fusion/transpose_softmax_transpose_fuse_pass.cc
+    fusion/interpolate_fuse_pass.cc
     fusion/conv_elementwise_fuse_pass.cc
     fusion/conv_activation_fuse_pass.cc
     fusion/conv_bn_fuse_pass.cc
@@ -30,6 +31,7 @@ lite_cc_library(mir_passes
     argument_type_display_pass.cc
     demo_pass.cc
     runtime_context_assign_pass.cc
+    memory_optimize_pass.cc
     DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
 
 # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
diff --git a/lite/core/mir/argument_type_display_pass.cc b/lite/core/mir/argument_type_display_pass.cc
index d53d705a2d772371a1cbd7a01db60ff90498c4be..2ed63b360c955b53eaa37af2f1e4832d0f88fd03 100644
--- a/lite/core/mir/argument_type_display_pass.cc
+++ b/lite/core/mir/argument_type_display_pass.cc
@@ -42,4 +42,5 @@ class ArgumentTypeDisplayPass : public DebugPass {
 }  // namespace paddle
 
 REGISTER_MIR_PASS(argument_type_display_pass,
-                  paddle::lite::mir::ArgumentTypeDisplayPass);
+                  paddle::lite::mir::ArgumentTypeDisplayPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/demo_pass.cc b/lite/core/mir/demo_pass.cc
index 837a5a1cbcce42bb2ce97767aa982263cd0228c6..0e0858332c9d10382d71fe7b50b3b2beb6ac257b 100644
--- a/lite/core/mir/demo_pass.cc
+++ b/lite/core/mir/demo_pass.cc
@@ -34,4 +34,5 @@ bool RegisterDemoPass() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass);
+REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
index 07d8dfd3f5b637bb8e1a49a4f538bc285496f30a..acea48c742522d5b6b5f1f3b570fcbfe0c4be08d 100644
--- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
+++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
@@ -69,4 +69,5 @@ class IdentityScaleEliminatePass : public ProgramPass {
 }  // namespace paddle
 
 REGISTER_MIR_PASS(identity_scale_eliminate_pass,
-                  paddle::lite::mir::IdentityScaleEliminatePass);
+                  paddle::lite::mir::IdentityScaleEliminatePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt
index 92421a2cf8fdaa34134faa04f005260893f139a9..5ac52837551f0b78d67dfe1733fe354ee2cf7f01 100644
--- a/lite/core/mir/fusion/CMakeLists.txt
+++ b/lite/core/mir/fusion/CMakeLists.txt
@@ -22,6 +22,9 @@ lite_cc_library(fuse_quant_dequant
 lite_cc_library(fuse_transpose_softmax_transpose
     SRCS transpose_softmax_transpose_fuser.cc
     DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_interpolate
+    SRCS interpolate_fuser.cc
+    DEPS pattern_matcher_high_api)
 
 set(mir_fusers
     fuse_fc
@@ -32,6 +35,7 @@ set(mir_fusers
     fuse_quant_dequant
     fuse_elementwise_add_activation
     fuse_transpose_softmax_transpose
+    fuse_interpolate
     CACHE INTERNAL "fusers")
 
 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index cad98cb26c2c7a16cfc8a02538f8e4cd9fbd6db3..ff064fb2ee93fc540e932da36fb07bb78eef989a 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -23,11 +23,21 @@ namespace lite {
 namespace mir {
 
 void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvActivationFuser fuser("conv2d", "relu");
-  fuser(graph.get());
-
-  fusion::ConvActivationFuser depthwise_fuser("depthwise_conv2d", "relu");
-  depthwise_fuser(graph.get());
+  std::vector<std::string> act_types{"relu"};
+  for (auto& place : graph->valid_places()) {
+    if (place.target == TARGET(kCUDA)) {
+      act_types.push_back("leaky_relu");
+      break;
+    }
+  }
+  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+    for (auto act_type : act_types) {
+      for (auto has_bias : {true, false}) {
+        fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
+        fuser(graph.get());
+      }
+    }
+  }
 }
 
 }  // namespace mir
@@ -35,4 +45,6 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }  // namespace paddle
 
 REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
-                  paddle::lite::mir::ConvActivationFusePass);
+                  paddle::lite::mir::ConvActivationFusePass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("conv2d");
diff --git a/lite/core/mir/fusion/conv_activation_fuser.cc b/lite/core/mir/fusion/conv_activation_fuser.cc
index c49a9ad4f0b375bd1f73cf72a8ff993af3f6c38d..6ba11a6a4e82416eb386ec3b34c71183cef5adcb 100644
--- a/lite/core/mir/fusion/conv_activation_fuser.cc
+++ b/lite/core/mir/fusion/conv_activation_fuser.cc
@@ -22,35 +22,33 @@ namespace mir {
 namespace fusion {
 
 void ConvActivationFuser::BuildPattern() {
-  // create input nodes.
+  // create nodes.
   auto* input =
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
   auto* filter =
       VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
-  auto* bias =
-      VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
-
-  // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  PMNode* bias = nullptr;
+  if (has_bias_) {
+    bias = VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
+  }
+  auto* conv2d = OpNode("conv2d", conv_type_)->AsIntermediate();
 
-  auto* act =
-      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
+  auto* act = OpNode("act", act_type_)->AsIntermediate();
 
-  // create intermediate nodes
   auto* conv2d_out = VarNode("conv2d_out")
                          ->assert_is_op_output(conv_type_, "Output")
                          ->assert_is_op_input(act_type_, "X")
                          ->AsIntermediate();
 
-  // create output node
   auto* out =
       VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
 
   // create topology.
-  std::vector<PMNode*> conv2d_inputs{filter, input, bias};
-  conv2d_inputs >> *conv2d >> *conv2d_out;
-  *conv2d_out >> *act >> *out;
+  std::vector<PMNode*> conv2d_inputs{filter, input};
+  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
+  if (has_bias_) {
+    *bias >> *conv2d;
+  }
 }
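The optional-bias handling in `BuildPattern()` above is the general idiom for patterns with optional inputs: build the required chain unconditionally and link the optional node only when the variant asks for it. A toy sketch of that control flow, with `Pattern`/`Node` as hypothetical stand-ins for the real `PMNode` machinery:

```cpp
#include <deque>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Toy sketch of the optional-input idiom in BuildPattern(): required edges
// are created unconditionally, the bias edge only in the has_bias variant.
struct Node {
  std::string name;
};

struct Pattern {
  std::deque<Node> nodes;  // deque keeps pointers stable as it grows
  std::vector<std::pair<std::string, std::string>> edges;
  Node* Var(const std::string& n) {
    nodes.push_back({n});
    return &nodes.back();
  }
  void Link(const Node* a, const Node* b) { edges.emplace_back(a->name, b->name); }
};

void BuildConvActPattern(Pattern* p, bool has_bias) {
  Node* input = p->Var("input");
  Node* filter = p->Var("filter");
  Node* conv = p->Var("conv2d");
  Node* act = p->Var("act");
  p->Link(input, conv);
  p->Link(filter, conv);
  p->Link(conv, act);
  if (has_bias) p->Link(p->Var("bias"), conv);  // optional edge
}

int main() {
  for (bool has_bias : {true, false}) {
    Pattern p;
    BuildConvActPattern(&p, has_bias);
    std::cout << "has_bias=" << has_bias << " edges=" << p.edges.size() << "\n";
  }
}
```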
 
 void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
@@ -66,34 +64,25 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
 
   IR_NODE_LINK_TO(matched.at("input"), new_op_node);
   IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  if (has_bias_) {
+    IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  }
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }
 
 cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
-
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
+  cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
   op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
+  cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
 
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
+  op_desc.SetAttr("with_act", true);
+  op_desc.SetAttr("act_type", act_type_);
+  if (act_type_ == "relu") {
+    op_desc.SetAttr("fuse_relu", true);
+  } else if (act_type_ == "leaky_relu") {
+    float alpha = act_op_desc.GetAttr<float>("alpha");
+    op_desc.SetAttr("leaky_relu_alpha", alpha);
   }
-  // Only consider strides, padding, groups, dilations, fuse_relu for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  // TODO(sangoly): support other activation types
-  op_desc.SetAttr("fuse_relu", true);
   return op_desc;
 }
diff --git a/lite/core/mir/fusion/conv_activation_fuser.h b/lite/core/mir/fusion/conv_activation_fuser.h
index 3377e28e29a22d32f4fa7ec1c5ad02b509b3e050..d352a32f9f8a7e232acee9a84dcaf23ae5676b55 100644
--- a/lite/core/mir/fusion/conv_activation_fuser.h
+++ b/lite/core/mir/fusion/conv_activation_fuser.h
@@ -26,10 +26,11 @@ namespace fusion {
 class ConvActivationFuser : public FuseBase {
  public:
   explicit ConvActivationFuser(const std::string& conv_type,
-                               const std::string& act_type) {
-    CHECK(act_type == "relu") << "Only relu activation be supported now";
+                               const std::string& act_type,
+                               bool has_bias) {
     conv_type_ = conv_type;
     act_type_ = act_type;
+    has_bias_ = has_bias;
   }
 
   void BuildPattern() override;
@@ -39,6 +40,7 @@ class ConvActivationFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
   std::string conv_type_;
   std::string act_type_;
+  bool has_bias_;
 };
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index 954e007a850a1b9fd0afe1f0e157e8fe5aeff621..d9d9c1bbf55bd33c31aa9a22de934d4eae8657c6 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -16,6 +16,7 @@
 #include
 #include
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
+#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
 
 namespace paddle {
@@ -23,15 +24,25 @@ namespace lite {
 namespace mir {
 
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvBNFuser fuser("conv2d");
-  fuser(graph.get());
+  // initialize fuser params
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
 
-  fusion::ConvBNFuser fuser2("depthwise_conv2d");
-  fuser2(graph.get());
+  // run the fusers over all parameter combinations
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvBNFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
+REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
+    .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kX86)});
diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc
index 77ad8237fe8108c8b9d19d09bf45b724f6c0ca2d..ec07278eed1f259c45e225497f94d682b544c57c 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.cc
+++ b/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -14,6 +14,7 @@
 
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
 #include
+#include
 #include
 
 namespace paddle {
@@ -30,7 +31,8 @@ void ConvBNFuser::BuildPattern() {
   auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* conv_out = VarNode("conv_out")
                        ->assert_is_op_output(conv_type_, "Output")
-                       ->assert_is_op_input("batch_norm", "X");
+                       ->assert_is_op_input("batch_norm", "X")
+                       ->AsIntermediate();
 
   auto* bn_scale = VarNode("bn_scale")
                        ->assert_is_op_input("batch_norm", "Scale")
@@ -61,34 +63,29 @@ void ConvBNFuser::BuildPattern() {
                            ->assert_is_op_output("batch_norm", "SavedVariance")
                            ->AsIntermediate();
 
-  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  if (conv_has_bias_) {
+    auto* conv_bias = VarNode("conv_bias")
+                          ->assert_is_op_input(conv_type_, "Bias")
+                          ->AsIntermediate();
+    conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out});
+  } else {
+    conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  }
   bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
       .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
 }
 
 void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
-
   auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
   auto conv = conv_instruct->op();
   auto* scope = conv->scope();
-  auto& valid_places = conv->valid_places();
-
-  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
-                           ->GetMutable<lite::Tensor>();
-  auto conv_weight_dims = conv_weight_t->dims();
-  size_t weight_num = conv_weight_t->data_size();
 
+  // bn
   auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
                         ->GetMutable<lite::Tensor>();
-  size_t bias_size = bn_scale_t->data_size();
   auto bn_scale_d = bn_scale_t->mutable_data<float>();
-  CHECK_EQ(bias_size, static_cast<size_t>(conv_weight_dims[0]))
-      << "The BN bias's size should be equal to the size of the first "
-      << "dim size of the conv weights";
-
   auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
                        ->GetMutable<lite::Tensor>();
   auto bn_mean_d = bn_mean_t->mutable_data<float>();
@@ -102,59 +99,103 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   auto bn_bias_d = bn_bias_t->mutable_data<float>();
   auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
 
-  auto conv_op_desc = conv_instruct->mutable_op_info();
-
+  // conv
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  CHECK_EQ(static_cast<int>(bn_scale_t->data_size()),
+           static_cast<int>(conv_weight_t->dims()[0]))
+      << "The BN bias's size should be equal to the size of the first "
+      << "dim size of the conv weights";
+  size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
+
+  // compute BN alpha and beta
   Tensor alpha_tensor, beta_tensor;
   alpha_tensor.CopyDataFrom(*bn_bias_t);
   beta_tensor.CopyDataFrom(*bn_bias_t);
   auto alpha_data = alpha_tensor.mutable_data<float>();
   auto beta_data = beta_tensor.mutable_data<float>();
-  int h = bias_size;
-  int w = weight_num / bias_size;
+  int h =
+      bn_scale_t
+          ->data_size();  // h == bias_size == out channel num of conv weight
+  int w = weight_num /
+          (bn_scale_t->data_size());  // w = `conv_weight_num` / bias_size = in
+                                      // channel num of conv weight
+
   ComputeAlphaAndBeta(
       bn_scale_d, bn_mean_d, bn_var_d, alpha_data, beta_data, eps, h, w);
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Compute ConvBNFuser
+  // Before fusion
+  //
+  //   conv(x) = kx + z = y
+  //   bn(y) = ay + b
+  //
+  // Note: `alpha_data` is a, `beta_data` is b from `ComputeAlphaAndBeta`
+  //
+  // After fusion:
+  //
+  //   bn(conv(x)) = a(kx + z) + b = akx + az + b
+  //
+  // Note: h == bias_size == out channel num of conv weight
+  //       w = `conv_weight_num` / bias_size = in channel num of conv weight
+  //       the int8 path differs slightly
+  ///////////////////////////////////////////////////////////////////////////////
   if (enable_int8) {
     PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
                    "INT8 mode: Conv should have a weight_scale attr");
+    auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
+    // compute the new conv_weight for int8
     auto weight_scale =
         conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
-    for (int i = 0; i < h; i++) {
-      weight_scale[i] *= alpha_data[i];
+    for (unsigned int i = 0; i < h; ++i) {
+      weight_scale[i] *= fabsf(alpha_data[i]);
+      if (alpha_data[i] < 0.f) {
+        auto ptr_row = conv_weight_d + i * w;
+        for (unsigned int j = 0; j < w; ++j) {
+          ptr_row[j] *= -1;
+        }
+      }
     }
-    // Interface like this should be abandoned.
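The comment block above gives the algebra behind the fold: with `conv(x) = kx + z` and `bn(y) = ay + b`, the fused op computes `akx + (az + b)`, so scaling each output-channel weight row by `a` and rebuilding the bias absorbs the batch norm at no inference cost. A minimal numeric sketch of that fold (one output channel, float path only, made-up values):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Fold bn(y) = a*y + b into a conv row y = k.x + z, as in ConvBNFuser.
// alpha/beta follow the usual BN derivation: a = gamma / sqrt(var + eps),
// b = beta - gamma * mean / sqrt(var + eps). All values here are made up.
int main() {
  std::vector<float> k{0.5f, -1.0f, 2.0f};  // one output-channel weight row
  float z = 0.25f;                          // conv bias for that channel
  float gamma = 1.2f, beta = 0.1f, mean = 0.4f, var = 0.9f, eps = 1e-5f;

  float a = gamma / std::sqrt(var + eps);
  float b = beta - gamma * mean / std::sqrt(var + eps);

  // fused parameters: weights *= a, bias = a*z + b
  for (auto& w : k) w *= a;
  float fused_bias = a * z + b;

  // evaluate the fused conv on one input; equals bn(conv(x)) by construction
  std::vector<float> x{1.0f, 2.0f, 3.0f};
  float fused = fused_bias;
  for (size_t i = 0; i < k.size(); ++i) fused += k[i] * x[i];
  std::printf("fused conv output: %f\n", fused);
}
```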
     conv_op_desc->SetAttr("weight_scale", weight_scale);
-    auto update_conv_desc = *conv_instruct->mutable_op_info();
-    conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
   } else {
+    // compute the new conv_weight
     auto conv_weight_d = conv_weight_t->mutable_data<float>();
-    for (int i = 0; i < h; i++) {
-      for (int j = 0; j < w; j++) {
+    for (unsigned int i = 0; i < h; ++i) {    // h: conv2d output channels
+      for (unsigned int j = 0; j < w; ++j) {  // w: conv2d input channels
         conv_weight_d[i * w + j] *= alpha_data[i];
       }
     }
   }
-  for (int i = 0; i < bias_size; i++) {
+
+  // compute the new conv_bias
+  if (conv_has_bias_ && conv_op_desc->HasInput("Bias") &&
+      conv_op_desc->Input("Bias").size() > 0) {
+    auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+    auto conv_bias_d = conv_bias_t->data<float>();
+    for (unsigned int i = 0; i < bn_bias_t->data_size();
+         ++i) {  // bias_size == h == conv2d output channels
+      bn_bias_d[i] += alpha_data[i] * conv_bias_d[i];
+    }
+  }
+  for (unsigned int i = 0; i < bn_bias_t->data_size(); ++i) {
     bn_bias_d[i] += beta_data[i];
   }
 
-  eltwise_op->Attach(op_desc, scope);
-  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
-
-  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
-}
-
-cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
-  cpp::OpDesc op_desc;
-  op_desc.SetType("elementwise_add");
-  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
-  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
-  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
-  op_desc.SetAttr("axis", 1);
-  return op_desc;
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("conv_input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("conv_weight")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("bn_out")->arg()->name});
+  conv_op_desc->SetInput("Bias",
+                         {matched.at("bn_bias")->arg()->name});  // conv_bias
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+
+  IR_NODE_LINK_TO(matched.at("bn_bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("bn_out"));
 }
 
 }  // namespace fusion
diff --git a/lite/core/mir/fusion/conv_bn_fuser.h b/lite/core/mir/fusion/conv_bn_fuser.h
index 9acf65f9e21be6b3f18b3512acbc1f80b12368d9..8bd8c0ce0600bb68667d96d07d43fa3028b5a856 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.h
+++ b/lite/core/mir/fusion/conv_bn_fuser.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include
 #include
 #include
 #include "lite/core/mir/pattern_matcher_high_api.h"
@@ -26,12 +27,12 @@ namespace fusion {
 
 class ConvBNFuser : public FuseBase {
  public:
-  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
+  explicit ConvBNFuser(const std::string& conv_type, const bool conv_has_bias)
+      : conv_type_(conv_type), conv_has_bias_(conv_has_bias) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
-  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
   void ComputeAlphaAndBeta(float* scale_d,
                            float* mean_d,
                            float* var_d,
@@ -50,6 +51,7 @@ class ConvBNFuser : public FuseBase {
 
  private:
   std::string conv_type_{"conv2d"};
+  bool conv_has_bias_{false};
 };
diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
index 57ed845150a01cd3cb85af822ca09196fbcefe83..fd9aadc5d01c2cb3b6c7a3e888503072a0798725 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
@@ -23,14 +23,21 @@ namespace lite {
 namespace mir {
 
 void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvElementwiseFuser fuser("conv2d");
-  fuser(graph.get());
+  // initialize fuser params
+  // note: the `true` conv_has_bias case must be matched first
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
 
-  fusion::ConvElementwiseFuser depthwise_fuser("depthwise_conv2d");
-  depthwise_fuser(graph.get());
-
-  fusion::ConvElementwiseFuser conv2d_transpose_fuser("conv2d_transpose");
-  conv2d_transpose_fuser(graph.get());
+  // run the fusers over all parameter combinations
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvElementwiseFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 }  // namespace mir
@@ -38,4 +45,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }  // namespace paddle
 
 REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
-                  paddle::lite::mir::ConvElementwiseFusePass);
+                  paddle::lite::mir::ConvElementwiseFusePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc
index c3ab3e4c4ca9bd8d6a6eaaf82e40dcb06cf99ea9..22ec1fa0d22378adf3776c6bb391f50fde376b7a 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuser.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc
@@ -27,12 +27,13 @@ void ConvElementwiseFuser::BuildPattern() {
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
   auto* filter =
       VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
-  auto* bias =
-      VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+  auto* bias = VarNode("bias")
+                   ->assert_is_op_input("elementwise_add", "Y")
+                   ->AsInput()
+                   ->assert_is_persistable_var();
 
   // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* add = OpNode("add", "elementwise_add")
                   ->assert_is_op("elementwise_add")
                   ->AsIntermediate();
@@ -49,6 +50,13 @@ void ConvElementwiseFuser::BuildPattern() {
 
   // create topology.
   std::vector<PMNode*> conv2d_inputs{filter, input};
+  // consider a special case: conv with bias
+  if (conv_has_bias_) {
+    PMNode* conv_bias = VarNode("conv_bias")
+                            ->assert_is_op_input(conv_type_, "Bias")
+                            ->AsIntermediate();
+    conv2d_inputs.emplace_back(conv_bias);
+  }
   std::vector<PMNode*> add_inputs{conv2d_out, bias};
   conv2d_inputs >> *conv2d >> *conv2d_out;
   add_inputs >> *add >> *add_out;
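The `assert_is_persistable_var()` guard added above is what keeps the fuser from folding an `elementwise_add` whose Y operand is a runtime tensor: only weight-like (persistable) variables can be merged into a conv bias at optimization time. A toy sketch of that kind of predicate-gated eligibility check, with `Var` as a hypothetical stand-in for the graph's variable nodes:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Only an elementwise_add whose Y operand is persistable (a constant weight)
// is eligible for bias folding; activations computed at runtime are skipped.
struct Var {
  std::string name;
  bool persistable;
};

bool EligibleForBiasFold(const Var& y) { return y.persistable; }

int main() {
  std::vector<Var> candidates{{"conv_bias_w", true}, {"runtime_feat", false}};
  for (const auto& v : candidates) {
    std::cout << v.name << (EligibleForBiasFold(v) ? ": fuse" : ": skip")
              << "\n";
  }
}
```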
@@ -56,44 +64,49 @@ void ConvElementwiseFuser::BuildPattern() {
 
 void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph,
                                          const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
-  auto conv_old = matched.at("conv2d")->stmt()->op();
-  auto* scope = conv_old->scope();
-  auto& valid_places = conv_old->valid_places();
-  conv_op->Attach(op_desc, scope);
-
-  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
+  auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
+  auto* scope = conv_instruct->op()->scope();
 
-  IR_NODE_LINK_TO(matched.at("input"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
-}
+  /////////////////////////////////////////////////////////////////////////////////////
+  // ConvElementwiseFuser
+  //   if `conv_bias` exists, add it into the `elementwise_add` bias, which
+  //   then becomes the new conv bias;
+  //   if `conv_bias` does not exist, use the `elementwise_add` bias as the
+  //   new conv bias directly.
+  /////////////////////////////////////////////////////////////////////////////////////
 
-cpp::OpDesc ConvElementwiseFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
+  if (conv_has_bias_ == true && conv_op_desc->HasInput("Bias") &&
+      conv_op_desc->Input("Bias").size() > 0) {
+    auto conv_bias_var = scope->FindVar(conv_op_desc->Input("Bias").front());
+    if (conv_bias_var != nullptr) {
+      // conv bias
+      auto conv_bias_t = &(conv_bias_var->Get<lite::Tensor>());
+      auto conv_bias_d = conv_bias_t->data<float>();
 
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
-  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
+      // elementwise_add bias
+      auto elementwise_add_bias_t =
+          scope->FindVar(matched.at("bias")->arg()->name)
+              ->GetMutable<lite::Tensor>();
+      auto elementwise_add_bias_d =
+          elementwise_add_bias_t->mutable_data<float>();
 
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
+      for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) {
+        elementwise_add_bias_d[i] += conv_bias_d[i];
+      }
+    }
   }
-  // Only consider strides, padding, groups, dilations for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  return op_desc;
+
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("filter")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("output")->arg()->name});
+  conv_op_desc->SetInput("Bias", {matched.at("bias")->arg()->name});
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+
+  IR_NODE_LINK_TO(matched.at("bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("output"));
 }
 
 }  // namespace fusion
diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.h b/lite/core/mir/fusion/conv_elementwise_fuser.h
index 4514fc5010b5c40f31a69c4459f0a26f33d6046a..fdcb5d8912d87c61f13c47e5ef07b926a96d7272 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuser.h
+++ b/lite/core/mir/fusion/conv_elementwise_fuser.h
@@ -25,16 +25,18 @@ namespace fusion {
 
 class ConvElementwiseFuser : public FuseBase {
  public:
-  explicit ConvElementwiseFuser(const std::string& conv_type) {
+  explicit ConvElementwiseFuser(const std::string& conv_type,
+                                const bool conv_has_bias) {
     conv_type_ = conv_type;
+    conv_has_bias_ = conv_has_bias;
   }
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
-  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  std::string conv_type_;
+  std::string conv_type_{"conv2d"};
+  bool conv_has_bias_{false};
 };
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
index 33223cb140cb5a7f01dc70861b034cb24cd6e19a..af66f5ab66bd09907cb9d28f00f17d983e54c252 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
@@ -33,4 +33,6 @@ void ElementwiseAddActivationFusePass::Apply(
 }  // namespace paddle
 
 REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
-                  paddle::lite::mir::ElementwiseAddActivationFusePass);
+                  paddle::lite::mir::ElementwiseAddActivationFusePass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("fusion_elementwise_add_activation");
diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc
index 0303ae06e6322940fd3f63da551b3c437e0bdfaa..ed10f06f5651f4000485279d682689101d80aa5a 100644
--- a/lite/core/mir/fusion/fc_fuse_pass.cc
+++ b/lite/core/mir/fusion/fc_fuse_pass.cc
@@ -31,4 +31,6 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass);
+REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("fc");
diff --git a/lite/core/mir/fusion/fc_fuse_pass_test.cc b/lite/core/mir/fusion/fc_fuse_pass_test.cc
index cbf77084dd6e6de77617931a077331c16e1f693a..f7aa4bb5adcb848531ecc3a8f63bace1c2e3e0ff 100644
--- a/lite/core/mir/fusion/fc_fuse_pass_test.cc
+++ b/lite/core/mir/fusion/fc_fuse_pass_test.cc
@@ -30,16 +30,12 @@ namespace mir {
 TEST(fc_fuse_pass, fuse_test) {
   lite::Predictor predictor;
 #ifndef LITE_WITH_CUDA
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+  std::vector<Place> valid_places({Place{TARGET(kX86), PRECISION(kFloat)}});
 #else
   std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
       Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
       Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
-      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
       Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
-      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
   });
 #endif
 
@@ -72,8 +68,7 @@ TEST(fc_fuse_pass, fuse_test) {
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 TEST(fc_fuse_pass, save_model_test) {
   lite::Predictor predictor;
-  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
-                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+  std::vector<Place> valid_places({Place{TARGET(kX86), PRECISION(kFloat)}});
 
   predictor.Build(FLAGS_model_dir,
                   "",
                   "",
diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc
index 72e1a4684d6f0e1554d9fb385e21d31c12dcbb6c..460c0fdf7a4309638b9852a315ca0efda02801ab 100644
--- a/lite/core/mir/fusion/fc_fuser.cc
+++ b/lite/core/mir/fusion/fc_fuser.cc
@@ -25,7 +25,7 @@ void FcFuser::BuildPattern() {
   // create nodes.
   auto* x = VarNode("x")->assert_is_op_input("mul", "X");
   auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
-  auto* b = VarNode("b");
+  auto* b = VarNode("b")->assert_is_persistable_var();
   auto* mul = OpNode("mul", "mul");
   auto* mul_out = VarNode("mul_out");
   auto* add = OpNode("add", "elementwise_add");
@@ -61,6 +61,8 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
 
 cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
   cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
+  op_desc.mutable_inputs()->clear();
+  op_desc.mutable_outputs()->clear();
   op_desc.SetType("fc");
   op_desc.SetInput("Input", {matched.at("x")->arg()->name});
   op_desc.SetInput("W", {matched.at("W")->arg()->name});
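The bias folding in `ConvElementwiseFuser::InsertNewNode` above is plain per-channel addition: the old conv bias is accumulated into the `elementwise_add` operand, which then serves as the fused op's only bias. A minimal sketch of that merge, with made-up values:

```cpp
#include <cstdio>
#include <vector>

// Per-channel bias merge as in ConvElementwiseFuser: conv with bias z
// followed by "+ add_bias" becomes a single conv whose bias is z + add_bias.
// Sizes and values are made up for illustration.
int main() {
  std::vector<float> conv_bias{0.1f, -0.2f, 0.3f};        // old conv Bias
  std::vector<float> elementwise_bias{1.0f, 2.0f, 3.0f};  // elementwise_add Y

  // accumulate the old conv bias into the elementwise_add operand,
  // which becomes the fused conv's Bias input
  for (size_t i = 0; i < conv_bias.size(); ++i) {
    elementwise_bias[i] += conv_bias[i];
  }

  for (float b : elementwise_bias) std::printf("%f\n", b);  // 1.1 1.8 3.3
}
```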
+ +#include "lite/core/mir/fusion/interpolate_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/interpolate_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void InterpolateFusePass::Apply(const std::unique_ptr& graph) { + fusion::InterpolateFuser bilinear_interp_fuser("bilinear_interp"); + bilinear_interp_fuser(graph.get()); + + fusion::InterpolateFuser nearest_interp_fuser("nearest_interp"); + nearest_interp_fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_interpolate_fuse_pass, + paddle::lite::mir::InterpolateFusePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/interpolate_fuse_pass.h b/lite/core/mir/fusion/interpolate_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..2beb4bb5b0d714e165ca7fc72227dbb325f66f9d --- /dev/null +++ b/lite/core/mir/fusion/interpolate_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class InterpolateFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/interpolate_fuser.cc b/lite/core/mir/fusion/interpolate_fuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..458ef76cb4432dd54678824b1a179e554bcbbf78 --- /dev/null +++ b/lite/core/mir/fusion/interpolate_fuser.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/mir/fusion/interpolate_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void InterpolateFuser::BuildPattern() { + auto* x = VarNode("x"); + auto* shape = OpNode("shape", "shape")->AsIntermediate(); + auto* shape_out = VarNode("shape_out")->AsIntermediate(); + auto* slice = OpNode("slice", "slice") + ->assert_op_attr_satisfied>( + "axes", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 0; + }) + ->assert_op_attr_satisfied>( + "starts", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 2; + }) + ->assert_op_attr_satisfied>( + "ends", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 4; + }) + ->AsIntermediate(); + auto* slice_out = VarNode("slice_out")->AsIntermediate(); + auto* cast = OpNode("cast", "cast")->AsIntermediate(); + auto* cast_out = VarNode("cast_out")->AsIntermediate(); + auto* fill_constant = + OpNode("fill_constant", "fill_constant")->AsIntermediate(); + auto* fill_constant_out = VarNode("fill_constant_out")->AsIntermediate(); + auto* elementwise_mul = + OpNode("elementwise_mul", "elementwise_mul") + ->assert_op_attr_satisfied( + "axis", [](int attr) { return attr == -1 || attr == 0; }) + ->AsIntermediate(); + auto* elementwise_mul_out = VarNode("elementwise_mul_out")->AsIntermediate(); + auto* interpolate = OpNode("interpolate", interp_type_)->AsIntermediate(); + auto* interpolate_out = VarNode("interpolate_out"); + + // create topology. + *x >> *shape >> *shape_out >> *slice >> *slice_out >> *cast >> *cast_out >> + *elementwise_mul >> *elementwise_mul_out >> *interpolate >> + *interpolate_out; + *fill_constant >> *fill_constant_out >> *elementwise_mul; + *x >> *interpolate; +} + +void InterpolateFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto interp_op = LiteOpRegistry::Global().Create(interp_type_); + auto interp_old = matched.at("interpolate")->stmt()->op(); + auto* scope = interp_old->scope(); + auto& valid_places = interp_old->valid_places(); + interp_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(interp_op, valid_places); + + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("interpolate_out")); +} + +cpp::OpDesc InterpolateFuser::GenOpDesc(const key2nodes_t& matched) { + auto op_desc = *matched.at("interpolate")->stmt()->op_info(); + op_desc.SetInput("OutSize", {}); + op_desc.SetAttr( + "scale", + matched.at("fill_constant")->stmt()->op_info()->GetAttr("value")); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/interpolate_fuser.h b/lite/core/mir/fusion/interpolate_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..51f5655e76749ea4de6e1789f499862f2ac46437 --- /dev/null +++ b/lite/core/mir/fusion/interpolate_fuser.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class InterpolateFuser : public FuseBase { + public: + explicit InterpolateFuser(const std::string& interp_type) + : interp_type_(interp_type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string interp_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index 83b70c7828191d8f94d38da50ed0bbcc69694bc3..8ec50b8112b6b853e83abf5c491163fa4475f2f4 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h" +#include #include #include +#include "lite/api/paddle_place.h" #include "lite/core/mir/fusion/quant_dequant_op_fuser.h" #include "lite/core/mir/pass_registry.h" @@ -23,17 +25,26 @@ namespace lite { namespace mir { void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { - std::unordered_set quant_types = { + // delete quant node + std::vector quant_op_types = { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; - std::unordered_set quantized_op_types = { + for (auto& op_type : quant_op_types) { + fusion::DeleteQuantOpFuser fuser(op_type); + fuser(graph.get()); + } + + // fuse quantized node and dequant node + std::vector quantized_op_types = { "conv2d", "mul", "depthwise_conv2d"}; - for (auto& quant_type : quant_types) { - for (auto& op_type : quantized_op_types) { - for (int i = 6; i >= 1; i--) { - fusion::QuantDequantOpFuser fuser(op_type, quant_type, i); - fuser(graph.get()); - } - } + for (auto& op_type : quantized_op_types) { + fusion::DequantOpFuser fuser(op_type); + fuser(graph.get()); + } + + // delete quant_dequant_node + for (auto op_type : {"pool2d", "elementwise_add"}) { + fusion::DeleteQuantDequantOpFuser fuser(op_type); + fuser(graph.get()); } } @@ -42,4 +53,6 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, - paddle::lite::mir::QuantDequantFusePass); + paddle::lite::mir::QuantDequantFusePass) + .BindTargets({TARGET(kAny)}) + .BindKernel("calib"); diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index 1c7cf866b90ad85f22a9a87d8bbc2db4de94a718..c8b32d46e20586bddc0c1c61fd03cf2a082137e7 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -14,6 +14,7 @@ #include "lite/core/mir/fusion/quant_dequant_op_fuser.h" #include +#include #include #include "lite/utils/string.h" @@ -22,174 +23,334 @@ namespace lite { namespace mir { namespace fusion { -void QuantDequantOpFuser::BuildPattern() { - const int kNumFields = 5; - const int kQuantizedWeightOffset = 0; - const int kQuantizedOpOffset = 1; - const int kQuantizedOpOutOffset = 2; - const int kDequantOpOffset = 3; - const int kDequantOpOutOffset = 4; +void DeleteQuantOpFuser::BuildPattern() { + auto* input_scale_node = VarNode("input_scale_node") + 
->assert_is_op_input(quant_op_type_, "InScale"); + auto* input_act_node = + VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X"); + auto* quant_node = + OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_op_type_, "OutScale"); + auto* output_act_node = + VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out"); + quant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_node}); + output_act_node->LinksFrom({quant_node}); + VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_; +} + +void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto* input_scale_node = matched.at("input_scale_node"); + auto* input_act_node = matched.at("input_act_node"); + auto* quant_node = matched.at("quant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + + // obtain values, save values and relink node + int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto outlinks = output_act_node->outlinks; + for (auto* quantized_node : outlinks) { + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + op_desc->SetAttr("input_scale", scale_value); + IR_NODE_LINK_TO(input_act_node, quantized_node) + } + + // delete nodes and edges + std::unordered_set nodes2rm = { + input_scale_node, quant_node, output_scale_node, output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); +} + +cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} + +void DequantOpFuser::BuildPattern() { std::string weight_name = ""; if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { weight_name = "Filter"; } else { weight_name = "Y"; } - auto* quant_op_input = VarNode("quant_op_input") - ->assert_is_op_input(quant_type_, "X") - ->AsInput(); - auto* quant_op_in_scale = VarNode("quant_op_in_scale") - ->assert_is_op_input(quant_type_, "InScale") - ->AsIntermediate(); - auto* quant_op = OpNode("quant_op", quant_type_) - ->assert_is_op(quant_type_) - ->AsIntermediate(); - - auto* quant_op_out_scale = - VarNode("quant_op_out_scale") - ->assert_is_op_output(quant_type_, "OutScale") - ->assert_is_op_input("fake_dequantize_max_abs", "Scale") - ->AsIntermediate(); - auto* quant_op_out = VarNode("quant_op_out") - ->assert_is_op_output(quant_type_, "Out") - ->assert_is_op_input(op_type_) + auto* quantized_op_input = + VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput(); + auto* quantized_op_weight = VarNode("quantized_op_weight") + ->assert_is_op_input(op_type_, weight_name) + ->AsInput(); + auto* quantized_op = OpNode("quantized_op", op_type_) + ->assert_is_op(op_type_) ->AsIntermediate(); - std::vector nodes; - for (int i = 0; i < times_; i++) { - nodes.push_back(VarNode(string_format("quantized_op_weight%d", i)) - ->assert_is_op_input(op_type_, weight_name) - ->AsInput()); - - nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_) - ->assert_is_op(op_type_) - ->AsIntermediate()); - - nodes.push_back(VarNode(string_format("quantized_op_out%d", 
i)) - ->assert_is_op_output(op_type_) - ->assert_is_op_input("fake_dequantize_max_abs", "X") - ->AsIntermediate()); - - nodes.push_back( - OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs") - ->assert_is_op("fake_dequantize_max_abs") - ->AsIntermediate()); - nodes.push_back(VarNode(string_format("dequant_op_out%d", i)) - ->assert_is_op_output("fake_dequantize_max_abs", "Out") - ->AsOutput()); + auto* quantized_op_out = + VarNode("quantized_op_out") + ->assert_is_op_output(op_type_) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate(); + auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs") + ->assert_is_op("fake_dequantize_max_abs") + ->AsIntermediate(); + auto* dequant_op_out = + VarNode("dequant_op_out") + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput(); + + quantized_op->LinksFrom({quantized_op_input, quantized_op_weight}); + quantized_op_out->LinksFrom({quantized_op}); + dequant_op->LinksFrom({quantized_op_out}); + dequant_op_out->LinksFrom({dequant_op}); + VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_; +} + +void DequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto* quant_op_input = matched.at("quantized_op_input"); + auto* quantized_op_weight = matched.at("quantized_op_weight"); + auto* quantized_op = matched.at("quantized_op"); + auto* dequant_op = matched.at("dequant_op"); + auto* dequant_op_out = matched.at("dequant_op_out"); + + // obtain input_scale and weight_scale + auto* scope = quantized_op->stmt()->op()->scope(); + auto& valid_places = quantized_op->stmt()->op()->valid_places(); + int bit_length = quantized_op->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + float input_scale = + quantized_op->stmt()->op_info()->GetAttr("input_scale"); + float max_range = dequant_op->stmt()->op_info()->GetAttr("max_range"); + float whole_weight_scale = + static_cast(range * range) / max_range / range; + // max_range = range * range / max(abs(weight)) + // weight_scale = range * range / (range * range / max(abs(weight))) / range + // = max(abs(weight)) / range + + // set op desc + cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); + auto quantized_weight_var_name = quantized_op_weight->arg()->name; + auto quantized_weight_t = + scope->FindVar(quantized_weight_var_name)->GetMutable(); + std::vector weight_scale; + int weight_scale_size; + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {quant_op_input->arg()->name}); + op_desc.SetOutput("Output", {dequant_op_out->arg()->name}); + // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should + // be Cout. + weight_scale_size = quantized_weight_t->dims()[0]; + } else if (op_type_ == "mul") { + op_desc.SetInput("X", {quant_op_input->arg()->name}); + op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); + // Fc weight: Cin * Cout, the weight_scale_size should be Cout. 
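The comment inside `DequantOpFuser::InsertNewNode` above derives `whole_weight_scale`: the training-time dequant op stores `max_range = range² / max(abs(weight))`, so `range² / max_range / range` recovers `max(abs(weight)) / range`. A quick numeric check of that identity, with made-up values:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// Numeric check of the dequant scale derivation above. The weight maximum is
// made up; bit_length = 8 gives range = 127 as in the fuser.
int main() {
  int bit_length = 8;
  int range = (1 << (bit_length - 1)) - 1;  // 127
  float weight_abs_max = 0.5f;              // hypothetical max(abs(weight))

  // what the training-time dequant op stores:
  float max_range = static_cast<float>(range) * range / weight_abs_max;

  // what DequantOpFuser recovers:
  float whole_weight_scale =
      static_cast<float>(range * range) / max_range / range;

  assert(std::fabs(whole_weight_scale - weight_abs_max / range) < 1e-7f);
  std::printf("weight_scale = %g\n", whole_weight_scale);  // 0.5 / 127
}
```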
+ weight_scale_size = quantized_weight_t->dims()[1]; } + for (int i = 0; i < weight_scale_size; i++) { + weight_scale.push_back(whole_weight_scale); + } + op_desc.SetAttr("enable_int8", true); + op_desc.SetAttr("input_scale", input_scale); + op_desc.SetAttr("weight_scale", weight_scale); - quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); - quant_op_out->LinksFrom({quant_op}); - quant_op_out_scale->LinksFrom({quant_op}); - for (int i = 0; i < times_; i++) { - nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( - {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); - nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( - {nodes[i * kNumFields + kQuantizedOpOffset]}); - nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( - {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); - nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( - {nodes[i * kNumFields + kDequantOpOffset]}); + // change the weight from the float type to int8 type. + Tensor temp_tensor; + temp_tensor.CopyDataFrom(*quantized_weight_t); + float* temp_data = temp_tensor.mutable_data(); + size_t weight_num = quantized_weight_t->data_size(); + int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = static_cast(temp_data[i]); } + quantized_weight_t->set_persistable(true); + quantized_weight_t->set_precision(PRECISION(kInt8)); + + // new op and relink nodes + auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_); + new_quantized_op->Attach(op_desc, scope); + auto* new_quantized_op_node = + graph->GraphCreateInstructNode(new_quantized_op, valid_places); + IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node); + IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node); + IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out); +} + +cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; } -void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph, - const key2nodes_t& matched) { - const int kNumFields = 5; - const int kQuantizedWeightOffset = 0; - const int kQuantizedOpOffset = 1; - const int kDequantOpOffset = 3; - const int kDequantOpOutOffset = 4; - - auto* quant_op_input = matched.at("quant_op_input"); - auto* quant_op_in_scale = matched.at("quant_op_in_scale"); - auto* quant_op = matched.at("quant_op"); - - std::vector nodes; - for (int i = 0; i < times_; i++) { - nodes.push_back(matched.at(string_format("quantized_op_weight%d", i))); - nodes.push_back(matched.at(string_format("quantized_op%d", i))); - nodes.push_back(matched.at(string_format("quantized_op_out%d", i))); - nodes.push_back(matched.at(string_format("dequant_op%d", i))); - nodes.push_back(matched.at(string_format("dequant_op_out%d", i))); +void DeleteQuantDequantOpFuser::BuildPattern() { + std::string quant_dequant_op_type = + "fake_quantize_dequantize_moving_average_abs_max"; + if (quantized_op_type_ == "pool2d") { + auto* input_scale_node = + VarNode("input_scale_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_node = VarNode("input_act_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_node = + OpNode("quant_dequant_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_node = + VarNode("output_act_node") + ->assert_is_op_output(quant_dequant_op_type, 
"Out"); + auto* quantized_node = OpNode("quantized_node", quantized_op_type_) + ->assert_is_op(quantized_op_type_); + + quant_dequant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_dequant_node}); + output_act_node->LinksFrom({quant_dequant_node}); + quantized_node->LinksFrom({output_act_node}); + } else if (quantized_op_type_ == "elementwise_add") { + auto* input_scale_left_node = + VarNode("input_scale_left_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_left_node = + VarNode("input_act_left_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_left_node = + OpNode("quant_dequant_left_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_left_node = + VarNode("output_scale_left_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_left_node = + VarNode("output_act_left_node") + ->assert_is_op_output(quant_dequant_op_type, "Out") + ->assert_is_op_input(quantized_op_type_, "X"); + quant_dequant_left_node->LinksFrom( + {input_scale_left_node, input_act_left_node}); + output_scale_left_node->LinksFrom({quant_dequant_left_node}); + output_act_left_node->LinksFrom({quant_dequant_left_node}); + + auto* input_scale_right_node = + VarNode("input_scale_right_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_right_node = + VarNode("input_act_right_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_right_node = + OpNode("quant_dequant_right_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_right_node = + VarNode("output_scale_right_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_right_node = + VarNode("output_act_right_node") + ->assert_is_op_output(quant_dequant_op_type, "Out") + ->assert_is_op_input(quantized_op_type_, "Y"); + quant_dequant_right_node->LinksFrom( + {input_scale_right_node, input_act_right_node}); + output_scale_right_node->LinksFrom({quant_dequant_right_node}); + output_act_right_node->LinksFrom({quant_dequant_right_node}); + + auto* quantized_node = OpNode("quantized_node", quantized_op_type_) + ->assert_is_op(quantized_op_type_); + quantized_node->LinksFrom({output_act_left_node, output_act_right_node}); + } else { + LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_; } - int bit_length = quant_op->stmt()->op_info()->GetAttr("bit_length"); - auto* scope = quant_op->stmt()->op()->scope(); - auto& valid_places = quant_op->stmt()->op()->valid_places(); - int range = ((1 << (bit_length - 1)) - 1); - auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name) - ->GetMutable(); - float input_scale = input_scale_t->data()[0] / range; - - VLOG(4) << "range: " << range << " input_scale: " << input_scale; - for (int i = 0; i < times_; i++) { - float max_range = nodes[i * kNumFields + kDequantOpOffset] - ->stmt() - ->op_info() - ->GetAttr("max_range"); - // weight_scale = max(abs(weight)) - float whole_weight_scale = - static_cast(range * range) / max_range / range; - - cpp::OpDesc op_desc = - *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info(); - - auto quantized_weight_var_name = - nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name; - auto quantized_weight_t = - scope->FindVar(quantized_weight_var_name)->GetMutable(); - std::vector weight_scale; - int weight_scale_size; - - if (op_type_ == "conv2d" || op_type_ == 
"depthwise_conv2d") { - op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name}); - op_desc.SetOutput( - "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); - // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should - // be Cout. - weight_scale_size = quantized_weight_t->dims()[0]; - } else if (op_type_ == "mul") { - op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name}); - op_desc.SetOutput( - "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); - // Fc weight: Cin * Cout, the weight_scale_size should be Cout. - weight_scale_size = quantized_weight_t->dims()[1]; - } - for (int i = 0; i < weight_scale_size; i++) { - weight_scale.push_back(whole_weight_scale); - } - op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("input_scale", input_scale); - op_desc.SetAttr("weight_scale", weight_scale); - - Tensor temp_tensor; - temp_tensor.CopyDataFrom(*quantized_weight_t); - float* temp_data = temp_tensor.mutable_data(); - - size_t weight_num = quantized_weight_t->data_size(); - int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); - - // change the weight from the float type to int8 type. - for (size_t i = 0; i < weight_num; i++) { - quantized_weight_data[i] = static_cast(temp_data[i]); - } - quantized_weight_t->set_persistable(true); - quantized_weight_t->set_precision(PRECISION(kInt8)); - auto quantized_op = LiteOpRegistry::Global().Create(op_type_); - - quantized_op->Attach(op_desc, scope); - auto* new_op_node = - graph->GraphCreateInstructNode(quantized_op, valid_places); - IR_NODE_LINK_TO(quant_op_input, new_op_node); - IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], - new_op_node); - IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]); + VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:" + << quantized_op_type_; +} + +void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + if (quantized_op_type_ == "pool2d") { + auto* input_scale_node = matched.at("input_scale_node"); + auto* input_act_node = matched.at("input_act_node"); + auto* quant_dequant_node = matched.at("quant_dequant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + auto* quantized_node = matched.at("quantized_node"); + + // obtain values, save values and relink node + int bit_length = + quant_dequant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_dequant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + op_desc->SetAttr("input_scale", scale_value); + op_desc->SetInput("X", {input_act_node->arg()->name}); + IR_NODE_LINK_TO(input_act_node, quantized_node) + + // delete nodes and edges + std::unordered_set nodes2rm = {input_scale_node, + quant_dequant_node, + output_scale_node, + output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); + } else if (quantized_op_type_ == "elementwise_add") { + auto* input_scale_left_node = matched.at("input_scale_left_node"); + auto* input_act_left_node = matched.at("input_act_left_node"); + auto* quant_dequant_left_node = matched.at("quant_dequant_left_node"); + auto* output_scale_left_node = matched.at("output_scale_left_node"); + auto* 
output_act_left_node = matched.at("output_act_left_node");
+
+    auto* input_scale_right_node = matched.at("input_scale_right_node");
+    auto* input_act_right_node = matched.at("input_act_right_node");
+    auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
+    auto* output_scale_right_node = matched.at("output_scale_right_node");
+    auto* output_act_right_node = matched.at("output_act_right_node");
+
+    auto* quantized_node = matched.at("quantized_node");
+
+    // obtain values, save values and relink node
+    int bit_length =
+        quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
+    int range = ((1 << (bit_length - 1)) - 1);
+    auto* scope = quant_dequant_left_node->stmt()->op()->scope();
+    auto* left_scale_tensor =
+        scope->FindVar(output_scale_left_node->arg()->name)
+            ->GetMutable<lite::Tensor>();
+    float left_scale_value = left_scale_tensor->data<float>()[0] / range;
+    auto* right_scale_tensor =
+        scope->FindVar(output_scale_right_node->arg()->name)
+            ->GetMutable<lite::Tensor>();
+    float right_scale_value = right_scale_tensor->data<float>()[0] / range;
+
+    auto* op_desc = quantized_node->stmt()->mutable_op_info();
+    op_desc->SetAttr("bit_length", bit_length);
+    op_desc->SetAttr("x_input_scale", left_scale_value);
+    op_desc->SetAttr("y_input_scale", right_scale_value);
+    op_desc->SetInput("X", {input_act_left_node->arg()->name});
+    op_desc->SetInput("Y", {input_act_right_node->arg()->name});
+    IR_NODE_LINK_TO(input_act_left_node, quantized_node);
+    IR_NODE_LINK_TO(input_act_right_node, quantized_node);
+
+    // delete nodes and edges
+    std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
+                                                quant_dequant_left_node,
+                                                output_scale_left_node,
+                                                output_act_left_node,
+                                                input_scale_right_node,
+                                                quant_dequant_right_node,
+                                                output_scale_right_node,
+                                                output_act_right_node};
+    GraphSafeRemoveNodes(graph, nodes2rm);
+  } else {
+    LOG(FATAL) << "Unsupported quantized_op_type: " << quantized_op_type_;
   }
 }
-cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
   cpp::OpDesc op_desc;
   return op_desc;
 }
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h index 15833ad25805235a408950c9874fdb3566d26976..a56fb665770cb3d523c5666550e295ef51af8474 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.h +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -34,13 +34,46 @@ namespace fusion {
  * the quantized_op.
  * In addition, the fuser delete fake_quant and fake_dequant op in the graph at
  * the last.
- */
-class QuantDequantOpFuser : public FuseBase {
+*/
+
+class DeleteQuantOpFuser : public FuseBase {
+ public:
+  explicit DeleteQuantOpFuser(const std::string& quant_op_type)
+      : quant_op_type_(quant_op_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string quant_op_type_{};
+};
+
+class DequantOpFuser : public FuseBase {
+ public:
+  explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string op_type_{};
+};
+
+/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
+ * pool2d/elementwise_add" can be detected by this fuser.
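For example, the
+ * subgraph "X -> fake_quant_dequant -> pool2d" is rewritten to
+ * "X -> pool2d", with the extracted scale stored as the input_scale
+ * attribute of pool2d.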
The fuser
+ * extracts the input_scale from the fake_quant_dequant_op and saves it
+ * into the quantized_op. Besides, the fuser deletes the
+ * fake_quant_dequant_op in the graph.
+*/
+
+class DeleteQuantDequantOpFuser : public FuseBase {
  public:
-  explicit QuantDequantOpFuser(const std::string& op_type,
-                               const std::string& quant_type,
-                               int times)
-      : op_type_(op_type), quant_type_(quant_type), times_(times) {}
+  explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
+      : quantized_op_type_(quantized_op_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -48,9 +81,7 @@ class QuantDequantOpFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
  private:
-  std::string op_type_{"conv2d"};
-  std::string quant_type_;
-  int times_;
+  std::string quantized_op_type_{};
 };
 } // namespace fusion
diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index fbeea88de1ba6ec27b9ae1d829a90931001e108d..2c289da82c69e9abac8cbc32a2efab47ebc05336 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -35,4 +35,6 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 } // namespace paddle
 REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
-                  paddle::lite::mir::ShuffleChannelFusePass);
+                  paddle::lite::mir::ShuffleChannelFusePass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("shuffle_channel");
diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc index 93bfef0ae517ce736a302f6bdd39b807991f80f6..c233d6473959d2cb2c7e15fe6074844db0ba5850 100644 --- a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc @@ -36,4 +36,5 @@ void TransposeSoftmaxTransposeFusePass::Apply(
 } // namespace paddle
 REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
-                  paddle::lite::mir::TransposeSoftmaxTransposeFusePass);
+                  paddle::lite::mir::TransposeSoftmaxTransposeFusePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/generate_program_pass.cc b/lite/core/mir/generate_program_pass.cc index b957e70f981b42f71ac5a253839042d5da307e35..76c97d2da6ed9e7c6fc1f1889d80095278b68ec0 100644 --- a/lite/core/mir/generate_program_pass.cc +++ b/lite/core/mir/generate_program_pass.cc @@ -38,5 +38,5 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 } // namespace lite
 } // namespace paddle
-REGISTER_MIR_PASS(generate_program_pass,
-                  paddle::lite::mir::GenerateProgramPass);
+REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index 74245ad11d185fc02bf9ca3db87a1ee33c78a275..76ea9555c29a245aa9f20b158f0706557940bef8 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -90,7 +90,9 @@ std::string Visualize(mir::SSAGraph* graph) {
   }
   auto res = dot.Build();
-  LOG(INFO) << "dot:\n" << res;
+  // VLOG may truncate the output here, so the whole graph could not be
+  // printed; use std::cout instead of VLOG.
+ std::cout << "dot:\n" << res << std::endl; return res; } @@ -98,4 +100,5 @@ std::string Visualize(mir::SSAGraph* graph) { } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass); +REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/io_copy_kernel_pick_pass.cc b/lite/core/mir/io_copy_kernel_pick_pass.cc index 6c62ac9a1a093978996a1433be2e5198ef13b40e..df5ddffe8a17f0087da91491bca777748ad7aa9c 100644 --- a/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -71,4 +71,6 @@ class IoCopyKernelPickPass : public StmtPass { } // namespace paddle REGISTER_MIR_PASS(io_copy_kernel_pick_pass, - paddle::lite::mir::IoCopyKernelPickPass); + paddle::lite::mir::IoCopyKernelPickPass) + .BindTargets({TARGET(kAny)}) + .BindKernel("io_copy"); diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f2355e8a3205cce3410bd2cb6ac4a17d8fde602 --- /dev/null +++ b/lite/core/mir/memory_optimize_pass.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/memory_optimize_pass.h" +#include +#include +#include +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace mir { + +typedef struct { + std::string name; + int cluster; + std::pair lifetime; + std::unordered_set adj; +} MemNode; + +void MemoryOptimizePass::CollectLifeCycleByDevice( + std::unordered_map* lifecycles, + SSAGraph* graph) { + max_lifecycle_ = 0; + + auto is_host = [](TargetType x) -> bool { + return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); + }; + // The vars which inputs or outputs are invalid op will not be reused. 
+  auto valid_var = [&](Node* node) -> bool {
+    std::set<std::string> invalid_op = {"while",
+                                        "conditional_block",
+                                        "conditional_block_infer",
+                                        "merge_lod_tensor_infer",
+                                        "merge_lod_tensor",
+                                        "equal",
+                                        "lod_reset",
+                                        "concat",
+                                        "yolo_box",
+                                        "graph_op",
+                                        "feed",
+                                        "fetch"};
+    for (auto* tmp : node->inlinks) {
+      CHECK(tmp->IsStmt());
+      std::string op_type = tmp->AsStmt().op_info()->Type();
+      if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
+          invalid_op.end()) {
+        return false;
+      }
+    }
+    for (auto* tmp : node->outlinks) {
+      CHECK(tmp->IsStmt());
+      std::string op_type = tmp->AsStmt().op_info()->Type();
+      if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
+          invalid_op.end()) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  for (auto& op_node : graph->StmtTopologicalOrder()) {
+    if (op_node->IsStmt()) {
+      auto inputs = op_node->inlinks;
+      auto outputs = op_node->outlinks;
+      std::vector<Node*> requires(inputs.begin(), inputs.end());
+      requires.insert(requires.end(), outputs.begin(), outputs.end());
+      for (Node* node : requires) {
+        CHECK(node->IsArg());
+        auto& arg = node->AsArg();
+        if (arg.is_weight || arg.is_persist) continue;
+        if (!valid_var(node)) continue;
+        std::string var_name = arg.name;
+        TargetType target_type = node->AsArg().type->target();
+        if (is_host(target_type)) target_type = TARGET(kHost);
+
+        if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
+          (*lifecycles)[TargetToStr(target_type)].emplace(
+              var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
+        } else {
+          int cur_life =
+              (*lifecycles)[TargetToStr(target_type)][var_name].second;
+          (*lifecycles)[TargetToStr(target_type)][var_name].second =
+              std::max(max_lifecycle_, cur_life);
+        }
+      }
+      ++max_lifecycle_;
+    }
+  }
+  LOG(INFO) << "There are " << (*lifecycles).size() << " types of device vars.";
+}
+
+void MemoryOptimizePass::MakeReusePlan(
+    const lifecycle_map_t& lifecycles,
+    std::unordered_map<std::string, std::string>* node2cluster) {
+  std::vector<MemNode> mem_nodes;
+  std::vector<std::string> cluster;
+  for (auto& data : lifecycles) {
+    MemNode temp_node;
+    temp_node.name = data.first;
+    temp_node.cluster = -1;
+    temp_node.lifetime = data.second;
+    mem_nodes.push_back(temp_node);
+  }
+  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) -> bool {
+    return b.second >= a.first && a.second >= b.first;
+  };
+  // If the lifetimes of two nodes overlap, we set them as adjacent nodes.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (overlap(mem_nodes[i].lifetime, mem_nodes[j].lifetime)) {
+        mem_nodes[i].adj.insert(mem_nodes[j].name);
+        mem_nodes[j].adj.insert(mem_nodes[i].name);
+      }
+    }
+  }
+
+  // Generate the memory reuse strategy in a greedy way: vars can share a
+  // cluster (and thus memory) if their lifetimes do not overlap.
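+  // For example, with lifetimes a:[0,2], b:[1,3] and c:[3,5], a overlaps b
+  // and b overlaps c, but a and c do not overlap, so a and c end up in one
+  // cluster and share memory, while b gets a cluster of its own.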
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (mem_nodes[i].cluster >= 0) continue;
+    int cluster_index = cluster.size();
+    mem_nodes[i].cluster = cluster_index;
+    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+    cluster.push_back(mem_nodes[i].name);
+    std::unordered_set<std::string> cluster_adj = mem_nodes[i].adj;
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (mem_nodes[j].cluster < 0 &&
+          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+        mem_nodes[j].cluster = cluster_index;
+        for (auto& n : mem_nodes[j].adj) {
+          cluster_adj.insert(n);
+        }
+      }
+    }
+  }
+  for (auto& name : cluster) {
+    LOG(INFO) << "cluster: " << name;
+  }
+}
+
+void MemoryOptimizePass::PerformReusePlan(
+    SSAGraph* graph,
+    const std::unordered_map<std::string, std::string>& reuse_table) {
+  int node_append_idx = 0;
+  for (auto& op_node : graph->StmtTopologicalOrder()) {
+    if (!op_node->IsStmt()) continue;
+    auto& stmt = op_node->AsStmt();
+    auto* op_info = stmt.mutable_op_info();
+    std::unordered_map<std::string, std::vector<std::string>> in_args,
+        out_args;
+    // replace the op's inputs according to the reuse table.
+    for (auto argument : op_info->inputs()) {
+      for (const auto& x : argument.second) {
+        auto name = x;
+        if (reuse_table.count(x) && reuse_table.at(x) != x) {
+          name = reuse_table.at(x);
+        }
+        in_args[argument.first].push_back(name);
+        VLOG(4) << op_info->Type() << " input " << x << " -> " << name;
+      }
+    }
+
+    // modify the graph
+    for (Node* input_node : op_node->inlinks) {
+      CHECK(input_node->IsArg()) << "The op node's inputs should be var node.";
+      std::string name = input_node->AsArg().name;
+      if (reuse_table.count(name) && reuse_table.at(name) != name) {
+        auto replace_name = reuse_table.at(name);
+        input_node->AsArg().name =
+            replace_name + "(" + std::to_string(node_append_idx) + ")";
+        node_append_idx++;
+      }
+    }
+
+    // replace the op's outputs according to the reuse table.
+    for (auto argument : op_info->outputs()) {
+      for (const auto& x : argument.second) {
+        auto name = x;
+        if (reuse_table.count(x) && reuse_table.at(x) != x) {
+          name = reuse_table.at(x);
+        }
+        out_args[argument.first].push_back(name);
+        VLOG(4) << op_info->Type() << " output " << x << " -> " << name;
+      }
+    }
+
+    // modify the graph
+    for (Node* out_node : op_node->outlinks) {
+      CHECK(out_node->IsArg()) << "The op node's outputs should be var node.";
+      std::string name = out_node->AsArg().name;
+      if (reuse_table.count(name) && reuse_table.at(name) != name) {
+        auto replace_name = reuse_table.at(name);
+        out_node->AsArg().name =
+            replace_name + "(" + std::to_string(node_append_idx) + ")";
+        node_append_idx++;
+      }
+    }
+
+    for (auto& arg : in_args) {
+      op_info->SetInput(arg.first, arg.second);
+    }
+    for (auto& arg : out_args) {
+      op_info->SetOutput(arg.first, arg.second);
+    }
+
+    auto original_selected_kernel = std::move(stmt.kernels().front());
+    auto updated_op_info = *stmt.mutable_op_info();
+    stmt.ResetOp(updated_op_info, graph->valid_places());
+    stmt.kernels().clear();
+    stmt.kernels().emplace_back(std::move(original_selected_kernel));
+    for (auto& kernel : stmt.kernels()) {
+      VLOG(4) << "kernel info: " << kernel->name();
+      stmt.op()->AttachKernel(kernel.get());
+    }
+    graph->CheckValid();
+  }
+}
+
+void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Memory optimization.
+  // We will perform the following operations:
+  // 1. Collect all vars' lifetimes, then classify them according to the device.
+  // Only the vars on the same device can be reused.
+  // 2.
Make reuse plan: the vars can be reused if there is no overlap between
+  // them.
+  // The final plan is a mapping table in which the key represents the original
+  // name of a var and the value represents its current name.
+  // 3. Perform reuse plan: Replace all var names in the model according to the
+  // mapping table.
+  std::unordered_map<std::string, lifecycle_map_t> lifecycles;
+  CollectLifeCycleByDevice(&lifecycles, graph.get());
+  for (auto& ele : lifecycles) {
+    std::unordered_map<std::string, std::string> node2cluster;
+    MakeReusePlan(ele.second, &node2cluster);
+    PerformReusePlan(graph.get(), node2cluster);
+  }
+}
+
+} // namespace mir
+} // namespace lite
+} // namespace paddle
+
+REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
+    .BindTargets({TARGET(kARM)});
diff --git a/lite/core/mir/memory_optimize_pass.h b/lite/core/mir/memory_optimize_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..874fb648cd05931175159bad43e7be38a7aee928 --- /dev/null +++ b/lite/core/mir/memory_optimize_pass.h @@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "lite/core/kernel.h"
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * MemoryOptimizePass collects the lifecycles of all non-persistable vars,
+ * clusters the vars whose lifecycles do not overlap, and renames them so
+ * that each cluster shares a single buffer.
+ */
+class MemoryOptimizePass : public ProgramPass {
+ public:
+  using lifecycle_t = std::pair<int, int>;
+  using lifecycle_map_t = std::unordered_map<std::string, lifecycle_t>;
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+ private:
+  void CollectLifeCycleByDevice(
+      std::unordered_map<std::string, lifecycle_map_t>* lifecycles, SSAGraph*);
+  void MakeReusePlan(
+      const lifecycle_map_t& lifecycles,
+      std::unordered_map<std::string, std::string>* node2cluster);
+  void PerformReusePlan(
+      SSAGraph* graph,
+      const std::unordered_map<std::string, std::string>& reuse_table);
+
+ private:
+  int max_lifecycle_{-1};
+};
+
+} // namespace mir
+} // namespace lite
+} // namespace paddle
diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc index 61d3d317e7b7bbbfc4064cfbe0f2503f8fbe7a31..4a90e530a46c4d42d2ba032da1828973dfc1bcef 100644 --- a/lite/core/mir/node.cc +++ b/lite/core/mir/node.cc @@ -54,11 +54,6 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
   valid_kernels_ = op_->CreateKernels(valid_places);
 }
-std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
-  os << "Statement " << other.op_type() << " " << other.place().DebugString();
-  return os;
-}
-
 mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
   auto &x = AsArg();
   x.name = name;
diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h index 9c7d441ca3811d39b8ba9f5b49746c9a31c1d449..60fa1fb1ebe49e1be38a7d84cb82545389ea4aac 100644 --- a/lite/core/mir/node.h +++ b/lite/core/mir/node.h @@ -74,7 +74,11 @@ class Node {
     KernelBase& picked_kernel();
-    friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
+    friend std::ostream& operator<<(std::ostream& os, const
Stmt& other) {
+      os << "Statement " << other.op_type() << " "
+         << other.place().DebugString();
+      return os;
+    }
     // Description.
     std::string desc;
diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index bd1ce1412ae3dee1236677e851da322dc45c05ff..4de0fdbf357160348a403d3c8527fe62891237f0 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -14,7 +14,10 @@
 #pragma once
 #include
+#include
 #include
+#include
+
 #include "lite/core/mir/node.h"
 #include "lite/core/mir/ssa_graph.h"
@@ -44,6 +47,63 @@ class Pass {
   void set_doc(const std::string& doc) { doc_ = doc; }
   const std::string& doc() const { return doc_; }
+  // Some passes only apply to qualified targets, which need to be explicitly
+  // declared.
+
+  // Bind targets. At runtime, the pass runs only if at least one of the
+  // valid devices is among the bound targets.
+  void BindTargets(const std::set<TargetType>& targets) {
+    std::set<TargetType> res;
+    for (const auto& target : targets) {
+      const std::set<TargetType>& universe = ExpandValidTargets(target);
+      std::set_union(bound_targets_.begin(),
+                     bound_targets_.end(),
+                     universe.begin(),
+                     universe.end(),
+                     std::inserter(res, res.begin()));
+    }
+    bound_targets_ = res;
+  }
+
+  // Exclude targets. The expanded targets are removed from the bound set.
+  void ExcludeTargets(const std::set<TargetType>& targets) {
+    std::set<TargetType> res;
+    for (const auto& target : targets) {
+      const std::set<TargetType>& universe = ExpandValidTargets(target);
+      std::set_difference(bound_targets_.begin(),
+                          bound_targets_.end(),
+                          universe.begin(),
+                          universe.end(),
+                          std::inserter(res, res.begin()));
+    }
+    bound_targets_ = res;
+  }
+
+  // Get all bound targets.
+  const std::set<TargetType>& Targets() const { return bound_targets_; }
+
+  // Some passes are only available on qualified kernels and need to be
+  // explicitly declared.
+  // Bind kernels. All kernels bound at runtime must be registered.
+  void BindKernels(
+      const std::unordered_map<std::string, std::set<lite_api::Place>>&
+          kernels) {
+    bound_kernels_ = kernels;
+  }
+  // Get all bound kernels.
+  const std::unordered_map<std::string, std::set<lite_api::Place>>&
+  GetBoundKernels() const {
+    return bound_kernels_;
+  }
+  // Add one kernel to the bound kernels.
+  void BindKernel(const std::string& kernel_name,
+                  const lite_api::Place& place) {
+    if (!bound_kernels_.count(kernel_name)) {
+      bound_kernels_.insert({kernel_name, {place}});
+    } else {
+      bound_kernels_.at(kernel_name).insert(place);
+    }
+  }
+
   Kind kind() const { return kind_; }
   bool is_debug_pass() const { return kind_ == Kind::kDebug; }
   bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
@@ -55,6 +115,8 @@ class Pass {
   const Kind kind_;
   std::string name_;
   std::string doc_;
+  std::set<TargetType> bound_targets_;
+  std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
 };
 // Different kinds.
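+
+// A minimal usage sketch of the binding API above (the pass and kernel
+// names here are hypothetical): a pass that should run on every target
+// except XPU and that requires a "demo_kernel" kernel can be registered as
+//
+//   REGISTER_MIR_PASS(demo_pass, paddle::lite::mir::DemoPass)
+//       .BindTargets({TARGET(kAny)})
+//       .ExcludeTargets({TARGET(kXPU)})
+//       .BindKernel("demo_kernel");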
diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index 6144ea2c24a22dfff11d6055ceb3771db0c30af5..849f80aea2191b72ac423c7125a4e69cb6927be5 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -14,8 +14,10 @@
 #pragma once
+#include
 #include
 #include "lite/api/paddle_lite_factory_helper.h"
+#include "lite/api/paddle_place.h"
 #include "lite/core/mir/pass_manager.h"
 namespace paddle {
@@ -24,12 +26,33 @@ namespace mir {
 class PassRegistry {
  public:
-  PassRegistry(const std::string& name, mir::Pass* pass) {
-    VLOG(2) << "Registry add MIR pass " << name;
-    PassManager::Global().AddNewPass(name, pass);
+  PassRegistry(const std::string& name, mir::Pass* pass)
+      : name_(name), pass_(pass) {
+    PassManager::Global().AddNewPass(name_, pass_);
+  }
+  PassRegistry& BindTargets(const std::set<TargetType>& targets) {
+    pass_->BindTargets(targets);
+    return *this;
+  }
+  PassRegistry& ExcludeTargets(const std::set<TargetType>& targets) {
+    pass_->ExcludeTargets(targets);
+    return *this;
+  }
+  PassRegistry& BindKernel(const std::string& name,
+                           const lite_api::Place& place) {
+    pass_->BindKernel(name, place);
+    return *this;
+  }
+  PassRegistry& BindKernel(const std::string& name) {
+    pass_->BindKernel(name,
+                      Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny)));
+    return *this;
   }
   bool Touch() const { return true; }
+
+ private:
+  std::string name_;
+  mir::Pass* pass_;
 };
 } // namespace mir
@@ -41,4 +64,6 @@ class PassRegistry {
                                        new class__);                      \
   bool mir_pass_registry##name__##_fake() {                               \
     return mir_pass_registry##name__.Touch();                             \
-  }
+  }                                                                       \
+  static paddle::lite::mir::PassRegistry mir_pass_registry_func_##name__  \
+      __attribute__((unused)) = mir_pass_registry##name__
diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f6be2c186d2d940a799201812cce397a9e94eb4 --- /dev/null +++ b/lite/core/mir/pass_utils.cc @@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/mir/pass_utils.h" +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +using lite_api::Place; + +void ExpandPlaces(std::set* places, const Place& place) { + for (const auto& target : lite_api::ExpandValidTargets(place.target)) { + for (const auto& precision : + lite_api::ExpandValidPrecisions(place.precision)) { + for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) { + places->insert(Place(target, precision, layout)); + } + } + } +} + +bool KernelRegistered(const std::string name, const Place& place) { + std::set places; + ExpandPlaces(&places, place); + for (const auto& p : places) { + if (!KernelRegistry::Global() + .Create(name, p.target, p.precision, p.layout) + .empty()) { + return true; + } + } + return false; +} + +bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { + const auto& targets = pass.Targets(); + if (targets.find(TARGET(kAny)) != targets.end()) return true; + return (targets.find(target) != targets.end()); +} + +bool PassMatchesKernels(const mir::Pass& pass) { + const auto& kernels = pass.GetBoundKernels(); + for (const auto& kernel : kernels) { + for (const auto& place : kernel.second) { + if (!KernelRegistered(kernel.first, place)) { + return false; + } + } + } + return true; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_utils.h b/lite/core/mir/pass_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..942f64bf3190be1f399ac6f014be0881b1450d9b --- /dev/null +++ b/lite/core/mir/pass_utils.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { + +// Query if the specified kernel has been registered. +bool KernelRegistered(const std::string name, const Place& place); + +// Check if the pass hits the hardware target. +bool PassMatchesTarget(const mir::Pass& pass, TargetType target); + +// Check if the pass hits all necessary operators. 
+bool PassMatchesKernels(const mir::Pass& pass);
+
+} // namespace lite
+} // namespace paddle
diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8ec85a4ef124ae461cb1e16cf56717a0227b06e6..8e0fc55be2389244ae065b4c2809bbdd74be370c 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -415,7 +415,8 @@ bool IsNthOutput(const Node *var,
   CHECK(var->IsArg());
   CHECK(op->IsStmt());
   auto op_info = op->stmt()->op_info();
-  if (op_info->Output(argument).size() <= nth) return false;
+  if (!op_info->HasOutput(argument) || op_info->Output(argument).size() <= nth)
+    return false;
   return var->arg()->name == op_info->Output(argument)[nth];
 }
@@ -426,7 +427,8 @@ bool IsNthInput(const Node *var,
   CHECK(var->IsArg());
   CHECK(op->IsStmt());
   auto op_info = op->stmt()->op_info();
-  if (op_info->Input(argument).size() <= nth) return false;
+  if (!op_info->HasInput(argument) || op_info->Input(argument).size() <= nth)
+    return false;
   return var->arg()->name == op_info->Input(argument)[nth];
 }
diff --git a/lite/core/mir/runtime_context_assign_pass.cc b/lite/core/mir/runtime_context_assign_pass.cc index 7a063b0bfd487374ce31e5ebfe1d00369448eb9d..97c4819eaf6734ba9b374444166d17cb15e8ae65 100644 --- a/lite/core/mir/runtime_context_assign_pass.cc +++ b/lite/core/mir/runtime_context_assign_pass.cc @@ -38,4 +38,5 @@ class RuntimeContextAssignPass : public StmtPass {
 } // namespace paddle
 REGISTER_MIR_PASS(runtime_context_assign_pass,
-                  paddle::lite::mir::RuntimeContextAssignPass);
+                  paddle::lite::mir::RuntimeContextAssignPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc index 5193d9c899badba3571f26697c822e34bdf45f47..8f22022789046900c3c09cfb122c914968d8d87f 100644 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -26,8 +26,8 @@ namespace mir {
 bool SSAGraph::CheckBidirectionalConnection() {
   VLOG(4) << "node count " << node_storage_.size();
   for (auto &node : node_storage_) {
-    if (node.IsStmt()) VLOG(4) << node.AsStmt().op_info()->Type();
-    if (node.IsArg()) VLOG(4) << node.AsArg().name << " " << node.AsArg().id;
+    if (node.IsStmt()) VLOG(6) << node.AsStmt().op_info()->Type();
+    if (node.IsArg()) VLOG(6) << node.AsArg().name << " " << node.AsArg().id;
     for (auto *in : node.inlinks) {
       CHECK(in->outlinks.end() !=
             std::find(in->outlinks.begin(), in->outlinks.end(), &node));
diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 729ad4c9ae42c8cea919326bec899a9560f44e8c..90aca56aec426f6b7ca0d300ded979ae7b10f6df 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -17,33 +17,41 @@
 #include
 #include
 #include
+#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
 namespace paddle {
 namespace lite {
 namespace mir {
-bool KernelScoreCmp(const std::pair<size_t, std::unique_ptr<KernelBase>>& a,
-                    const std::pair<size_t, std::unique_ptr<KernelBase>>& b) {
+bool KernelScoreCmp(const std::pair<float, std::unique_ptr<KernelBase>>& a,
+                    const std::pair<float, std::unique_ptr<KernelBase>>& b) {
   return a.first > b.first;
 }
 void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  kernel_pick_factors_.ConsiderTarget();
+  kernel_pick_factors_.ConsiderPrecision();
+  kernel_pick_factors_.ConsiderDataLayout();
   CHECK(kernel_pick_factors_.any_factor_considered())
       << "kernel_pick_factors should be specified first";
   CHECK(graph) << "graph not valid";
-  // sort kernels by the factors.
+  // sort kernels by the factors.
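+  // Target matches weigh heaviest, then precision, then data layout (see
+  // core::KernelPickFactor::Factor for the relative weights).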
+ VLOG(4) << "graph->mutable_nodes().size():" << graph->mutable_nodes().size(); for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); // Get candidate kernels - std::vector>> scored; + std::vector>> scored; CHECK(!instruct.kernels().empty()) << "No kernels found for " << instruct.op_type(); + VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); for (auto&& kernel : instruct.kernels()) { - size_t score = KernelGrade(*kernel); + float score = KernelGrade(*kernel, graph->valid_places()); + VLOG(4) << "kernel->summary():" << kernel->summary() + << " score:" << score; scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); @@ -54,7 +62,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { // Just keep a single best kernel. // TODO(Superjomn) reconsider this. instruct.kernels().emplace_back(std::move(scored.front().second)); - VLOG(2) << "pick " << instruct.kernels().front()->name(); + VLOG(2) << "pick " << instruct.kernels().front()->name() << "\n\n"; } else { bool out_type_int8 = true; @@ -91,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { instruct.ResetOp(update_desc, graph->valid_places()); scored.clear(); for (auto&& kernel : instruct.kernels()) { - size_t score = KernelGrade(*kernel); + float score = KernelGrade(*kernel, graph->valid_places()); scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); @@ -117,7 +125,8 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (all_output_type_match) { instruct.kernels().emplace_back(std::move(candidate.second)); - VLOG(2) << "pick " << instruct.kernels().front()->name(); + VLOG(2) << "instruct.kernels.emplace_back " + << instruct.kernels().front()->name(); break; } } @@ -132,4 +141,5 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(static_kernel_pick_pass, - paddle::lite::mir::StaticKernelPickPass); + paddle::lite::mir::StaticKernelPickPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h index 34122782292eaafd33b5f99ce5bfeb32faf30c7f..7187ddcef6626888eaaf372f7b027aa5d9bd2a3a 100644 --- a/lite/core/mir/static_kernel_pick_pass.h +++ b/lite/core/mir/static_kernel_pick_pass.h @@ -16,6 +16,7 @@ #include #include +#include #include "lite/core/mir/pass.h" #include "lite/core/types.h" @@ -38,8 +39,6 @@ class StaticKernelPickPass : public mir::StmtPass { public: void Apply(const std::unique_ptr& graph) override; - void SetPreferPlace(const Place& place) { place_ = place; } - const Place& place() const { return place_; } const core::KernelPickFactor& kernel_pick_factors() const { return kernel_pick_factors_; } @@ -49,47 +48,82 @@ class StaticKernelPickPass : public mir::StmtPass { private: // Score the kernel. 
-  size_t KernelGrade(const lite::KernelBase& kernel) {
-    size_t score{};
+  float KernelGrade(const lite::KernelBase& kernel,
+                    const std::vector<Place>& places) {
+    CHECK_GT(places.size(), 0) << "valid_places is empty.";
+    float final_score{-1.};
+    Place winner_place{places[0]};
     const int kMax = std::numeric_limits<int>::max();
+    size_t place_size = places.size();
-    // The more important factor comes first
-    if (kernel_pick_factors_.IsTargetConsidered() &&
-        (place().target == kernel.target() || kernel.target() == TARGET(kAny) ||
-         place().target == TARGET(kAny))) {
-      score +=
-          kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
-    }
-    if (kernel_pick_factors_.IsPrecisionConsidered() &&
-        (place().precision == kernel.precision() ||
-         kernel.precision() == PRECISION(kAny) ||
-         place().precision == PRECISION(kAny))) {
-      score += kMax /
-               static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
-    }
-    if (kernel_pick_factors_.IsDataLayoutConsidered() &&
-        (place().layout == kernel.layout() ||
-         kernel.layout() == DATALAYOUT(kAny) ||
-         place().layout == DATALAYOUT(kAny))) {
-      score += kMax / static_cast<int>(
-                          core::KernelPickFactor::Factor::DataLayoutFirst);
+    // NOTE: We compare the kernel's place with each place in valid_places to
+    // select the best match. The order of the places in the valid_places
+    // array decides the user's preference.
+    // final_score = weight * score
+    // weight: computed as (valid_places.size() - i) / valid_places.size() by
+    //         default, where i is the place's index in the valid_places array.
+    // score:  the weighted sum of the target, precision and layout matches
+    for (int i = 0; i < place_size; ++i) {
+      const auto& place = places[i];
+      float weight = static_cast<float>(place_size - i) / place_size;
+      size_t score{};
+      // The more important factor comes first
+      if (kernel_pick_factors_.IsTargetConsidered() &&
+          (place.target == kernel.target() || kernel.target() == TARGET(kAny) ||
+           place.target == TARGET(kAny))) {
+        score += kMax /
+                 static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
+      }
+      VLOG(4) << "[score s1]:" << score;
+      if (kernel_pick_factors_.IsPrecisionConsidered() &&
+          (place.precision == kernel.precision() ||
+           kernel.precision() == PRECISION(kAny) ||
+           place.precision == PRECISION(kAny))) {
+        score += kMax / static_cast<int>(
+                            core::KernelPickFactor::Factor::PrecisionFirst);
+      }
+      VLOG(4) << "[score s2]:" << score;
+      if (kernel_pick_factors_.IsDataLayoutConsidered() &&
+          (place.layout == kernel.layout() ||
+           kernel.layout() == DATALAYOUT(kAny) ||
+           place.layout == DATALAYOUT(kAny))) {
+        score += kMax / static_cast<int>(
+                            core::KernelPickFactor::Factor::DataLayoutFirst);
+      }
+      VLOG(4) << "[score s3]:" << score;
+      if (weight * score > final_score) {
+        final_score = weight * score;
+        winner_place = place;
+      }
+    }
+
+    VLOG(4) << "[score(final)]:" << final_score;
+    VLOG(4) << "-------- pick summary --------";
+    VLOG(4) << " ===> place():" << PrecisionToStr(winner_place.precision) << " "
+            << DataLayoutToStr(winner_place.layout) << " "
+            << TargetToStr(winner_place.target);
+    VLOG(4) << " ===> kernel.place():"
+            << PrecisionToStr(kernel.place().precision) << " "
+            << DataLayoutToStr(kernel.place().layout) << " "
+            << TargetToStr(kernel.place().target);
+    VLOG(4) << "kernel.op_type():" << kernel.op_type();
     VLOG(4) << "picker tactic " << kernel_pick_factors_;
     VLOG(4) << "kernel place " << kernel.place().DebugString();
-    VLOG(4) << "picker place " << place().DebugString();
-    VLOG(4) << "score " << score;
+    VLOG(4) << "picker place " << winner_place.DebugString();
+
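+    // For example, with valid_places = {kARM, kX86}, a kernel matching kARM
+    // at index 0 is weighted 2/2 while one matching kX86 at index 1 is
+    // weighted 1/2, so the ARM kernel wins whenever the raw scores are equal.
+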
VLOG(4) << "------------------------------"; // The data layout is not considered, for the input and output arguments // might have different data layout. // TODO(Superjomn) reconsider the idea of taking the data layout as a kernel // specification. - return score; + return final_score; } private: core::KernelPickFactor kernel_pick_factors_; - Place place_; }; } // namespace mir diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt index 9984e202db08d95bbcd691888d91f06a7a3b1d2f..95b5fe5ae13e03940bda8d83fcfc252b4ca490ab 100644 --- a/lite/core/mir/subgraph/CMakeLists.txt +++ b/lite/core/mir/subgraph/CMakeLists.txt @@ -16,10 +16,10 @@ set(subgraph_passes subgraph_pass) if(LITE_WITH_NPU) lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc - DEPS mir_pass types context ${mir_fusers} ${npu_bridges} npu_helper ${npu_ddk_libs} graph_op subgraph_pass) + DEPS mir_pass types context ${mir_fusers} ${npu_bridges} graph_op subgraph_pass) list(APPEND subgraph_passes npu_pass) lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc - DEPS npu_pass cxx_api mir_passes gflags + DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) if (WITH_TESTING) @@ -30,5 +30,21 @@ if(LITE_WITH_NPU) endif() endif() +if(LITE_WITH_XPU) + lite_cc_library(xpu_pass SRCS generate_xpu_program_pass.cc + DEPS mir_pass types context ${mir_fusers} ${xpu_bridges} ${xpu_builder_libs} graph_op subgraph_pass) + list(APPEND subgraph_passes xpu_pass) + lite_cc_test(test_xpu_pass SRCS generate_xpu_program_pass_test.cc + DEPS xpu_pass mir_passes paddle_api_full gflags + ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 + --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) + if (WITH_TESTING) + add_dependencies(test_xpu_pass extern_lite_download_mobilenet_v1_tar_gz) + add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) + set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") + set_target_properties(test_xpu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") + endif() +endif() + set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes") message(STATUS "----> subgraph_passes: ${subgraph_passes}") diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index e7322087afae9098abe61d3769a7312737e7d7eb..c5465a5edaa28d3cc2cfb4a7ffe0cca2e3c1bc79 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -22,15 +22,9 @@ #include "lite/core/mir/pass_registry.h" #include "lite/core/mir/pattern_matcher.h" -#include "ai_ddk_lib/include/HiAiModelManagerService.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" // for ge::op::Data -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/paddle_use_npu_bridges.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { @@ -52,7 +46,7 @@ std::shared_ptr GenerateNPUProgramPass::CvtVarNode( auto wgt = std::make_shared(arg.name); LOG(INFO) << "in convert const:" 
<< arg.name;
     VLOG(4) << dims;
-    wgt->set_attr_value(lite::npu::bridge::CvtFromLiteTensor(tensor));
+    wgt->set_attr_value(lite::npu::CvtFromLiteTensor(tensor));
     return wgt;
   } else {
     CHECK_EQ(dims.size(), 4);
@@ -75,13 +69,13 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
 void GenerateNPUProgramPass::CvtAllOpNodes(
     const std::vector<Node*>& nodes2cvt,
-    lite::npu::bridge::node_map_type* converted_vars) {
-  const auto& bridges = lite::npu::bridge::Factory::Instance();
+    lite::kernels::npu::bridges::node_map_type* converted_vars) {
+  const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
   const auto& cvtfunc_map = bridges.AllFunctions();
   // return record all converted vars
   // op node's inputs must be found in converted_vars
   for (auto& node : nodes2cvt) {
-    lite::npu::bridge::node_map_type node_inputs;
+    lite::kernels::npu::bridges::node_map_type node_inputs;
     auto& stmt = node->AsStmt();
     for (auto& var_node : node->inlinks) {
       auto& arg = var_node->AsArg();
@@ -107,7 +101,7 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
     const std::unordered_set<Node*>& out_data_vars,
     int sub_id) {
   auto ordered_nodes = GetTopologicalOrder(op_nodes);
-  lite::npu::bridge::node_map_type converted_vars;
+  lite::kernels::npu::bridges::node_map_type converted_vars;
   CvtAllOpNodes(ordered_nodes, &converted_vars);
   std::vector<std::string> in_var_names;
@@ -125,13 +119,20 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
     outputs.push_back(*converted_vars.at(argname));
   }
-  std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
-  if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
+  std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
+  auto any_op = (*op_nodes.begin())->AsStmt().op();
+  auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<lite::Tensor>();
+  weight->set_persistable(true);
+  weight->set_precision(PRECISION(kInt8));
+  // Compile the IR graph to an NPU model and store the model data into the
+  // weight tensor with persistable=true, so that the model parser can
+  // recognize it and save it to the param files
+  if (!lite::npu::BuildModel(inputs, outputs, weight)) {
     LOG(WARNING) << "Build NPU model failed for subgraph " << sub_id;
     throw std::runtime_error("Build NPU model failed for subgraph.");
   }
   LOG(INFO) << "[NPU] Build NPU model success for subgraph " << sub_id;
-  return model_name;
+  return weight_var_name;
 }
 void GenerateNPUProgramPass::GenNPUSubgraph(
@@ -145,12 +146,12 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
   FindInputOutputVars(
       op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
-  auto model_name =
+  auto weight_var_name =
       BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
   auto any_op = (*op_nodes.begin())->AsStmt().op();
   InsertNewNode(graph,
-                model_name,
+                weight_var_name,
                 any_op->scope(),
                 any_op->valid_places(),
                 in_data_vars,
@@ -166,7 +167,7 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
 void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
-  const auto& bridges = lite::npu::bridge::Factory::Instance();
+  const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
   const auto& op_map = bridges.AllFunctions();
   std::vector<std::string> supported_op_types;
   for (auto& i : op_map) {
@@ -214,4 +215,5 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
 } // namespace paddle
 REGISTER_MIR_PASS(generate_npu_program_pass,
-                  paddle::lite::mir::subgraph::GenerateNPUProgramPass);
+                  paddle::lite::mir::subgraph::GenerateNPUProgramPass)
+    .BindTargets({TARGET(kNPU)});
diff --git
a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h index 9e030287cb7d91a06bc930f9c1daefb06b3d6965..823ca5f1f624a9e920a5f395a9d5098c5ea52929 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ b/lite/core/mir/subgraph/generate_npu_program_pass.h @@ -20,10 +20,10 @@ #include #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/builder.h" #include "lite/core/mir/pass.h" #include "lite/core/mir/subgraph/subgraph_program_pass.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { @@ -41,7 +41,7 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { // nodes2cvt: op nodes to convert // return cvted_vars: converted var nodes void CvtAllOpNodes(const std::vector& nodes2cvt, - lite::npu::bridge::node_map_type* cvted_vars); + lite::kernels::npu::bridges::node_map_type* cvted_vars); std::shared_ptr CvtVarNode(lite::mir::Node* var_node, const Scope* scope); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc index 8bfdb7381baa03fb56955baa0a372b40cb8d1c25..95339d6175c98f22d542db24f02d6d714ccbe2a8 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc @@ -12,101 +12,160 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include "lite/core/op_registry.h" -#include "lite/core/program.h" -#include "lite/core/tensor.h" - -#include "lite/api/cxx_api.h" +#include +#include "lite/api/paddle_api.h" #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_passes.h" #include "lite/api/test_helper.h" +#include "lite/utils/cp_logging.h" -#include "lite/model_parser/pb/program_desc.h" - -DEFINE_string(optimized_model, "", "optimized_model"); -DEFINE_int32(batch_size, 1, "batch size"); -DEFINE_int32(im_channel, 3, "im_channel"); +DEFINE_string(model_file, "", "model file path of combined protobuf model"); +DEFINE_string(params_file, "", "params file path of combined protobuf model"); +DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); +DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors"); +DEFINE_int32(output_tensor_num, 1, "number of output tensors"); namespace paddle { namespace lite { -void TestModel(lite::Predictor* predictor, - const std::vector& valid_places, - const std::string& model_dir) { - predictor->Build(model_dir, - model_dir + "/model", - model_dir + "/params", - Place{TARGET(kARM), PRECISION(kFloat)}, - valid_places); - - auto* input_tensor = predictor->GetInput(0); - input_tensor->Resize(DDim(std::vector( - {FLAGS_batch_size, FLAGS_im_channel, FLAGS_im_height, FLAGS_im_width}))); - auto* data = input_tensor->mutable_data(); - auto item_size = input_tensor->dims().production(); - for (int i = 0; i < item_size; i++) { - data[i] = 1; +std::vector> ParseShape(std::string txt) { + std::vector> shape; + while (!txt.empty()) { + size_t idx = txt.find_first_of(":"); + std::string dims = txt.substr(0, idx); + std::vector s; + while (!dims.empty()) { + size_t idx = dims.find_first_of(","); + int d = atoi(dims.substr(0, idx).c_str()); + VLOG(3) << d; + s.push_back(d); + if 
(idx == std::string::npos) { + break; + } else { + dims = dims.substr(idx + 1); + } + } + shape.push_back(s); + if (idx == std::string::npos) { + break; + } else { + txt = txt.substr(idx + 1); + } } + return shape; +} - predictor->Run(); - if (model_dir != FLAGS_optimized_model && - std::find(valid_places.begin(), - valid_places.end(), - Place{TARGET(kNPU), PRECISION(kFloat)}) != valid_places.end()) { - predictor->SaveModel(FLAGS_optimized_model); +int64_t ShapeProduction(std::vector shape) { + int64_t s = 1; + for (int64_t dim : shape) { + s *= dim; } + return s; } -void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) { - auto* tgt_otensor = tgt.GetOutput(0); - auto* ref_otensor = ref.GetOutput(0); - const auto* tgt_pdata = tgt_otensor->data(); - const auto* ref_pdata = ref_otensor->data(); - EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production()); - for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) { - auto diff = std::fabs((tgt_pdata[i] - ref_pdata[i]) / ref_pdata[i]); - VLOG(3) << diff; - EXPECT_LT(diff, 0.1); +void FillInputTensor( + const std::shared_ptr& predictor, + const std::vector>& input_tensor_shape, + const float value) { + for (int i = 0; i < input_tensor_shape.size(); i++) { + auto input_tensor = predictor->GetInput(i); + input_tensor->Resize(input_tensor_shape[i]); + auto input_tensor_data = input_tensor->mutable_data(); + auto input_tensor_size = ShapeProduction(input_tensor->shape()); + for (int j = 0; j < input_tensor_size; j++) { + input_tensor_data[i] = value; + } } } -TEST(NPUSubgraph, compare) { - DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, 1); - - lite::Predictor predictor_arm, predictor_npu, predictor_npu_savedmodel; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}}); - - TestModel(&predictor_arm, valid_places, FLAGS_model_dir); - - valid_places.push_back(Place{TARGET(kNPU), PRECISION(kFloat)}); - TestModel(&predictor_npu, valid_places, FLAGS_model_dir); - - CompareOutData(predictor_npu, predictor_arm); - LOG(INFO) << " ================ NPU speed ================== "; - for (int i = 0; i < FLAGS_repeats; ++i) { - auto start = GetCurrentUS(); - predictor_npu.Run(); - LOG(INFO) << i << ", " << GetCurrentUS() - start << "us"; +void CompareOutputTensor( + const std::shared_ptr& tar_predictor, + const std::shared_ptr& ref_predictor, + const int output_tensor_num) { + for (int i = 0; i < output_tensor_num; i++) { + auto tar_output_tensor = tar_predictor->GetOutput(i); + auto ref_output_tensor = ref_predictor->GetOutput(i); + auto tar_output_tensor_data = tar_output_tensor->data(); + auto ref_output_tensor_data = ref_output_tensor->data(); + auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape()); + auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape()); + EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size); + for (size_t j = 0; j < ref_output_tensor_size; j++) { + auto abs_diff = + std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]); + auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6); + VLOG(3) << "val: " << tar_output_tensor_data[j] + << " ref: " << ref_output_tensor_data[j] + << " abs_diff: " << abs_diff << " rel_diff: " << rel_diff; + EXPECT_LT(rel_diff, 0.1); + } } +} - LOG(INFO) << " =================== ARM CPU speed =================== "; - for (int i = 0; i < FLAGS_repeats; ++i) { +std::shared_ptr TestModel( + const std::string& 
model_dir, + const std::string& model_file, + const std::string& params_file, + const std::vector& valid_places, + const std::vector>& input_tensor_shape, + const std::string& optimized_model_dir) { + // generate optimized model + lite_api::CxxConfig cxx_config; + cxx_config.set_model_dir(model_dir); + cxx_config.set_model_file(model_file); + cxx_config.set_param_file(params_file); + cxx_config.set_valid_places(valid_places); + auto predictor = lite_api::CreatePaddlePredictor(cxx_config); + FillInputTensor(predictor, input_tensor_shape, 1); + predictor->SaveOptimizedModel(optimized_model_dir, + lite_api::LiteModelType::kNaiveBuffer); + // load optimized model + lite_api::MobileConfig mobile_config; + mobile_config.set_model_dir(optimized_model_dir); + mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); + mobile_config.set_threads(1); + predictor = lite_api::CreatePaddlePredictor(mobile_config); + FillInputTensor(predictor, input_tensor_shape, 1); + // run optimized model + for (int i = 0; i < FLAGS_warmup; i++) { + predictor->Run(); + } + for (int i = 0; i < FLAGS_repeats; i++) { auto start = GetCurrentUS(); - predictor_arm.Run(); + predictor->Run(); LOG(INFO) << i << ", " << GetCurrentUS() - start << "us"; } + return predictor; +} - TestModel(&predictor_npu_savedmodel, valid_places, FLAGS_optimized_model); - - CompareOutData(predictor_npu_savedmodel, predictor_arm); +TEST(NPUSubgraph, compare) { + // parsing input tensor shape, supported formats: "1,3,224,224" + // "1,3,224,224:1,80" + std::vector> input_tensor_shape = + ParseShape(FLAGS_input_tensor_shape); + // generate and run optimized CPU model + LOG(INFO) << " ================ CPU ================== "; + auto cpu_predictor = + TestModel(FLAGS_model_dir, + FLAGS_model_file, + FLAGS_params_file, + {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}}, + input_tensor_shape, + FLAGS_optimized_model_dir + "/CPU"); + // generate and run optimized NPU model + LOG(INFO) << " ================ NPU ================== "; + auto npu_predictor = + TestModel(FLAGS_model_dir, + FLAGS_model_file, + FLAGS_params_file, + {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}, + lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}}, + input_tensor_shape, + FLAGS_optimized_model_dir + "/NPU"); + // verify results + CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num); } } // namespace lite diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.cc b/lite/core/mir/subgraph/generate_xpu_program_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..319e1e51feb917b803753807ddbb1f72c2cb7084 --- /dev/null +++ b/lite/core/mir/subgraph/generate_xpu_program_pass.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
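+
+// This pass fuses the subgraphs whose ops all have XPU bridges, compiles
+// each fused subgraph into an XPU model through the XTCL builder, and
+// replaces it with a single graph op whose compiled model is stored in a
+// persistable weight tensor (see BuildXPUGraph below).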
+ +#include "lite/core/mir/subgraph/generate_xpu_program_pass.h" +#include +#include +#include +#include +#include +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher.h" + +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace subgraph { + +std::shared_ptr GenerateXPUProgramPass::CvtVarNode( + lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, + lite::mir::Node* var_node, + const Scope* scope) { + CHECK(var_node->IsArg()); + const auto& arg = var_node->AsArg(); + auto var_name = arg.name; + VLOG(4) << "[XPU] Convert var node " << var_name; + + auto* var = scope->FindVar(var_name); + CHECK(var); + auto* tensor = var->GetMutable(); + CHECK(tensor); + auto dims = tensor->dims(); + auto cvted_var_node = + std::make_shared(graph_ctx->builder->CreateTensor( + var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32))); + if (arg.is_weight) { + auto cvted_var_tensor = lite::xpu::CvtTensor(tensor); + graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor)); + } + return cvted_var_node; +} + +void GenerateXPUProgramPass::CvtAllOpNodes( + const std::vector& op_nodes, + lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, + lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) { + const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); + const auto& supported_lists = bridges.AllFunctions(); + // return record all converted vars + // op node's inputs must be found in converted_vars + for (auto& node : op_nodes) { + lite::kernels::xpu::bridges::node_map_type input_nodes; + auto& stmt = node->AsStmt(); + for (auto& var_node : node->inlinks) { + auto& arg = var_node->AsArg(); + // weight should be handled in the converter, so skip here + if (arg.is_weight) { + continue; + } + auto var_name = arg.name; + if (!cvted_var_nodes->count(var_name)) { + cvted_var_nodes->insert(std::make_pair( + var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope()))); + } + input_nodes.insert(*cvted_var_nodes->find(var_name)); + } + auto output_nodes = + supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes); + cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end()); + } +} + +std::string GenerateXPUProgramPass::BuildXPUGraph( + const std::unordered_set& op_nodes, + const std::unordered_set& in_data_vars, + const std::unordered_set& out_data_vars, + int sub_id) { + auto ordered_op_nodes = GetTopologicalOrder(op_nodes); + lite::kernels::xpu::bridges::graph_ctx_type graph_ctx; + graph_ctx.builder = std::make_shared(); + graph_ctx.params = + std::make_shared(); + lite::kernels::xpu::bridges::node_map_type cvted_var_nodes; + CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes); + + std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights"; + auto any_op = (*op_nodes.begin())->AsStmt().op(); + auto weight = any_op->scope()->Var(weight_var_name)->GetMutable(); + weight->set_persistable(true); + weight->set_precision(PRECISION(kInt8)); + // Compiling graph to XPU model and store mode data into weight tensor with + // persistable=true, Sothat the model parser can recognize it and save it to + // param files + std::vector> ordered_cvted_var_nodes; + for (auto out_data_var : out_data_vars) { + auto var_name = out_data_var->AsArg().name; + ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]); 
+  }
+  if (!lite::xpu::BuildModel(graph_ctx.builder,
+                             graph_ctx.params,
+                             &ordered_cvted_var_nodes,
+                             weight)) {
+    LOG(WARNING) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
+    throw std::runtime_error("[XPU] Build XPU graph failed.");
+  }
+  LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
+  return weight_var_name;
+}
+
+void GenerateXPUProgramPass::GenXPUSubgraph(
+    const std::unique_ptr<SSAGraph>& graph,
+    const std::unordered_set<Node*>& op_nodes,
+    int sub_id) {
+  std::unordered_set<Node*> in_data_vars;
+  std::unordered_set<Node*> in_wgt_vars;
+  std::unordered_set<Node*> out_data_vars;
+  std::unordered_set<Node*> out_unused_vars;
+  FindInputOutputVars(
+      op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
+
+  auto weight_var_name =
+      BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
+
+  auto any_op = (*op_nodes.begin())->AsStmt().op();
+  InsertNewNode(graph,
+                weight_var_name,
+                any_op->scope(),
+                any_op->valid_places(),
+                in_data_vars,
+                in_wgt_vars,
+                out_data_vars,
+                out_unused_vars);
+
+  auto nodes2rm = GetNode2rm(
+      op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
+
+  GraphSafeRemoveNodes(graph.get(), nodes2rm);
+}
+
+void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get());
+  const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
+  const auto& op_map = bridges.AllFunctions();
+  std::vector<std::string> supported_op_types;
+  for (auto& i : op_map) {
+    LOG(INFO) << "[XPU] Supported type: " << i.first;
+    supported_op_types.push_back(i.first);
+  }
+
+  try {
+    int num_subgraph = FuseSubgraph(graph, supported_op_types);
+    InferOnce(graph);
+    auto op_nodes_all = ClassifySubgraph(graph);
+    CHECK_EQ(op_nodes_all.size(), num_subgraph);
+    int id = 1;
+    for (auto& op_nodes : op_nodes_all) {
+      LOG(INFO) << "[XPU] Converting Subgraph " << id;
+      GenXPUSubgraph(graph, op_nodes.second, id);
+      LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
+                << Visualize(graph.get());
+      id++;
+    }
+  } catch (...) {
+    LOG(WARNING) << "[XPU] Build XPU graph failed.";
+    throw std::runtime_error("[XPU] Build XPU graph failed.");
+  }
+
+  for (auto& item : graph->StmtTopologicalOrder()) {
+    if (item->IsStmt()) {
+      auto& stmt = item->AsStmt();
+      LOG(INFO) << stmt;
+      insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
+    }
+  }
+}
+
+std::unique_ptr<RuntimeProgram> GenerateXPUProgramPass::GenProgram() {
+  LOG(INFO) << "[XPU] program insts.size=" << insts_.size();
+  std::unique_ptr<RuntimeProgram> program(
+      new RuntimeProgram(std::move(insts_)));
+  return program;
+}
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(generate_xpu_program_pass,
+                  paddle::lite::mir::subgraph::GenerateXPUProgramPass)
+    .BindTargets({TARGET(kXPU)});
diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.h b/lite/core/mir/subgraph/generate_xpu_program_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf121ae9503201e8cf6be40fe9054ccaf6e4b172
--- /dev/null
+++ b/lite/core/mir/subgraph/generate_xpu_program_pass.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "lite/backends/xpu/builder.h"
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include "lite/kernels/xpu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace subgraph {
+
+class GenerateXPUProgramPass : public SubgraphProgramPass {
+ public:
+  using key2nodes_t = std::map<std::string, Node*>;
+
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+  std::unique_ptr<RuntimeProgram> GenProgram();
+
+ protected:
+  // op_nodes: the op nodes to convert
+  // cvted_var_nodes: receives the converted var nodes
+  void CvtAllOpNodes(
+      const std::vector<Node*>& op_nodes,
+      lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
+      lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes);
+
+  std::shared_ptr<xtcl::xExpr> CvtVarNode(
+      lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
+      lite::mir::Node* var_node,
+      const Scope* scope);
+
+  std::string BuildXPUGraph(const std::unordered_set<Node*>& op_nodes,
+                            const std::unordered_set<Node*>& in_data_vars,
+                            const std::unordered_set<Node*>& out_data_vars,
+                            int sub_id);
+
+  void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
+                      const std::unordered_set<Node*>& op_nodes,
+                      int sub_id);
+
+ private:
+  std::vector<Instruction> insts_;
+};
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc b/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..728ecbc6b77666accd432b1ad82a03860588ab40
--- /dev/null
+++ b/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc
@@ -0,0 +1,172 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+
+DEFINE_string(model_file, "", "model file path of combined protobuf model");
+DEFINE_string(params_file, "", "params file path of combined protobuf model");
+DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
+DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
+DEFINE_int32(output_tensor_num, 1, "number of output tensors");
+
+namespace paddle {
+namespace lite {
+
+std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
+  std::vector<std::vector<int64_t>> shape;
+  while (!txt.empty()) {
+    size_t idx = txt.find_first_of(":");
+    std::string dims = txt.substr(0, idx);
+    std::vector<int64_t> s;
+    while (!dims.empty()) {
+      size_t idx = dims.find_first_of(",");
+      int d = atoi(dims.substr(0, idx).c_str());
+      VLOG(3) << d;
+      s.push_back(d);
+      if (idx == std::string::npos) {
+        break;
+      } else {
+        dims = dims.substr(idx + 1);
+      }
+    }
+    shape.push_back(s);
+    if (idx == std::string::npos) {
+      break;
+    } else {
+      txt = txt.substr(idx + 1);
+    }
+  }
+  return shape;
+}
+
+int64_t ShapeProduction(std::vector<int64_t> shape) {
+  int64_t s = 1;
+  for (int64_t dim : shape) {
+    s *= dim;
+  }
+  return s;
+}
+
+void FillInputTensor(
+    const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
+    const std::vector<std::vector<int64_t>>& input_tensor_shape,
+    const float value) {
+  for (size_t i = 0; i < input_tensor_shape.size(); i++) {
+    auto input_tensor = predictor->GetInput(i);
+    input_tensor->Resize(input_tensor_shape[i]);
+    auto input_tensor_data = input_tensor->mutable_data<float>();
+    auto input_tensor_size = ShapeProduction(input_tensor->shape());
+    for (int j = 0; j < input_tensor_size; j++) {
+      input_tensor_data[j] = value;
+    }
+  }
+}
+
+void CompareOutputTensor(
+    const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
+    const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
+    const int output_tensor_num) {
+  for (int i = 0; i < output_tensor_num; i++) {
+    auto tar_output_tensor = tar_predictor->GetOutput(i);
+    auto ref_output_tensor = ref_predictor->GetOutput(i);
+    auto tar_output_tensor_data = tar_output_tensor->data<float>();
+    auto ref_output_tensor_data = ref_output_tensor->data<float>();
+    auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
+    auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
+    EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
+    for (int64_t j = 0; j < ref_output_tensor_size; j++) {
+      auto diff =
+          std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
+          (std::fabs(ref_output_tensor_data[j]) + 1e-6);
+      VLOG(3) << diff;
+      EXPECT_LT(diff, 0.1);
+    }
+  }
+}
+
+std::shared_ptr<lite_api::PaddlePredictor> TestModel(
+    const std::string& model_dir,
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::vector<lite_api::Place>& valid_places,
+    const std::vector<std::vector<int64_t>>& input_tensor_shape,
+    const std::string& optimized_model_dir) {
+  // generate optimized model
+  lite_api::CxxConfig cxx_config;
+  cxx_config.set_model_dir(model_dir);
+  cxx_config.set_model_file(model_file);
+  cxx_config.set_param_file(params_file);
+  cxx_config.set_valid_places(valid_places);
+  auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
+  FillInputTensor(predictor, input_tensor_shape, -1);
+  predictor->SaveOptimizedModel(optimized_model_dir,
+                                lite_api::LiteModelType::kNaiveBuffer);
+#if 0  // TODO(hong19860320) support light api for XPU
+  // load optimized model
+  lite_api::MobileConfig mobile_config;
+  mobile_config.set_model_dir(optimized_model_dir);
+  mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
+  mobile_config.set_threads(1);
+  predictor = lite_api::CreatePaddlePredictor(mobile_config);
+  FillInputTensor(predictor, input_tensor_shape, 1);
+#endif
+  // run optimized model
+  for (int i = 0; i < FLAGS_warmup; i++) {
+    predictor->Run();
+  }
+  for (int i = 0; i < FLAGS_repeats; i++) {
+    auto start = GetCurrentUS();
+    predictor->Run();
+    LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
+  }
+  return predictor;
+}
+
+TEST(XPUSubgraph, compare) {
+  // parse the input tensor shapes; supported formats: "1,3,224,224" or
+  // "1,3,224,224:1,80"
+  std::vector<std::vector<int64_t>> input_tensor_shape =
+      ParseShape(FLAGS_input_tensor_shape);
+  // generate and run optimized CPU model
+  LOG(INFO) << " ================ CPU ================== ";
+  auto cpu_predictor =
+      TestModel(FLAGS_model_dir,
+                FLAGS_model_file,
+                FLAGS_params_file,
+                {lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
+                input_tensor_shape,
+                FLAGS_optimized_model_dir + "/CPU");
+  // generate and run optimized XPU model
+  LOG(INFO) << " ================ XPU ================== ";
+  auto xpu_predictor =
+      TestModel(FLAGS_model_dir,
+                FLAGS_model_file,
+                FLAGS_params_file,
+                {lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
+                 lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
+                input_tensor_shape,
+                FLAGS_optimized_model_dir + "/XPU");
+  // verify results
+  CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc
index dddcdad7efc0b518e8c6396b2724808186adc2c2..719a01dfd892f83da5e1d9b1efa6df758612acc7 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass.cc
+++ b/lite/core/mir/subgraph/subgraph_program_pass.cc
@@ -43,20 +43,20 @@ SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
 }
 
 cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
-    const std::string& model_name,
+    const std::string& weight_var_name,
     const std::vector<std::string>& in_var_names,
     const std::vector<std::string>& out_var_names) {
   cpp::OpDesc op_desc;
   op_desc.SetType("graph_op");
   op_desc.SetInput("Inputs", in_var_names);
+  op_desc.SetInput("Weight", {weight_var_name});
   op_desc.SetOutput("Outputs", out_var_names);
-  op_desc.SetAttr("model_name", model_name);
   return op_desc;
 }
 
 void SubgraphProgramPass::InsertNewNode(
     const std::unique_ptr<SSAGraph>& graph,
-    const std::string& model_name,
+    const std::string& weight_var_name,
     Scope* scope,
     const std::vector<Place>& valid_places,
     std::unordered_set<Node*> in_data_vars,
@@ -72,7 +72,7 @@ void SubgraphProgramPass::InsertNewNode(
     out_var_names.push_back(i->AsArg().name);
   }
 
-  auto op_desc = GenGraphOpDesc(model_name, in_var_names, out_var_names);
+  auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names);
   auto graph_op = LiteOpRegistry::Global().Create("graph_op");
   graph_op->Attach(op_desc, scope);
@@ -91,6 +91,12 @@ void SubgraphProgramPass::InsertNewNode(
     IR_OP_VAR_LINK(new_op_node, out_var);
   }
 
+  // add a weight node to store the pre-compiled NPU/XPU model
+  auto new_weight_node = graph->NewArgumentNode(weight_var_name);
+  new_weight_node->AsArg().is_weight = true;
+  new_weight_node->AsArg().is_persist = true;
+  DirectedLink(new_weight_node, new_op_node);
+
   // assign context
   auto& inst = new_op_node->AsStmt();
   inst.picked_kernel().SetContext(
@@ -201,8 +207,30 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
     if (!item->IsStmt()) continue;
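Review note: the feed-dims guard added to InferOnce below assumes the feed variables already carry concrete dims when the subgraph passes run; in the test above this is guaranteed by resizing every input first:

    // as done by FillInputTensor in the test above
    auto input_tensor = predictor->GetInput(0);
    input_tensor->Resize({1, 3, 224, 224});  // shape from --input_tensor_shape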
     auto& stmt = item->AsStmt();
     auto& op = stmt.op();
+    auto scope = op->scope();
+    std::string op_type = op->op_info()->Type();
+    // check the dimensions of the input variables in the scope; they must
+    // not be empty!
+    if (op_type == "feed") {
+      auto input_var_names = op->op_info()->output_names();
+      CHECK_GE(input_var_names.size(), 1);
+      for (auto input_var_name : input_var_names) {
+        auto input_var = scope->FindVar(input_var_name);
+        CHECK(input_var) << "No input variable '" << input_var_name
+                         << "' found in scope " << scope;
+        auto input = input_var->GetMutable<lite::Tensor>();
+        CHECK(!input->dims().empty()) << "The dimension of input variable '"
+                                      << input_var_name
+                                      << "' cannot be empty.";
+      }
+      continue;
+    }
+    if (op_type == "fetch") {
+      continue;
+    }
     op->CheckShape();
     op->InferShape();
+
+#ifndef LITE_WITH_XPU
     // TODO(xxx): remove Launch() at last
     auto& kkks = stmt.kernels();
     if (!kkks.empty()) {
@@ -211,6 +239,7 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
       kk->Launch();
     }
   }
+#endif
 }
 
@@ -278,19 +307,21 @@ int SubgraphProgramPass::FuseSubgraphID(
     const std::unique_ptr<SSAGraph>& graph) {
   int sub_id = 1;  // ids start from 1, not 0
   for (auto& item : graph->StmtTopologicalOrder()) {
-    bool inputvar = 0;
+    // bool inputvar = false;
     if (!item->IsStmt()) continue;
     auto& stmt = item->AsStmt();
+    /*
     if (stmt.subgraph_id() == -1) {
       for (auto& i : item->outlinks) {
        for (auto& j : i->outlinks) {
          if (j->IsStmt()) {
            auto& jstmt = j->AsStmt();
-            if (jstmt.subgraph_id() == 0) inputvar = 1;
+            if (jstmt.subgraph_id() == 0) inputvar = true;
          }
        }
      }
    }
+    */
     if (stmt.subgraph_id() != 0) continue;
     ChangeAllOutConnectedID(item, sub_id);
     sub_id++;
@@ -310,4 +341,5 @@ int SubgraphProgramPass::FuseSubgraph(
 }  // namespace paddle
 
 REGISTER_MIR_PASS(subgraph_program_pass,
-                  paddle::lite::mir::subgraph::SubgraphProgramPass);
+                  paddle::lite::mir::subgraph::SubgraphProgramPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h
index 51e9367539caa6f0868138235bc7b0907c189df5..24c0233bbb428a71fa5645b23573494b5067d8b1 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass.h
+++ b/lite/core/mir/subgraph/subgraph_program_pass.h
@@ -60,13 +60,13 @@ class SubgraphProgramPass : public ProgramPass {
       const std::unique_ptr<SSAGraph>& graph);
 
   // generate the graph op desc
-  cpp::OpDesc GenGraphOpDesc(const std::string& model_name,
+  cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name,
                              const std::vector<std::string>& in_var_names,
                              const std::vector<std::string>& out_var_names);
 
   // insert a new graph op node
   void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
-                     const std::string& model_name,
+                     const std::string& weight_var_name,
                      Scope* scope,
                      const std::vector<Place>& valid_places,
                      std::unordered_set<Node*> in_data_vars,
diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
index de4acec91d3eacd5f880d6495367a4826eb90cfa..22e20b81d831ff25df090a7565e671b9139122f7 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc
+++ b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
@@ -46,6 +46,9 @@ TEST(SubgraphTest, models) {
 #endif
 #ifdef LITE_WITH_NPU
         Place{TARGET(kNPU), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_XPU
+        Place{TARGET(kXPU), PRECISION(kFloat)},
 #endif
   });
   lite::Program program(program_desc, scope, valid_places);
@@ -214,7 +217,6 @@ TEST(SubGraphTest, SimpleNet) {
 
   auto* pass = new mir::subgraph::SubgraphProgramPass;
   ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
-  const int num_nodes = graph->nodes().size();
   ASSERT_EQ(graph->nodes().size(), 9);
   // LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
 }
diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc
index 2b216ceec590f243aab48a9af4b14cf3f8a13bd4..9d63dcbb38b2354c567ca1e0d434ac1a4be424c1 100644
--- a/lite/core/mir/type_layout_cast_pass.cc
+++ b/lite/core/mir/type_layout_cast_pass.cc
@@ -28,19 +28,24 @@ namespace mir {
 
 void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Start from inputs of the graph, those should have place set.
+  VLOG(4) << "\n" << Visualize(graph.get());
   std::list<Node*> nodes;
-  for (auto& node : graph->mutable_nodes()) {
-    nodes.push_back(&node);
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    nodes.push_back(node);
   }
+  VLOG(4) << "nodes.size():" << nodes.size();
 
   for (auto& node : nodes) {
-    if (!node->IsStmt()) continue;
+    VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
     auto inlinks = node->inlinks;
+    VLOG(4) << "node->AsStmt().desc:" << node->AsStmt().desc
+            << " inlinks.size():" << inlinks.size();
     for (auto* in : inlinks) {
       ComplementInputs(graph.get(), node, in);
     }
   }
-  VLOG(3) << "\n" << Visualize(graph.get());
+  VLOG(4) << "\n" << Visualize(graph.get());
 }
 
 void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
@@ -53,13 +58,21 @@
   CHECK(inst_node->IsStmt());
   auto& inst = inst_node->AsStmt();
+  VLOG(4) << "found Target tensor: " << in->AsArg().name;
   CHECK(in->IsRoleSet());
   CHECK(in->IsArg());
   auto in_arg_name = in->AsArg().name;
-  std::string tmp;
-  CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
-  auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
+  std::string inst_in_tensor_name;
+  CHECK(inst.op_info()->GetInputArgname(in_arg_name, &inst_in_tensor_name));
+  auto decl_arg_type =
+      inst.picked_kernel().GetInputDeclType(inst_in_tensor_name);
   CHECK(in->AsArg().type);
+  VLOG(5) << "\n inst_in_tensor_name:" << inst_in_tensor_name
+          << "\n in->AsArg().name:" << in->AsArg().name
+          << "\n *in->AsArg().type:" << *in->AsArg().type
+          << "\n *decl_arg_type:" << *decl_arg_type
+          << "\n inst.op()->DebugString():" << inst.op()->DebugString();
+
   if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) {
     VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name
             << " for kernel " << inst.op()->DebugString() << " "
@@ -83,10 +96,13 @@ void TypeLayoutTransformPass::AddLayoutInst(
   CHECK(!valid_places.empty()) << "valid_place should be set";
   CHECK(in->IsArg());
-  auto node_id = [&] { return graph->nodes().size(); };
+  // auto node_id = [&] { return graph->nodes().size(); };
   auto layout_output_name =
-      string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
+      string_format("%s/layout_trans", in->AsArg().name.c_str());
   auto* layout_output_arg = graph->NewArgumentNode(layout_output_name);
+  layout_output_arg->AsArg().type =
+      LiteType::GetTensorTy(from.target(), from.precision(), to.layout());
+
   auto* layout_inst = graph->NewInstructNode();
 
   bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
@@ -111,17 +127,31 @@ void TypeLayoutTransformPass::AddLayoutInst(
   for (auto& kernel : kernels) {
     const Type* in_arg_ty = kernel->GetInputDeclType("Input");
     const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
-    if (TypeCompatible(*in_arg_ty, from)) {
+#ifdef LITE_WITH_OPENCL
+    // when choosing a layout kernel, the [layout check] on the kernel's
+    // input and output must be ignored
+    if (TargetCompatibleTo(*in_arg_ty, from) &&
+        PrecisionCompatibleTo(*in_arg_ty, from) &&
+        DeviceCompatibleTo(*in_arg_ty, from) &&
+        out_arg_ty->layout() == to.layout()) {
+#else
+    if (TypeCompatible(*in_arg_ty, from) &&
+        out_arg_ty->layout() == to.layout()) {
+#endif
       is_found = true;
       selected_kernels.emplace_back(std::move(kernel));
       // we pick the kernel
-      layout_inst->AsStmt(layout_type, std::move(kernels), layout_op);
+      layout_inst->AsStmt(layout_type, std::move(selected_kernels), layout_op);
       break;
     }
   }
-  CHECK(is_found) << "Can't find a layout kernel for layout op: " << from
-                  << ":" << in->AsArg().name << "->" << to << ":"
+  CHECK(is_found) << "Can't find a layout kernel for layout op: " << from << ":"
+                  << in->AsArg().name << "->" << to << ":"
                   << inst_node->AsStmt().op_info()->Type();
+  VLOG(4) << "========= final picked layout kernel ========= ";
+  VLOG(4) << "[info]:" << layout_inst->AsStmt().picked_kernel().name();
+  VLOG(4) << "[summary]:" << layout_inst->AsStmt().picked_kernel().summary()
+          << "\n";
 
   // Remove the old link
   RemoveDirectedLink(in, inst_node);
@@ -173,4 +203,7 @@ void TypeLayoutTransformPass::SetValidPlaces(
 }  // namespace paddle
 
 REGISTER_MIR_PASS(type_layout_cast_pass,
-                  paddle::lite::mir::TypeLayoutTransformPass);
+                  paddle::lite::mir::TypeLayoutTransformPass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("layout_once")
+    .BindKernel("layout");
diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc
index 517f9a9b70fe4aad155b38ef248e6f3c7dbea922..2f177383fc2b3a035313c0654c961c0b21a7f197 100644
--- a/lite/core/mir/type_precision_cast_pass.cc
+++ b/lite/core/mir/type_precision_cast_pass.cc
@@ -28,12 +28,12 @@ namespace mir {
 void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Start from inputs of the graph, those should have place set.
   std::list<Node*> nodes;
-  for (auto& node : graph->mutable_nodes()) {
-    nodes.push_back(&node);
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    nodes.push_back(node);
   }
 
   for (auto& node : nodes) {
-    if (!node->IsStmt()) continue;
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
     auto inlinks = node->inlinks;
     for (auto* in : inlinks) {
       ComplementInputs(graph.get(), node, in);
@@ -86,10 +86,12 @@ void PrecisionCastPass::AddCastInst(const Type& from,
   // var -> new_transform_op -> new_var -> inst
   // So there will be a new Argument node and a new Cast Statement Node.
   CHECK(in->IsArg());
-  auto node_id = [&] { return graph->nodes().size(); };
-  auto cast_op_output_name =
-      in->AsArg().name + "/trans/" + std::to_string(node_id());
+  // auto node_id = [&] { return graph->nodes().size(); };
+  auto cast_op_output_name = in->AsArg().name + "/precision_trans";
+  // in->AsArg().name + "/precision_trans/" + std::to_string(node_id());
   auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
+  cast_op_output_arg->AsArg().type =
+      LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
   auto* cast_inst = graph->NewInstructNode();
 
   // create Op and kernels.
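Review note: a worked example of the precision-cast rewiring above (the concrete places are illustrative; op and name conventions follow this patch):

    // If `in` is (kARM, kFloat, kNCHW) and the picked kernel declares
    // (kARM, kInt8, kNCHW), the new argument node is typed as:
    cast_op_output_arg->AsArg().type = LiteType::GetTensorTy(
        TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNCHW));
    // and the graph becomes:  x -> calib -> x/precision_trans -> inst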
@@ -118,13 +120,8 @@ void PrecisionCastPass::AddCastInst(const Type& from, for (auto& kernel : kernels) { const Type* in_arg_ty = kernel->GetInputDeclType("Input"); const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); -// TODO(xg): to optimize this -#ifndef LITE_WITH_FPGA - if (in_arg_ty->precision() == from.precision() && + if (TypeCompatible(*in_arg_ty, from) && out_arg_ty->precision() == to.precision()) { -#else - if (TypeCompatible(*in_arg_ty, from)) { -#endif is_found = true; selected_kernels.emplace_back(std::move(kernel)); // we pick the kernel @@ -179,4 +176,7 @@ void PrecisionCastPass::SetValidPlaces(const std::vector& valid_places) { } // namespace paddle REGISTER_MIR_PASS(type_precision_cast_pass, - paddle::lite::mir::PrecisionCastPass); + paddle::lite::mir::PrecisionCastPass) + .BindTargets({TARGET(kAny)}) + .BindKernel("calib_once") + .BindKernel("calib"); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index f653654e9675bde7229c599d2d5640ac2016c37f..7a3277786553d8a256c48e9e5c99530b8d5681b5 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -29,14 +29,14 @@ namespace mir { void TypeTargetTransformPass::Apply(const std::unique_ptr& graph) { // Start from inputs of the graph, those should have place set. std::list nodes; - for (auto& node : graph->mutable_nodes()) { - nodes.push_back(&node); + for (auto& node : graph->StmtTopologicalOrder()) { + nodes.push_back(node); } CHECK(!valid_places_.empty()); for (auto& node : nodes) { - if (!node->IsStmt()) continue; + if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; auto inlinks = node->inlinks; for (auto* in : inlinks) { ComplementInputs(graph.get(), node, in); @@ -54,7 +54,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, CHECK(inst_node->IsStmt()); auto& inst = inst_node->AsStmt(); - LOG(INFO) << "found Target tensor: " << in->AsArg().name; + VLOG(3) << "found Target tensor: " << in->AsArg().name; CHECK(in->IsRoleSet()); CHECK(in->IsArg()); auto in_arg_name = in->AsArg().name; @@ -63,9 +63,9 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); CHECK(in->AsArg().type); if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { - LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name - << " for kernel " << inst.op()->DebugString() << " " - << *in->AsArg().type << " -> " << *decl_arg_type; + VLOG(3) << "found Target unmatched tensor: " << in->AsArg().name + << " for kernel " << inst.op()->DebugString() << " " + << *in->AsArg().type << " -> " << *decl_arg_type; // Add an IoCopy instruction to make the input compatible with other dist. AddIoCopyInst( *in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_); @@ -84,11 +84,17 @@ void TypeTargetTransformPass::AddIoCopyInst( // So there will be a new Argument node and a new IoCopy Statement Node. CHECK(in->IsArg()); - auto node_id = [&] { return graph->nodes().size(); }; + // auto node_id = [&] { return graph->nodes().size(); }; auto io_copy_output_name = - string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id()); + string_format("%s/target_trans", in->AsArg().name.c_str()); + // string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); // TODO(MyPandaShaoxiang) should set same place with input? 
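Review note on the TODO above: the hunk that follows settles it halfway. The io_copy output argument keeps the input's precision and layout and only switches the target, i.e. GetTensorTy(to.target(), from.precision(), from.layout()), so only the target changes across an io_copy boundary.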
   auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
+  // Set the place for the io_copy_output_arg node: the target should be
+  // equal to to.target(), while the precision and layout should be equal to
+  // from.precision() and from.layout()
+  io_copy_output_arg->AsArg().type =
+      LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
   auto* io_copy_inst = graph->NewInstructNode();
 
   bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
@@ -115,7 +121,36 @@ void TypeTargetTransformPass::AddIoCopyInst(
   for (auto& kernel : kernels) {
     const Type* in_arg_ty = kernel->GetInputDeclType("Input");
     const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
-    if (TypeCompatible(*in_arg_ty, from)) {
+
+    VLOG(4) << "------ kernel info -------";
+    VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
+    VLOG(4) << "from(last kernel output):" << from;
+    VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
+    VLOG(4) << "to:" << to << "\n";
+
+// kernel selection branch for the OpenCL backend:
+// check whether the inst's target is kOpenCL
+// Note: to == *decl_arg_type == input of inst, not the output of the last inst
+#ifdef LITE_WITH_OPENCL
+    // ignore the [layout check] between [to] and [from]: none of the original
+    // OpenCL insts in the model use the default NCHW layout, so the layout
+    // check has to be skipped.
+    // detailed node info:
+    //   [*in->AsArg().type] -> [from]: output of the inst's previous kernel
+    //   [*decl_arg_type] -> [to]: input of the inst, not output of the last
+    //   [in_arg_ty]: input of io_copy
+    //   [out_arg_ty]: output of io_copy
+    if (TargetCompatibleTo(*in_arg_ty, from) &&
+        PrecisionCompatibleTo(*in_arg_ty, from) &&
+        DeviceCompatibleTo(*in_arg_ty, from) &&
+        TargetCompatibleTo(*out_arg_ty, to)) {
+      VLOG(4) << "do nothing. opencl found";
+#else
+    if (TypeCompatible(*in_arg_ty, from) &&
+        out_arg_ty->target() == to.target()) {
+#endif
+      VLOG(4) << "picked";
       is_found = true;
       selected_kernels.emplace_back(std::move(kernel));
       // we pick the kernel
@@ -123,20 +158,23 @@ void TypeTargetTransformPass::AddIoCopyInst(
           io_copy_type, std::move(selected_kernels), io_copy_op);
       break;
     }
+    VLOG(4) << "not picked";
   }
 
   CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
-                  << ":" << in->AsArg().name << "->" << to << ":"
+                  << ":" << in->AsArg().name << " -> " << to << ":"
                   << inst_node->AsStmt().op_info()->Type();
-
   // Remove the old link
   RemoveDirectedLink(in, inst_node);
 
   // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Add new link, var -> new_inst, new_inst->newarg, newarg->inst - DirectedLink(in, io_copy_inst); - DirectedLink(io_copy_inst, io_copy_output_arg); - DirectedLink(io_copy_output_arg, inst_node); + DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel] + DirectedLink( + io_copy_inst, + io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output + DirectedLink(io_copy_output_arg, + inst_node); // [io_copy kernel]'s output -> [current kernel] // reset opdesc and update kernel information UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), @@ -179,4 +217,5 @@ void TypeTargetTransformPass::SetValidPlaces( } // namespace paddle REGISTER_MIR_PASS(type_target_cast_pass, - paddle::lite::mir::TypeTargetTransformPass); + paddle::lite::mir::TypeTargetTransformPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/variable_place_inference_pass.cc b/lite/core/mir/variable_place_inference_pass.cc index e3795ae642972e58fe698210f9093f6ebba1c3c8..f1b6381fc0010e08cffa4baee4dc7b33a678b387 100644 --- a/lite/core/mir/variable_place_inference_pass.cc +++ b/lite/core/mir/variable_place_inference_pass.cc @@ -31,4 +31,5 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr &graph) { } // namespace paddle REGISTER_MIR_PASS(variable_place_inference_pass, - paddle::lite::mir::VariablePlaceInferencePass); + paddle::lite::mir::VariablePlaceInferencePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h index d5b0bb8aa67573cf9c3f9cce7f9deeea15712ded..fe6ecfd66df23bb704fafcbf94106f7ca973c4f1 100644 --- a/lite/core/mir/variable_place_inference_pass.h +++ b/lite/core/mir/variable_place_inference_pass.h @@ -57,12 +57,21 @@ class VariablePlaceInferencePass : public DebugPass { // Set the tye of the weight void SetWeightType(Node* w, const LiteType& type) { // TODO(xg) to optimize this -#ifndef LITE_WITH_FPGA - w->AsArg().type = - LiteType::GetTensorTy(TARGET(kHost), type.precision(), type.layout()); -#else +#ifdef LITE_WITH_FPGA + w->AsArg().type = LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); +#endif + +#ifdef LITE_WITH_OPENCL w->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); +#endif + +#ifndef LITE_WITH_FPGA +#ifndef LITE_WITH_OPENCL + w->AsArg().type = LiteType::GetTensorTy( + TARGET(kHost), type.precision(), DATALAYOUT(kNCHW)); +#endif #endif } @@ -74,7 +83,10 @@ class VariablePlaceInferencePass : public DebugPass { // in fpga, we has io_copy+cali+layout tool ops, so we need type inference for // tool operator #ifndef LITE_WITH_FPGA +#ifndef LITE_WITH_OPENCL + VLOG(3) << "inst.op_type() == 'io_copy', continue"; if (inst.op_type() == "io_copy") continue; +#endif #endif // deal with inputs VLOG(4) << "Infering op " << inst.op_info()->Repr(); @@ -97,8 +109,8 @@ class VariablePlaceInferencePass : public DebugPass { std::string arg_name = get_argname(node_name, inst.op_info()->inputs()); CHECK(arg_name.size() > 0) << "can not found op arguments for node " << node_name; - VLOG(4) << "-- input arg_name " << arg_name - << "-- node name :" << node_name; + VLOG(4) << "-- input arg_name:" << arg_name << " " + << "-- node name:" << node_name; auto type = inst.picked_kernel().GetInputDeclType(arg_name); if (!x_in->AsArg().type) { VLOG(4) << "set type " << *type << " " << x_in->AsArg().name; diff --git a/lite/core/mir/variable_place_inference_pass_test.cc 
b/lite/core/mir/variable_place_inference_pass_test.cc index cf86afd590db8b05dcec720455284b3311551848..dec37078fa24e6c7974391d254f3847b7a90e8ba 100644 --- a/lite/core/mir/variable_place_inference_pass_test.cc +++ b/lite/core/mir/variable_place_inference_pass_test.cc @@ -63,18 +63,6 @@ TEST(variable_place_inference_pass, test) { "type_target_cast_pass", // }); - Place prefered_place{ -#ifdef PADDLE_WITH_CUDA - TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), -#else -#ifdef PADDLE_WITH_ARM - TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW), -#else // X86 - TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW), -#endif // ARM -#endif - }; - optimizer.KernelPickPreferPlace(prefered_place); optimizer.Run(std::move(program), places, factor, passes); } diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index 412b299339a3cc516a59df5533a5d1a4bbc9bc13..0936a44a66e4777633b84dadf0a1dc049213faab 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -63,7 +63,7 @@ std::vector> OpLite::CreateKernels( targets.insert(place.target); } - VLOG(4) << "op " << op_type_ << " get " << kernels.size() << " kernels"; + VLOG(5) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index f843ef6f2b35fb285d4ae0259463e723e3775cd4..5dec9ed7aace837e3eb085a55d7b9b5382f7dea3 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -57,7 +57,7 @@ class OpLite : public Registry { : valid_places_(valid_places) {} void SetValidPlaces(const std::vector &places) { - VLOG(3) << "valid places " << valid_places_.size(); + VLOG(5) << "valid places " << valid_places_.size(); valid_places_ = places; } const std::vector &valid_places() const { return valid_places_; } @@ -80,6 +80,8 @@ class OpLite : public Registry { // Human-readable information. virtual std::string DebugString() const = 0; + virtual std::string SerializedOpInfo() const { return "N/A"; } + const Place &kernel_place() const { return kernel_place_; } // Create all the kernels for the valid targets. diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 816837effc2a49100416bf600c5298e6c396c13f..3b8b350ad82f2cc1ce296b1ad74a6e322abec8ff 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -54,6 +54,8 @@ std::list> KernelRegistry::Create( CREATE_KERNEL1(target__, kFP16); \ case PRECISION(kAny): \ CREATE_KERNEL1(target__, kAny); \ + case PRECISION(kInt64): \ + CREATE_KERNEL1(target__, kInt64); \ default: \ CHECK(false) << "not supported kernel precision " \ << PrecisionToStr(precision); \ @@ -78,6 +80,9 @@ std::list> KernelRegistry::Create( case TARGET(kNPU): { CREATE_KERNEL(kNPU); } break; + case TARGET(kXPU): { + CREATE_KERNEL(kXPU); + } break; case TARGET(kFPGA): { CREATE_KERNEL(kFPGA); } break; @@ -105,8 +110,11 @@ KernelRegistry::KernelRegistry() DATALAYOUT(layout__)>::Global()); // Currently, just register 2 kernel targets. 
INIT_FOR(kCUDA, kFloat, kNCHW); + INIT_FOR(kCUDA, kFloat, kNHWC); + INIT_FOR(kCUDA, kInt8, kNCHW); INIT_FOR(kCUDA, kAny, kNCHW); INIT_FOR(kCUDA, kAny, kAny); + INIT_FOR(kCUDA, kInt8, kNHWC); INIT_FOR(kHost, kFloat, kNCHW); INIT_FOR(kHost, kAny, kNCHW); @@ -120,6 +128,7 @@ KernelRegistry::KernelRegistry() INIT_FOR(kX86, kFloat, kNCHW); INIT_FOR(kX86, kAny, kNCHW); INIT_FOR(kX86, kAny, kAny); + INIT_FOR(kX86, kInt64, kNCHW); INIT_FOR(kARM, kFloat, kNCHW); INIT_FOR(kARM, kInt8, kNCHW); @@ -127,7 +136,11 @@ KernelRegistry::KernelRegistry() INIT_FOR(kARM, kAny, kAny); INIT_FOR(kOpenCL, kFloat, kNCHW); + INIT_FOR(kOpenCL, kFloat, kNHWC); INIT_FOR(kOpenCL, kAny, kNCHW); + INIT_FOR(kOpenCL, kAny, kNHWC); + INIT_FOR(kOpenCL, kFloat, kAny); + INIT_FOR(kOpenCL, kInt8, kNCHW); INIT_FOR(kOpenCL, kAny, kAny); INIT_FOR(kNPU, kFloat, kNCHW); @@ -135,6 +148,11 @@ KernelRegistry::KernelRegistry() INIT_FOR(kNPU, kAny, kNCHW); INIT_FOR(kNPU, kAny, kAny); + INIT_FOR(kXPU, kFloat, kNCHW); + INIT_FOR(kXPU, kInt8, kNCHW); + INIT_FOR(kXPU, kAny, kNCHW); + INIT_FOR(kXPU, kAny, kAny); + INIT_FOR(kFPGA, kFP16, kNHWC); INIT_FOR(kFPGA, kFP16, kAny); INIT_FOR(kFPGA, kFloat, kNHWC); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 3eaa0e033d41527aabe71596b8d27a52f1307da0..1c67ee8f3dcafe30d9bda587d62233d0e715071e 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -15,9 +15,11 @@ #pragma once #include +#include #include #include #include +#include #include #include #include @@ -26,9 +28,47 @@ #include "lite/core/op_lite.h" #include "lite/core/target_wrapper.h" #include "lite/utils/all.h" +#include "lite/utils/macros.h" using LiteType = paddle::lite::Type; +class OpKernelInfoCollector { + public: + static OpKernelInfoCollector &Global() { + static auto *x = new OpKernelInfoCollector; + return *x; + } + void AddOp2path(const std::string &op_name, const std::string &op_path) { + size_t index = op_path.find_last_of('/'); + if (index != std::string::npos) { + op2path_.insert(std::pair( + op_name, op_path.substr(index + 1))); + } + } + void AddKernel2path(const std::string &kernel_name, + const std::string &kernel_path) { + size_t index = kernel_path.find_last_of('/'); + if (index != std::string::npos) { + kernel2path_.insert(std::pair( + kernel_name, kernel_path.substr(index + 1))); + } + } + void SetKernel2path( + const std::map &kernel2path_map) { + kernel2path_ = kernel2path_map; + } + const std::map &GetOp2PathDict() { + return op2path_; + } + const std::map &GetKernel2PathDict() { + return kernel2path_; + } + + private: + std::map op2path_; + std::map kernel2path_; +}; + namespace paddle { namespace lite { @@ -56,7 +96,6 @@ class OpLiteRegistor : public Registor { }); }) {} }; - template using KernelRegistryForTarget = Factory, std::unique_ptr>; @@ -67,9 +106,15 @@ class KernelRegistry final { variant *, // + KernelRegistryForTarget *, // KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTarget *, // @@ -100,12 +145,29 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // @@ -115,6 +177,17 @@ class KernelRegistry final { KernelRegistryForTarget *, // + + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // @@ -150,15 +223,16 @@ 
class KernelRegistry final { const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { - VLOG(3) << "register for " << TargetToStr(Target) << ":" - << PrecisionToStr(Precision) << "//" - << GetKernelOffset(); using kernel_registor_t = KernelRegistryForTarget; auto &varient = registries_[GetKernelOffset()]; auto *reg = varient.template get(); CHECK(reg) << "Can not be empty of " << name; reg->Register(name, std::move(creator)); +#ifdef LITE_ON_MODEL_OPTIMIZE_TOOL + kernel_info_map_[name].push_back( + std::make_tuple(Target, Precision, Layout)); +#endif // LITE_ON_MODEL_OPTIMIZE_TOOL } template > Create(const std::string &op_type) { using kernel_registor_t = KernelRegistryForTarget; - return registries_[GetKernelOffset()] - .template get() - ->Creates(op_type); + std::list> kernel_list; + if (registries_[GetKernelOffset()].valid()) { + kernel_list = registries_[GetKernelOffset()] + .template get() + ->Creates(op_type); + } + return kernel_list; } std::list> Create(const std::string &op_type, @@ -190,22 +268,42 @@ class KernelRegistry final { } std::string DebugString() const { +#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL + return "No more debug info"; +#else // LITE_ON_MODEL_OPTIMIZE_TOOL STL::stringstream ss; - ss << "KernelCreator:\n"; - constexpr TargetType tgt = TARGET(kHost); - constexpr PrecisionType dt = PRECISION(kFloat); - constexpr DataLayoutType lt = DATALAYOUT(kNCHW); - constexpr DataLayoutType kany = DATALAYOUT(kAny); - using kernel_registor_t = KernelRegistryForTarget; - auto *reg = registries_[GetKernelOffset()] - .template get(); - ss << reg->DebugString() << "\n"; + ss << "\n"; + ss << "Count of kernel kinds: "; + int count = 0; + for (auto &item : kernel_info_map_) { + count += item.second.size(); + } + ss << count << "\n"; + + ss << "Count of registered kernels: " << kernel_info_map_.size() << "\n"; + for (auto &item : kernel_info_map_) { + ss << "op: " << item.first << "\n"; + for (auto &kernel : item.second) { + ss << " - (" << TargetToStr(std::get<0>(kernel)) << ","; + ss << PrecisionToStr(std::get<1>(kernel)) << ","; + ss << DataLayoutToStr(std::get<2>(kernel)); + ss << ")"; + ss << "\n"; + } + } + return ss.str(); - return ""; +#endif // LITE_ON_MODEL_OPTIMIZE_TOOL } private: mutable std::vector registries_; +#ifndef LITE_ON_TINY_PUBLISH + mutable std::map< + std::string, + std::vector>> + kernel_info_map_; +#endif }; template { public: KernelRegistor(const std::string &op_type, const std::string &alias) : Registor([=] { - VLOG(3) << "Register kernel " << op_type << " for " - << TargetToStr(target) << " " << PrecisionToStr(precision) - << " " << DataLayoutToStr(layout) << " alias " << alias; KernelRegistry::Global().Register( op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); @@ -238,6 +333,7 @@ class KernelRegistor : public lite::Registor { static paddle::lite::OpLiteRegistor LITE_OP_REGISTER_INSTANCE( \ op_type__)(#op_type__); \ int touch_op_##op_type__() { \ + OpKernelInfoCollector::Global().AddOp2path(#op_type__, __FILE__); \ return LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); \ } @@ -246,7 +342,8 @@ class KernelRegistor : public lite::Registor { op_type__##__##target__##__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) \ - op_type__##__##target__##__##precision__##__registor__instance__##alias__ + op_type__##__##target__##__##precision__##__##layout__##registor__instance__##alias__ // NOLINT + #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, 
                                  precision__, alias__)                        \
  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
@@ -262,6 +359,9 @@ class KernelRegistor : public lite::Registor {
   static KernelClass LITE_KERNEL_INSTANCE(                                    \
       op_type__, target__, precision__, layout__, alias__);                   \
   int touch_##op_type__##target__##precision__##layout__##alias__() {         \
+    OpKernelInfoCollector::Global().AddKernel2path(                           \
+        #op_type__ "," #target__ "," #precision__ "," #layout__ "," #alias__, \
+        __FILE__);                                                            \
     LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
         .Touch();                                                             \
     return 0;                                                                 \
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index f1b92e06106a9a5d45d3ed16081beb9f7807f253..739615e2763f509f2dec97f5ab3e536aca7acc4f 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -18,6 +18,7 @@
 #include <vector>
 #include "lite/core/mir/generate_program_pass.h"
 #include "lite/core/mir/pass_manager.h"
+#include "lite/core/mir/pass_utils.h"
 #include "lite/core/mir/ssa_graph.h"
 #include "lite/core/mir/static_kernel_pick_pass.h"
 #include "lite/core/mir/type_target_cast_pass.h"
@@ -27,6 +28,9 @@
 #ifdef LITE_WITH_NPU
 #include "lite/core/mir/subgraph/generate_npu_program_pass.h"
 #endif
+#ifdef LITE_WITH_XPU
+#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
+#endif
 
 namespace paddle {
 namespace lite {
@@ -67,19 +71,28 @@ class Optimizer {
          "lite_fc_fuse_pass",                        //
          "lite_shuffle_channel_fuse_pass",           //
          "lite_transpose_softmax_transpose_fuse_pass",  //
+         "lite_interpolate_fuse_pass",               //
          "identity_scale_eliminate_pass",            //
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
          "lite_elementwise_add_activation_fuse_pass",  //
 #endif
-         "static_kernel_pick_pass",        //
+         "static_kernel_pick_pass",  // pick the original kernel from the
+                                     // graph
+         "variable_place_inference_pass",  // infer each arg/var's info
+                                           // (target/precision/layout/device)
+                                           // from the picked kernel
+         "argument_type_display_pass",  // debug pass: show each arg-type
+                                        // node's info
+                                        // (target/precision/layout/device)
+
+         "type_target_cast_pass",  // insert io_copy/io_copy_once when two
+                                   // adjacent nodes have different targets
          "variable_place_inference_pass",  //
          "argument_type_display_pass",     //
-         "type_target_cast_pass",          //
-         "variable_place_inference_pass",  //
-         "argument_type_display_pass",     //
+         "io_copy_kernel_pick_pass",    //
+         "argument_type_display_pass",  //
 
-         "io_copy_kernel_pick_pass",       //
          "variable_place_inference_pass",  //
         "argument_type_display_pass",     //
 
@@ -87,38 +100,52 @@ class Optimizer {
          "variable_place_inference_pass",  //
          "argument_type_display_pass",     //
 
-         "type_layout_cast_pass",          //
+         "type_layout_cast_pass",  // insert a layout/layout_once op when two
+                                   // adjacent nodes have different layouts
+         "argument_type_display_pass",  //
+
+         "variable_place_inference_pass",  //
          "argument_type_display_pass",     //
 
          "runtime_context_assign_pass",
-         "graph_visualze"}});
+         "argument_type_display_pass",  //
+#if !defined(LITE_WITH_OPENCL) && !defined(LITE_WITH_NPU) && \
+    !defined(LITE_WITH_XPU)
+         // TODO(ysh329): causes CL_INVALID_MEM_OBJECT when setArg in kernel
+         "memory_optimize_pass",
+#endif
+         "argument_type_display_pass"}});
    } else {
      RunPasses(passes);
    }
    exec_scope_ = program.exec_scope();
  }
 
-  void KernelPickPreferPlace(const Place& place) {
-    auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
-        "static_kernel_pick_pass");
-    CHECK(pass);
-    pass->SetPreferPlace(place);
-  }
-
   const lite::Scope* exec_scope() const { return exec_scope_; }
 
   // Generate a new program based on the mir graph.
std::unique_ptr GenRuntimeProgram() { +#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU) + auto target_place = Place{ #ifdef LITE_WITH_NPU - if (std::find(valid_places_.begin(), - valid_places_.end(), - Place{TARGET(kNPU), PRECISION(kFloat)}) != + TARGET(kNPU), +#endif +#ifdef LITE_WITH_XPU + TARGET(kXPU), +#endif + PRECISION(kFloat)}; + if (std::find(valid_places_.begin(), valid_places_.end(), target_place) != valid_places_.end()) { - CheckInputDimsNotEmpty(exec_scope_); +#ifdef LITE_WITH_NPU auto pass = mir::PassManager::Global() .LookUp( "generate_npu_program_pass"); +#endif +#ifdef LITE_WITH_XPU + auto pass = mir::PassManager::Global() + .LookUp( + "generate_xpu_program_pass"); +#endif try { pass->Apply(graph_); auto program = pass->GenProgram(); @@ -126,7 +153,8 @@ class Optimizer { program->set_exec_scope(exec_scope_); return program; } catch (...) { - LOG(WARNING) << "Build NPU graph failed"; + LOG(WARNING) << "Build " << TargetToStr(target_place.target) + << " program failed!"; } } #endif @@ -139,19 +167,6 @@ class Optimizer { return program; } - // check the input dims in the scope, must not be empty - void CheckInputDimsNotEmpty(const lite::Scope* scope) { - CHECK(scope); - auto* feed_var = scope->FindVar("feed"); - CHECK(feed_var) << "no feed variable in exec_scope: " << scope; - auto* feed_tensor_list = feed_var->GetMutable>(); - CHECK_GE(feed_tensor_list->size(), 1); - for (size_t i = 0; i < feed_tensor_list->size(); ++i) { - CHECK(!feed_tensor_list->at(i).dims().empty()) - << "Input " << i << " dims can not be empty."; - } - } - void InitTargetTypeTransformPass() { auto* pass = mir::PassManager::Global().LookUp( @@ -182,11 +197,23 @@ class Optimizer { // Specify the passes and run them. void RunPasses(const std::vector& passes) { for (auto& x : passes) { - LOG(INFO) << "== Running pass " << x; - auto* pass = mir::PassManager::Global().LookUp(x); + LOG(INFO) << "== Running pass: " << x; + mir::Pass* pass = mir::PassManager::Global().LookUp(x); CHECK(pass) << "Can not find pass: " << x; - pass->Apply(graph_); - LOG(INFO) << "== Running pass Done." << x; + bool matched = false; + for (const auto& place : valid_places_) { + if (PassMatchesTarget(*pass, place.target)) { + matched = true; + } + } + matched = matched && PassMatchesKernels(*pass); + if (!matched) { + LOG(INFO) << " - Skip " << x + << " because the target or kernel does not match."; + } else { + pass->Apply(graph_); + LOG(INFO) << "== Finished running: " << x; + } } } diff --git a/lite/core/profile/CMakeLists.txt b/lite/core/profile/CMakeLists.txt index de8a60bdc27e30b47a10e85acab95a5fb418d095..54a239024413834cb30c6e135c378d10480863e7 100644 --- a/lite/core/profile/CMakeLists.txt +++ b/lite/core/profile/CMakeLists.txt @@ -2,7 +2,7 @@ if (NOT LITE_WITH_PROFILE) return() endif() -lite_cc_library(basic_profiler SRCS basic_profiler.cc) +lite_cc_library(basic_profiler SRCS basic_profiler.cc DEPS gflags) lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler) diff --git a/lite/core/profile/basic_profiler.cc b/lite/core/profile/basic_profiler.cc index 031b86beb6b2008bc7299e233b98c1a12ac7286b..a947bfa295658d720a448f2376dfe26c507c3da2 100644 --- a/lite/core/profile/basic_profiler.cc +++ b/lite/core/profile/basic_profiler.cc @@ -13,14 +13,226 @@ // limitations under the License. 
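Review note on RunPasses above: a pass now runs only if one of the targets bound via BindTargets matches a valid place (kAny always matches) and its bound kernels are registered. A hedged sketch of the target rule; the real PassMatchesTarget comes from lite/core/mir/pass_utils.h and this body is illustrative only:

    // illustrative, not the actual implementation
    bool Matches(const std::set<TargetType>& bound, TargetType valid) {
      return bound.count(TARGET(kAny)) > 0 || bound.count(valid) > 0;
    }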
#include "lite/core/profile/basic_profiler.h" +#include +#include + +DEFINE_string(time_profile_file, + "time_profile.txt", + "Lite time profile information dump file"); + +DEFINE_string(time_profile_summary_file, + "time_profile_summary.txt", + "Lite time profile summary information dump file"); + +DEFINE_string(time_profile_unit, + "ms", + "Unit of time in profile infomation, ms or us"); namespace paddle { namespace lite { namespace profile { +static std::string GetTimeUnit() { + auto time_unit = FLAGS_time_profile_unit; + if (time_unit != "ms" && time_unit != "us") { + LOG(FATAL) << "Profile time unit only support ms or us now"; + } + return time_unit; +} + const int BasicTimer::data_w = 10; const int BasicTimer::name_w = 15; +void BasicTimer::Start(const std::string& timer_key) { + TimerInfo& timer_info = timer_infos_[timer_key]; + timer_info.timer_ = static_cast( + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +void BasicTimer::Stop(const std::string& timer_key) { + if (timer_infos_.find(timer_key) == timer_infos_.end()) { + LOG(FATAL) << "Error: Can't found timer key [" << timer_key << "] for " + << key_; + } + TimerInfo& timer_info = timer_infos_[timer_key]; + auto duration = static_cast< + uint64_t>( // timer unit: microsecond, 1second = 1e6 microsecond + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count() - + timer_info.timer_); + Log(&timer_info, duration); +} + +void BasicTimer::SetCustomInfo(const std::string& key, + const std::string& value) { + if (custom_infos_.find(key) != custom_infos_.end()) { + LOG(FATAL) << "Error: Custom Info for key [" << key + << "] can't be overwritten"; + } + custom_infos_[key] = value; +} + +std::string BasicTimer::GetCustomInfo(const std::string& key) const { + auto iter = custom_infos_.find(key); + if (iter == custom_infos_.end()) { + LOG(FATAL) << "Error: Custom Info for key [" << key << "] can't be found"; + } + return iter->second; +} + +const TimerInfo& BasicTimer::GetTimerInfo(const std::string& key) const { + auto iter = timer_infos_.find(key); + if (iter == timer_infos_.end()) { + LOG(FATAL) << "Error: Timer Info for key [" << key << "] can't be found"; + } + return iter->second; +} + +void BasicTimer::SetWarmup(int warmup_times) { + CHECK_GE(warmup_times, 0) << "warmup times must >= 0"; + warmup_ = warmup_times; +} + +void BasicTimer::Log(TimerInfo* timer_info, uint64_t timespan) { + if (warmup_ > 0) { + --warmup_; + return; + } + CHECK(timer_info); + timer_info->count_++; + timer_info->total_ += timespan; + timer_info->max_ = std::max(timer_info->max_, timespan); + timer_info->min_ = std::min(timer_info->min_, timespan); +} + +std::string BasicTimer::basic_repr_header() { + auto time_unit = GetTimeUnit(); + STL::stringstream ss; + // clang-format off + ss << "op" << "\t" + << "kernel" << "\t" + << "k_average(" << time_unit << ")\t" + << "k_min(" << time_unit << ")\t" + << "k_max(" << time_unit << ")\t" + << "i_average(" << time_unit << ")\t" + << "i_min(" << time_unit << ")\t" + << "i_max(" << time_unit << ")\t" + << "count" << "\t" + << "op_info"; + // clang-format on + return ss.str(); +} + +std::string BasicTimer::basic_repr() const { + auto& kernel_timer_info = GetTimerInfo("kernel"); + auto& inst_timer_info = GetTimerInfo("instruction"); + float time_unit_factor = 1.; + if (GetTimeUnit() == "ms") { + time_unit_factor = 1000.; + } + STL::stringstream ss; + // clang-format off + ss << GetCustomInfo("op_type") << "\t" + << key() << "\t" + 
+     << kernel_timer_info.ave() / time_unit_factor << "\t"
+     << kernel_timer_info.min() / time_unit_factor << "\t"
+     << kernel_timer_info.max() / time_unit_factor << "\t"
+     << inst_timer_info.ave() / time_unit_factor << "\t"
+     << inst_timer_info.min() / time_unit_factor << "\t"
+     << inst_timer_info.max() / time_unit_factor << "\t"
+     << inst_timer_info.count() << "\t"
+     << GetCustomInfo("op_info");
+  // clang-format on
+  return ss.str();
+}
+
+template class BasicProfiler<BasicTimer>;
+
+template <typename TimerT>
+std::string BasicProfiler<TimerT>::summary_repr_header() const {
+  auto time_unit = GetTimeUnit();
+  STL::stringstream ss;
+  // clang-format off
+  ss << "op" << "\t"
+     << "average(" << time_unit << ")\t"
+     << "min(" << time_unit << ")\t"
+     << "max(" << time_unit << ")\t"
+     << "op_time(" << time_unit << ")\t"
+     << "total_time(" << time_unit << ")\t"
+     << "percent" << "\t"
+     << "count";
+  // clang-format on
+  return ss.str();
+}
+
+template <typename TimerT>
+std::string BasicProfiler<TimerT>::summary_repr() const {
+  std::map<std::string, TimerInfo> op_summary;
+  uint64_t total{0};
+
+  for (const auto& rcd : records_) {
+    // We use the kernel run time here
+    auto kernel_timer = rcd.GetTimerInfo("kernel");
+    auto op_type = rcd.GetCustomInfo("op_type");
+    auto& op_timer = op_summary[op_type];
+
+    total += kernel_timer.total_;
+    op_timer.total_ += kernel_timer.total_;
+    op_timer.max_ = std::max(kernel_timer.max_, op_timer.max_);
+    op_timer.min_ = std::min(kernel_timer.min_, op_timer.min_);
+    op_timer.count_ += kernel_timer.count_;
+  }
+
+  float time_unit_factor = 1.;
+  if (GetTimeUnit() == "ms") {
+    time_unit_factor = 1000.;
+  }
+  STL::stringstream ss;
+  for (auto& iter : op_summary) {
+    auto& op_timer = iter.second;
+    // clang-format off
+    ss << iter.first << "\t"
+       << op_timer.ave() / time_unit_factor << "\t"
+       << op_timer.min() / time_unit_factor << "\t"
+       << op_timer.max() / time_unit_factor << "\t"
+       << op_timer.total() / time_unit_factor << "\t"
+       << total / time_unit_factor << "\t"
+       << (op_timer.total() * 1. / total * 100) << "%\t"
+       << op_timer.count() << "\t"
+       << "\n";
+    // clang-format on
+  }
+  return ss.str();
+}
+
+template <typename TimerT>
+BasicProfiler<TimerT>::~BasicProfiler() {
+  LOG(INFO) << "Basic Profile dumps:";
+  auto b_repr = TimerT::basic_repr_header() + "\n" + basic_repr();
+  LOG(INFO) << "\n" + b_repr;
+
+  // Dump to file
+  std::ofstream basic_ostream(FLAGS_time_profile_file);
+  CHECK(basic_ostream.is_open()) << "Open " << FLAGS_time_profile_file
+                                 << " failed";
+  basic_ostream.write(b_repr.c_str(), b_repr.size());
+  basic_ostream.close();
+
+  LOG(INFO) << "Summary Profile dumps:";
+  auto s_repr = summary_repr_header() + "\n" + summary_repr();
+  LOG(INFO) << "\n" + s_repr;
+
+  // Dump to file
+  std::ofstream summary_ostream(FLAGS_time_profile_summary_file);
+  CHECK(summary_ostream.is_open())
+      << "Open " << FLAGS_time_profile_summary_file << " failed";
+  summary_ostream.write(s_repr.c_str(), s_repr.size());
+  summary_ostream.close();
+}
+
 }  // namespace profile
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/core/profile/basic_profiler.h b/lite/core/profile/basic_profiler.h
index f55a5764a0cc25237f8258009dcd19d49e4e4e99..660650655e6fb5035e897f939aac621a784389b0 100644
--- a/lite/core/profile/basic_profiler.h
+++ b/lite/core/profile/basic_profiler.h
@@ -18,10 +18,13 @@
  * of each kernel.
  */
 #pragma once
+#include
 #include
 #include   // NOLINT
+#include
 #include
+#include
 #include
 #include
 #include
@@ -33,35 +36,49 @@
 namespace paddle {
 namespace lite {
 namespace profile {
 
+struct TimerInfo {
+  uint64_t total_{0};
+  uint64_t count_{0};
+  uint64_t max_{std::numeric_limits<uint64_t>::min()};
+  uint64_t min_{std::numeric_limits<uint64_t>::max()};
+  uint64_t timer_{0};
+
+  double ave() const { return total_ * 1. / count_; }
+  double max() const { return max_; }
+  double min() const { return min_; }
+  uint64_t total() const { return total_; }
+  uint64_t count() const { return count_; }
+};
+
 /* Base class of all the profile records */
 template <typename ChildT>
 class TimerBase {
  public:
-  void Start() { self()->Start(); }
-  void Stop() { self()->Stop(); }
-  void Log(uint32_t x) { return self()->Log(x); }
+  void Start(const std::string& key) { self()->Start(key); }
+  void Stop(const std::string& key) { self()->Stop(key); }
+  void Log(TimerInfo* timer_info, uint64_t x) {
+    return self()->Log(timer_info, x);
+  }
   std::string basic_repr() const { return const_self()->basic_repr(); }
 
   void SetId(int id) { self()->SetId(id); }
-  void SetKey(const std::string &key) { self()->SetKey(key); }
+  void SetKey(const std::string& key) { self()->SetKey(key); }
   int id() const { return const_self()->id(); }
 
 protected:
-  ChildT *self() { return reinterpret_cast<ChildT *>(this); }
-  const ChildT *const_self() const {
-    return reinterpret_cast<const ChildT *>(this);
+  ChildT* self() { return reinterpret_cast<ChildT*>(this); }
+  const ChildT* const_self() const {
+    return reinterpret_cast<const ChildT*>(this);
   }
 };
 
 class BasicTimer : TimerBase<BasicTimer> {
-  uint64_t total_{};
-  uint64_t count_{};
-  uint32_t max_{std::numeric_limits<uint32_t>::min()};
-  uint32_t min_{std::numeric_limits<uint32_t>::max()};
   int id_{-1};
+  int warmup_{0};
   std::string key_;
-  uint64_t timer_{};
+  std::map<std::string, TimerInfo> timer_infos_;
+  std::map<std::string, std::string> custom_infos_;
 
   // TODO(Superjomn) make static
   static const int name_w;
@@ -69,68 +86,35 @@ class BasicTimer : TimerBase<BasicTimer> {
 
  public:
   BasicTimer() = default;
-  BasicTimer(int id, const std::string &key) : id_(id), key_(key) {}
+  BasicTimer(int id, const std::string& key) : id_(id), key_(key) {}
 
   void SetId(int id) { id_ = id; }
-  void SetKey(const std::string &key) { key_ = key; }
-  void Start() {
-    timer_ = static_cast<uint64_t>(
-        std::chrono::duration_cast<std::chrono::microseconds>(
-            std::chrono::system_clock::now().time_since_epoch())
-            .count());
-  }
-  void Stop() {
-    auto duration = static_cast<
-        uint64_t>(  // timer unit: microsecond, 1second = 1e6 microsecond
-        std::chrono::duration_cast<std::chrono::microseconds>(
-            std::chrono::system_clock::now().time_since_epoch())
-            .count() -
-        timer_);
-    Log(duration);
-  }
-  int count() const { return count_; }
-
-  void Log(uint32_t timespan) {
-    total_ += timespan;
-    max_ = std::max(max_, timespan);
-    min_ = std::min(min_, timespan);
-    count_++;
+  int id() const {
+    CHECK_GE(id_, 0) << "id is not inited";
+    return id_;
   }
 
-  static std::string basic_repr_header() {
-    STL::stringstream ss;
-    ss << std::setw(name_w) << "kernel"   //
-       << std::setw(data_w) << "average"  //
-       << std::setw(data_w) << "min"      //
-       << std::setw(data_w) << "max"      //
-       << std::setw(data_w) << "count";
-    return ss.str();
-  }
+  void SetKey(const std::string& key) { key_ = key; }
+  const std::string& key() const { return key_; }
 
-  std::string basic_repr() const {
-    STL::stringstream ss;
-    ss << std::setw(name_w) << key()  //
-       << std::setw(data_w) << ave()  //
-       << std::setw(data_w) << min()  //
-       << std::setw(data_w) << max()  //
-       << std::setw(data_w) << count_;
-    return ss.str();
-  }
+  void Start(const std::string& timer_key);
timer_key); - const std::string &key() const { return key_; } + void Log(TimerInfo* timer_info, uint64_t timespan); - int id() const { - CHECK_GE(id_, 0) << "id is not inited"; - return id_; - } + void SetCustomInfo(const std::string& key, const std::string& value); + std::string GetCustomInfo(const std::string& key) const; - double ave() const { return total_ * 1. / count_; } - double max() const { return max_; } - double min() const { return min_; } + const TimerInfo& GetTimerInfo(const std::string& key) const; + + static std::string basic_repr_header(); + std::string basic_repr() const; // BasicRecord(const BasicRecord &) = delete; - void operator=(const BasicTimer &) = delete; + void operator=(const BasicTimer&) = delete; + + void SetWarmup(int warmup_times); }; /* @@ -139,28 +123,29 @@ class BasicTimer : TimerBase<BasicTimer> { template <typename TimerT> class BasicProfiler { public: - explicit BasicProfiler(const std::string &name) : name_(name) {} + explicit BasicProfiler(const std::string& name) : name_(name) {} using record_t = TimerT; - static BasicProfiler &Global() { + static BasicProfiler& Global() { static std::unique_ptr<BasicProfiler> x(new BasicProfiler("[global]")); return *x; } - record_t &NewRcd(const std::string &key) { + record_t& NewRcd(const std::string& key) { records_.emplace_back(); records_.back().SetId(records_.size() - 1); records_.back().SetKey(key); + records_.back().SetWarmup(warmup_); return records_.back(); } - const record_t &record(int id) { + const record_t& record(int id) { CHECK_LT(id, records_.size()); CHECK_GE(id, 0); return records_[id]; } - record_t *mutable_record(int id) { + record_t* mutable_record(int id) { CHECK_GE(id, 0); CHECK_LT(static_cast<size_t>(id), records_.size()); return &records_[id]; @@ -168,42 +153,65 @@ class BasicProfiler { std::string basic_repr() const { STL::stringstream ss; - for (const auto &rcd : records_) { + for (const auto& rcd : records_) { ss << rcd.basic_repr() << "\n"; } return ss.str(); } - ~BasicProfiler() { - LOG(INFO) << "Profile dumps:"; - LOG(INFO) << "\n" + BasicTimer::basic_repr_header() + "\n" + basic_repr(); + std::string summary_repr_header() const; + std::string summary_repr() const; + + void SetWarmup(int warmup_times) { + CHECK_GE(warmup_times, 0) << "warmup times must be >= 0"; + // Instruction and kernel share the common BasicTimer instance, so the + // warmup count will be decreased twice each time an instruction executes. + // TODO(sangoly): fix the ugly code.
+ warmup_ = warmup_times * 2; } + ~BasicProfiler(); + private: std::string name_; std::vector records_; + int warmup_{0}; }; struct ProfileBlock { - explicit ProfileBlock(int id) : id_(id) { - BasicProfiler::Global().mutable_record(id_)->Start(); + explicit ProfileBlock(int id, const std::string& key) : id_(id), key_(key) { + BasicProfiler::Global().mutable_record(id_)->Start(key_); + } + + void Record() { + if (has_recorded_) { + LOG(FATAL) << "You can only call Record() once"; + } + BasicProfiler::Global().mutable_record(id_)->Stop(key_); + has_recorded_ = true; } ~ProfileBlock() { - BasicProfiler::Global().mutable_record(id_)->Stop(); + if (!has_recorded_) { + BasicProfiler::Global().mutable_record(id_)->Stop(key_); + } } private: int id_{}; + bool has_recorded_{false}; + std::string key_{}; }; -#define LITE_PROFILE_ONE(key__) \ - static int key__##__profiler_id = \ - ::paddle::lite::profile::BasicProfiler< \ - ::paddle::lite::profile::BasicTimer>::Global() \ - .NewRcd(#key__) \ - .id(); \ - ::paddle::lite::profile::ProfileBlock key__##profiler__(key__##__profiler_id); +#define LITE_PROFILE_ONE(key__) \ + static int key__##__profiler_id = \ + ::paddle::lite::profile::BasicProfiler< \ + ::paddle::lite::profile::BasicTimer>::Global() \ + .NewRcd(#key__) \ + .id(); \ + ::paddle::lite::profile::ProfileBlock key__##profiler__( \ + key__##__profiler_id, #key__); } // namespace profile } // namespace lite diff --git a/lite/core/profile/basic_profiler_test.cc b/lite/core/profile/basic_profiler_test.cc index 928fdd61cb9ef9e688b5f9bfe34658fdbe26f255..e61d3383d07a69d4bd29659842d515d53ad8d37b 100644 --- a/lite/core/profile/basic_profiler_test.cc +++ b/lite/core/profile/basic_profiler_test.cc @@ -27,18 +27,21 @@ TEST(basic_record, init) { timer.SetKey("hello"); } -TEST(basic_profile, init) { - auto& rcd = BasicProfiler::Global().NewRcd("fc"); - for (int i = 11; i < 100; i++) { - rcd.Log(i); - } +TEST(basic_profile, real_latency) { + auto profile_id = profile::BasicProfiler::Global() + .NewRcd("test0") + .id(); + auto& profiler = + *BasicProfiler::Global().mutable_record(profile_id); + // Set op info + profiler.SetCustomInfo("op_type", "fc"); + profiler.SetCustomInfo("op_info", "size:5x6"); - LOG(INFO) << BasicProfiler::Global().basic_repr(); -} + profile::ProfileBlock x(profile_id, "instruction"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); -TEST(basic_profile, real_latency) { - LITE_PROFILE_ONE(test0); - std::this_thread::sleep_for(std::chrono::milliseconds(1200)); + profile::ProfileBlock y(profile_id, "kernel"); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); } } // namespace profile diff --git a/lite/core/program.cc b/lite/core/program.cc index 179cdf909a6ad488fc1b487e89b1b6808c8a4c5a..b60f279c0fc74904477a080579a799f601e359b0 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -113,9 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { void RuntimeProgram::Run() { for (auto& inst : instructions_) { - VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr() - << " on Target " << TargetToStr(inst.kernel()->target()); - + std::string op_type = inst.op()->op_info()->Type(); + if (op_type == "feed" || op_type == "fetch") continue; inst.Run(); #ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE @@ -140,7 +139,7 @@ void Program::Build(const cpp::ProgramDesc& prog) { auto op = LiteOpRegistry::Global().Create(op_type); CHECK(op) << "no Op found for " << op_type; if (op_type == "while") { - auto sub_block_idx = 
op_desc.GetAttr("sub_block"); + auto sub_block_idx = op_desc.GetAttr("sub_block"); auto sub_block = const_cast(prog).GetBlock( sub_block_idx); @@ -182,19 +181,29 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) { } void Instruction::Run() { -#ifdef LITE_WITH_PROFILE - profile::ProfileBlock x(profile_id_); -#endif // LITE_WITH_PROFILE CHECK(op_) << "op null"; CHECK(kernel_) << "kernel null"; +#ifdef LITE_WITH_PROFILE + if (profile_id_ >= 0) { + profile::ProfileBlock x(profile_id_, "instruction"); + } +#endif // LITE_WITH_PROFILE if (first_epoch_) { first_epoch_ = false; CHECK(op_->CheckShape()); } - if (op_->run_once() && has_run_) return; + if (op_->run_once() && has_run_) { + return; + } +#ifndef LITE_SHUTDOWN_LOG VLOG(4) << "kernel launch"; +#endif op_->InferShape(); +#ifndef LITE_SHUTDOWN_LOG + VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target " + << TargetToStr(kernel_->target()); +#endif kernel_->Launch(); has_run_ = true; } diff --git a/lite/core/program.h b/lite/core/program.h index 1b3c036db502e5dae00c0e6baed00047f2e2458d..7a6700da61f7ba9f35491613d7733b4b637b8ff0 100644 --- a/lite/core/program.h +++ b/lite/core/program.h @@ -89,9 +89,18 @@ struct Instruction { std::unique_ptr&& kernel) : op_(op), kernel_(std::move(kernel)) { #ifdef LITE_WITH_PROFILE - profile_id_ = profile::BasicProfiler::Global() - .NewRcd(kernel_->SerializedKernelType()) - .id(); + if (op_->Type() != "feed" && op_->Type() != "fetch") { + profile_id_ = profile::BasicProfiler::Global() + .NewRcd(kernel_->SerializedKernelType()) + .id(); + kernel_->SetProfileID(profile_id_); + // Set profile custom info + auto& profiler = + *profile::BasicProfiler::Global().mutable_record( + profile_id_); + profiler.SetCustomInfo("op_type", op_->Type()); + profiler.SetCustomInfo("op_info", op_->SerializedOpInfo()); + } #endif // LITE_WITH_PROFILE } diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc index 4dd4f5319d6238e202fdfd93ef0c2de4b45de291..1c7db871c7b525d6e4944fd0d669e81bcaff7f2a 100644 --- a/lite/core/tensor.cc +++ b/lite/core/tensor.cc @@ -79,6 +79,14 @@ void TensorLite::ShareDataWith(const TensorLite &other) { memory_size_ = other.memory_size_; } +void TensorLite::CopyDataFrom(const TensorLite &other) { + dims_ = other.dims_; + target_ = other.target_; + lod_ = other.lod_; + memory_size_ = other.memory_size_; + buffer_->CopyDataFrom(*other.buffer_, memory_size_); +} + void *TensorLite::mutable_data(size_t memory_size) { memory_size_ = memory_size; buffer_->ResetLazy(target_, memory_size_); @@ -90,26 +98,15 @@ void *TensorLite::mutable_data(TargetType target, size_t memory_size) { return mutable_data(memory_size); } -void TensorLite::CopyDataFrom(const TensorLite &other) { - dims_ = other.dims_; - target_ = other.target_; - lod_ = other.lod_; - memory_size_ = other.memory_size_; - buffer_->CopyDataFrom(*other.buffer_, memory_size_); +#ifdef LITE_WITH_OPENCL +template <> +const cl::Image2D *TensorLite::data() const { + if (nullptr == buffer_->data()) return nullptr; + return static_cast(buffer_->data()); } - -// static LoD TensorLite::ToAbsOffset(const LoD &lod) { -// if (lod.empty() || lod.size() == 1) return lod; -// LoD ret = lod; -// for (int level = static_cast(lod.size()) - 2; level >= 0; --level) { -// for (size_t i = 0; i < lod[level].size(); ++i) { -// size_t index = lod[level][i]; -// result[level][i] = result[level + 1][index]; -// } -// } -//} +#endif } // namespace lite } // namespace paddle -#endif +#endif // #ifndef LITE_WITH_FPGA diff --git a/lite/core/tensor.h 
b/lite/core/tensor.h index 205e586ab33545e685dbe8083dfda9f4dffdb13d..8c4fe1604a517332e52b243404828e81af26f419 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -138,14 +138,34 @@ class TensorLite { // and the data type can be float/int8_t. // For other devices, T and R may be the same type. template - R *mutable_data(); + R *mutable_data() { + memory_size_ = dims_.production() * sizeof(T); + buffer_->ResetLazy(target_, memory_size_); + return reinterpret_cast(static_cast(buffer_->data()) + + offset_); + } + +#ifdef LITE_WITH_OPENCL + template + R *mutable_data(const size_t img_w, const size_t img_h) { + target_ = TARGET(kOpenCL); + buffer_->ResetLazyImage2D(target_, img_w, img_h); + return static_cast(buffer_->data()); + } +#endif // T is the data type and R is the return type // For OpenCL, the return type can be cl::Buffer // and the data type can be float/int8_t. // For other devices, T and R may be the same type. template - R *mutable_data(TargetType target); + R *mutable_data(TargetType target) { + target_ = target; + memory_size_ = dims_.production() * sizeof(T); + buffer_->ResetLazy(target, memory_size()); + return reinterpret_cast(static_cast(buffer_->data()) + + offset_); + } void *mutable_data(size_t memory_size); void *mutable_data(TargetType target, size_t memory_size); @@ -201,33 +221,24 @@ class TensorLite { size_t offset_{0}; }; -template -R *TensorLite::mutable_data() { - memory_size_ = dims_.production() * sizeof(T); - buffer_->ResetLazy(target_, memory_size_); - return reinterpret_cast(static_cast(buffer_->data()) + offset_); -} - -template -R *TensorLite::mutable_data(TargetType target) { - target_ = target; - memory_size_ = dims_.production() * sizeof(T); - buffer_->ResetLazy(target, memory_size()); - return reinterpret_cast(static_cast(buffer_->data()) + offset_); -} - template TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { - int64_t base = numel() / dims_[0]; - - TensorLite dst; - dst.buffer_ = buffer_; - dst.target_ = target_; - auto dst_dims = dims_; - dst_dims[0] = end - begin; - dst.Resize(dst_dims); - dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); - return dst; + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); + if (dims_[0] == 1) { + return *this; + } else { + int64_t base = numel() / dims_[0]; + TensorLite dst; + dst.buffer_ = buffer_; + dst.target_ = target_; + auto dst_dims = dims_; + dst_dims[0] = end - begin; + dst.Resize(dst_dims); + dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); + return dst; + } } template @@ -237,7 +248,12 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) { return true; } +#ifdef LITE_WITH_OPENCL +template <> +const cl::Image2D *TensorLite::data() const; +#endif + } // namespace lite } // namespace paddle -#endif +#endif // #ifndef LITE_WITH_FPGA diff --git a/lite/core/type_system.h b/lite/core/type_system.h index 722cdca0eb1bdbe338647f23aa12eab61ce99c5a..aeddf965c3b999750c7cca3595cc9f669b32d50e 100644 --- a/lite/core/type_system.h +++ b/lite/core/type_system.h @@ -27,6 +27,7 @@ #include #include #include "lite/core/tensor.h" +#include "lite/core/version.h" #include "lite/utils/all.h" namespace paddle { @@ -280,7 +281,7 @@ struct ParamTypeRecorder { */ class ParamTypeRegistry { public: - enum class IO : int { kInput = 0, kOutput }; + enum class IO : int { kInvalid = 0, kInput, kOutput }; template types_; + std::map versions_; }; } // namespace lite diff --git a/lite/core/types.cc b/lite/core/types.cc index 
ec89e83e5808fb85803adea0555c76b7e424424c..4ea383333d519ac2c481dce459ca49124a64df32 100644 --- a/lite/core/types.cc +++ b/lite/core/types.cc @@ -82,6 +82,10 @@ Type StdTypeToRepr<double>() { return Type::_float64; } template <> +Type StdTypeToRepr<std::vector<char>>() { + return Type::_char_list; +} +template <> Type StdTypeToRepr<std::string>() { return Type::_string; } diff --git a/lite/core/types.h b/lite/core/types.h index 0664aba6b6fbe2be89cbb2b0a0ad46497c3a5f3c..8f154f9dd509d3627750ecbf301923a2296252d1 100644 --- a/lite/core/types.h +++ b/lite/core/types.h @@ -16,6 +16,7 @@ #include #include +#include <vector> #include "lite/api/paddle_place.h" #include "lite/utils/all.h" @@ -36,7 +37,9 @@ enum class Type { _float64, _bool, _string, - // primary list types + // primary list type + _char_list, + // list types _list, // enum type _enum, @@ -45,6 +48,37 @@ enum class Type { __num__, }; +enum class FluidType { + // Pod Types + BOOL = 0, + INT16 = 1, + INT32 = 2, + INT64 = 3, + FP16 = 4, + FP32 = 5, + FP64 = 6, + // Tensor is used in C++. + SIZE_T = 19, + UINT8 = 20, + INT8 = 21, + + // Other types that may need additional descriptions + LOD_TENSOR = 7, + SELECTED_ROWS = 8, + FEED_MINIBATCH = 9, + FETCH_LIST = 10, + STEP_SCOPES = 11, + LOD_RANK_TABLE = 12, + LOD_TENSOR_ARRAY = 13, + PLACE_LIST = 14, + READER = 15, + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17, + TUPLE = 18, +}; + template <typename T> Type StdTypeToRepr() { return Type::_unk; @@ -58,6 +92,8 @@ Type StdTypeToRepr(); template <> Type StdTypeToRepr(); template <> +Type StdTypeToRepr<std::vector<char>>(); +template <> Type StdTypeToRepr(); // Factors that impact the kernel picking strategy. Multiple factors can be diff --git a/lite/core/version.h.in b/lite/core/version.h.in new file mode 100644 index 0000000000000000000000000000000000000000..3082adc5abecb20f5ce19032177fc7cdb75299ff --- /dev/null +++ b/lite/core/version.h.in @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
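For orientation, the header generated from this template packs the semantic version into one integer; a worked example of the arithmetic implemented just below:

    // int_version("major.minor.patch") = major*1000000 + minor*1000 + patch,
    // so e.g. int_version("2.3.1") == 2*1000000 + 3*1000 + 1 == 2003001.
    // version() prefers the tag and falls back to "branch(commit)".
    static_assert(2 * 1000000 + 3 * 1000 + 1 == 2003001, "worked example");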
+ +#pragma once + +#include +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { + +static constexpr int MAJOR_COEFF = 1000000; +static constexpr int MINOR_COEFF = 1000; +static constexpr int PATCH_COEFF = 1; + +static std::string paddlelite_commit() { + return "@PADDLE_LITE_COMMIT@"; +} + +static std::string paddlelite_branch() { + return "@PADDLE_LITE_BRANCH@"; +} + +static std::string paddlelite_tag() { + return "@PADDLE_LITE_TAG@"; +} + +static std::string version() { + STL::stringstream ss; + + std::string tag = paddlelite_tag(); + if (tag.empty()) { + ss << paddlelite_branch() << "(" << paddlelite_commit() << ")"; + } else { + ss << tag; + } + + return ss.str(); +} + +static int64_t int_version(const std::string& version) { + const std::vector vec = Split(version, "."); + if (vec.size() == 3) { + return std::stoi(vec[0]) * MAJOR_COEFF + + std::stoi(vec[1]) * MINOR_COEFF + + std::stoi(vec[2]) * PATCH_COEFF; + } + return -1; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/demo/cxx/Makefile.def b/lite/demo/cxx/Makefile.def index f0a0ec1dcb13ff509b718223b1cfd4f937c94ad7..1b5da970e8fa9b2793f7a4982d5ed22ed21e79fd 100644 --- a/lite/demo/cxx/Makefile.def +++ b/lite/demo/cxx/Makefile.def @@ -15,7 +15,7 @@ THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include ifeq ($(ARM_ABI), arm8) CC = /opt/android-ndk-r17c/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/bin/aarch64-linux-android-g++ - CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=22 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE + CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--gc-sections SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-24/arch-arm64 SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++_static.a \ @@ -24,9 +24,9 @@ ifeq ($(ARM_ABI), arm8) else CC = /opt/android-ndk-r17c/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin/arm-linux-androideabi-g++ CXX_FLAGS = -march=armv7-a -mthumb -mfpu=neon -mfloat-abi=softfp -funwind-tables -no-canonical-prefixes \ - -D__ANDROID_API__=22 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE + -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--fix-cortex-a8 -Wl,--gc-sections -Wl,-z,nocopyreloc - SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-22/arch-arm + SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-23/arch-arm SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++_static.a \ /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++abi.a \ /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libandroid_support.a \ diff --git a/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc index 18167e3ca115cbe994882951be909c1d30482e74..5ac041b2cc53e8f17ad86a2b71e6b02058b7e249 100644 --- a/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc +++ b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc @@ -38,10 +38,8 @@ void RunModel() { config.set_model_dir(FLAGS_model_dir); std::vector valid_places{Place{TARGET(kARM), PRECISION(kFloat)}}; if (FLAGS_prefer_int8_kernel) { - valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)}); - 
config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)}); - } else { - config.set_preferred_place(Place{TARGET(kARM), PRECISION(kFloat)}); + valid_places.insert(valid_places.begin(), + Place{TARGET(kARM), PRECISION(kInt8)}); } config.set_valid_places(valid_places); diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java b/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java index e8eb01bd5574508d22b8b47fe014b315ecbe9b2c..84bebe6f2a209257c9056ce79bd65f0a3317a034 100644 --- a/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java @@ -20,9 +20,13 @@ public class MainActivity extends AppCompatActivity { setContentView(R.layout.activity_main); String textOutput = ""; + + String version = getVersionInfo("lite_naive_model_opt.nb", this); + textOutput += "Version: " + version + "\n"; + Tensor output; output = setInputAndRunNaiveModel("lite_naive_model_opt.nb", this); - textOutput += "lite_naive_model output: " + output.getFloatData()[0] + ", " + textOutput += "\nlite_naive_model output: " + output.getFloatData()[0] + ", " + output.getFloatData()[1] + "\n"; textOutput += "expected: 50.2132, -28.8729\n"; @@ -54,6 +58,14 @@ public class MainActivity extends AppCompatActivity { textView.setText(textOutput); } + public static String getVersionInfo(String modelName, Context context) { + String modelPath = copyFromAssetsToCache(modelName, context); + MobileConfig config = new MobileConfig(); + config.setModelDir(modelPath); + PaddlePredictor predictor = PaddlePredictor.createPaddlePredictor(config); + return predictor.getVersion(); + } + public static String copyFromAssetsToCache(String modelPath, Context context) { String newPath = context.getCacheDir() + "/" + modelPath; // String newPath = "/sdcard/" + modelPath; diff --git a/lite/demo/python/mobilenetv1_full_api.py b/lite/demo/python/mobilenetv1_full_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a31469e3e8da81f3753dc5d241d4ef39ac03832f --- /dev/null +++ b/lite/demo/python/mobilenetv1_full_api.py @@ -0,0 +1,67 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Paddle-Lite full python api demo +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +sys.path.append('../../python/lib') + +from lite_core import * + +# Command arguments +parser = argparse.ArgumentParser() +parser.add_argument( + "--model_dir", default="", type=str, help="Non-combined Model dir path") +parser.add_argument( + "--model_file", default="", type=str, help="Model file") +parser.add_argument( + "--param_file", default="", type=str, help="Combined model param file") + +def RunModel(args): + # 1. 
Set config information + config = CxxConfig() + if args.model_file != '' and args.param_file != '': + config.set_model_file(args.model_file) + config.set_param_file(args.param_file) + else: + config.set_model_dir(args.model_dir) + # For x86, you can set places = [Place(TargetType.X86, PrecisionType.FP32)] + places = [Place(TargetType.ARM, PrecisionType.FP32)] + config.set_valid_places(places) + + # 2. Create paddle predictor + predictor = create_paddle_predictor(config) + + # 3. Set input data + input_tensor = predictor.get_input(0) + input_tensor.resize([1, 3, 224, 224]) + input_tensor.set_float_data([1.] * 3 * 224 * 224) + + # 4. Run model + predictor.run() + + # 5. Get output data + output_tensor = predictor.get_output(0) + print(output_tensor.shape()) + print(output_tensor.float_data()[:10]) + +if __name__ == '__main__': + args = parser.parse_args() + RunModel(args) diff --git a/lite/demo/python/mobilenetv1_light_api.py b/lite/demo/python/mobilenetv1_light_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a44427092bae88aa41b3b1d0684cfcf36835b3d2 --- /dev/null +++ b/lite/demo/python/mobilenetv1_light_api.py @@ -0,0 +1,56 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Paddle-Lite light python api demo +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +sys.path.append('../../python/lib') + +from lite_core import * + +# Command arguments +parser = argparse.ArgumentParser() +parser.add_argument( + "--model_dir", default="", type=str, help="Non-combined Model dir path") + +def RunModel(args): + # 1. Set config information + config = MobileConfig() + config.set_model_dir(args.model_dir) + + # 2. Create paddle predictor + predictor = create_paddle_predictor(config) + + # 3. Set input data + input_tensor = predictor.get_input(0) + input_tensor.resize([1, 3, 224, 224]) + input_tensor.set_float_data([1.] * 3 * 224 * 224) + + # 4. Run model + predictor.run() + + # 5. 
Get output data + output_tensor = predictor.get_output(0) + print(output_tensor.shape()) + print(output_tensor.float_data()[:10]) + +if __name__ == '__main__': + args = parser.parse_args() + RunModel(args) diff --git a/lite/fluid/CMakeLists.txt b/lite/fluid/CMakeLists.txt index 308dcb2c3052a338c52cf888cec789e66cb8e887..ceb1f7d982392cfbf130719fc04cbd2337fad28c 100644 --- a/lite/fluid/CMakeLists.txt +++ b/lite/fluid/CMakeLists.txt @@ -1,4 +1,4 @@ if (LITE_WITH_X86) lite_cc_library(fluid_data_type SRCS data_type.cc DEPS framework_proto eigen3) -# lite_cc_library(selected_rows SRCS selected_rows.cc) +lite_cc_library(selected_rows SRCS selected_rows.cc DEPS tensor model_parser) endif() diff --git a/lite/fluid/data_type.cc b/lite/fluid/data_type.cc index aa8971499fb0f55d523d2a95bdbe64777b689c63..d33a77c4bfcefbc349d453de05dcbb7c27707a19 100644 --- a/lite/fluid/data_type.cc +++ b/lite/fluid/data_type.cc @@ -68,6 +68,7 @@ framework::proto::VarType::Type ToDataType(std::type_index type) { return it->second; } PADDLE_THROW("Not support %s as tensor type", type.name()); + return static_cast(-1); } std::type_index ToTypeIndex(framework::proto::VarType::Type type) { @@ -77,6 +78,7 @@ std::type_index ToTypeIndex(framework::proto::VarType::Type type) { } PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type", static_cast(type)); + return std::type_index(typeid(void)); } std::string DataTypeToString(const framework::proto::VarType::Type type) { @@ -86,6 +88,7 @@ std::string DataTypeToString(const framework::proto::VarType::Type type) { } PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type", static_cast(type)); + return std::string(); } size_t SizeOfType(framework::proto::VarType::Type type) { @@ -93,7 +96,8 @@ size_t SizeOfType(framework::proto::VarType::Type type) { if (it != gDataTypeMap().proto_to_size_.end()) { return it->second; } - PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type)); + PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type).c_str()); + return 0; } } // namespace fluid diff --git a/lite/fluid/eigen.h b/lite/fluid/eigen.h index f5d5e4b5e516315b16369be2e2dd9c46281fc3d0..eac5332b53c857b05aacbfa95ee2e4b9fcd98a93 100644 --- a/lite/fluid/eigen.h +++ b/lite/fluid/eigen.h @@ -32,7 +32,7 @@ struct EigenDim { static Type From(const lite::DDim& dims) { PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); Type ret; - for (int64_t d = 0; d < dims.size(); d++) { + for (size_t d = 0; d < dims.size(); d++) { ret[d] = dims[d]; } return ret; @@ -118,7 +118,9 @@ struct EigenScalar { using ConstType = Eigen::TensorMap< Eigen::TensorFixedSize, MajorType, IndexType>>; - static Type From(Tensor& tensor) { return Type(tensor.data()); } // NOLINT + static Type From(Tensor* tensor) { + return Type(const_cast(tensor->data())); + } // NOLINT static ConstType From(const Tensor& tensor) { return ConstType(tensor.data()); diff --git a/lite/fluid/for_range.h b/lite/fluid/for_range.h new file mode 100644 index 0000000000000000000000000000000000000000..a51d6c1b18f10a5029db5b0748b7a8978cf23dbc --- /dev/null +++ b/lite/fluid/for_range.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace fluid { + +template +struct ForRange { + ForRange(const lite::Context& dev_ctx, size_t limit); + + template + void operator()(Function func) const; +}; + +template <> +struct ForRange { + ForRange(lite::X86Context& dev_ctx, size_t limit) : limit_(limit) {} + + template + void operator()(Function func) const { + for (size_t i = 0; i < limit_; ++i) { + func(i); + } + } + + size_t limit_; +}; + +} // namespace fluid +} // namespace lite +} // namespace paddle diff --git a/lite/fluid/hostdevice.h b/lite/fluid/hostdevice.h new file mode 100644 index 0000000000000000000000000000000000000000..c297d19e93a5b00e8fc5389ca18b4d5264829025 --- /dev/null +++ b/lite/fluid/hostdevice.h @@ -0,0 +1,18 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#define HOSTDEVICE +#define DEVICE +#define HOST diff --git a/lite/fluid/lod.h b/lite/fluid/lod.h index 68068ba1d018ffd4fbd1b52ec3cd382326b7a69f..36386f7eb967f31ec258681fe17222a928aa7b4b 100644 --- a/lite/fluid/lod.h +++ b/lite/fluid/lod.h @@ -21,7 +21,7 @@ namespace lite { namespace fluid { using LoD = std::vector>; -LoD ToAbsOffset(const LoD &in) { +static LoD ToAbsOffset(const LoD &in) { // the lowest level stores relative offsets if (in.empty() || in.size() == 1) return in; LoD result = in; diff --git a/lite/fluid/rw_lock.h b/lite/fluid/rw_lock.h new file mode 100644 index 0000000000000000000000000000000000000000..eb9829425eca9d8bd363a45961302a7f3818e513 --- /dev/null +++ b/lite/fluid/rw_lock.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#if !defined(_WIN32) +#include +#else +#include // NOLINT +#endif // !_WIN32 + +#include "lite/utils/paddle_enforce.h" + +namespace paddle { +namespace lite { +namespace fluid { + +#if !defined(_WIN32) +struct RWLock { + RWLock() { pthread_rwlock_init(&lock_, nullptr); } + + ~RWLock() { pthread_rwlock_destroy(&lock_); } + + inline void RDLock() { + PADDLE_ENFORCE_EQ( + pthread_rwlock_rdlock(&lock_), 0, "acquire read lock failed"); + } + + inline void WRLock() { + PADDLE_ENFORCE_EQ( + pthread_rwlock_wrlock(&lock_), 0, "acquire write lock failed"); + } + + inline void UNLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed"); + } + + private: + pthread_rwlock_t lock_; +}; +// TODO(paddle-dev): Support RWLock for WIN32 for correctness. +#else +// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive +// In windows, rw_lock seems like a hack. Use empty object and do nothing. +struct RWLock { + // FIXME(minqiyang): use mutex here to do fake lock + inline void RDLock() { mutex_.lock(); } + + inline void WRLock() { mutex_.lock(); } + + inline void UNLock() { mutex_.unlock(); } + + private: + std::mutex mutex_; +}; +#endif + +class AutoWRLock { + public: + explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); } + + ~AutoWRLock() { UnLock(); } + + private: + inline void Lock() { lock_->WRLock(); } + + inline void UnLock() { lock_->UNLock(); } + + private: + RWLock* lock_; +}; + +class AutoRDLock { + public: + explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); } + + ~AutoRDLock() { UnLock(); } + + private: + inline void Lock() { lock_->RDLock(); } + + inline void UnLock() { lock_->UNLock(); } + + private: + RWLock* lock_; +}; + +} // namespace fluid +} // namespace lite +} // namespace paddle diff --git a/lite/fluid/selected_rows.cc b/lite/fluid/selected_rows.cc new file mode 100644 index 0000000000000000000000000000000000000000..98e9325ca2f8fab3f8aa77a0bb074ae5d1be7670 --- /dev/null +++ b/lite/fluid/selected_rows.cc @@ -0,0 +1,247 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/fluid/selected_rows.h" +namespace paddle { +namespace lite { +namespace fluid { + +struct ReAllocateVisitor { + ReAllocateVisitor(const lite::DDim& dims, lite::Tensor* tensor) + : dims_(dims), tensor_(tensor) {} + + template + void operator()() const { + lite::Tensor cpu_tensor; + T* ptr = cpu_tensor.mutable_data(lite::TargetType::kX86, dims_); + const T* old_ptr = + tensor_->memory_size() == 0 ? 
nullptr : tensor_->mutable_data<T>(); + if (old_ptr != nullptr) { + std::copy(old_ptr, old_ptr + tensor_->numel(), ptr); + } + tensor_->ShareDataWith(cpu_tensor); + } + + lite::DDim dims_; + lite::Tensor* tensor_; +}; + +struct TensorCopyVisitor { + TensorCopyVisitor(lite::Tensor* dst, + int64_t dst_offset, + const lite::Tensor src, + int64_t src_offset, + int64_t size) + : dst_(dst), + dst_offset_(dst_offset), + src_(src), + src_offset_(src_offset), + size_(size) {} + + template <typename T> + void apply() const { + // TODO(Yancey1989): support other place + std::copy_n(src_.data<T>() + src_offset_, + size_, + dst_->mutable_data<T>(lite::TargetType::kX86) + dst_offset_); + } + + lite::Tensor* dst_; + int64_t dst_offset_; + lite::Tensor src_; + int64_t src_offset_; + int64_t size_; +}; + +struct TensorFillVisitor { + TensorFillVisitor(lite::Tensor* dst, + int64_t dst_offset, + int64_t size, + float value) + : dst_(dst), dst_offset_(dst_offset), size_(size) {} + + template <typename T> + void apply() const { + // TODO(qiao): support other place + // paddle::platform::CPUPlace cpu; + auto* tensor_data = dst_->mutable_data<T>(lite::TargetType::kX86); + auto* start = tensor_data + dst_offset_; + auto* end = start + size_; + std::fill(start, end, static_cast<T>(0.0)); + } + + lite::Tensor* dst_; + int64_t dst_offset_; + int64_t size_; +}; + +void SerializeToStream(std::ostream& os, + const SelectedRows& selected_rows, + const lite::Context& dev_ctx) { + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + os.write(reinterpret_cast<const char*>(&version), sizeof(version)); + } + { + // the 2nd field, rows information + auto& rows = selected_rows.rows(); + uint64_t size = rows.size(); + os.write(reinterpret_cast<const char*>(&size), sizeof(size)); + for (uint64_t i = 0; i < size; ++i) { + os.write(reinterpret_cast<const char*>(&rows[i]), sizeof(rows[i])); + } + } + { + // the 3rd field, the height of SelectedRows + int64_t height = selected_rows.height(); + os.write(reinterpret_cast<const char*>(&height), sizeof(height)); + } + // the 4th field, Tensor data + TensorToStream(os, selected_rows.value()); +} + +void DeserializeFromStream( + std::istream& is, + SelectedRows* selected_rows, + const lite::Context& dev_ctx) { + { + // the 1st field, uint32_t version for SelectedRows + uint32_t version; + is.read(reinterpret_cast<char*>(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + } + { + // the 2nd field, rows information + uint64_t size; + is.read(reinterpret_cast<char*>(&size), sizeof(size)); + auto& rows = *selected_rows->mutable_rows(); + rows.resize(size); + for (uint64_t i = 0; i < size; ++i) { + is.read(reinterpret_cast<char*>(&rows[i]), sizeof(int64_t)); + } + } + { + // the 3rd field, the height of the SelectedRows + int64_t height; + is.read(reinterpret_cast<char*>(&height), sizeof(int64_t)); + selected_rows->set_height(height); + } + // the 4th field, the tensor which contains the data + TensorFromStream(is, selected_rows->mutable_value()); +} + +bool SelectedRows::HasKey(int64_t key) const { + return std::find(rows_.begin(), rows_.end(), key) != rows_.end(); +}
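Since both functions above take plain std::ostream/std::istream, the four-field layout (version, rows, height, then the tensor payload) can be round-tripped through an in-memory stream; a minimal sketch, keeping the device-context type generic because the declarations above elide it:

    // Sketch only: serialize a SelectedRows into a stringstream and read it back.
    #include <sstream>
    template <typename DeviceContext>
    void RoundTrip(const paddle::lite::fluid::SelectedRows& src,
                   paddle::lite::fluid::SelectedRows* dst,
                   const DeviceContext& dev_ctx) {
      std::stringstream buffer;  // any std::ostream/std::istream pair works
      paddle::lite::fluid::SerializeToStream(buffer, src, dev_ctx);
      paddle::lite::fluid::DeserializeFromStream(buffer, dst, dev_ctx);
    }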
+ +int64_t SelectedRows::AutoGrownIndex(int64_t key, + bool auto_grown, + bool is_test) { + if (is_test) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %ld not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %lu should be the same as rows_ size %lu", + map_size, + vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("SelectedRows is full, length exceeds %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast<int64_t>(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } +} + +void SelectedRows::SyncIndex() { + rwlock_->WRLock(); + id_to_index_.clear(); + for (size_t i = 0; i < rows_.size(); ++i) { + id_to_index_[rows_[i]] = i; + } + rwlock_->UNLock(); +} + +void SelectedRows::Get(const lite::Tensor& ids, + lite::Tensor* value, + bool auto_grown, + bool is_test) { + PADDLE_ENFORCE(value->IsInitialized(), + "The value tensor should be initialized."); + if (ids.numel() == 0) { + VLOG(3) << "ids is empty, please check the input data!"; + } else { + int64_t value_width = value_->numel() / value_->dims()[0]; + PADDLE_ENFORCE_EQ(value_width, + value->numel() / value->dims()[0], + "output tensor should have the same shape as the table " + "except dims[0]."); + for (int i = 0; i < ids.numel(); ++i) { + auto id = ids.data<int64_t>()[i]; + int64_t index = AutoGrownIndex(id, auto_grown, is_test); + if (index < 0) { + VLOG(5) << "id " << id << " not in the table, filling with 0"; + TensorFillVisitor(value, i * value_width, value_width, 0.0) + .apply<float>(); + } else { + TensorCopyVisitor(value, + i * value_width, + *value_.get(), + index * value_width, + value_width) + .apply<float>(); + } + } + } +} + +} // namespace fluid +} // namespace lite +} // namespace paddle
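The lookup rules above can be summarized in a short sketch (using the constructor and SyncIndex() declared in the header that follows):

    // Sketch only: index-lookup semantics of SelectedRows.
    paddle::lite::fluid::SelectedRows table({3, 7}, /*height=*/10);
    table.SyncIndex();  // builds id_to_index_ = {3 -> 0, 7 -> 1} from rows_
    int64_t hit = table.AutoGrownIndex(7, /*auto_grown=*/false);    // -> 1
    int64_t miss = table.AutoGrownIndex(9, /*auto_grown=*/false,
                                        /*is_test=*/true);          // -> -1
    // With auto_grown=false and is_test=false, a missing key throws instead.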
diff --git a/lite/fluid/selected_rows.h b/lite/fluid/selected_rows.h new file mode 100644 index 0000000000000000000000000000000000000000..0624ec2b8d85d1dd6b32a0f3765bdaba84aa20ea --- /dev/null +++ b/lite/fluid/selected_rows.h @@ -0,0 +1,173 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/fluid/rw_lock.h" +#include "lite/model_parser/model_parser.h" +namespace paddle { +namespace lite { +namespace fluid { + +class SelectedRows { + /* + * @brief We can use the SelectedRows structure to represent a sparse table. + * A sparse table is a key-value structure whose key is an `int64_t` + * and whose value is a Tensor whose first dimension is 0. + * You can use the following interface to operate on the sparse table, and + * you can find more detailed information in the comments of each interface: + * + * HasKey(key), whether the sparse table has the specified key. + * Set(key, value), set a key-value pair into the sparse table. + * Get(keys, value*), get value by the given key list and apply it to the + * given value pointer with the specified offset. + * + */ + public: + SelectedRows(const std::vector<int64_t>& rows, const int64_t& height) + : rows_(rows), height_(height) { + value_.reset(new Tensor()); + rwlock_.reset(new RWLock); + } + + SelectedRows() { + height_ = 0; + value_.reset(new Tensor()); + rwlock_.reset(new RWLock); + } + + TargetType target() const { return value_->target(); } + + const Tensor& value() const { return *value_; } + + Tensor* mutable_value() { return value_.get(); } + + int64_t height() const { return height_; } + + void set_height(int64_t height) { height_ = height; } + + const std::vector<int64_t>& rows() const { return rows_; } + + std::vector<int64_t>* mutable_rows() { return &rows_; } + + void set_rows(const std::vector<int64_t>& rows) { rows_ = rows; } + + /* + * @brief Get the index of key in rows + * + * @return the index of the key; throws if the key is not found. + */ + int64_t Index(int64_t key) const { + auto it = std::find(rows_.begin(), rows_.end(), key); + if (it == rows_.end()) { + PADDLE_THROW("id %ld not in table", key); + } + return static_cast<int64_t>(std::distance(rows_.begin(), it)); + } + + /* + * @brief whether the table has the specified key. + * + * @return true if the key exists. + */ + bool HasKey(int64_t key) const; + + /* + * @brief Get value by the key list. + * Note!!! this interface is only used when selected_rows is used as a + * parameter for the distributed lookup table. + * + * @note ids that are not found are zero-filled in the output value. + */ + void Get(const lite::Tensor& ids, + lite::Tensor* value, + bool auto_grown = false, + bool is_test = false); + + /* + * @brief Get the index of the key from the id_to_index_ map. If the key + * does not exist, add it into id_to_index_. + * + * Note!!! this interface is only used when selected_rows is used as a + * parameter for the distributed lookup table. + * + * @return index of the key. + */ + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + + /* + * @brief Get the index of the key from the id_to_index_ map. + */ + inline int64_t GetIndexFromId(int64_t key) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + + void SyncIndex(); + /* + * @brief Get the complete dims, i.e. the value dims with dims[0] replaced + * by height_. + */ + DDim GetCompleteDims() const { + DDim dims = value_->dims(); + dims[0] = height_; + return dims; + } + + private: + // Notice: rows can contain duplicates. We can have {0, 4, 7, 0, 5, 7, 9} + // here. SelectedRows are simply concatenated when added together; the + // duplicate rows are only resolved when a SelectedRows is added to a Tensor. + std::vector<int64_t> rows_; + std::unordered_map<int64_t, int64_t> + id_to_index_; // should not be used when rows_ has duplicate members + std::unique_ptr<Tensor> value_{nullptr}; + int64_t height_; // height indicates the underlying tensor's height + std::unique_ptr<RWLock> rwlock_{nullptr}; +};
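To make the interface comment above concrete, here is a hedged lookup sketch; the braced Resize shapes assume lite::Tensor/DDim accept them, which this patch does not show:

    // Sketch only: a height-10 sparse table with 3 stored rows of width 4.
    paddle::lite::fluid::SelectedRows table({0, 4, 7}, /*height=*/10);
    table.mutable_value()->Resize({3, 4});         // one stored row per id
    table.mutable_value()->mutable_data<float>();  // allocate the table
    table.SyncIndex();

    paddle::lite::Tensor ids, out;
    ids.Resize({2});
    auto* id_data = ids.mutable_data<int64_t>();
    id_data[0] = 4;  // stored: copied out of the table
    id_data[1] = 5;  // missing: zero-filled because is_test=true
    out.Resize({2, 4});
    out.mutable_data<float>();
    table.Get(ids, &out, /*auto_grown=*/false, /*is_test=*/true);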
+ +/* + * Serialize/Deserialize SelectedRows to std::ostream. + * You can pass an ofstream or ostringstream to serialize to a file + * or to an in-memory string. A GPU tensor will be copied to CPU. + */ +void SerializeToStream(std::ostream& os, + const SelectedRows& selected_rows, + const lite::Context& dev_ctx); +void DeserializeFromStream( + std::istream& is, + SelectedRows* selected_rows, + const lite::Context& dev_ctx); + +} // namespace fluid +} // namespace lite +} // namespace paddle diff --git a/lite/fluid/transform.h b/lite/fluid/transform.h new file mode 100644 index 0000000000000000000000000000000000000000..4577e07e7430081cd21f34185ae5f0b00412a9ee --- /dev/null +++ b/lite/fluid/transform.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "lite/core/op_lite.h" +#include "lite/fluid/hostdevice.h" + +namespace paddle { +namespace lite { +namespace fluid { + +// Transform applies a unary or a binary functor to each element in a +// range defined by a pair of iterators. +// +// - The specialization for CPU calls std::transform. +// - The specialization for CUDA calls thrust::transform. +// +// NOTE: We define InputIter and OutputIter as different types because +// InputIter points to an op's inputs while OutputIter points to an op's +// outputs. +// +// NOTE: We don't assume InputIter to be const InputType* and OutputIter +// to be OutputType*, because we might use an iterator class, e.g. +// paddle::fluid::operators::RowwiseTransformIterator. +template <lite::TargetType Target> +struct Transform { + // The unary version. + template <typename InputIter, typename OutputIter, typename UnaryOperation> + void operator()(const lite::Context<Target>& context, + InputIter first, + InputIter last, + OutputIter result, + UnaryOperation op); + + // The binary version.
+ template <typename InputIter1, + typename InputIter2, + typename OutputIter, + typename BinaryOperation> + void operator()(const lite::Context<Target>& context, + InputIter1 first1, + InputIter1 last1, + InputIter2 first2, + OutputIter result, + BinaryOperation op); +}; + +template <> +struct Transform<lite::TargetType::kX86> { + template <typename InputIter, typename OutputIter, typename UnaryOperation> + void operator()(const lite::X86Context& context, + InputIter first, + InputIter last, + OutputIter result, + UnaryOperation op) { + std::transform(first, last, result, op); + } + + template <typename InputIter1, + typename InputIter2, + typename OutputIter, + typename BinaryOperation> + void operator()(const lite::X86Context& context, + InputIter1 first1, + InputIter1 last1, + InputIter2 first2, + OutputIter result, + BinaryOperation op) { + std::transform(first1, last1, first2, result, op); + } +}; + +} // namespace fluid +} // namespace lite +} // namespace paddle diff --git a/lite/gen_code/CMakeLists.txt b/lite/gen_code/CMakeLists.txt index d83657ad3e24eb5661225a4a0684c141e40a6163..40c95415546d99a66abf2d6f3595ae8695c4df86 100644 --- a/lite/gen_code/CMakeLists.txt +++ b/lite/gen_code/CMakeLists.txt @@ -15,6 +15,7 @@ lite_cc_test(test_gen_code SRCS gen_code_test.cc X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} EXCLUDE_COMPILE_DEPS "ON" @@ -42,6 +43,7 @@ lite_cc_test(test_generated_code SRCS generated_code_test.cc DEPS __generated_co X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} EXCLUDE_COMPILE_DEPS "ON" diff --git a/lite/gen_code/gen_code.h b/lite/gen_code/gen_code.h index 7dea36636af6fa3682c6f9a66ab237573a54f0b6..58a7959f4eb34cb438bf0e25b49b36110435cc6b 100644 --- a/lite/gen_code/gen_code.h +++ b/lite/gen_code/gen_code.h @@ -102,7 +102,7 @@ class Module { void AddValidPlaceDecl() { // clang-format off - Line("std::vector<lite::Place> valid_places({lite::Place({TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW)}), lite::Place({TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)})});"); // NOLINT + Line("std::vector<lite::Place> valid_places({lite::Place({TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW)})});"); // NOLINT // clang-format on } diff --git a/lite/kernels/CMakeLists.txt b/lite/kernels/CMakeLists.txt index 1996f50133acc6f3bdf651e8c0daae5b68c96832..0bfd39ae9a0bdf6e8af606711fd4dcc6011994b5 100644 --- a/lite/kernels/CMakeLists.txt +++ b/lite/kernels/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(x86) add_subdirectory(opencl) add_subdirectory(fpga) add_subdirectory(npu) +add_subdirectory(xpu) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index dd5d985cc1aff47fe6effc7276355323a3792226..80aacbf7efe2f13a6cb2b04201e036561e682bf1 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -1,8 +1,10 @@ -if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) - return() -endif() - -message(STATUS "compile with lite ARM kernels") +# for conv op +add_kernel(conv_depthwise ARM basic SRCS conv_depthwise.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_direct ARM basic SRCS conv_direct.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_gemmlike ARM basic SRCS conv_gemmlike.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_winograd ARM basic SRCS conv_winograd.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_compute_arm ARM basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} + conv_depthwise conv_direct conv_gemmlike conv_winograd) add_kernel(fc_compute_arm ARM basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(activation_compute_arm ARM basic SRCS activation_compute.cc DEPS ${lite_kernel_deps}
math_arm) @@ -10,7 +12,6 @@ add_kernel(mul_compute_arm ARM basic SRCS mul_compute.cc DEPS ${lite_kernel_deps add_kernel(matmul_compute_arm ARM basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(scale_compute_arm ARM basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(softmax_compute_arm ARM basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(conv_compute_arm ARM basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -39,19 +40,24 @@ add_kernel(shape_compute_arm ARM basic SRCS shape_compute.cc DEPS ${lite_kernel_ add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(unsqueeze_compute_arm ARM extra SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(expand_compute_arm ARM basic SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(box_clip_compute_arm ARM basic SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_mean_compute_arm ARM extra SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(stack_compute_arm ARM extra SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(affine_channel_compute_arm ARM extra SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_compute.cc DEPS 
${lite_kernel_deps} math_arm) +add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(range_compute_arm ARM extra SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) # for OCR specific add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -72,10 +78,16 @@ add_kernel(fill_constant_compute_arm ARM extra SRCS fill_constant_compute.cc DEP add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) +# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered +# to the model_optimize_tool. +if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) + return() +endif() + +message(STATUS "compile with lite ARM kernels") + lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) -lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm) lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) @@ -88,4 +100,5 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_ lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra) lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) +lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm) diff --git a/lite/kernels/arm/activation_compute.cc b/lite/kernels/arm/activation_compute.cc index add88b9294b6458449cb49fae10ddcbac6d65631..d50049d48748cf7ec43485a12fa7c65c0171a63d 100644 --- a/lite/kernels/arm/activation_compute.cc +++ b/lite/kernels/arm/activation_compute.cc @@ -147,6 +147,28 @@ void FloorCompute::Run() { x_data, output_data, x_dims.production(), ctx.threads()); } +void HardSigmoidCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + float slope = param.hard_sigmoid_slope; + float offset = param.hard_sigmoid_offset; + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_hard_sigmoid( + x_data, output_data, x_dims.production(), slope, offset, ctx.threads()); +} + +void RsqrtCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_rsqrt( + x_data, output_data, x_dims.production(), ctx.threads()); +} + } // namespace arm 
 } // namespace kernels
 } // namespace lite
@@ -224,3 +246,17 @@ REGISTER_LITE_KERNEL(
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
+REGISTER_LITE_KERNEL(hard_sigmoid,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::HardSigmoidCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    rsqrt, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::RsqrtCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/activation_compute.h b/lite/kernels/arm/activation_compute.h
index 0d7d34727bcc62c37c72a418eb14782452d6df91..ba1318ea36d01d1c3352679e7b5de12d013c0e84 100644
--- a/lite/kernels/arm/activation_compute.h
+++ b/lite/kernels/arm/activation_compute.h
@@ -121,6 +121,24 @@ class FloorCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
   virtual ~FloorCompute() = default;
 };
 
+class HardSigmoidCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override;
+
+  virtual ~HardSigmoidCompute() = default;
+};
+
+class RsqrtCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override;
+
+  virtual ~RsqrtCompute() = default;
+};
+
 } // namespace arm
 } // namespace kernels
 } // namespace lite
diff --git a/lite/kernels/arm/argmax_compute.cc b/lite/kernels/arm/argmax_compute.cc
index 5cb0e48c158286d9463bead60cc87a59168da1b4..ad279e8f8e1f80639c0b2512f89595d01ef062fd 100644
--- a/lite/kernels/arm/argmax_compute.cc
+++ b/lite/kernels/arm/argmax_compute.cc
@@ -40,8 +40,12 @@ void ArgmaxCompute::Run() {
 } // namespace lite
 } // namespace paddle
 
-REGISTER_LITE_KERNEL(
-    argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def)
+REGISTER_LITE_KERNEL(arg_max,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ArgmaxCompute,
+                     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/arm/argmax_compute_test.cc b/lite/kernels/arm/argmax_compute_test.cc
index ee603efa86aa0390b568ddd5948cc811e49019b6..58bdf18474ae69b2bdb863b9818dab41e25bf17b 100644
--- a/lite/kernels/arm/argmax_compute_test.cc
+++ b/lite/kernels/arm/argmax_compute_test.cc
@@ -68,7 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
 TEST(argmax_arm, retrive_op) {
   auto argmax = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
-      "argmax");
+      "arg_max");
   ASSERT_FALSE(argmax.empty());
   ASSERT_TRUE(argmax.front());
 }
@@ -136,4 +136,4 @@ TEST(argmax_arm, compute) {
 } // namespace kernels
 } // namespace lite
 } // namespace paddle
-USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(arg_max, kARM, kFloat, kNCHW, def);
diff --git a/lite/kernels/arm/assign_value_compute.cc b/lite/kernels/arm/assign_value_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..45f28ba36369cc79d70d683894c8a934b9308863
--- /dev/null
+++ b/lite/kernels/arm/assign_value_compute.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/assign_value_compute.h"
+#include
+#include
+#include "lite/backends/arm/math/funcs.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/core/type_system.h"
+#include "lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+template <typename T>
+void TensorFromVector(const std::vector<T>& src, lite::Tensor* dst) {
+  auto* src_ptr = static_cast<const void*>(src.data());
+  auto* dst_ptr = static_cast<void*>(dst->mutable_data<T>());
+  auto size = src.size() * sizeof(T);
+  std::memcpy(dst_ptr, src_ptr, size);
+}
+
+void AssignValueCompute::Run() {
+  auto& param = Param<param_t>();
+  int dtype = param.dtype;
+  std::vector<float> fp32_values = param.fp32_values;
+  std::vector<int> int32_values = param.int32_values;
+  auto* out = param.Out;
+
+  if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
+    TensorFromVector(int32_values, out);
+  } else if (dtype == static_cast<int>(lite::core::FluidType::FP32)) {
+    TensorFromVector(fp32_values, out);
+  } else {
+    LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
+  }
+  return;
+}
+
+} // namespace arm
+} // namespace kernels
+} // namespace lite
+} // namespace paddle
+
+REGISTER_LITE_KERNEL(assign_value,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::AssignValueCompute,
+                     def)
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/assign_value_compute.h b/lite/kernels/arm/assign_value_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..f0c33f865bb770adc64a1727521fad10d0516ede
--- /dev/null
+++ b/lite/kernels/arm/assign_value_compute.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
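+
+// AssignValueCompute materializes the constant values carried by the
+// assign_value op (fp32_values / int32_values) into its output tensor;
+// see assign_value_compute.cc for the memcpy-based implementation.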
+
+#pragma once
+#include
+#include "lite/core/kernel.h"
+#include "lite/operators/assign_value_op.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class AssignValueCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::AssignValueParam;
+
+  void Run() override;
+
+  virtual ~AssignValueCompute() = default;
+};
+
+} // namespace arm
+} // namespace kernels
+} // namespace lite
+} // namespace paddle
diff --git a/lite/kernels/arm/batch_norm_compute_test.cc b/lite/kernels/arm/batch_norm_compute_test.cc
index c603a04d470263a09d3ab2674aa7ea70b8cf3b31..bf690f88a5e776709a3988cc843762db3bf684e6 100644
--- a/lite/kernels/arm/batch_norm_compute_test.cc
+++ b/lite/kernels/arm/batch_norm_compute_test.cc
@@ -14,6 +14,7 @@
 #include "lite/kernels/arm/batch_norm_compute.h"
 #include
+#include
 #include
 #include
 #include
diff --git a/lite/kernels/arm/beam_search_decode_compute.cc b/lite/kernels/arm/beam_search_decode_compute.cc
index a417baa6d7201e2abf067648b1d90ab37ac5ee21..49ca51bf697f272dacf55db655bc237aff2cc460 100644
--- a/lite/kernels/arm/beam_search_decode_compute.cc
+++ b/lite/kernels/arm/beam_search_decode_compute.cc
@@ -276,6 +276,10 @@ void BeamSearchDecodeCompute::Run() {
                        param.end_id);
   func.apply();
+
+  // when decoding finishes, clear the ids and scores
+  param.ids->clear();
+  param.scores->clear();
 }
 
 } // namespace arm
diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc
index 8b6971ec138c0adeb7691b05917f403ab7031664..87afbae153ecd1c259ab4696d91b46b42b99d7e8 100644
--- a/lite/kernels/arm/cast_compute.cc
+++ b/lite/kernels/arm/cast_compute.cc
@@ -23,7 +23,7 @@ namespace arm {
 template <typename in_type, typename out_type>
 out_type TransOp(in_type in) {
-  return static_cast<in_type>(in);
+  return static_cast<out_type>(in);
 }
 
 void CastCompute::PrepareForRun() {}
@@ -45,6 +45,20 @@ void CastCompute::Run() {
     const char* x_data_end = x_data_begin + param.X->numel();
     float* out_data = param.Out->mutable_data<float>();
     std::transform(x_data_begin, x_data_end, out_data, TransOp<char, float>);
+  } else if (param.in_dtype == 2 && param.out_dtype == 5) {  // int32 -> float32
+    const int32_t* x_data_begin = param.X->data<int32_t>();
+    const int32_t* x_data_end = x_data_begin + param.X->numel();
+    float* out_data = param.Out->mutable_data<float>();
+    // std::transform(x_data_begin, x_data_end, out_data, TransOp<int32_t, float>);
+    // TODO: the input data is actually float here.
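+    // NOTE: the in_dtype/out_dtype codes follow fluid's VarType enum
+    // (2 = INT32, 5 = FP32, 20 = UINT8). In this branch the bytes are
+    // copied verbatim instead of cast, because the producer actually
+    // writes float data under an int32 dtype tag (see the TODO above).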
+ memcpy(out_data, x_data_begin, sizeof(float) * param.X->numel()); + } else if (param.in_dtype == 20 && param.out_dtype == 5) { // uint8->float32 + const unsigned char* x_data_begin = param.X->data(); + const unsigned char* x_data_end = x_data_begin + param.X->numel(); + float* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented"; } diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc index fe4b3d6587aa72d234344ef430ef43e7fd9057fe..95014b4ccd427e152dfe919643afa5ff5eb3011d 100644 --- a/lite/kernels/arm/compare_compute.cc +++ b/lite/kernels/arm/compare_compute.cc @@ -87,14 +87,13 @@ void CompareCompute::Run() { auto x_dims = param.X->dims(); auto y_dims = param.Y->dims(); bool *z = param.Out->template mutable_data(); - const auto *x = param.X->template data(); + const auto *x = param.X->template data(); const auto *y = param.Y->template data(); auto axis = param.axis; bool force_cpu = param.force_cpu; if (x_size == y_size) { for (int i = 0; i < x_size; ++i) { z[i] = CompareFunctor()(x[i], y[i]); - // z[i] = x[i] < y[i]; } } else { int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis); diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 9de1a85900db7808e74e3f71004ce33287c7d883..ebb96e21d5e856325b7abdb8342df2aea3d5b5c3 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -13,101 +13,87 @@ // limitations under the License. #include "lite/kernels/arm/conv_compute.h" +#include #include "lite/core/op_registry.h" #include "lite/core/type_system.h" +#include "lite/kernels/arm/conv_depthwise.h" +#include "lite/kernels/arm/conv_direct.h" +#include "lite/kernels/arm/conv_gemmlike.h" +#include "lite/kernels/arm/conv_winograd.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -void ConvCompute::PrepareForRun() { +template <> +void ConvCompute::PrepareForRun() { auto& param = this->Param(); - auto x_dims = param.x->dims(); auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - auto& ctx = this->ctx_->template As(); - int win = x_dims[3]; // nchw - int hin = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; + int ic = w_dims[1] * param.groups; + int oc = w_dims[0]; int kh = w_dims[2]; // oihw int kw = w_dims[3]; int pad = param.paddings[0]; int stride = param.strides[0]; - const auto* i_data = param.x->data(); - const auto* w_data = param.filter->data(); - const auto* b_data = param.bias ? 
param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); + int chin = param.x->dims()[1]; + int hin = param.x->dims()[2]; + int win = param.x->dims()[3]; + int chout = param.output->dims()[1]; + int hout = param.output->dims()[2]; + int wout = param.output->dims()[3]; bool kps_equal = (param.paddings[0] == param.paddings[1]) && (param.strides[0] == param.strides[1]) && (kw == kh); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); - bool flag_dw_3x3 = - (kw == 3 && (pad == 0 || pad == 1) && (stride == 1 || stride == 2)); + bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2)); bool flag_dw_5x5 = (kw == 5 && stride == 1) || (kw == 5 && stride == 2 && pad == 2); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; - // select conv impl + /// select conv impl if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { - // dw conv impl - impl_ = new lite::arm::math::DepthwiseConv; + /// dw conv impl + impl_ = new DepthwiseConv; VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation) { - if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) { - // winograd conv impl - impl_ = new lite::arm::math::WinogradConv; + if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) { + /// winograd conv impl + impl_ = new WinogradConv; VLOG(3) << "invoking winograd conv"; } else { - // direct conv impl - impl_ = new lite::arm::math::DirectConv; + /// direct conv impl + impl_ = new DirectConv; VLOG(3) << "invoking direct conv"; } - } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal && - no_dilation) { - // direct conv impl - impl_ = new lite::arm::math::DirectConv; + } else if (param.groups == 1 && kw == 3 && stride == 2 && + chin * chout < 4 * hin * win && kps_equal && no_dilation) { + /// direct conv impl + impl_ = new DirectConv; VLOG(3) << "invoking direct conv"; } else { - impl_ = new lite::arm::math::GemmLikeConv; + impl_ = new GemmLikeConv; VLOG(3) << "invoking gemm like conv"; } - CHECK(this->impl_->create(param, &ctx)); -} - -void ConvCompute::Run() { - auto& param = this->Param(); - CHECK(impl_); - impl_->run(param); - // if (this->act_ != nullptr) { - // this->act_->run(outputs, outputs, param.activation_param); - // } + impl_->SetContext(std::move(this->ctx_)); + impl_->SetParam(param); + impl_->PrepareForRun(); + is_first_epoch_ = false; } -template -void ConvComputeInt8::PrepareForRun() { +template <> +void ConvCompute::PrepareForRun() { auto& param = this->Param(); - auto x_dims = param.x->dims(); auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); auto& ctx = this->ctx_->template As(); - int win = x_dims[3]; // nchw - int hin = x_dims[2]; - int ic = x_dims[1]; - int bs = x_dims[0]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; + int ic = param.groups * w_dims[1]; + int oc = w_dims[0]; int kh = w_dims[2]; // oihw int kw = w_dims[3]; int ph = param.paddings[1]; @@ -115,78 +101,98 @@ void ConvComputeInt8::PrepareForRun() { int sh = param.strides[1]; int sw = param.strides[0]; - bool with_bias = param.bias; bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); - bool flag_dw_3x3 = (kw == 3) && (ph == 1) && (sw == 1 || sw == 2); - bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2); + bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); + bool flag_dw_5x5 = (kw == 5 && sw == 1); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; if (param.groups 
== ic && ic == oc && kps_equal && no_dilation && flag_dw) { - impl_ = new lite::arm::math::DepthwiseConvInt8; + impl_ = new DepthwiseConv; VLOG(3) << "Run DepthwiseConv Int8"; } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && kps_equal && no_dilation) { + impl_ = new DirectConv; VLOG(3) << "Run DirectConv Int8"; - impl_ = new lite::arm::math::DirectConvInt8; } else { + impl_ = new GemmLikeConv; VLOG(3) << "Run GemmLikeConvInt8"; - impl_ = new lite::arm::math::GemmLikeConvInt8; } - // Convert fp32 bias to int32 bias. - if (with_bias) { - Tensor temp_tensor; - temp_tensor.CopyDataFrom(*param.bias); - lite::arm::math::trans_fp32_bias_to_int32_basic( - &temp_tensor, param.bias, param.input_scale, param.weight_scale); - } - // param.bias->data(); - CHECK(this->impl_->create(param, &ctx)); + impl_->SetContext(std::move(this->ctx_)); + impl_->SetParam(param); + impl_->PrepareForRun(); + is_first_epoch_ = false; } -template -void ConvComputeInt8::Run() { +template <> +void ConvCompute::PrepareForRun() { auto& param = this->Param(); - CHECK(impl_); - impl_->run(param); -} + auto w_dims = param.filter->dims(); -template class ConvComputeInt8; -template class ConvComputeInt8; -template class ConvComputeInt8; + auto& ctx = this->ctx_->template As(); + + int ic = w_dims[1] * param.groups; + int oc = w_dims[0]; + int kh = w_dims[2]; // oihw + int kw = w_dims[3]; + int ph = param.paddings[1]; + int pw = param.paddings[0]; + int sh = param.strides[1]; + int sw = param.strides[0]; + + bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); + bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); + bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); + bool flag_dw_5x5 = (kw == 5 && sw == 1); + bool flag_dw = flag_dw_3x3 || flag_dw_5x5; + + if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { + impl_ = new DepthwiseConv; + VLOG(3) << "Run DepthwiseConv Int8"; + } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && + kps_equal && no_dilation) { + impl_ = new DirectConv; + VLOG(3) << "Run DirectConv Int8"; + } else { + impl_ = new GemmLikeConv; + VLOG(3) << "Run GemmLikeConvInt8"; + } + impl_->SetContext(std::move(this->ctx_)); + impl_->SetParam(param); + impl_->PrepareForRun(); + is_first_epoch_ = false; +} } // namespace arm } // namespace kernels } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL( - conv2d, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConvCompute, def) +typedef paddle::lite::kernels::arm::ConvCompute + ConvFp32; +typedef paddle::lite::kernels::arm::ConvCompute + ConvInt8_Fp32; +typedef paddle::lite::kernels::arm::ConvCompute + ConvInt8_Int8; + +REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, ConvFp32, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -REGISTER_LITE_KERNEL(depthwise_conv2d, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ConvCompute, - def) +REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, ConvFp32, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -REGISTER_LITE_KERNEL( - conv2d, - kARM, - kInt8, - kNCHW, - 
paddle::lite::kernels::arm::ConvComputeInt8, - int8_out) +REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Filter", @@ -195,13 +201,7 @@ REGISTER_LITE_KERNEL( {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .Finalize(); -REGISTER_LITE_KERNEL( - conv2d, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::ConvComputeInt8, - fp32_out) +REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Filter", @@ -211,12 +211,7 @@ REGISTER_LITE_KERNEL( .Finalize(); REGISTER_LITE_KERNEL( - depthwise_conv2d, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::ConvComputeInt8, - int8_out) + depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Filter", @@ -226,12 +221,7 @@ REGISTER_LITE_KERNEL( .Finalize(); REGISTER_LITE_KERNEL( - depthwise_conv2d, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::ConvComputeInt8, - fp32_out) + depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Filter", diff --git a/lite/kernels/arm/conv_compute.h b/lite/kernels/arm/conv_compute.h index 28b8d40677cd97c257d135c6c491e98efb948881..267b4746a35b431c4b4e36b26604a8654e0e58bd 100644 --- a/lite/kernels/arm/conv_compute.h +++ b/lite/kernels/arm/conv_compute.h @@ -15,20 +15,26 @@ #pragma once #include "lite/backends/arm/math/funcs.h" #include "lite/core/kernel.h" -#include "lite/operators/conv_op.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -class ConvCompute : public KernelLite { +template +class ConvCompute : public KernelLite { public: - using param_t = operators::ConvParam; + virtual void PrepareForRun(); - void PrepareForRun() override; + virtual void ReInitWhenNeeded() { + CHECK(impl_); + impl_->ReInitWhenNeeded(); + } - void Run() override; + virtual void Run() { + CHECK(impl_); + impl_->Run(); + } ~ConvCompute() { if (impl_ != nullptr) { @@ -37,28 +43,8 @@ class ConvCompute : public KernelLite { } private: - lite::arm::math::ImplBase* impl_{ - nullptr}; -}; - -template -class ConvComputeInt8 : public KernelLite { - public: using param_t = operators::ConvParam; - - void PrepareForRun() override; - - void Run() override; - - ~ConvComputeInt8() { - if (impl_ != nullptr) { - delete impl_; - } - } - - private: - lite::arm::math::ImplBase* impl_{ - nullptr}; + KernelLite* impl_{nullptr}; }; } // namespace arm diff --git a/lite/kernels/arm/conv_compute_test.cc b/lite/kernels/arm/conv_compute_test.cc deleted file mode 100644 index 40f678164e015f82a599f4759e434eed9e69905e..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/conv_compute_test.cc +++ /dev/null @@ -1,1045 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/arm/conv_compute.h" -#include -#include -#include -#include -#include -#include "lite/backends/arm/math/type_trans.h" -#include "lite/core/op_registry.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -static int get_rand(int start, int end) { - int i = rand(); // NOLINT - i = (i % (end - start)) + start; - return i; -} - -template -static void conv_basic(const Dtype1* din, - Dtype2* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const Dtype1* weights, - const Dtype2* bias, - int group, - int kernel_w, - int kernel_h, - int stride_w, - int stride_h, - int dila_w, - int dila_h, - int pad_w, - int pad_h, - bool flag_bias, - bool flag_relu) { - Dtype2 beta = 0; - auto src_data = din; - auto dst_data_ref = dout; - auto weights_data = weights; - auto with_bias = flag_bias; - auto bias_data = bias; - - int in_num = num; - int out_channels = chout; - int out_h = hout; - int out_w = wout; - - int in_channel = chin; - int in_h = hin; - int in_w = win; - int out_c_group = out_channels / group; - int in_c_group = in_channel / group; - - for (int n = 0; n < in_num; ++n) { - for (int g = 0; g < group; ++g) { - for (int oc = 0; oc < out_c_group; ++oc) { - for (int oh = 0; oh < out_h; ++oh) { - for (int ow = 0; ow < out_w; ++ow) { - int out_idx = n * group * out_c_group * out_h * out_w + - g * out_c_group * out_h * out_w + oc * out_h * out_w + - oh * out_w + ow; - Dtype2 bias_d = - with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0; - dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta; - for (int ic = 0; ic < in_c_group; ++ic) { - for (int kh = 0; kh < kernel_h; ++kh) { - for (int kw = 0; kw < kernel_w; ++kw) { - int iw = ow * stride_w - pad_w + kw * (dila_w); - int ih = oh * stride_h - pad_h + kh * (dila_h); - if (iw < 0 || iw >= in_w) continue; - if (ih < 0 || ih >= in_h) continue; - - int iidx = n * in_channel * in_h * in_w + - g * in_c_group * in_h * in_w + ic * in_h * in_w + - ih * in_w + iw; - int widx = - g * out_c_group * in_c_group * kernel_h * kernel_w + - oc * in_c_group * kernel_h * kernel_w + - ic * kernel_h * kernel_w + kh * kernel_w + kw; - - dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx]; - } - } - } - if (flag_relu) { - dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 - ? 
dst_data_ref[out_idx] - : (Dtype2)0; - } - } - } - } - } - } -} - -template -void conv_compute_ref(const operators::ConvParam& param) { - const Dtype1* din = param.x->data(); - Dtype2* dout = param.output->mutable_data(); - - int num = param.x->dims()[0]; - int chout = param.output->dims()[1]; - int hout = param.output->dims()[2]; - int wout = param.output->dims()[3]; - - int chin = param.x->dims()[1]; - int hin = param.x->dims()[2]; - int win = param.x->dims()[3]; - - const Dtype1* weights = param.filter->mutable_data(); - Dtype2* bias = nullptr; - if (param.bias != nullptr) { - bias = param.bias->mutable_data(); - } - - int group = param.groups; - int kernel_w = param.filter->dims()[2]; - int kernel_h = param.filter->dims()[3]; - int stride_w = param.strides[0]; - int stride_h = param.strides[1]; - int dila_w = param.dilations[0]; - int dila_h = param.dilations[1]; - - int pad_w = param.paddings[0]; - int pad_h = param.paddings[1]; - bool flag_bias = (param.bias != nullptr); - bool flag_relu = param.fuse_relu; - - conv_basic(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - group, - kernel_w, - kernel_h, - stride_w, - stride_h, - dila_w, - dila_h, - pad_w, - pad_h, - flag_bias, - flag_relu); -} - -TEST(conv_arm, retrive_op) { - auto conv = KernelRegistry::Global().Create( - "conv2d"); - ASSERT_FALSE(conv.empty()); - ASSERT_TRUE(conv.front()); -} - -TEST(conv_arm_int8, retrive_op) { - auto conv = - KernelRegistry::Global().Create("conv2d"); - ASSERT_FALSE(conv.empty()); - ASSERT_TRUE(conv.front()); -} - -TEST(conv_arm, init) { - ConvCompute conv; - ASSERT_EQ(conv.precision(), PRECISION(kFloat)); - ASSERT_EQ(conv.target(), TARGET(kARM)); -} - -TEST(conv_arm_int8, init) { - ConvComputeInt8 float_out; - ASSERT_EQ(float_out.precision(), PRECISION(kInt8)); - ASSERT_EQ(float_out.target(), TARGET(kARM)); - ConvComputeInt8 int8_out; - ASSERT_EQ(float_out.precision(), PRECISION(kInt8)); - ASSERT_EQ(float_out.target(), TARGET(kARM)); -} - -TEST(conv_arm_int8, int8_int32) { - DeviceInfo::Init(); - for (auto n : {2}) { - for (auto ic : {6}) { - for (auto oc : {6}) { - for (auto ih : {9}) { - for (auto iw : {9}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto depthwise : {false, /*true*/}) { - for (auto dilation : {1}) { - for (auto stride : {1}) { - for (auto padding : {0}) { - for (auto ks : {1}) { - int group = 1; - if (depthwise) { // depthwise convolution ? 
- group = oc = ic; - } - - const int dks = dilation * (ks - 1) + 1; - int oh = (ih + 2 * padding - dks) / stride + 1; - int ow = (iw + 2 * padding - dks) / stride + 1; - std::vector input_shape = {n, ic, ih, iw}; - std::vector filter_shape = { - oc, ic / group, ks, ks}; - std::vector output_shape({n, oc, oh, ow}); - - Tensor input_int8; - Tensor filter_int8; - Tensor output_int32, output_int32_ref; - - input_int8.Resize(input_shape); - filter_int8.Resize(filter_shape); - output_int32.Resize(output_shape); - output_int32_ref.Resize(output_shape); - - int8_t* input_int8_data = - input_int8.mutable_data(); - int8_t* filter_int8_data = - filter_int8.mutable_data(); - for (int i = 0; i < input_int8.dims().production(); - i++) { - input_int8_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < filter_int8.dims().production(); - i++) { - filter_int8_data[i] = i % 10 * (i % 3 - 1); - } - - operators::ConvParam param; - param.x = &input_int8; - param.filter = &filter_int8; - param.bias = nullptr; - param.fuse_relu = false; - param.paddings = std::vector({padding, padding}); - param.strides = std::vector({stride, stride}); - param.dilations = - std::vector({dilation, dilation}); - param.groups = group; - param.output = &output_int32_ref; - conv_compute_ref(param); - - param.output = &output_int32; - std::unique_ptr ctx(new KernelContext); - lite::arm::math::GemmLikeConvInt8 - int8gemm_int32; - int8gemm_int32.init(param, &ctx->As()); - int8gemm_int32.create(param, &ctx->As()); - int8gemm_int32.run(param); - - int* output_int32_data = - output_int32.mutable_data(); - int* output_int32_ref_data = - output_int32_ref.mutable_data(); - - for (int i = 0; i < output_int32.dims().production(); - i++) { - EXPECT_NEAR(output_int32_data[i], - output_int32_ref_data[i], - 1e-3); - } - } - } - } - } - } - } - } - } - } - } - } - } -} - -TEST(conv_arm_int8, int8_fp32) { - DeviceInfo::Init(); - for (auto n : {2}) { - for (auto ic : {6}) { - for (auto oc : {6}) { - for (auto ih : {9}) { - for (auto iw : {9}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto depthwise : {false, /*true*/}) { - for (auto dilation : {1}) { - for (auto stride : {1}) { - for (auto padding : {0}) { - for (auto ks : {1}) { - int group = 1; - if (depthwise) { // depthwise convolution ? 
- group = oc = ic; - } - - LOG(INFO) << "flag_bias: " << flag_bias; - - const int dks = dilation * (ks - 1) + 1; - int oh = (ih + 2 * padding - dks) / stride + 1; - int ow = (iw + 2 * padding - dks) / stride + 1; - std::vector input_shape = {n, ic, ih, iw}; - std::vector filter_shape = { - oc, ic / group, ks, ks}; - std::vector bias_shape({1, oc, 1, 1}); - std::vector output_shape({n, oc, oh, ow}); - - Tensor input_fp32, input_int8; - Tensor filter_fp32, filter_int8; - Tensor bias_fp32, bias_int32; - Tensor output_int32_ref, output_int32; - Tensor output_fp32_ref, output_fp32; - Tensor output_int8_ref, output_int8; - - input_fp32.Resize(input_shape); - input_int8.Resize(input_shape); - filter_fp32.Resize(filter_shape); - filter_int8.Resize(filter_shape); - bias_fp32.Resize(bias_shape); - bias_int32.Resize(bias_shape); - output_int32.Resize(output_shape); - output_int32_ref.Resize(output_shape); - output_fp32_ref.Resize(output_shape); - output_fp32.Resize(output_shape); - output_int8_ref.Resize(output_shape); - output_int8.Resize(output_shape); - - float* input_fp32_data = - input_fp32.mutable_data(); - int8_t* input_int8_data = - input_int8.mutable_data(); - - float* filter_fp32_data = - filter_fp32.mutable_data(); - int8_t* filter_int8_data = - filter_int8.mutable_data(); - - float* bias_fp32_data = - bias_fp32.mutable_data(); - int* bias_int32_data = bias_int32.mutable_data(); - - for (int i = 0; i < input_fp32.dims().production(); - i++) { - input_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < filter_fp32.dims().production(); - i++) { - filter_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < bias_fp32.dims().production(); - i++) { - bias_fp32_data[i] = i % 10 * (i % 3 - 1); - } - - std::vector in_scale; - lite::arm::math::get_tensor_scale( - input_fp32, &in_scale, -1, 127.f); - lite::arm::math::trans_tensor_fp32_to_int8( - &input_fp32, &input_int8, in_scale[0]); - - std::vector w_scale; - lite::arm::math::get_tensor_scale( - filter_fp32, &w_scale, -1, 127.f); - int axis_size = oc; - int inner_size = ic / group * ks * ks; - w_scale = lite::arm::math::get_tensor_scale_n( - filter_fp32_data, axis_size, inner_size, 127.f); - lite::arm::math::fp32_to_int8(filter_fp32_data, - filter_int8_data, - w_scale.data(), - axis_size, - 1, - inner_size); - - // lite::arm::math::trans_fp32_bias_to_int32_basic(&bias_fp32, - // &bias_int32, in_scale[0], w_scale); - for (int i = 0; i < bias_int32.dims().production(); - i++) { - bias_int32_data[i] = 1; - } - - operators::ConvParam param; - param.x = &input_int8; - param.filter = &filter_int8; - if (flag_bias) { - param.bias = &bias_int32; - } else { - param.bias = nullptr; - } - param.fuse_relu = false; - param.paddings = std::vector({padding, padding}); - param.strides = std::vector({stride, stride}); - param.dilations = - std::vector({dilation, dilation}); - param.groups = group; - param.output = &output_int32_ref; - conv_compute_ref(param); - - int* output_int32_ref_data = - output_int32_ref.mutable_data(); - - // ============ int8gemm_int32 ============ - /* - param.output = &output_int32; - std::unique_ptr ctx_int32( - new KernelContext); - lite::arm::math::GemmLikeConvInt8 - int8gemm_int32; - int8gemm_int32.init(param, - &ctx_int32->As()); - int8gemm_int32.create(param, - &ctx_int32->As()); - int8gemm_int32.run(param); - int* output_int32_data = - output_int32.mutable_data(); - for (int i = 0; i < output_int32.dims().production(); - i++) { - EXPECT_NEAR(output_int32_data[i], - output_int32_ref_data[i], 1e-3); - } - */ - // 
============ int8gemm_int8 ============ - int8_t* output_int8_ref_data = - output_int8_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_int8( - &output_int32_ref, - &output_int8_ref, - in_scale[0], - 1, - w_scale); - param.output = &output_int8; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr ctx_int8( - new KernelContext); - lite::arm::math::GemmLikeConvInt8 - int8gemm_int8; - int8gemm_int8.init(param, - &ctx_int8->As()); - int8gemm_int8.create(param, - &ctx_int8->As()); - int8gemm_int8.run(param); - int8_t* output_int8_data = - output_int8.mutable_data(); - for (int i = 0; i < output_int8.dims().production(); - i++) { - EXPECT_NEAR(output_int8_data[i], - output_int8_ref_data[i], - 1e-3); - } - - // ============ int8gemm_float32 ============ - float* output_fp32_ref_data = - output_fp32_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_fp32( - &output_int32_ref, - &output_fp32_ref, - in_scale[0], - w_scale); - param.output = &output_fp32; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr ctx_fp32( - new KernelContext); - lite::arm::math::GemmLikeConvInt8 - int8gemm_fp32; - int8gemm_fp32.init(param, - &ctx_fp32->As()); - int8gemm_fp32.create(param, - &ctx_fp32->As()); - int8gemm_fp32.run(param); - float* output_fp32_data = - output_fp32.mutable_data(); - for (int i = 0; i < output_fp32.dims().production(); - i++) { - EXPECT_NEAR(output_fp32_data[i], - output_fp32_ref_data[i], - 1e-3); - } - } - } - } - } - } - } - } - } - } - } - } - } -} - -TEST(conv_direct_int8, compute) { - DeviceInfo::Init(); - for (auto n : {1, 2}) { - for (auto ic : {1, 3, 8}) { - for (auto oc : {1, 3, 8}) { - for (auto ih : {5, 15, 28}) { - for (auto iw : {5, 15, 28}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto depthwise : {false, /*true*/}) { - for (auto dilation : {1}) { - for (auto stride : {1, 2}) { - for (auto padding : {1}) { - for (auto ks : {3}) { - int group = 1; - if (depthwise) { // depthwise convolution ? 
- group = oc = ic; - } - - const int dks = dilation * (ks - 1) + 1; - int oh = (ih + 2 * padding - dks) / stride + 1; - int ow = (iw + 2 * padding - dks) / stride + 1; - std::vector input_shape = {n, ic, ih, iw}; - std::vector filter_shape = { - oc, ic / group, ks, ks}; - std::vector bias_shape({1, oc, 1, 1}); - std::vector output_shape({n, oc, oh, ow}); - - Tensor input_fp32, input_int8; - Tensor filter_fp32, filter_int8; - Tensor bias_int32; - Tensor output_int32_ref, output_int32; - Tensor output_fp32_ref, output_fp32; - Tensor output_int8_ref, output_int8; - - input_fp32.Resize(input_shape); - input_int8.Resize(input_shape); - filter_fp32.Resize(filter_shape); - filter_int8.Resize(filter_shape); - bias_int32.Resize(bias_shape); - output_int32.Resize(output_shape); - output_int32_ref.Resize(output_shape); - output_fp32_ref.Resize(output_shape); - output_fp32.Resize(output_shape); - output_int8_ref.Resize(output_shape); - output_int8.Resize(output_shape); - - float* input_fp32_data = - input_fp32.mutable_data(); - int8_t* input_int8_data = - input_int8.mutable_data(); - - float* filter_fp32_data = - filter_fp32.mutable_data(); - int8_t* filter_int8_data = - filter_int8.mutable_data(); - - int* bias_int32_data = - bias_int32.mutable_data(); - - for (int i = 0; i < input_fp32.dims().production(); - i++) { - input_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < filter_fp32.dims().production(); - i++) { - filter_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < bias_int32.dims().production(); - i++) { - bias_int32_data[i] = i % 10 * (i % 3 - 1); - } - - std::vector in_scale; - lite::arm::math::get_tensor_scale( - input_fp32, &in_scale, -1, 127.f); - lite::arm::math::trans_tensor_fp32_to_int8( - &input_fp32, &input_int8, in_scale[0]); - - std::vector w_scale; - lite::arm::math::get_tensor_scale( - filter_fp32, &w_scale, -1, 127.f); - int axis_size = oc; - int inner_size = ic / group * ks * ks; - w_scale = lite::arm::math::get_tensor_scale_n( - filter_fp32_data, axis_size, inner_size, 127.f); - lite::arm::math::fp32_to_int8(filter_fp32_data, - filter_int8_data, - w_scale.data(), - axis_size, - 1, - inner_size); - - operators::ConvParam param; - param.x = &input_int8; - param.filter = &filter_int8; - if (flag_bias) { - param.bias = &bias_int32; - } - param.fuse_relu = false; - param.paddings = std::vector({padding, padding}); - param.strides = std::vector({stride, stride}); - param.dilations = - std::vector({dilation, dilation}); - param.groups = group; - param.output = &output_int32_ref; - conv_compute_ref(param); - - int* output_int32_ref_data = - output_int32_ref.mutable_data(); - - // ============ int8direct_int32 ============ - param.output = &output_int32; - std::unique_ptr ctx_int32( - new KernelContext); - lite::arm::math::DirectConvInt8 - int8direct_int32; - int8direct_int32.init(param, - &ctx_int32->As()); - int8direct_int32.create(param, - &ctx_int32->As()); - int8direct_int32.run(param); - int* output_int32_data = - output_int32.mutable_data(); - for (int i = 0; i < output_int32.dims().production(); - i++) { - EXPECT_NEAR(output_int32_data[i], - output_int32_ref_data[i], - 1e-3); - } - - // ============ int8direct_int8 ============ - int8_t* output_int8_ref_data = - output_int8_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_int8( - &output_int32_ref, - &output_int8_ref, - in_scale[0], - 1, - w_scale); - param.output = &output_int8; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr 
ctx_int8( - new KernelContext); - lite::arm::math::DirectConvInt8 - int8direct_int8; - int8direct_int8.init(param, - &ctx_int8->As()); - int8direct_int8.create(param, - &ctx_int8->As()); - int8direct_int8.run(param); - int8_t* output_int8_data = - output_int8.mutable_data(); - for (int i = 0; i < output_int8.dims().production(); - i++) { - EXPECT_NEAR(output_int8_data[i], - output_int8_ref_data[i], - 1e-3); - } - - // ============ int8direct_float32 ============ - float* output_fp32_ref_data = - output_fp32_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_fp32( - &output_int32_ref, - &output_fp32_ref, - in_scale[0], - w_scale); - param.output = &output_fp32; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr ctx_fp32( - new KernelContext); - lite::arm::math::DirectConvInt8 - int8direct_fp32; - int8direct_fp32.init(param, - &ctx_fp32->As()); - int8direct_fp32.create(param, - &ctx_fp32->As()); - int8direct_fp32.run(param); - float* output_fp32_data = - output_fp32.mutable_data(); - for (int i = 0; i < output_fp32.dims().production(); - i++) { - EXPECT_NEAR(output_fp32_data[i], - output_fp32_ref_data[i], - 1e-3); - } - } - } - } - } - } - } - } - } - } - } - } - } -} - -TEST(conv_depthwise_int8, compute) { - DeviceInfo::Init(); - for (auto n : {1, 2}) { - for (auto ic : {1, 3, 8}) { - for (auto ih : {5, 15, 28}) { - for (auto iw : {5, 15, 28}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto dilation : {1}) { - for (auto stride : {1, 2}) { - for (auto padding : {1, 2}) { - for (auto ks : {3, /*5 */}) { - int group = ic; - int oc = ic; - - bool flag_dw_3x3 = (ks == 3) && (padding == 1) && - (stride == 1 || stride == 2); - bool flag_dw_5x5 = - (ks == 5 && stride == 1 && padding == 2); - bool flag_dw = flag_dw_3x3 || flag_dw_5x5; - if (!flag_dw) continue; - - const int dks = dilation * (ks - 1) + 1; - int oh = (ih + 2 * padding - dks) / stride + 1; - int ow = (iw + 2 * padding - dks) / stride + 1; - std::vector input_shape = {n, ic, ih, iw}; - std::vector filter_shape = { - oc, ic / group, ks, ks}; - std::vector bias_shape({1, oc, 1, 1}); - std::vector output_shape({n, oc, oh, ow}); - - Tensor input_fp32, input_int8; - Tensor filter_fp32, filter_int8; - Tensor bias_int32; - Tensor output_int32_ref, output_int32; - Tensor output_fp32_ref, output_fp32; - Tensor output_int8_ref, output_int8; - - input_fp32.Resize(input_shape); - input_int8.Resize(input_shape); - filter_fp32.Resize(filter_shape); - filter_int8.Resize(filter_shape); - bias_int32.Resize(bias_shape); - - output_int32.Resize(output_shape); - output_int32_ref.Resize(output_shape); - output_fp32_ref.Resize(output_shape); - output_fp32.Resize(output_shape); - output_int8_ref.Resize(output_shape); - output_int8.Resize(output_shape); - - float* input_fp32_data = input_fp32.mutable_data(); - int8_t* input_int8_data = - input_int8.mutable_data(); - float* filter_fp32_data = - filter_fp32.mutable_data(); - int8_t* filter_int8_data = - filter_int8.mutable_data(); - - int* bias_int32_data = bias_int32.mutable_data(); - - for (int i = 0; i < input_fp32.dims().production(); i++) { - input_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < filter_fp32.dims().production(); - i++) { - filter_fp32_data[i] = i % 10 * (i % 3 - 1); - } - for (int i = 0; i < bias_int32.dims().production(); i++) { - bias_int32_data[i] = i % 10 * (i % 3 - 1); - } - - std::vector in_scale; - lite::arm::math::get_tensor_scale( - input_fp32, 
&in_scale, -1, 127.f); - lite::arm::math::trans_tensor_fp32_to_int8( - &input_fp32, &input_int8, in_scale[0]); - - std::vector w_scale; - lite::arm::math::get_tensor_scale( - filter_fp32, &w_scale, -1, 127.f); - int axis_size = oc; - int inner_size = ic / group * ks * ks; - w_scale = lite::arm::math::get_tensor_scale_n( - filter_fp32_data, axis_size, inner_size, 127.f); - lite::arm::math::fp32_to_int8(filter_fp32_data, - filter_int8_data, - w_scale.data(), - axis_size, - 1, - inner_size); - - operators::ConvParam param; - param.x = &input_int8; - param.filter = &filter_int8; - if (flag_bias) { - param.bias = &bias_int32; - } - param.fuse_relu = false; - param.paddings = std::vector({padding, padding}); - param.strides = std::vector({stride, stride}); - param.dilations = std::vector({dilation, dilation}); - param.groups = group; - param.output = &output_int32_ref; - conv_compute_ref(param); - - int* output_int32_ref_data = - output_int32_ref.mutable_data(); - - // ============ int8depthwise_int32 ============ - param.output = &output_int32; - std::unique_ptr ctx_int32( - new KernelContext); - lite::arm::math::DepthwiseConvInt8 - int8depthwise_int32; - int8depthwise_int32.init(param, - &ctx_int32->As()); - int8depthwise_int32.create(param, - &ctx_int32->As()); - int8depthwise_int32.run(param); - int* output_int32_data = output_int32.mutable_data(); - for (int i = 0; i < output_int32.dims().production(); - i++) { - EXPECT_NEAR(output_int32_data[i], - output_int32_ref_data[i], - 1e-3); - } - - // ============ int8depthwise_int8============ - int8_t* output_int8_ref_data = - output_int8_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_int8( - &output_int32_ref, - &output_int8_ref, - in_scale[0], - 1, - w_scale); - param.output = &output_int8; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr ctx_int8( - new KernelContext); - lite::arm::math::DepthwiseConvInt8 - int8depthwise_int8; - int8depthwise_int8.init(param, - &ctx_int8->As()); - int8depthwise_int8.create(param, - &ctx_int8->As()); - int8depthwise_int8.run(param); - int8_t* output_int8_data = - output_int8.mutable_data(); - for (int i = 0; i < output_int8.dims().production(); - i++) { - EXPECT_NEAR( - output_int8_data[i], output_int8_ref_data[i], 1e-3); - } - - // ============int8depthwise_float32 ============ - float* output_fp32_ref_data = - output_fp32_ref.mutable_data(); - lite::arm::math::trans_tensor_int32_to_fp32( - &output_int32_ref, - &output_fp32_ref, - in_scale[0], - w_scale); - param.output = &output_fp32; - param.input_scale = in_scale[0]; - param.output_scale = 1; - param.weight_scale = w_scale; - std::unique_ptr ctx_fp32( - new KernelContext); - lite::arm::math::DepthwiseConvInt8 - int8depthwise_fp32; - int8depthwise_fp32.init(param, - &ctx_fp32->As()); - int8depthwise_fp32.create(param, - &ctx_fp32->As()); - int8depthwise_fp32.run(param); - float* output_fp32_data = - output_fp32.mutable_data(); - for (int i = 0; i < output_fp32.dims().production(); - i++) { - EXPECT_NEAR( - output_fp32_data[i], output_fp32_ref_data[i], 1e-3); - } - } - } - } - } - } - } - } - } - } - } -} - -TEST(conv_arm, compute) { - DeviceInfo::Init(); -#if 1 - for (auto n : {2}) { - for (auto ic : {6}) { - for (auto oc : {6}) { - for (auto ih : {9}) { - for (auto iw : {9}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto depthwise : {false, true}) { - for (auto dilation : {1}) { - for (auto stride : {1, 2}) { - for (auto padding : {0, 
1, 2}) { - for (auto ks : {1, 3, 5}) { -#else - for (auto n : {1, 2}) { - for (auto ic : {6, 32 /*, 128*/}) { - for (auto oc : {6, 32 /*, 128*/}) { - for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) { - for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) { - for (auto flag_bias : {false, true}) { - for (auto flag_relu : {false, true}) { - for (auto depthwise : {false, true}) { - for (auto dilation : {1, 2}) { - for (auto stride : {1, 2}) { - for (auto padding : {0, 1, 2}) { - for (auto ks : {1, 3, 5}) { -#endif - int group = 1; - if (depthwise) { // depthwise convolution ? - group = oc = ic; - } - // get input, filter and output shape - std::vector input_shape = {n, ic, ih, iw}; - std::vector filter_shape = { - oc, ic / group, ks, ks}; - const int dks = dilation * (ks - 1) + 1; - int oh = (ih + 2 * padding - dks) / stride + 1; - int ow = (iw + 2 * padding - dks) / stride + 1; - std::vector output_shape({n, oc, oh, ow}); - // resize input, filter and output - Tensor input; - Tensor filter; - Tensor bias; - Tensor output; - Tensor output_ref; - input.Resize(input_shape); - filter.Resize(filter_shape); - output.Resize(output_shape); - output_ref.Resize(output_shape); - VLOG(3) << "input: " << input.dims(); - VLOG(3) << "filter: " << filter.dims() - << " padding:" << padding - << " stride:" << stride - << " dilation:" << dilation; - VLOG(3) << "output: " << output.dims(); - auto* input_data = input.mutable_data(); - auto* filter_data = filter.mutable_data(); - auto* output_data = output.mutable_data(); - for (int i = 0; i < input.dims().production(); i++) { - float sign = i % 3 == 0 ? -1.0f : 1.0f; - input_data[i] = sign * static_cast(i % 128); - } - for (int i = 0; i < filter.dims().production(); i++) { - filter_data[i] = - i * 0.001f / - static_cast(filter.dims().production()); - } - // prepare kernel params and run - ConvCompute conv; - std::unique_ptr ctx(new KernelContext); - ctx->As(); - conv.SetContext(std::move(ctx)); - operators::ConvParam param; - param.x = &input; - param.filter = &filter; - param.output = &output; - param.bias = nullptr; - if (flag_bias) { - bias.Resize({oc}); - auto* bias_data = bias.mutable_data(); - for (int i = 0; i < bias.dims().production(); i++) { - bias_data[i] = static_cast(i); - } - param.bias = &bias; - } - param.fuse_relu = flag_relu; - param.paddings = std::vector({padding, padding}); - param.strides = std::vector({stride, stride}); - param.dilations = - std::vector({dilation, dilation}); - param.groups = group; - conv.SetParam(param); - conv.Launch(); - // invoking ref implementation and compare results - param.output = &output_ref; - conv_compute_ref(param); - auto* output_ref_data = - output_ref.mutable_data(); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR( - output_data[i], output_ref_data[i], 1e-3); - } - } - } - } - } - } - } - } - } - } - } - } - } -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a20d607e3a594c8eff83e1f872433f1c6025fd2 --- /dev/null +++ b/lite/kernels/arm/conv_depthwise.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/conv_depthwise.h" +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template <> +void DepthwiseConv::PrepareForRun() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + auto w_dims = param.filter->dims(); + auto kw = w_dims[3]; + // select dw conv kernel + if (kw == 3) { + VLOG(5) << "invoke 3x3 dw conv fp32"; + // trans weights + constexpr int cblock = 4; + auto oc = w_dims[0]; + auto kh = w_dims[2]; + auto cround = ROUNDUP(oc, cblock); + weights_.Resize({cround, 1, kh, kw}); + // auto w_data = weights_.mutable_data(); + // auto w_data_in = param.filter->data(); + // lite::arm::math::conv_trans_weights_numc( + // w_data_in, w_data, oc, 1, cblock, kh * kw); + impl_ = lite::arm::math::conv_depthwise_3x3_fp32; + flag_trans_weights_ = false; + // flag_trans_weights_ = true; + } else if (kw == 5) { + VLOG(5) << "invoke 5x5 dw conv fp32"; + impl_ = lite::arm::math::conv_depthwise_5x5_fp32; + } else { + LOG(FATAL) << "this type dw conv not impl"; + } +} + +template <> +void DepthwiseConv::PrepareForRun() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + auto w_dims = param.filter->dims(); + int kh = w_dims[2]; + int kw = w_dims[3]; + int oc = w_dims[0]; + /// update scale + float in_scale = param.input_scale; + auto& scale = param.weight_scale; + CHECK(scale.size() == 1 || scale.size() == oc) + << "weights scale size must = filter size or = 1"; + w_scale_.resize(oc); + for (int i = 0; i < oc; ++i) { + if (scale.size() == 1) { + w_scale_[i] = scale[0] * in_scale; + } else { + w_scale_[i] = scale[i] * in_scale; + } + } + /// select dw conv kernel + if (kw == 3) { + // trans weights + VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out"; + impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); + flag_trans_weights_ = true; + } else if (kw == 5) { + // trans weights + VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out"; + impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25); + flag_trans_weights_ = true; + } else { + LOG(FATAL) << "this type dw conv not impl"; + } +} + +template <> +void DepthwiseConv::PrepareForRun() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + auto w_dims = param.filter->dims(); + int kw = w_dims[3]; + int kh = w_dims[2]; + int oc = w_dims[0]; + /// update scale + float in_scale = param.input_scale; + float out_scale = param.output_scale; + auto& scale = param.weight_scale; + CHECK(scale.size() 
== 1 || scale.size() == oc) + << "weights scale size must = filter size or = 1"; + w_scale_.resize(oc); + for (int i = 0; i < oc; ++i) { + if (scale.size() == 1) { + w_scale_[i] = scale[0] * in_scale / out_scale; + } else { + w_scale_[i] = scale[i] * in_scale / out_scale; + } + } + /// update bias + if (param.bias) { + bias_.Resize(param.bias->dims()); + auto ptr = bias_.mutable_data(); + auto ptr_in = param.bias->data(); + for (int i = 0; i < bias_.numel(); ++i) { + ptr[i] = ptr_in[i] / out_scale; + } + flag_trans_bias_ = true; + } + /// select dw conv kernel + if (kw == 3) { + // trans weights + VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out"; + impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); + flag_trans_weights_ = true; + } else if (kw == 5) { + // trans weights + VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out"; + impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25); + flag_trans_weights_ = true; + } else { + LOG(FATAL) << "this type dw conv not impl"; + } +} + +template <> +void DepthwiseConv::Run() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = flag_trans_weights_ ? weights_.data() + : param.filter->data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + impl_(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); +} + +template <> +void DepthwiseConv::Run() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = flag_trans_weights_ ? weights_.data() + : param.filter->data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + impl_(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); +} + +template <> +void DepthwiseConv::Run() { + auto& param = this->Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = flag_trans_weights_ ? weights_.data() + : param.filter->data(); + const auto* b_data = param.bias ? 
param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + impl_(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_depthwise.h b/lite/kernels/arm/conv_depthwise.h new file mode 100644 index 0000000000000000000000000000000000000000..e1e70355f621d043ec196bf68735acef8e918e69 --- /dev/null +++ b/lite/kernels/arm/conv_depthwise.h @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +class DepthwiseConv : public KernelLite { + public: + typedef void (*conv_dw_impl)(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale); + DepthwiseConv() = default; + ~DepthwiseConv() {} + virtual void PrepareForRun(); + virtual void Run(); + + private: + using param_t = operators::ConvParam; + Tensor weights_; + Tensor bias_; + bool flag_trans_weights_{false}; + bool flag_trans_bias_{false}; + conv_dw_impl impl_{nullptr}; + std::vector w_scale_; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_direct.cc b/lite/kernels/arm/conv_direct.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae8c1d1b9aa4e1e3e79c68116d91a0d0c1e9b1ab --- /dev/null +++ b/lite/kernels/arm/conv_direct.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
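+
+// DirectConv: dispatches the 3x3 direct convolution to the stride-1 or
+// stride-2 implementations in lite::arm::math, for fp32 and int8 outputs
+// (see conv_direct.h for the weight and scale transformations).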
+ +#include "lite/kernels/arm/conv_direct.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template <> +void DirectConv::ReInitWhenNeeded() { + auto& param = this->template Param(); + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + if (last_shape_ == x_dims) { + return; + } + auto& ctx = this->ctx_->template As(); + if (param.strides[0] == 2) { + ctx.ExtendWorkspace( + lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx)); + } else { + ctx.ExtendWorkspace( + lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx)); + } +} + +template <> +void DirectConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = weights_.data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (param.strides[0] == 1) { + lite::arm::math::conv_3x3s1_direct_fp32(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } else { + lite::arm::math::conv_3x3s2_direct_fp32(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } +} + +template <> +void DirectConv::ReInitWhenNeeded() {} + +template <> +void DirectConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = weights_.data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (param.strides[0] == 1) { + lite::arm::math::conv_3x3s1_direct_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); + } else { + lite::arm::math::conv_3x3s2_direct_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); + } +} + +template <> +void DirectConv::ReInitWhenNeeded() {} + +template <> +void DirectConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = weights_.data(); + const auto* b_data = param.bias ? 
param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (param.strides[0] == 1) { + lite::arm::math::conv_3x3s1_direct_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); + } else { + lite::arm::math::conv_3x3s2_direct_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx, + w_scale_.data()); + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h new file mode 100644 index 0000000000000000000000000000000000000000..24c934e14b5ae5e7de4d089da1611deb0e77fefb --- /dev/null +++ b/lite/kernels/arm/conv_direct.h @@ -0,0 +1,201 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) + +template +inline bool direct_conv_trans_weights( + const Tensor* win, + Tensor* wout, + const Tensor* bin, + Tensor* bout, + int stride, + const std::vector& w_scale, + float in_scale, + float out_scale, + std::vector& merge_scale) { // NOLINT + constexpr int cblock = 4; + int oc = win->dims()[0]; + int ic = win->dims()[1]; + int kh = win->dims()[2]; + int kw = win->dims()[3]; + int cround = ROUNDUP(oc, cblock); + wout->Resize({cround, ic, kh, kw}); + auto w_in_data = win->data(); + auto transed_w_data = wout->mutable_data(); + lite::arm::math::conv_trans_weights_numc( + w_in_data, transed_w_data, oc, ic, cblock, kh * kw); + return false; +} + +template <> +inline bool direct_conv_trans_weights( + const Tensor* win, + Tensor* wout, + const Tensor* bin, + Tensor* bout, + int stride, + const std::vector& w_scale, + float in_scale, + float out_scale, + std::vector& merge_scale) { // NOLINT + int cblock = 4; + if (stride == 2) { + cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num(); + } + int oc = win->dims()[0]; + int ic = win->dims()[1]; + int kh = win->dims()[2]; + int kw = win->dims()[3]; + int cround = ROUNDUP(oc, cblock); + wout->Resize({cround, ic, kh, kw}); + auto w_in_data = win->data(); + auto transed_w_data = wout->mutable_data(); + lite::arm::math::conv_trans_weights_numc( + w_in_data, transed_w_data, oc, ic, cblock, kh * kw); + /// update scale + CHECK(w_scale.size() == 1 || w_scale.size() == oc) + << "weights scale size must = filter size or = 
1"; + merge_scale.resize(oc); + for (int i = 0; i < oc; ++i) { + if (w_scale.size() == 1) { + merge_scale[i] = w_scale[0] * in_scale; + } else { + merge_scale[i] = w_scale[i] * in_scale; + } + } + return false; +} + +template <> +inline bool direct_conv_trans_weights( + const Tensor* win, + Tensor* wout, + const Tensor* bin, + Tensor* bout, + int stride, + const std::vector& w_scale, + float in_scale, + float out_scale, + std::vector& merge_scale) { // NOLINT + int cblock = 4; + if (stride == 2) { + cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num(); + } + int oc = win->dims()[0]; + int ic = win->dims()[1]; + int kh = win->dims()[2]; + int kw = win->dims()[3]; + int cround = ROUNDUP(oc, cblock); + wout->Resize({cround, ic, kh, kw}); + auto w_in_data = win->data(); + auto transed_w_data = wout->mutable_data(); + lite::arm::math::conv_trans_weights_numc( + w_in_data, transed_w_data, oc, ic, cblock, kh * kw); + /// update scale + CHECK(w_scale.size() == 1 || w_scale.size() == oc) + << "weights scale size must = filter size or = 1"; + merge_scale.resize(oc); + float scale = in_scale / out_scale; + for (int i = 0; i < oc; ++i) { + if (w_scale.size() == 1) { + merge_scale[i] = w_scale[0] * scale; + } else { + merge_scale[i] = w_scale[i] * scale; + } + } + /// update bias + if (bin) { + bout->Resize(bin->dims()); + auto ptr = bout->mutable_data(); + auto ptr_in = bin->data(); + for (int i = 0; i < bin->numel(); ++i) { + ptr[i] = ptr_in[i] / out_scale; + } + return true; + } + return false; +} + +/// only support 3x3s1 and 3x3s2 +template +class DirectConv : public KernelLite { + public: + DirectConv() = default; + ~DirectConv() {} + + virtual void PrepareForRun() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + last_shape_ = x_dims; + + int ic = x_dims[1]; + int oc = o_dims[1]; + int sw = param.strides[1]; + int kw = w_dims[3]; + int kh = w_dims[2]; + CHECK(sw == 1 || sw == 2) + << "direct conv only support conv3x3s1 and conv3x3s2"; + CHECK(kw == 3 && kh == 3) + << "direct conv only support conv3x3s1 and conv3x3s2"; + flag_trans_bias_ = + direct_conv_trans_weights(param.filter, + &weights_, + param.bias, + &bias_, + sw, + param.weight_scale, + param.input_scale, + param.output_scale, + w_scale_); + } + + virtual void ReInitWhenNeeded(); + virtual void Run(); + + /// todo, support inplace weights transform + protected: + DDim last_shape_; + Tensor weights_; + Tensor bias_; + bool flag_trans_weights_{false}; + bool flag_trans_bias_{false}; + std::vector w_scale_; + + private: + using param_t = operators::ConvParam; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_gemmlike.cc b/lite/kernels/arm/conv_gemmlike.cc new file mode 100644 index 0000000000000000000000000000000000000000..56dc72f2d6bbc331ccc14305d502f11cf4f27609 --- /dev/null +++ b/lite/kernels/arm/conv_gemmlike.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/conv_gemmlike.h" +#include +#include "lite/backends/arm/math/gemm_prepacked_int8.h" +#include "lite/backends/arm/math/packed_sgemm.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template <> +void GemmLikeConv::PrepareForRun() { + ReInitWhenNeeded(); +} + +template <> +void GemmLikeConv::PrepareForRun() { + ReInitWhenNeeded(); + auto& param = this->Param(); + /// update scale + w_scale_ = param.weight_scale; + if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) { + LOG(FATAL) << "weights scale size must equal to filter size"; + return; + } + if (w_scale_.size() == 1) { + for (int i = 0; i < param.filter->dims()[0] - 1; ++i) { + w_scale_.push_back(w_scale_[0]); + } + } + float input_scale = param.input_scale; + for (auto& ws : w_scale_) { + ws *= input_scale; + } +} + +template <> +void GemmLikeConv::PrepareForRun() { + ReInitWhenNeeded(); + auto& param = this->Param(); + /// update scale + /// update scale + w_scale_ = param.weight_scale; + if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) { + LOG(FATAL) << "weights scale size must equal to filter size"; + return; + } + if (w_scale_.size() == 1) { + for (int i = 0; i < param.filter->dims()[0] - 1; ++i) { + w_scale_.push_back(w_scale_[0]); + } + } + float input_scale = param.input_scale; + float output_scale = param.output_scale; + for (auto& ws : w_scale_) { + ws = ws * input_scale / output_scale; + } + //! update bias + if (param.bias) { + bias_.Resize(param.bias->dims()); + auto ptr = bias_.mutable_data(); + auto ptr_in = param.bias->data(); + for (int i = 0; i < bias_.numel(); ++i) { + ptr[i] = ptr_in[i] / param.output_scale; + } + flag_trans_bias_ = true; + } +} + +template <> +void GemmLikeConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto weights = param.filter->data(); + if (flag_trans_weights_) { + weights = weights_.data(); + } + const float* bias = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + bias = bias_.data(); + } + auto din = param.x->data(); + auto dout = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (flag_1x1gemm_) { + lite::arm::math::conv1x1s1_gemm( + din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx); + } else { + lite::arm::math::conv_im2col_gemm( + din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx); + } +} + +template <> +void GemmLikeConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto weights = param.filter->data(); + if (flag_trans_weights_) { + weights = weights_.data(); + } + auto bias = param.bias ? 
param.bias->data() : nullptr; + if (flag_trans_bias_) { + bias = bias_.data(); + } + auto din = param.x->data(); + auto dout = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (flag_1x1gemm_) { + lite::arm::math::conv1x1s1_gemm_int8(din, + dout, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + weights, + bias, + param, + &ctx, + w_scale_.data()); + } else { + lite::arm::math::conv_im2col_gemm_int8(din, + dout, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + weights, + bias, + param, + &ctx, + w_scale_.data()); + } +} + +template <> +void GemmLikeConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto weights = param.filter->data(); + if (flag_trans_weights_) { + weights = weights_.data(); + } + auto bias = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + bias = bias_.data(); + } + auto din = param.x->data(); + auto dout = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + if (flag_1x1gemm_) { + lite::arm::math::conv1x1s1_gemm_int8(din, + dout, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + weights, + bias, + param, + &ctx, + w_scale_.data()); + } else { + lite::arm::math::conv_im2col_gemm_int8(din, + dout, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + weights, + bias, + param, + &ctx, + w_scale_.data()); + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_gemmlike.h b/lite/kernels/arm/conv_gemmlike.h new file mode 100644 index 0000000000000000000000000000000000000000..0f1213390b1febcc721cefc8a4005184dd00d3ec --- /dev/null +++ b/lite/kernels/arm/conv_gemmlike.h @@ -0,0 +1,105 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +class GemmLikeConv : public KernelLite { + public: + GemmLikeConv() = default; + ~GemmLikeConv() {} + + virtual void ReInitWhenNeeded() { + auto& param = this->template Param(); + CHECK(this->ctx_); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + if (last_shape_ == x_dims) { + return; + } + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int kw = w_dims[3]; + int kh = w_dims[2]; + int sw = param.strides[1]; + int sh = param.strides[0]; + int pw = param.paddings[1]; + int ph = param.paddings[0]; + int dw = param.dilations[1]; + int dh = param.dilations[0]; + + int m = oc / param.groups; + int k = ic * kh * kw / param.groups; + int n = oh * ow; + + bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh); + bool ks_equal = (sw == sh) && (kw == kh); + //! select conv gemmlike kernel + if (kw == 1 && sw == 1 && pw == 0 && kps_equal) { + //! 1x1s1p0 gemmlike conv + flag_1x1gemm_ = true; + } else { + //! im2col gemmlike conv + flag_1x1gemm_ = false; + ctx.ExtendWorkspace(k * n * sizeof(float)); + } + if (!flag_trans_weights_ && n > 1) { + lite::arm::math::trans_gemm_weights( + *(param.filter), weights_, param.groups, &ctx); + flag_trans_weights_ = true; + } else if (n == 1) { + flag_trans_weights_ = false; + } + last_shape_ = x_dims; + } + + virtual void PrepareForRun(); + virtual void Run(); + + /// todo, support inplace weights transform + protected: + using param_t = operators::ConvParam; + DDim last_shape_; + std::vector w_scale_; + bool flag_1x1gemm_{true}; + bool flag_trans_weights_{false}; + bool flag_trans_bias_{false}; + Tensor weights_; + Tensor bias_; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc index fdadbabfc189814870bce96db121233eecf88c24..9fca00ad6b01b0c420d9a0d3ad0f712604a4a441 100644 --- a/lite/kernels/arm/conv_transpose_compute.cc +++ b/lite/kernels/arm/conv_transpose_compute.cc @@ -49,10 +49,11 @@ void Conv2DTransposeCompute::PrepareForRun() { lite::Tensor tmp_weights; lite::arm::math::prepackA( - &tmp_weights, *(param.filter), 1., m, k, group, true, &ctx); + &tmp_weights, *(param.filter), 1.f, m, k, group, true, &ctx); param.filter->Resize(tmp_weights.dims()); param.filter->CopyDataFrom(tmp_weights); param.filter->Resize(w_dims); + is_first_epoch_ = false; } void Conv2DTransposeCompute::Run() { @@ -80,7 +81,7 @@ void Conv2DTransposeCompute::Run() { int group_size_out = wout * hout * chout / group; int group_size_coldata = m * n; auto& ctx = this->ctx_->template As(); - int hblock = lite::arm::math::get_hblock(ctx.arch()); + int hblock = lite::arm::math::get_hblock(&ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int group_size_weights = ((m_roundup * k + 15) / 16) * 16; bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) && @@ -96,7 +97,7 @@ void Conv2DTransposeCompute::Run() { const float* din_batch = din + i * chin * hin * win; float* dout_batch = dout + i * chout * hout * wout; float* 
col_data = static_cast(ctx.workspace_data()) + - ctx.l2_cache_size() / sizeof(float); + ctx.llc_size() / sizeof(float); if (flag_1x1s1p1) { col_data = dout_batch; } @@ -112,7 +113,7 @@ void Conv2DTransposeCompute::Run() { weights_group, din_group, n, - 0., + 0.f, coldata_group, n, nullptr, diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6e73a0a59f8dbf9f0549a4732daaa53b89b9666 --- /dev/null +++ b/lite/kernels/arm/conv_winograd.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/conv_winograd.h" +#include +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/backends/arm/math/packed_sgemm.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template <> +void WinogradConv::ReInitWhenNeeded() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + if (last_shape_ == x_dims) { + return; + } + + int ic = x_dims[1]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int tile_w = (ow + 5) / 6; + int tile_h = (oh + 5) / 6; + int size_tile = tile_h * tile_w; + int size_trans_channel = 8 * 8 * size_tile; + int max_ch = ic > oc ? ic : oc; + + const int n_wino = size_tile; + ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) * + sizeof(float)); + last_shape_ = x_dims; +} + +template <> +void WinogradConv::PrepareForRun() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + last_shape_ = x_dims; + + int ic = x_dims[1]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int tile_w = (ow + 5) / 6; + int tile_h = (oh + 5) / 6; + int size_tile = tile_h * tile_w; + int size_trans_channel = 8 * 8 * size_tile; + int max_ch = ic > oc ? 
ic : oc; + + const int m_wino = oc; + const int n_wino = size_tile; + int hblock = lite::arm::math::get_hblock(&ctx); + int m_round = hblock * ((m_wino + hblock - 1) / hblock); + weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic}); + ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) * + sizeof(float)); + auto weights_wino = + static_cast(malloc(sizeof(float) * 8 * 8 * oc * ic)); + void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic); + lite::arm::math::winograd_transform_weights( + weights_wino, param.filter->data(), oc, ic, trans_tmp_ptr); + auto weights_trans = weights_.mutable_data(); + for (int i = 0; i < 64; ++i) { + float* packed_weights = weights_trans + i * m_round * ic; + const float* weights_wino_ptr = weights_wino + i * oc * ic; + lite::arm::math::prepackA(packed_weights, + weights_wino_ptr, + 1.f, + ic, + 0, + m_wino, + 0, + ic, + false, + &ctx); + } + free(trans_tmp_ptr); + free(weights_wino); +} + +template <> +void WinogradConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + const auto* i_data = param.x->data(); + const auto* w_data = weights_.data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + lite::arm::math::conv_winograd3x3( + i_data, o_data, bs, oc, oh, ow, ic, ih, iw, w_data, b_data, param, &ctx); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h new file mode 100644 index 0000000000000000000000000000000000000000..8b6de0af5ed359015e15515b559bfaf754d4c3f9 --- /dev/null +++ b/lite/kernels/arm/conv_winograd.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
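The workspace arithmetic in the Winograd `PrepareForRun`/`ReInitWhenNeeded` above corresponds to an F(6x6, 3x3) transform: each tile produces a 6x6 output block from an 8x8 transformed patch, so `tile_w = (ow + 5) / 6` is a ceiling division and the transform buffers hold `8 * 8 * tiles` floats per channel. A small helper reproducing that sizing (a sketch under the assumption that the `* 2` covers one transformed-input plus one transformed-output buffer, sized for the larger of `ic`/`oc`; not a Lite API):

```cpp
#include <algorithm>
#include <cstddef>

// Workspace bytes for F(6x6, 3x3) Winograd, matching the expression
// (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float) above.
size_t WinogradF63WorkspaceBytes(int ic, int oc, int oh, int ow) {
  int tile_w = (ow + 5) / 6;              // ceil(ow / 6)
  int tile_h = (oh + 5) / 6;              // ceil(oh / 6)
  int tiles = tile_w * tile_h;
  int trans_per_channel = 8 * 8 * tiles;  // 8x8 patch per 6x6 output tile
  int max_ch = std::max(ic, oc);
  return (static_cast<size_t>(trans_per_channel) * max_ch * 2 + tiles) *
         sizeof(float);
}
```

The weight transform follows the same layout: 64 (8x8) transformed positions, each prepacked as an independent `m_round x ic` GEMM A-matrix.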
+ +#pragma once + +#include +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +/// only support 3x3s1 and 3x3s2 +template +class WinogradConv : public KernelLite { + public: + WinogradConv() = default; + ~WinogradConv() {} + virtual void PrepareForRun(); + virtual void ReInitWhenNeeded(); + virtual void Run(); + + protected: + using param_t = operators::ConvParam; + Tensor weights_; + DDim last_shape_; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/decode_bboxes_compute_test.cc b/lite/kernels/arm/decode_bboxes_compute_test.cc index 262fd7b19a2eb9e501ac2119d3fad6db608dc68c..271a99c29b61063877b7d1c0d2e50bc65d135d72 100644 --- a/lite/kernels/arm/decode_bboxes_compute_test.cc +++ b/lite/kernels/arm/decode_bboxes_compute_test.cc @@ -14,6 +14,7 @@ #include "lite/kernels/arm/decode_bboxes_compute.h" #include +#include #include #include #include "lite/core/op_registry.h" diff --git a/lite/kernels/arm/density_prior_box_compute.cc b/lite/kernels/arm/density_prior_box_compute.cc index 3a9fa85411464081e7b006afffe4e80d87ef90f6..e45fd5ba4d3c120140de13d00f074fd3526ac3f5 100644 --- a/lite/kernels/arm/density_prior_box_compute.cc +++ b/lite/kernels/arm/density_prior_box_compute.cc @@ -100,7 +100,8 @@ void DensityPriorBoxCompute::Run() { prior_num, is_flip, is_clip, - order); + order, + false); } } // namespace arm diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index a0a87628fff494807c82cfb9891ce717fbfd0cab..2e57b6a3b37c91845d75444333fb205683cfd81c 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -116,6 +116,51 @@ void ElementwiseAddActivationCompute::Run() { } } +void ElementwiseSubCompute::Run() { + auto& param = Param(); + const float* x_data = param.X->data(); + const float* y_data = param.Y->data(); + float* out_data = param.Out->mutable_data(); + int axis = param.axis; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int pre, n, post; + if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { + lite::arm::math::elementwise_sub_broadcast( + x_data, y_data, out_data, pre, n, post); + } else { + lite::arm::math::elementwise_sub( + x_data, y_data, out_data, x_dims.production()); + } +} + +void ElementwiseSubActivationCompute::Run() { + auto& param = Param(); + const float* x_data = param.X->data(); + const float* y_data = param.Y->data(); + float* out_data = param.Out->mutable_data(); + int axis = param.axis; + std::string act_type = param.act_type; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int pre, n, post; + if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { + if (act_type == "relu") { + lite::arm::math::elementwise_sub_relu_broadcast( + x_data, y_data, out_data, pre, n, post); + } else { + LOG(FATAL) << "unsupported Activation type: " << act_type; + } + } else { + if (act_type == "relu") { + lite::arm::math::elementwise_sub_relu( + x_data, y_data, out_data, x_dims.production()); + } else { + LOG(FATAL) << "unsupported Activation type: " << act_type; + } + } +} + void ElementwiseMulCompute::Run() { auto& param = Param(); const float* x_data = param.X->data(); @@ -249,10 +294,6 @@ void ElementwiseDivActivationCompute::Run() { LOG(FATAL) << "unsupported Activation type: " << act_type; } } - for (int i 
= 0; i < x_dims.production(); i++) { - LOG(INFO) << "x:" << x_data[i] << " y:" << y_data[i] - << " out:" << out_data[i]; - } } } // namespace arm @@ -283,6 +324,29 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +REGISTER_LITE_KERNEL(elementwise_sub, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ElementwiseSubCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + fusion_elementwise_sub_activation, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ElementwiseSubActivationCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kARM, kFloat, diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index 003f4d542fde5d588abd67cb85f540107b2fa417..e76449aebcfa16317df99771f2b686d9a179ec25 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -38,6 +38,22 @@ class ElementwiseAddActivationCompute virtual ~ElementwiseAddActivationCompute() = default; }; +class ElementwiseSubCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseSubCompute() = default; +}; + +class ElementwiseSubActivationCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseSubActivationCompute() = default; +}; + class ElementwiseMulCompute : public KernelLite { public: diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 83d40362e7bbd53737b005ccba99b89373cca215..1983c733180143dc0c715d6c8e3c4fddac6f8418 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -26,50 +26,74 @@ namespace lite { namespace kernels { namespace arm { -void FcCompute::PrepareForRun() { - auto& param = this->Param(); - auto x_dims = param.input->dims(); - auto w_dims = param.w->dims(); - - auto& ctx = this->ctx_->template As(); - - CHECK_GE(x_dims.size(), 2UL); - CHECK_EQ(w_dims.size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); - - m_ = x_dims.Slice(0, param.in_num_col_dims).production(); - k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); - CHECK_EQ(k_, w_dims[0]); - n_ = w_dims[1]; - CHECK_EQ(k_, static_cast(w_dims[0])); +/// for fp32 kernel +template <> +void FcCompute::PrepareForRun() { + ReInitWhenNeeded(); +} - if (m_ == 1) { - if (!transed_weight_) { - transed_weight_ = new Tensor; +/// for int8 kernel with fp32 output +template <> +void FcCompute::PrepareForRun() { + ReInitWhenNeeded(); + auto& param = this->template Param(); + /// update scale + float input_scale = param.input_scale; + int extend_size = flag_gemm_ ? 
m_ : n_; + scale_.resize(extend_size); + for (int i = 0; i < extend_size; ++i) { + if (flag_gemm_) { + scale_[i] = param.weight_scale[0] * input_scale; + } else { + scale_[i] = param.weight_scale[i] * input_scale; } - transed_weight_->Resize({n_, k_}); - const auto* w_data = param.w->data(); - auto* t_data = transed_weight_->mutable_data(); - int i = 0; + } +} - for (int nn = 0; nn < n_; ++nn) { - for (int kk = 0; kk < k_; ++kk) { - t_data[i++] = w_data[kk * n_ + nn]; - } +/// for int8 kernel with int8 output +template <> +void FcCompute::PrepareForRun() { + ReInitWhenNeeded(); + auto& param = this->template Param(); + /// update scale + scale_ = param.weight_scale; + float input_scale = param.input_scale; + float output_scale = param.output_scale; + int extend_size = flag_gemm_ ? m_ : n_; + scale_.resize(extend_size); + for (int i = 0; i < extend_size; ++i) { + if (flag_gemm_) { + scale_[i] = param.weight_scale[0] * input_scale / output_scale; + } else { + scale_[i] = param.weight_scale[i] * input_scale / output_scale; } } + /// update bias + if (param.bias) { + bias_.Resize(param.bias->dims()); + auto ptr = bias_.mutable_data(); + auto ptr_in = bias_.data(); + float out_scale = param.output_scale; + for (int i = 0; i < bias_.numel(); ++i) { + ptr[i] = ptr_in[i] / out_scale; + } + flag_trans_bias_ = true; + } } -void FcCompute::Run() { +template <> +void FcCompute::Run() { auto& param = this->Param(); - - const auto* i_data = param.input->data(); - const auto* w_data = param.w->data(); - const auto* b_data = param.bias ? param.bias->data() : nullptr; - auto* o_data = param.output->mutable_data(); - auto& ctx = this->ctx_->template As(); - if (m_ > 1) { + + auto i_data = param.input->data(); + auto o_data = param.output->mutable_data(); + auto w_data = flag_gemm_ ? param.w->data() : weights_.data(); + const float* b_data = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + if (flag_gemm_) { lite::arm::math::sgemm(false, false, m_, @@ -83,7 +107,7 @@ void FcCompute::Run() { 0.f, o_data, n_, - b_data, + nullptr, false, false, &ctx); @@ -92,134 +116,117 @@ void FcCompute::Run() { lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); } } else { - CHECK(transed_weight_); - const auto* t_data = transed_weight_->data(); - - lite::arm::math::sgemv(t_data, - i_data, - o_data, - false, - n_, - k_, - b_data != nullptr, - b_data, - false); + for (int i = 0; i < m_; ++i) { + auto i_data_batch = i_data + i * k_; + auto o_data_batch = o_data + i * n_; + lite::arm::math::sgemv(w_data, + i_data_batch, + o_data_batch, + false, + n_, + k_, + param.bias != nullptr, + b_data, + false); + } } } -template -void FcComputeInt8::PrepareForRun() { +template <> +void FcCompute::Run() { auto& param = this->Param(); - auto x_dims = param.input->dims(); - auto w_dims = param.w->dims(); - auto& ctx = this->ctx_->template As(); - if (!tmp_int32_out_) { - tmp_int32_out_ = new Tensor; - tmp_int32_out_->Resize(param.output->dims()); - } - - CHECK_GE(x_dims.size(), 2UL); - CHECK_EQ(w_dims.size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); - - this->m_ = x_dims.Slice(0, param.in_num_col_dims).production(); - this->k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); - this->n_ = w_dims[1]; - CHECK_EQ(k_, static_cast(w_dims[0])); - if (this->m_ == 1) { - if (!this->transed_weight_) { - this->transed_weight_ = new Tensor; + auto i_data = param.input->data(); + auto o_data = param.output->mutable_data(); + auto w_data = + flag_trans_weights_ ? 
weights_.data() : param.w->data(); + const float* b_data = param.bias ? param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); + } + if (flag_gemm_) { + lite::arm::math::gemm_s8(false, + false, + m_, + n_, + k_, + i_data, + w_data, + o_data, + nullptr, + false, + false, + scale_.data(), + &ctx); + if (param.bias) { + CHECK_EQ(param.bias->numel(), n_); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); } - this->transed_weight_->Resize({this->n_, this->k_}); - const auto* w_data = param.w->template data(); - auto* t_data = this->transed_weight_->template mutable_data(); - int i = 0; - - for (int nn = 0; nn < this->n_; ++nn) { - for (int kk = 0; kk < this->k_; ++kk) { - t_data[i++] = w_data[kk * this->n_ + nn]; - } + } else { + for (int i = 0; i < m_; ++i) { + auto i_data_batch = i_data + i * k_; + auto o_data_batch = o_data + i * n_; + lite::arm::math::gemv_int8(w_data, + i_data_batch, + o_data_batch, + false, + n_, + k_, + scale_.data(), + param.bias != nullptr, + b_data, + false, + &ctx); } } - - if (this->m_ > 1) { - int hblock = lite::arm::math::get_hblock(ctx.arch()); - int m_round = hblock * ((this->m_ + hblock - 1) / hblock); - ctx.ExtendWorkspace(m_round * this->k_); - } - bool with_bias = param.bias; - if (with_bias) { - Tensor temp_tensor; - temp_tensor.CopyDataFrom(*param.bias); - lite::arm::math::trans_fp32_bias_to_int32_basic( - &temp_tensor, param.bias, param.input_scale, param.weight_scale); - } } -template -void FcComputeInt8::Run() { +template <> +void FcCompute::Run() { auto& param = this->Param(); - - const auto* i_data = param.input->template data(); - const auto* w_data = param.w->template data(); - const auto* b_data = param.bias ? param.bias->template data() : nullptr; - int* o_data = nullptr; - auto& ctx = this->ctx_->template As(); - o_data = this->tmp_int32_out_->template mutable_data(); - if (m_ > 1) { - int8_t* packed_in = - static_cast(ctx.template workspace_data()) + - ctx.llc_size() / sizeof(int8_t); - lite::arm::math::prepackA_int8( - packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx); - lite::arm::math::gemm_prepack_int8(packed_in, - w_data, - b_data, - o_data, - m_, - n_, - k_, - false, - false, - false, - nullptr, - &ctx); - if (param.bias) { - CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); - } - } else { - CHECK(transed_weight_); - const auto* t_data = transed_weight_->template data(); - lite::arm::math::gemv_int8(t_data, - i_data, - o_data, - false, - n_, - k_, - nullptr, - b_data != nullptr, - b_data, - false); + auto i_data = param.input->data(); + auto o_data = param.output->mutable_data(); + auto w_data = + flag_trans_weights_ ? weights_.data() : param.w->data(); + const float* b_data = param.bias ? 
param.bias->data() : nullptr; + if (flag_trans_bias_) { + b_data = bias_.data(); } - - float i_scale = param.input_scale; - std::vector weight_scale = param.weight_scale; - if (Ptype_out == PRECISION(kInt8)) { - float o_scale = param.output_scale; - param.output->template mutable_data(); - lite::arm::math::trans_tensor_dtype( - tmp_int32_out_, param.output, i_scale, o_scale, weight_scale); - } else if (Ptype_out == PRECISION(kFloat)) { - param.output->template mutable_data(); - lite::arm::math::trans_tensor_dtype( - tmp_int32_out_, param.output, i_scale, 1.f, weight_scale); + if (flag_gemm_) { + CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " + "must not have bias"; + lite::arm::math::gemm_s8(false, + false, + m_, + n_, + k_, + i_data, + w_data, + o_data, + nullptr, + false, + false, + scale_.data(), + &ctx); } else { - LOG(ERROR) << "unsupported precision type!!"; + for (int i = 0; i < m_; ++i) { + auto i_data_batch = i_data + i * k_; + auto o_data_batch = o_data + i * n_; + lite::arm::math::gemv_int8(w_data, + i_data_batch, + o_data_batch, + false, + n_, + k_, + scale_.data(), + param.bias != nullptr, + b_data, + false, + &ctx); + } } } @@ -228,36 +235,33 @@ void FcComputeInt8::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL( - fc, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::FcCompute, def) +typedef paddle::lite::kernels::arm::FcCompute + FcCompute_FP32; +typedef paddle::lite::kernels::arm::FcCompute + FcCompute_int8_fp32; +typedef paddle::lite::kernels::arm::FcCompute + FcCompute_int8_int8; + +REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW, FcCompute_FP32, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -REGISTER_LITE_KERNEL( - fc, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::FcComputeInt8, - int8out) +REGISTER_LITE_KERNEL(fc, kARM, kInt8, kNCHW, FcCompute_int8_int8, int8out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .Finalize(); -REGISTER_LITE_KERNEL( - fc, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::FcComputeInt8, - fp32out) +REGISTER_LITE_KERNEL(fc, kARM, kInt8, kNCHW, FcCompute_int8_fp32, fp32out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .Finalize(); diff --git a/lite/kernels/arm/fc_compute.h b/lite/kernels/arm/fc_compute.h index 2af3845ebbbb76c1a2ae1681787abe2852c442ab..2e5f2345e824b13d78a1575d3374652b8474c7fd 100644 --- a/lite/kernels/arm/fc_compute.h +++ b/lite/kernels/arm/fc_compute.h @@ -14,6 +14,8 @@ #pragma once #include +#include +#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/type_trans.h" #include "lite/core/kernel.h" @@ -22,44 +24,108 @@ namespace lite { namespace kernels { namespace arm { -class 
FcCompute : public KernelLite { - public: - using param_t = operators::FcParam; +template +void naive_transpose(const Dtype* din, Dtype* dout, int m, int n) { + int k = 0; + for (int i = 0; i < n; ++i) { + for (int j = 0; j < m; ++j) { + dout[k++] = din[j * n + i]; + } + } +} - void PrepareForRun() override; +template +void fc_trans_weights(const Tensor& tin, Tensor* tout); - void Run() override; +template <> +void fc_trans_weights(const Tensor& tin, Tensor* tout) { + CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2"; + int m = tin.dims()[0]; + int n = tin.dims()[1]; + tout->Resize({n, m}); + auto ptr_in = tin.data(); + auto ptr_out = tout->mutable_data(); + naive_transpose(ptr_in, ptr_out, m, n); +} - ~FcCompute() override { - if (transed_weight_) { - delete transed_weight_; - } - }; +template <> +void fc_trans_weights(const Tensor& tin, Tensor* tout) { + CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2"; + int m = tin.dims()[0]; + int n = tin.dims()[1]; + tout->Resize({n, m}); + auto ptr_in = tin.data(); + auto ptr_out = tout->mutable_data(); + naive_transpose(ptr_in, ptr_out, m, n); +} - private: - lite::Tensor* transed_weight_{nullptr}; - int m_, n_, k_; -}; +template +bool check_fc_use_gemm(int m, const std::vector& scale, bool has_bias) { + return m > 1; +} -template -class FcComputeInt8 : public KernelLite { +template <> +bool check_fc_use_gemm( + int m, const std::vector& scale, bool has_bias) { + CHECK(scale.size() > 0) << "Int8 FC param must has weight_scale"; + return m > 1 && scale.size() == 1; +} + +template <> +bool check_fc_use_gemm( + int m, const std::vector& scale, bool has_bias) { + CHECK(scale.size() > 0) << "Int8 FC param must has weight_scale"; + return m > 1 && scale.size() == 1 && !has_bias; +} + +template +class FcCompute : public KernelLite { public: using param_t = operators::FcParam; - void PrepareForRun() override; + virtual void ReInitWhenNeeded() { + auto& param = this->template Param(); + auto x_dims = param.input->dims(); + if (last_shape_ == x_dims) { + return; + } + last_shape_ = x_dims; + auto w_dims = param.w->dims(); + auto& ctx = this->ctx_->template As(); - void Run() override; + CHECK_GE(x_dims.size(), 2UL); + CHECK_EQ(w_dims.size(), 2UL); + CHECK_EQ(param.output->dims().size(), 2UL); - ~FcComputeInt8() override { - if (transed_weight_) { - delete transed_weight_; + m_ = x_dims.Slice(0, param.in_num_col_dims).production(); + k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); + CHECK_EQ(k_, w_dims[0]); + n_ = w_dims[1]; + CHECK_EQ(k_, static_cast(w_dims[0])); + flag_gemm_ = check_fc_use_gemm( + m_, param.weight_scale, param.bias != nullptr); + if (!flag_trans_weights_ && !flag_gemm_) { + flag_trans_weights_ = true; + fc_trans_weights(*param.w, &weights_); } - }; + } + + virtual void PrepareForRun(); + virtual void Run(); + + ~FcCompute() = default; private: - lite::Tensor* transed_weight_{nullptr}; - Tensor* tmp_int32_out_{nullptr}; - int m_, n_, k_; + DDim last_shape_; + Tensor weights_; + Tensor bias_; + bool flag_trans_weights_{false}; + bool flag_trans_bias_{false}; + bool flag_gemm_{true}; + int m_; + int n_; + int k_; + std::vector scale_; }; } // namespace arm diff --git a/lite/kernels/arm/fc_compute_test.cc b/lite/kernels/arm/fc_compute_test.cc deleted file mode 100644 index acda9016673fde7227a7b0c5278e73c3c92cf8aa..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/fc_compute_test.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/arm/fc_compute.h" -#include -#include -#include -#include -#include -#include -#include -#include "lite/backends/arm/math/funcs.h" -#include "lite/core/op_registry.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -#define A(i, j) a[i * lda + j] -#define B(i, j) b[i * ldb + j] -#define C(i, j) c[i * ldc + j] - -template -void gemm_bias(const T* a, - const int M, - const int K, - const T* b, - const int K_, - const int N, - T* biases, - T* c) { - EXPECT_TRUE(K_ == K && M > 0 && N > 0 && K > 0); - EXPECT_TRUE(a && b && c); - const int lda = K; - const int ldb = N; - const int ldc = N; - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - C(m, n) = 0.0f; - for (int k = 0; k < K; ++k) { - C(m, n) += A(m, k) * B(k, n); - } - } - } - if (biases) { - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - C(m, n) += biases[n]; - } - } - } -} - -template -void FillData(T* a, - const int n, - const T lower = static_cast(-2.f), - const T upper = static_cast(2.f)) { - static unsigned int seed = 100; - std::mt19937 rng(seed++); - std::uniform_real_distribution uniform_dist(0, 1); - for (int i = 0; i < n; ++i) { - a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); - } -} - -TEST(fc_arm, retrive_op) { - auto fc = - KernelRegistry::Global().Create("fc"); - ASSERT_FALSE(fc.empty()); - ASSERT_TRUE(fc.front()); -} - -TEST(fc_arm, init) { - FcCompute fc; - ASSERT_EQ(fc.precision(), PRECISION(kFloat)); - ASSERT_EQ(fc.target(), TARGET(kARM)); -} - -TEST(fc_arm, compare_test) { - using T = float; - - for (int m : {1, 2, 3, 4}) { - for (int n : {1, 2, 3, 4}) { - for (int k : {1, 2, 3, 4}) { - for (bool with_bias : {true, false}) { - VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k - << (with_bias ? ", with bias" : ""); - lite::Tensor x, w, b, out, ref; - - x.Resize({m, k}); - w.Resize({k, n}); - b.Resize({1, n}); - out.Resize({m, n}); - ref.Resize({m, n}); - - auto* x_data = x.mutable_data(); - auto* w_data = w.mutable_data(); - auto* b_data = with_bias ? b.mutable_data() : nullptr; - - auto* out_data = out.mutable_data(); - auto* ref_data = ref.mutable_data(); - - FillData(x_data, x.dims().production()); - FillData(w_data, w.dims().production()); - FillData(out_data, out.dims().production(), 0, 0); - FillData(ref_data, ref.dims().production(), 0, 0); - - if (with_bias) { - FillData(b_data, b.dims().production()); - } - - FcCompute fc; - operators::FcParam param; - - param.input = &x; - param.w = &w; - param.bias = with_bias ? 
&b : nullptr; - param.output = &out; - param.in_num_col_dims = 1; - param.in_mat_dims = x.dims(); - - DeviceInfo::Init(); - std::unique_ptr ctx(new KernelContext); - ctx->As(); - fc.SetParam(param); - fc.SetContext(std::move(ctx)); - fc.PrepareForRun(); - fc.Run(); - - gemm_bias(x_data, m, k, w_data, k, n, b_data, ref_data); - - for (int i = 0; i < out.dims().production(); i++) { - EXPECT_NEAR(out_data[i], ref_data[i], 1e-3); - } - } - } - } - } -} - -TEST(fc_arm, num_col_dims) { - using T = float; - - for (bool with_bias : {true, false}) { - lite::Tensor x, w, b, out, ref; - - x.Resize({1, 2, 3}); - w.Resize({3, 4}); - b.Resize({1, 4}); - out.Resize({2, 4}); - ref.Resize({2, 4}); - - auto* x_data = x.mutable_data(); - auto* w_data = w.mutable_data(); - auto* b_data = with_bias ? b.mutable_data() : nullptr; - - auto* out_data = out.mutable_data(); - auto* ref_data = ref.mutable_data(); - - FillData(x_data, x.dims().production()); - FillData(w_data, w.dims().production()); - FillData(out_data, out.dims().production(), 0, 0); - FillData(ref_data, ref.dims().production(), 0, 0); - if (with_bias) { - FillData(b_data, b.dims().production()); - } - FcCompute fc; - operators::FcParam param; - param.input = &x; - param.w = &w; - param.bias = with_bias ? &b : nullptr; - param.output = &out; - param.in_num_col_dims = 2; - param.in_mat_dims = x.dims(); - - std::unique_ptr ctx(new KernelContext); - ctx->As(); - DeviceInfo::Init(); - - fc.SetParam(param); - fc.SetContext(std::move(ctx)); - fc.PrepareForRun(); - fc.Run(); - - gemm_bias(x_data, 2, 3, w_data, 3, 4, b_data, ref_data); - - for (int i = 0; i < out.dims().production(); i++) { - EXPECT_NEAR(out_data[i], ref_data[i], 1e-3); - } - } -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/fill_constant_compute.cc b/lite/kernels/arm/fill_constant_compute.cc index 1e4a58fc970cfe99d318e810ba07301d313d1814..ca7629f84f0200332d8ed0864792ae7bde46f7be 100644 --- a/lite/kernels/arm/fill_constant_compute.cc +++ b/lite/kernels/arm/fill_constant_compute.cc @@ -38,6 +38,31 @@ class FillConstantCompute : public KernelLite { virtual ~FillConstantCompute() = default; }; +template +class FillConstantBatchLikeCompute + : public KernelLite { + public: + using param_t = operators::FillConstantBatchLikeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + + if (param.input->lod().size() && param.input_dim_idx == 0) { + auto odims = param.out->dims(); + odims[param.output_dim_idx] = param.input->lod().back().size() - 1; + param.out->Resize(odims); + } + + auto data = param.out->template mutable_data(); + for (int i = 0; i < param.out->numel(); i++) { + data[i] = param.value; + } + } + + virtual ~FillConstantBatchLikeCompute() = default; +}; + } // namespace arm } // namespace kernels } // namespace lite @@ -52,3 +77,13 @@ REGISTER_LITE_KERNEL(fill_constant, def) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +REGISTER_LITE_KERNEL( + fill_constant_batch_size_like, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::FillConstantBatchLikeCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..a46a6f9d6ab4850506c681ac3ca80e23d18b97d4 
--- /dev/null +++ b/lite/kernels/arm/gather_compute.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lite/kernels/arm/gather_compute.h" +#include +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void GatherCompute::PrepareForRun() {} + +void GatherCompute::Run() { + auto& param = this->Param(); + + auto* p_output = param.Out->mutable_data(); + auto index_size = param.Index->dims()[0]; + auto src_dims = param.X->dims(); + const float* p_src = param.X->data(); + const float* p_index = param.Index->data(); + + int slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) { + slice_size *= src_dims[i]; + } + for (int i = 0; i < index_size; ++i) { + int index_ = p_index[i]; + memcpy(p_output + i * slice_size, + p_src + index_ * slice_size, + slice_size * sizeof(float)); + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/gather_compute.h b/lite/kernels/arm/gather_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..eb667f132b7975de4f74a43ae24475153aca058e --- /dev/null +++ b/lite/kernels/arm/gather_compute.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "lite/backends/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { +class GatherCompute : public KernelLite { + public: + using param_t = operators::GatherParam; + + void PrepareForRun() override; + + void Run() override; + + ~GatherCompute() {} +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/increment_compute.cc b/lite/kernels/arm/increment_compute.cc index fd548f91f9537cd2b168558f04c27d7f83c1ea28..2cf66805263ca5ee82174421ca037f72f4527b87 100644 --- a/lite/kernels/arm/increment_compute.cc +++ b/lite/kernels/arm/increment_compute.cc @@ -28,8 +28,8 @@ void IncrementCompute::Run() { int total_num = param.X->dims().production(); - const auto* x_data = param.X->data(); - auto* o_data = param.Out->mutable_data(); + const auto* x_data = param.X->data(); + auto* o_data = param.Out->mutable_data(); lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx); } diff --git a/lite/kernels/arm/layer_norm_compute.cc b/lite/kernels/arm/layer_norm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..71823ed707c8c74388615a80780f893a6c551d61 --- /dev/null +++ b/lite/kernels/arm/layer_norm_compute.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/layer_norm_compute.h" +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void LayerNormCompute::PrepareForRun() {} + +void LayerNormCompute::Run() { + auto& param = this->Param(); + + auto input_dims = param.X->dims(); + + const auto* x_data = param.X->data(); + const auto* scale = param.Scale ? param.Scale->data() : nullptr; + const auto* bias = param.Bias ? 
param.Bias->data() : nullptr; + auto* o_data = param.Y->mutable_data(); + auto* mean = param.Mean->mutable_data(); + auto* var = param.Variance->mutable_data(); + + int axis = param.begin_norm_axis; + auto matrix_dim = param.X->dims().Flatten2D(axis); + int left = matrix_dim[0]; + int right = matrix_dim[1]; + + lite::arm::math::matrix_norm_row( + x_data, scale, bias, o_data, mean, var, param.epsilon, left, right); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(layer_norm, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::LayerNormCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/layer_norm_compute.h b/lite/kernels/arm/layer_norm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..186234dbcbdce71a3aa20c7c28fadfe4f2625cb6 --- /dev/null +++ b/lite/kernels/arm/layer_norm_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class LayerNormCompute : public KernelLite { + public: + using param_t = operators::LayerNormParam; + + void PrepareForRun() override; + + void Run() override; + + ~LayerNormCompute() {} +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/layer_norm_compute_test.cc b/lite/kernels/arm/layer_norm_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..22fe3d06569fac424ab797712142b4d088dc7d3a --- /dev/null +++ b/lite/kernels/arm/layer_norm_compute_test.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/layer_norm_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void LayerNormComputeRef(const operators::LayerNormParam& param) { + auto* x = param.X; + auto* y = param.Y; + auto* scale_tensor = param.Scale; + auto* bias_tensor = param.Bias; + auto* mean_tensor = param.Mean; + auto* var_tensor = param.Variance; + + int begin_norm_axis = param.begin_norm_axis; + float epsilon = param.epsilon; + + auto* x_data = x->data(); + auto* scale_data = + (scale_tensor == nullptr ? nullptr : scale_tensor->data()); + auto* bias_data = + (bias_tensor == nullptr ? nullptr : bias_tensor->data()); + auto* out_data = y->mutable_data(); + auto* mean_data = mean_tensor->mutable_data(); + auto* var_data = var_tensor->mutable_data(); + + auto matrix_dim = x->dims().Flatten2D(begin_norm_axis); + int batch_size = matrix_dim[0]; + int feature_size = matrix_dim[1]; + for (int i = 0; i < batch_size; ++i) { + int start = i * feature_size; + int end = start + feature_size; + + float mean = 0; + float var = 0; + for (int j = start; j < end; ++j) { + mean += x_data[j]; + var += x_data[j] * x_data[j]; + } + mean /= feature_size; + var = var / feature_size - mean * mean; + mean_data[i] = mean; + var_data[i] = var; + var = sqrt(var + epsilon); + for (int j = start; j < end; ++j) { + out_data[j] = (x_data[j] - mean) / var; + if (scale_data) { + out_data[j] *= scale_data[j - start]; + } + if (bias_data) { + out_data[j] += bias_data[j - start]; + } + } + } +} + +TEST(layer_norm_arm, init) { + LayerNormCompute layer_norm; + ASSERT_EQ(layer_norm.precision(), PRECISION(kFloat)); + ASSERT_EQ(layer_norm.target(), TARGET(kARM)); +} + +TEST(layer_norm_arm, compute) { + LayerNormCompute layer_norm; + operators::LayerNormParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_mean; + lite::Tensor output_var; + lite::Tensor output_ref; + lite::Tensor output_mean_ref; + lite::Tensor output_var_ref; + lite::Tensor bias; + lite::Tensor scale; + lite::Tensor* bias_ptr; + lite::Tensor* scale_ptr; + + for (auto n : {1, 3}) { + for (auto c : {1, 3, 5}) { + for (auto h : {3, 16, 20, 32}) { + for (auto w : {3, 16, 20, 32}) { + for (auto axis : {0, 1, 2}) { + for (auto has_bias : {true, false}) { + for (auto has_scale : {true, false}) { + auto dims = DDim(std::vector({n, c, h, w})); + auto out_size = dims.Flatten2D(axis)[0]; + auto inner_size = dims.Flatten2D(axis)[1]; + bias_ptr = nullptr; + scale_ptr = nullptr; + if (has_bias) { + bias.Resize(std::vector({inner_size, 1, 1, 1})); + float* bias_data = bias.mutable_data(); + for (int i = 0; i < inner_size; ++i) { + bias_data[i] = 0.01; + } + bias_ptr = &bias; + } + if (has_scale) { + scale.Resize(std::vector({inner_size, 1, 1, 1})); + float* scale_data = scale.mutable_data(); + for (int i = 0; i < inner_size; ++i) { + scale_data[i] = 0.2; + } + scale_ptr = &scale; + } + + x.Resize(dims); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + output_mean.Resize(std::vector({out_size, 1, 1, 1})); + output_mean_ref.Resize( + std::vector({out_size, 1, 1, 1})); + output_var.Resize(std::vector({out_size, 1, 1, 1})); + output_var_ref.Resize( + std::vector({out_size, 1, 1, 1})); + + auto* x_data = x.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_mean_data = output_mean.mutable_data(); + auto* output_var_data = output_var.mutable_data(); + auto* output_data_ref 
= output_ref.mutable_data(); + auto* output_mean_data_ref = + output_mean_ref.mutable_data(); + auto* output_var_data_ref = + output_var_ref.mutable_data(); + + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i % 255 * 0.001; + } + param.X = &x; + param.Y = &output; + param.begin_norm_axis = axis; + param.Bias = bias_ptr; + param.Scale = scale_ptr; + param.Mean = &output_mean; + param.Variance = &output_var; + param.epsilon = 0.00001; + layer_norm.SetParam(param); + layer_norm.Run(); + + param.Y = &output_ref; + param.Mean = &output_mean_ref; + param.Variance = &output_var_ref; + LayerNormComputeRef(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_data_ref[i], 1e-4); + } + for (int i = 0; i < output_mean.dims().production(); ++i) { + EXPECT_NEAR( + output_mean_data[i], output_mean_data_ref[i], 1e-5); + EXPECT_NEAR(output_var_data[i], output_var_data_ref[i], 1e-5); + } + } + } + } + } + } + } + } +} + +TEST(layer_norm, retrive_op) { + auto layer_norm = + KernelRegistry::Global().Create( + "layer_norm"); + ASSERT_FALSE(layer_norm.empty()); + ASSERT_TRUE(layer_norm.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(layer_norm, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/logical_compute.cc b/lite/kernels/arm/logical_compute.cc index c1cef5c2ced4fca239bf000fc1617509d73f01f5..1e47329d8ff65f3d036fd4a8a653cfe5cdc80a3a 100644 --- a/lite/kernels/arm/logical_compute.cc +++ b/lite/kernels/arm/logical_compute.cc @@ -82,28 +82,25 @@ void UnaryLogicalCompute::Run() { } // namespace kernels } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL( - logical_xor, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::BinaryLogicalCompute< - paddle::lite::kernels::arm::_LogicalXorFunctor>, - // paddle::lite::kernels::arm::BinaryLogicalCompute>, - def) + +REGISTER_LITE_KERNEL(logical_xor, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BinaryLogicalCompute< + paddle::lite::kernels::arm::_LogicalXorFunctor>, + def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .Finalize(); -REGISTER_LITE_KERNEL( - logical_and, - kARM, - kFloat, - kNCHW, - // paddle::lite::kernels::arm::BinaryLogicalCompute>, - paddle::lite::kernels::arm::BinaryLogicalCompute< - paddle::lite::kernels::arm::_LogicalAndFunctor>, - def) +REGISTER_LITE_KERNEL(logical_and, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BinaryLogicalCompute< + paddle::lite::kernels::arm::_LogicalAndFunctor>, + def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) diff --git a/lite/kernels/arm/lookup_table_compute.cc b/lite/kernels/arm/lookup_table_compute.cc index d39d7ccb60d0e69ecc2a8f3278bdf032b2d8fb16..fa7e2c0c3ae4580f5d19e82f7c48c74db3058847 100644 --- a/lite/kernels/arm/lookup_table_compute.cc +++ b/lite/kernels/arm/lookup_table_compute.cc @@ -38,13 +38,14 @@ void LookupTableCompute::Run() { auto table_dim = w->dims(); int64_t ids_numel = ids->numel(); auto ids_data = ids->data(); - int ids_int = ids_data[0]; + int64_t row_number = table_dim[0]; int64_t row_width = table_dim[1]; auto table_data = w->data(); auto dout = out->mutable_data(); 
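+  // Embedding lookup: each id selects one row of the table, so the i-th
+  // output slice of row_width floats is expected to receive W[ids[i]];
+  // ids equal to padding_idx are zero-filled instead.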
for (int64_t i = 0; i < ids_numel; ++i) { + int ids_int = ids_data[i]; if (param.padding_idx != -1 && ids_data[i] == param.padding_idx) { memset(dout + i * row_width, 0, row_width * sizeof(float)); } else { diff --git a/lite/kernels/arm/lrn_compute_test.cc b/lite/kernels/arm/lrn_compute_test.cc index 03683aa21282b7cf5aff2db1b3d705df0d4f354d..8e030006151c5834a68037800192ec7d9bc5d94d 100644 --- a/lite/kernels/arm/lrn_compute_test.cc +++ b/lite/kernels/arm/lrn_compute_test.cc @@ -14,6 +14,7 @@ #include "lite/kernels/arm/lrn_compute.h" #include +#include #include #include #include "lite/core/op_registry.h" diff --git a/lite/kernels/arm/mul_compute.cc b/lite/kernels/arm/mul_compute.cc index d0ae2d0df165070e39b5193a61bfedfae69ec6c8..fa43b6cf8e5d7418583d44d2ed9b6e49d128d2d6 100644 --- a/lite/kernels/arm/mul_compute.cc +++ b/lite/kernels/arm/mul_compute.cc @@ -56,7 +56,7 @@ void MulCompute::Run() { } else { constexpr bool is_tranposed_y = false; auto& ctx = this->ctx_->template As(); - int hblock = lite::arm::math::get_hblock(ctx.arch()); + int hblock = lite::arm::math::get_hblock(&ctx); int m_round = hblock * ((m_ + hblock - 1) / hblock); ctx.ExtendWorkspace(m_round * k_ * sizeof(float)); diff --git a/lite/kernels/arm/norm_compute.cc b/lite/kernels/arm/norm_compute.cc index 3cc1645fc6823c4c3276cd1f22f4be8a584d2073..fb8b4bbe0773b808a0f6942d1120ebc7d4e844d2 100644 --- a/lite/kernels/arm/norm_compute.cc +++ b/lite/kernels/arm/norm_compute.cc @@ -47,4 +47,5 @@ REGISTER_LITE_KERNEL( norm, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::NormCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Norm", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/prior_box_compute.cc b/lite/kernels/arm/prior_box_compute.cc index 203f483351a27650b8ccb27c930dcdc5d2a5ceb9..48ae1e94dd74453c9160604a94ecfabd0e516034 100644 --- a/lite/kernels/arm/prior_box_compute.cc +++ b/lite/kernels/arm/prior_box_compute.cc @@ -65,6 +65,7 @@ void PriorBoxCompute::Run() { size_t prior_num = aspect_ratios_vec.size() * min_size.size(); prior_num += max_size.size(); std::vector order = param.order; + bool min_max_aspect_ratios_order = param.min_max_aspect_ratios_order; lite::arm::math::prior_box(param.input, param.image, @@ -82,7 +83,8 @@ void PriorBoxCompute::Run() { prior_num, is_flip, is_clip, - order); + order, + min_max_aspect_ratios_order); } } // namespace arm diff --git a/lite/kernels/arm/range_compute.cc b/lite/kernels/arm/range_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4629ac2de7af2965f35c3778e29c076fc515f87 --- /dev/null +++ b/lite/kernels/arm/range_compute.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
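+
+// Writes the arithmetic sequence out[i] = start + i * step; the number of
+// elements is taken from Out->dims(), which the range op is assumed to have
+// sized from (end - start) / step before Run() is called.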
+ +#include "lite/kernels/arm/range_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void RangeCompute::Run() { + auto& param = Param(); + // int start = static_cast(param.Start->data()[0]); + // int end = static_cast(param.End->data()[0]); + // int step = static_cast(param.Step->data()[0]); + int start = (param.Start->data()[0]); + int end = (param.End->data()[0]); + int step = (param.Step->data()[0]); + + float* out_data = param.Out->mutable_data(); + float value = start; + for (int i = 0; i < param.Out->dims().production(); ++i) { + out_data[i] = value; + value += step; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + range, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::RangeCompute, def) + .BindInput("Start", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("End", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Step", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/range_compute.h b/lite/kernels/arm/range_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..3713fadca1a35bd4b473066cf5dfd903571152c6 --- /dev/null +++ b/lite/kernels/arm/range_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class RangeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~RangeCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/arm/read_from_array_compute.cc b/lite/kernels/arm/read_from_array_compute.cc
index 945ada8c65abf53bd590247bdfca2ccc23eb1304..43fcca4221bff188bf37caed33bbc9dba2e2b965 100644
--- a/lite/kernels/arm/read_from_array_compute.cc
+++ b/lite/kernels/arm/read_from_array_compute.cc
@@ -28,14 +28,13 @@ void ReadFromArrayCompute::Run() {
 
   int in_num = param.X->size();
   CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
-  int id = param.I->data()[0];
+  int id = param.I->data()[0];
   CHECK_LE(id, in_num) << "id is not valid";
   int input_size = (*param.X)[id].numel();
   param.Out->Resize((*param.X)[id].dims());
-  auto* o_data = param.Out->mutable_data();
-  const auto* x_data = (*param.X)[id].data();
-  memcpy(o_data, x_data, sizeof(float) * input_size);
+  param.Out->CopyDataFrom((*param.X)[id]);
+
   auto out_lod = param.Out->mutable_lod();
   *out_lod = (*param.X)[id].lod();
 }
diff --git a/lite/kernels/arm/softmax_compute_test.cc b/lite/kernels/arm/softmax_compute_test.cc
index 5a883e4ebe6aaf1f7a8eb6f25815725fd7ea8e87..459112d8c0169375584baf0cb983037682e47a3d 100644
--- a/lite/kernels/arm/softmax_compute_test.cc
+++ b/lite/kernels/arm/softmax_compute_test.cc
@@ -14,6 +14,7 @@
 
 #include "lite/kernels/arm/softmax_compute.h"
 #include
+#include
 #include
 #include
 #include "lite/core/op_registry.h"
diff --git a/lite/kernels/arm/topk_compute.cc b/lite/kernels/arm/topk_compute.cc
index 994ef3f8dd00c0bf0c9c5f64025b23195462ce5f..c1abd42b41e7d15effd0d7c62f00c2460e54a793 100644
--- a/lite/kernels/arm/topk_compute.cc
+++ b/lite/kernels/arm/topk_compute.cc
@@ -43,5 +43,6 @@ REGISTER_LITE_KERNEL(
     top_k, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::TopkCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Indices", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Indices",
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .Finalize();
diff --git a/lite/kernels/arm/unsqueeze_compute.cc b/lite/kernels/arm/unsqueeze_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3dc7a274df609b7a96fdcc8978d5cd2e98ac5c93
--- /dev/null
+++ b/lite/kernels/arm/unsqueeze_compute.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
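+
+// unsqueeze/unsqueeze2 only insert size-1 dimensions, so the element count is
+// unchanged; Out->dims() is assumed to be set by the op's InferShape, and the
+// kernels below just perform a flat float copy (unsqueeze2 also fills XShape).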
+
+#include "lite/kernels/arm/unsqueeze_compute.h"
+#include
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+void UnsqueezeCompute::Run() {
+  auto& param = Param<operators::UnsqueezeParam>();
+  auto x = param.X;
+  auto output = param.Out;
+  auto x_dims = x->dims();
+  auto* x_data = x->data<float>();
+  auto* out_data = output->mutable_data<float>();
+  memcpy(out_data, x_data, x_dims.production() * sizeof(float));
+}
+
+void Unsqueeze2Compute::Run() {
+  auto& param = Param<operators::UnsqueezeParam>();
+  auto x = param.X;
+  auto output = param.Out;
+  auto xshape = param.XShape;
+  auto x_dims = x->dims();
+  auto* x_data = x->data<float>();
+  auto* out_data = output->mutable_data<float>();
+  auto* xshape_data = xshape->mutable_data<float>();
+  memcpy(out_data, x_data, x_dims.production() * sizeof(float));
+  memcpy(xshape_data, x_data, x_dims.production() * sizeof(float));
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(unsqueeze,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::UnsqueezeCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(unsqueeze2,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::Unsqueeze2Compute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/unsqueeze_compute.h b/lite/kernels/arm/unsqueeze_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..57d4c657f682e130f8eab830222d9b0eeec8a367
--- /dev/null
+++ b/lite/kernels/arm/unsqueeze_compute.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+class UnsqueezeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~UnsqueezeCompute() = default;
+};
+
+class Unsqueeze2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~Unsqueeze2Compute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/arm/while_compute.cc b/lite/kernels/arm/while_compute.cc
index ab3da93acc4d6a29f1a7b41c7dc43e4c05f59b88..00b37b2db9512adfe0d465dcbb9c76af78d32486 100644
--- a/lite/kernels/arm/while_compute.cc
+++ b/lite/kernels/arm/while_compute.cc
@@ -46,7 +46,7 @@ void WhileCompute::Run() {
 
 REGISTER_LITE_KERNEL(
     while, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::WhileCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM))})
     .BindInput("Condition",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
     .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
diff --git a/lite/kernels/arm/write_to_array_compute.cc b/lite/kernels/arm/write_to_array_compute.cc
index 42498e77f2c1ed8c1caf7f8640dfcbef5f004c08..ee68442ffcd0a5c12f3659e0739715c2128ece28 100644
--- a/lite/kernels/arm/write_to_array_compute.cc
+++ b/lite/kernels/arm/write_to_array_compute.cc
@@ -28,7 +28,7 @@ void WriteToArrayCompute::Run() {
   CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
 
   const auto* x_data = param.X->data<float>();
-  int id = param.I->data()[0];
+  int id = param.I->data()[0];
   int id_test = param.I->data()[0];
   if (id >= param.Out->size()) {
     for (int i = param.Out->size(); i < id + 1; i++) {
@@ -57,5 +57,5 @@ REGISTER_LITE_KERNEL(write_to_array,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 6623894ec7d796ee21baf0eb1a0c922b395da737..8aa084c243eb4f6b4ae3fc8b1aab408a14624c52 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -4,23 +4,36 @@ endif()
 
 message(STATUS "compile with lite CUDA kernels")
 
-nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
-lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
-nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
+add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
+add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
+add_kernel(concat_compute_cuda CUDA
basic SRCS concat_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise) +add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose) +add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale) +add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale) +add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) -nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps}) -lite_cc_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) - -lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) -nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) -lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) - -set(cuda_kernels -mul_compute_cuda -io_copy_compute_cuda -leaky_relu_compute_cuda -nearest_interp_compute_cuda -yolo_box_compute_cuda -) - -set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels") +lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) +nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) +nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) +nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda) +nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) +nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) +nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) +nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda) +nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda) +#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda) +nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda) +nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda ) +nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) diff --git a/lite/kernels/cuda/bilinear_interp_compute.cu b/lite/kernels/cuda/bilinear_interp_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..7e1dbaf228c31d8123e48832e93e0180c4920359 --- /dev/null +++ b/lite/kernels/cuda/bilinear_interp_compute.cu @@ -0,0 +1,195 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/bilinear_interp_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+using Tensor = lite::Tensor;
+
+template <typename T>
+__global__ void BilinearInterp(const T* in,
+                               const size_t in_img_h,
+                               const size_t in_img_w,
+                               const size_t input_h,
+                               const size_t input_w,
+                               T* out,
+                               const size_t out_img_h,
+                               const size_t out_img_w,
+                               const size_t output_h,
+                               const size_t output_w,
+                               const size_t num_channels,
+                               const float ratio_h,
+                               const float ratio_w,
+                               const bool align_corners,
+                               const int align_mode) {
+  int nthreads = output_h * output_w;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  bool align_flag = (align_mode == 0 && !align_corners);
+  for (; tid < nthreads; tid += stride) {
+    int out_id_h = tid / output_w;
+    int out_id_w = tid % output_w;
+    int in_img_size = input_w / num_channels;
+    int out_img_size = output_w / num_channels;
+
+    int channel_id = out_id_w / out_img_size;
+    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
+    int out_img_idx = tid % out_img_w;
+
+    int in_img_idy = align_flag
+                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
+                         : static_cast<int>(ratio_h * out_img_idy);
+    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
+    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
+    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
+    src_h = (src_h > 0) ? src_h : 0;
+    T h1lambda =
+        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
+    T h2lambda = 1.f - h1lambda;
+
+    int in_img_idx = align_flag
+                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
+                         : static_cast<int>(ratio_w * out_img_idx);
+    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
+    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
+    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
+    src_w = (src_w > 0) ? src_w : 0;
+    T w1lambda =
+        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
+    T w2lambda = 1.f - w1lambda;
+
+    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
+                          in_img_idy * in_img_w + in_img_idx];
+
+    // bilinear interpolation
+    out[out_id_h * output_w + out_id_w] =
+        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
+        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
+                    w1lambda * in_pos[h_id * in_img_w + w_id]);
+  }
+}
+
+void BilinearInterpCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  Tensor* input = param.X;
+  Tensor* output = param.Out;
+  Tensor* out_size = param.OutSize;
+
+  auto* input_data = input->data<float>();
+
+  const int n = input->dims()[0];
+  const int c = input->dims()[1];
+  const int in_h = input->dims()[2];
+  const int in_w = input->dims()[3];
+
+  int out_h = param.out_h;
+  int out_w = param.out_w;
+  float scale = param.scale;
+  bool align_corners = param.align_corners;
+  if (scale > 0) {
+    out_h = static_cast<int>(in_h * scale);
+    out_w = static_cast<int>(in_w * scale);
+  }
+
+  if (out_size != nullptr) {
+    Tensor sizes;
+    float* size_data = sizes.mutable_data<float>();
+    float* outsize_data = out_size->mutable_data<float>(TARGET(kCUDA));
+    cudaMemcpy(
+        size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost);
+    out_h = static_cast<int>(size_data[0]);
+    out_w = static_cast<int>(size_data[1]);
+  }
+
+  auto output_data = output->mutable_data<float>(TARGET(kCUDA));
+
+  if (in_h == out_h && in_w == out_w) {
+    cudaMemcpy(output_data,
+               input_data,
+               sizeof(float) * n * c * in_h * in_w,
+               cudaMemcpyHostToDevice);
+    return;
+  }
+
+  float ratio_h = 0.f;
+  float ratio_w = 0.f;
+  if (out_h > 1) {
+    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
+                              : static_cast<float>(in_h) / out_h;
+  }
+  if (out_w > 1) {
+    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
+                              : static_cast<float>(in_w) / out_w;
+  }
+
+  int in_hw = in_h * in_w;
+  int out_hw = out_h * out_w;
+  int in_chw = c * in_hw;
+  int out_chw = c * out_hw;
+
+  int pixel_num = n * out_chw;
+  int threads = 512;
+  int blocks = (pixel_num + threads - 1) / threads;
+  blocks = blocks > 8 ? 8 : blocks;
+  int align_mode = param.align_mode;
+
+  BilinearInterp<<<blocks, threads, 0, stream>>>(input_data,
+                                                 in_h,
+                                                 in_w,
+                                                 n,
+                                                 in_chw,
+                                                 output_data,
+                                                 out_h,
+                                                 out_w,
+                                                 n,
+                                                 out_chw,
+                                                 c,
+                                                 ratio_h,
+                                                 ratio_w,
+                                                 align_corners,
+                                                 align_mode);
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(bilinear_interp,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::BilinearInterpCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("OutSize",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/bilinear_interp_compute.h b/lite/kernels/cuda/bilinear_interp_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..333e67f8fff373a84ac9f3a19fc57214376bd34f
--- /dev/null
+++ b/lite/kernels/cuda/bilinear_interp_compute.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class BilinearInterpCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::InterpolateParam;
+
+  void Run() override;
+  virtual ~BilinearInterpCompute() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/bilinear_interp_compute_test.cc b/lite/kernels/cuda/bilinear_interp_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7e8143150d2963fb4cb74c3530cfd6e125a454c
--- /dev/null
+++ b/lite/kernels/cuda/bilinear_interp_compute_test.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/bilinear_interp_compute.h"
+#include
+#include
+#include
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+TEST(bilinear_interp, normal) {
+  BilinearInterpCompute bilinear_interp_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::InterpolateParam param;
+
+  Tensor x, osz, out;
+  Tensor x_cpu, osz_cpu, out_cpu;
+  Tensor x_ref, osz_ref, out_ref;
+
+  int n = 1, c = 1, in_h = 3, in_w = 3;
+  int out_h = 6, out_w = 6;
+  float scale = 2.0;
+
+  param.out_h = out_h;
+  param.out_w = out_w;
+  param.scale = scale;
+  param.align_corners = false;
+  param.align_mode = 0;
+
+  x.Resize({n, c, in_h, in_w});
+  osz.Resize({2});
+  out.Resize({n, c, out_h, out_w});
+
+  x_cpu.Resize({n, c, in_h, in_w});
+  osz_cpu.Resize({2});
+  out_cpu.Resize({n, c, out_h, out_w});
+
+  x_ref.Resize({n, c, in_h, in_w});
+  osz_ref.Resize({2});
+  out_ref.Resize({n, c, out_h, out_w});
+
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+
+  float* x_cpu_data = x_cpu.mutable_data<float>();
+  float* osz_cpu_data = osz_cpu.mutable_data<float>();
+  float* out_cpu_data = out_cpu.mutable_data<float>();
+
+  float* x_ref_data = x_ref.mutable_data<float>();
+  float* osz_ref_data = osz_ref.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); ++i) {
+    x_cpu_data[i] = i + 5.0;
+    x_ref_data[i] = i + 5.0;
+  }
+  osz_cpu_data[0] = out_h;
+  osz_cpu_data[1] = out_w;
+  osz_ref_data[0] = out_h;
+  osz_ref_data[1] = out_w;
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  osz.Assign<float, lite::DDim, TARGET(kCUDA)>(osz_cpu_data, osz_cpu.dims());
+
+  param.X = &x;
+  param.OutSize = &osz;
+  param.Out = &out;
+  bilinear_interp_kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  bilinear_interp_kernel.SetContext(std::move(ctx));
+  bilinear_interp_kernel.Launch();
+  cudaDeviceSynchronize();
+
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+  for (int i = 0; i < out.numel(); i++) {
+    LOG(INFO) << out_cpu_data[i];
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/calib_compute.cu b/lite/kernels/cuda/calib_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..77f233e00ed1b2bf5a7a61e8ca6fcd83c2f36f3f
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute.cu
@@ -0,0 +1,144 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "lite/backends/cuda/math/utils.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/type_system.h"
+#include "lite/kernels/cuda/calib_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+__global__ void Fp32ToInt8Kernel(const int num,
+                                 const float scale,
+                                 const float* input,
+                                 int8_t* output) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    output[index] = lite::cuda::math::from_float<int8_t>(input[index] / scale);
+  }
+}
+
+__global__ void Int8ToFp32Kernel(const int num,
+                                 const float scale,
+                                 const int8_t* input,
+                                 float* output) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    output[index] = input[index] * scale;
+  }
+}
+
+void CalibComputeFp32ToInt8::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  auto scale = param.scale;
+  const auto* din = param.input->data<float>();
+  auto* dout = param.output->mutable_data<int8_t>(TARGET(kCUDA));
+  int num = static_cast<int>(param.input->numel());
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  Fp32ToInt8Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
+  cudaError_t error = cudaGetLastError();
+  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
+}
+
+void CalibComputeInt8ToFp32::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  auto scale = param.scale;
+  const auto* din = param.input->data<int8_t>();
+  auto* dout = param.output->mutable_data<float>(TARGET(kCUDA));
+  int num = static_cast<int>(param.input->numel());
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  Int8ToFp32Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
+  cudaError_t error = cudaGetLastError();
+  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(calib,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kInt8),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+
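+// calib converts between fp32 and int8 activations: quantization computes
+// round(x / scale) saturated to [-128, 127] (see from_float<int8_t>), and
+// dequantization computes x * scale. The same kernels are registered again
+// below under calib_once, which is intended for data that only needs to be
+// converted a single time.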
+REGISTER_LITE_KERNEL(calib,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(calib_once,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kInt8),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+REGISTER_LITE_KERNEL(calib_once,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
diff --git a/lite/kernels/cuda/calib_compute.h b/lite/kernels/cuda/calib_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab5a03e90c52ec88be4809909a6588f1da20be0f
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include "lite/core/kernel.h"
+#include "lite/operators/calib_op.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class CalibComputeFp32ToInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
+ public:
+  using param_t = operators::CalibParam;
+
+  void Run() override;
+
+  virtual ~CalibComputeFp32ToInt8() = default;
+
+  std::string doc() const override { return "Fp32 --> Int8"; }
+};
+
+class CalibComputeInt8ToFp32
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
+ public:
+  using param_t = operators::CalibParam;
+
+  void Run() override;
+
+  virtual ~CalibComputeInt8ToFp32() = default;
+
+  std::string doc() const override { return "Int8 --> Fp32"; }
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/calib_compute_cuda_test.cc b/lite/kernels/cuda/calib_compute_cuda_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8703d8730a1880b5b93502e5095b1a17d03bee6c
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute_cuda_test.cc
@@ -0,0 +1,173 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+static void int8_to_fp32_basic(const int8_t* din,
+                               float* dout,
+                               const float scale,
+                               int num) {
+  for (int j = 0; j < num; ++j) {
+    dout[j] = din[j] * scale;
+  }
+}
+
+static void fp32_to_int8_basic(const float* din,
+                               int8_t* dout,
+                               const float scale,
+                               int num) {
+  for (int j = 0; j < num; ++j) {
+    auto v = din[j] / scale;
+    v = std::max(v, static_cast<float>(INT8_MIN));
+    v = std::min(v, static_cast<float>(INT8_MAX));
+    v = roundf(v);
+    dout[j] = static_cast<int8_t>(v);
+  }
+}
+
+void calib_ref(const operators::CalibParam& param, bool to_float = true) {
+  auto scale = param.scale;
+  if (to_float) {
+    const auto* din = param.input->data<int8_t>();
+    auto* dout = param.output->mutable_data<float>();
+    int8_to_fp32_basic(din, dout, scale, param.input->numel());
+  } else {
+    const auto* din = param.input->data<float>();
+    auto* dout = param.output->mutable_data<int8_t>();
+    fp32_to_int8_basic(din, dout, scale, param.input->numel());
+  }
+}
+
+TEST(calib_cuda, int8_to_fp32) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  auto calib = std::move(*std::next(kernels.begin(), 1));
+  LOG(INFO) << "get kernel: " << calib->doc();
+  const int n = 64, c = 32, h = 18, w = 18;
+  Tensor x;
+  Tensor x_cpu;
+  Tensor output;
+  Tensor output_cpu;
+  // set the dims of input, output tensors
+  x.Resize({n, c, h, w});
+  x_cpu.Resize({n, c, h, w});
+  output.Resize({n, c, h, w});
+  output_cpu.Resize({n, c, h, w});
+  // initialize the data of input tensors
+  auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
+  for (int i = 0; i < x.dims().production(); i++) {
+    float sign = i % 3 == 0 ? -1.0f : 1.0f;
+    x_cpu_data[i] = static_cast<int8_t>(sign * (i % 127));
+  }
+  x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  // prepare kernel params and run
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  calib->SetContext(std::move(ctx));
+
+  operators::CalibParam param;
+  param.scale = 0.013f;
+  param.input = &x;
+  param.output = &output;
+  calib->SetParam(param);
+  calib->Launch();
+  cudaDeviceSynchronize();
+  // invoking ref implementation and compare results
+  param.input = &x_cpu;
+  param.output = &output_cpu;
+  calib_ref(param);
+  auto* output_data = output.mutable_data<float>();
+  std::unique_ptr<float[]> output_gpu_copy(new float[output.numel()]);
+  CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
+                          output_data,
+                          sizeof(float) * output.numel(),
+                          IoDirection::DtoH);
+  const auto* output_cpu_data = output_cpu.data<float>();
+  for (int i = 0; i < output.dims().production(); i++) {
+    EXPECT_NEAR(output_gpu_copy[i], output_cpu_data[i], 1e-5);
+  }
+}
+
+TEST(calib_cuda, fp32_to_int8) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  auto calib = std::move(kernels.front());
+  LOG(INFO) << "get kernel: " << calib->doc();
+  const int n = 64, c = 32, h = 18, w = 18;
+  Tensor x;
+  Tensor x_cpu;
+  Tensor output;
+  Tensor output_cpu;
+  // set the dims of input, output tensors
+  x.Resize({n, c, h, w});
+  x_cpu.Resize({n, c, h, w});
+  output.Resize({n, c, h, w});
+  output_cpu.Resize({n, c, h, w});
+  // initialize the data of input tensors
+  auto* x_cpu_data = x_cpu.mutable_data<float>();
+  for (int i = 0; i < x.dims().production(); i++) {
+    float sign = i % 3 == 0 ? -1.0f : 1.0f;
+    x_cpu_data[i] = sign * (i % 127) * 0.013f;
+  }
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  // prepare kernel params and run
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  calib->SetContext(std::move(ctx));
+
+  operators::CalibParam param;
+  param.scale = 0.013f;
+  param.input = &x;
+  param.output = &output;
+  calib->SetParam(param);
+  calib->Launch();
+  cudaDeviceSynchronize();
+  // invoking ref implementation and compare results
+  param.input = &x_cpu;
+  param.output = &output_cpu;
+  calib_ref(param, false);
+  auto* output_data = output.mutable_data<int8_t>();
+  std::unique_ptr<int8_t[]> output_gpu_copy(new int8_t[output.numel()]);
+  CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
+                          output_data,
+                          sizeof(int8_t) * output.numel(),
+                          IoDirection::DtoH);
+  const auto* output_cpu_data = output_cpu.data<int8_t>();
+  for (int i = 0; i < output.dims().production(); i++) {
+    EXPECT_EQ(output_gpu_copy[i], output_cpu_data[i]);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/concat_compute.cu b/lite/kernels/cuda/concat_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..89a5be142a931eeb5226130d499525f694548667
--- /dev/null
+++ b/lite/kernels/cuda/concat_compute.cu
@@ -0,0 +1,101 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/concat_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+using Tensor = lite::Tensor;
+
+template <typename Dtype>
+__global__ void Concat(const int num,
+                       const Dtype* in_data,
+                       const int num_concats,
+                       const int concat_size,
+                       const int top_concat_axis,
+                       const int bottom_concat_axis,
+                       const int offset_concat_axis,
+                       Dtype* out_data) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    const int total_concat_size = concat_size * bottom_concat_axis;
+    const int concat_num = index / total_concat_size;
+    const int concat_index = index % total_concat_size;
+    const int top_index =
+        concat_index +
+        (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
+    out_data[top_index] = in_data[index];
+  }
+}
+
+template <typename Dtype>
+void ConcatCompute<Dtype>::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  std::vector<lite::Tensor*> input = param.x;
+  Tensor* output = param.output;
+  auto* output_data = output->mutable_data<Dtype>(TARGET(kCUDA));
+  int axis = param.axis;
+  int inner_size = 1;
+  int outer_size = 1;
+  auto input_dims = input[0]->dims();
+  for (int i = 0; i < axis; i++) {
+    outer_size *= input_dims[i];
+  }
+
+  for (int i = axis + 1; i < input_dims.size(); i++) {
+    inner_size *= input_dims[i];
+  }
+
+  int all_concat_axis = param.output->dims()[axis];
+  int in_num = input.size();
+  int offset_concat_axis = 0;
+
+  for (int i = 0; i < in_num; i++) {
+    auto* input_data = input[i]->data<Dtype>();
+    int input_concat_axis = input[i]->dims()[axis];
+    int input_concat_size = input_concat_axis * inner_size;
+    int num = input_concat_size * outer_size;
+    int threads = 1024;
+    int blocks = (num + threads - 1) / threads;
+    Concat<<<blocks, threads, 0, stream>>>(num,
+                                           input_data,
+                                           outer_size,
+                                           inner_size,
+                                           all_concat_axis,
+                                           input_concat_axis,
+                                           offset_concat_axis,
+                                           output_data);
+    offset_concat_axis += input_concat_axis;
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(concat,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ConcatCompute<float>,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
diff --git a/lite/kernels/cuda/concat_compute.h b/lite/kernels/cuda/concat_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..9952cc4c89ad6141879fec4ef440d319f237067c
--- /dev/null
+++ b/lite/kernels/cuda/concat_compute.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+template <typename Dtype>
+class ConcatCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ConcatParam;
+
+  void Run() override;
+  virtual ~ConcatCompute() = default;
+};
+
+template <typename Dtype>
+class ConcatComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::ConcatParam;
+
+  void Run() override {}
+  virtual ~ConcatComputeNHWC() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/concat_compute_test.cc b/lite/kernels/cuda/concat_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cc12fcd289d36c38f02663c6a7aaa0ec7c70653a
--- /dev/null
+++ b/lite/kernels/cuda/concat_compute_test.cc
@@ -0,0 +1,230 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/concat_compute.h"
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+bool infer_shape(const operators::ConcatParam& param) {
+  std::vector<lite::DDim> input_dims;
+  for (auto p : param.x) {
+    input_dims.push_back(p->dims());
+  }
+  size_t axis = static_cast<size_t>(param.axis);
+  const size_t n = input_dims.size();
+  CHECK_GT_OR_FALSE(n, 0);
+  auto& out_dims = input_dims[0];
+  size_t in_zero_dims_size = out_dims.size();
+  for (size_t i = 1; i < n; i++) {
+    for (size_t j = 0; j < in_zero_dims_size; j++) {
+      if (j == axis) {
+        out_dims[axis] += input_dims[i][j];
+      } else {
+        CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
+      }
+    }
+  }
+  if (out_dims[axis] < 0) {
+    out_dims[axis] = -1;
+  }
+  // Set output dims
+  param.output->Resize(lite::DDim(out_dims));
+  return true;
+}
+
+void concat_compute_ref(const operators::ConcatParam& param) {
+  std::vector<lite::Tensor*> input = param.x;
+  int axis = param.axis;
+  infer_shape(param);
+
+  lite::Tensor* output = param.output;
+  int num = input.size();
+  int rows = 1;
+  auto dim_0 = input[0]->dims();
+  for (int i = 0; i < axis; ++i) {
+    rows *= dim_0[i];
+  }
+  int out_rows = rows, out_cols = 0;
+
+  std::vector<int> input_cols(input.size());
+  for (int i = 0; i < num; ++i) {
+    int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
+    for (int didx = 0; didx < input[i]->dims().size(); ++didx) {
+      input_i_numel *= input[i]->dims()[didx];
+    }
+    int t_cols = input_i_numel / rows;
+    out_cols += t_cols;
+    input_cols[i] = t_cols;
+  }
+
+  auto output_data = output->mutable_data<float>();
+  int col_idx = 0;
+  for (int j = 0; j < num; ++j) {
+    int col_len = input_cols[j];
+    auto input_data = input[j]->data<float>();
+    for (int k = 0; k < out_rows; ++k) {
+      memcpy(output_data + k * out_cols + col_idx,
+             input_data + k * col_len,
+             sizeof(float) * col_len);
+    }
+    col_idx += col_len;
+  }
+}
+
+TEST(concat, init) {
+  ConcatCompute<float> concat;
+  ASSERT_EQ(concat.precision(), PRECISION(kFloat));
+  ASSERT_EQ(concat.target(), TARGET(kCUDA));
+}
+
+TEST(concat, compute_input_multi) {
+  ConcatCompute<float> concat_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::ConcatParam param;
+  operators::ConcatParam param_ref;
+
+  LOG(INFO) << "test concat start";
+  // init param
+  std::vector<lite::Tensor*> x;
+  std::vector<lite::Tensor*> x_cpu;
+  std::vector<lite::Tensor*> x_ref;
+  lite::Tensor out;
+  lite::Tensor out_cpu;
+  lite::Tensor out_ref;
+  lite::Tensor tensorA;
+  lite::Tensor tensorB;
+  lite::Tensor tensorC;
+  lite::Tensor tensorD;
+  lite::Tensor tensorA_cpu;
+  lite::Tensor tensorB_cpu;
+  lite::Tensor tensorC_cpu;
+  lite::Tensor tensorD_cpu;
+  lite::Tensor tensorA_ref;
+  lite::Tensor tensorB_ref;
+  lite::Tensor tensorC_ref;
+  lite::Tensor tensorD_ref;
+
+  DDimLite ddimA({1, 3, 38, 38});
+  DDimLite ddimB({1, 4, 38, 38});
+  DDimLite ddimC({1, 5, 38, 38});
+  DDimLite ddimD({1, 6, 38, 38});
+
+  tensorA.Resize(ddimA);
+  tensorB.Resize(ddimB);
+  tensorC.Resize(ddimC);
+  tensorD.Resize(ddimD);
+  tensorA_cpu.Resize(ddimA);
+  tensorB_cpu.Resize(ddimB);
+  tensorC_cpu.Resize(ddimC);
+  tensorD_cpu.Resize(ddimD);
+  tensorA_ref.Resize(ddimA);
+  tensorB_ref.Resize(ddimB);
+  tensorC_ref.Resize(ddimC);
+  tensorD_ref.Resize(ddimD);
+
+  out.Resize({1, 18, 38, 38});
+  out_cpu.Resize({1, 18, 38, 38});
+  out_ref.Resize({1, 18, 38, 38});
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+  auto* out_cpu_data = out_cpu.mutable_data<float>();
+  auto* out_ref_data = out_ref.mutable_data<float>();
+  for (int i = 0; i < tensorA_cpu.numel(); i++) {
+    tensorA_cpu.mutable_data<float>()[i] = i;
+    tensorA_ref.mutable_data<float>()[i] = i;
+  }
+  for (int i = 0; i < tensorB_cpu.numel(); i++) {
+    tensorB_cpu.mutable_data<float>()[i] = i + 3;
+    tensorB_ref.mutable_data<float>()[i] = i + 3;
+  }
+  for (int i = 0; i < tensorC_cpu.numel(); i++) {
+    tensorC_cpu.mutable_data<float>()[i] = i + 6;
+    tensorC_ref.mutable_data<float>()[i] = i + 6;
+  }
+  for (int i = 0; i < tensorD_cpu.numel(); i++) {
+    tensorD_cpu.mutable_data<float>()[i] = i + 9;
+    tensorD_ref.mutable_data<float>()[i] = i + 9;
+  }
+  tensorA.Assign<float, lite::DDim, TARGET(kCUDA)>(
+      tensorA_cpu.mutable_data<float>(), tensorA_cpu.dims());
+  tensorB.Assign<float, lite::DDim, TARGET(kCUDA)>(
+      tensorB_cpu.mutable_data<float>(), tensorB_cpu.dims());
+  tensorC.Assign<float, lite::DDim, TARGET(kCUDA)>(
+      tensorC_cpu.mutable_data<float>(), tensorC_cpu.dims());
+  tensorD.Assign<float, lite::DDim, TARGET(kCUDA)>(
+      tensorD_cpu.mutable_data<float>(), tensorD_cpu.dims());
+
+  x.push_back(&tensorA);
+  x.push_back(&tensorB);
+  x.push_back(&tensorC);
+  x.push_back(&tensorD);
+  x_cpu.push_back(&tensorA_cpu);
+  x_cpu.push_back(&tensorB_cpu);
+  x_cpu.push_back(&tensorC_cpu);
+  x_cpu.push_back(&tensorD_cpu);
+  x_ref.push_back(&tensorA_ref);
+  x_ref.push_back(&tensorB_ref);
+  x_ref.push_back(&tensorC_ref);
+  x_ref.push_back(&tensorD_ref);
+
+  for (int cur_axis : {1}) {
+    param.x = x;
+    param.axis = cur_axis;
+    param.output = &out;
+
+    concat_kernel.SetParam(param);
+    LOG(INFO) << "test concat start cur_axis:" << cur_axis;
+
+    cudaStream_t stream;
+    cudaStreamCreate(&stream);
+    context.SetExecStream(stream);
+
+    concat_kernel.SetContext(std::move(ctx));
+    concat_kernel.Launch();
+    cudaDeviceSynchronize();
+    LOG(INFO) << "sync end";
+    CHECK(cudaSuccess == cudaMemcpy(out_cpu_data,
+                                    out_data,
+                                    sizeof(float) * out.numel(),
+                                    cudaMemcpyDeviceToHost));
+    LOG(INFO) << "concat.Run end";
+
+    param_ref.x = x_ref;
+    param_ref.axis = cur_axis;
+    param_ref.output = &out_ref;
+
+    LOG(INFO) << "concat_compute_ref start";
+    concat_compute_ref(param_ref);
+    LOG(INFO) << "concat_compute_ref end";
+
+    for (int i = 0; i < out_ref.numel(); i++) {
+      EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
+    }
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/conv_compute.cc b/lite/kernels/cuda/conv_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..eea81602ddf94158250aecf01fe5e95193bf58c1
--- /dev/null
+++ b/lite/kernels/cuda/conv_compute.cc
@@ -0,0 +1,180 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/conv_compute.h"
+#include
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+inline int ConvOutputSize(
+    int input_size, int filter_size, int dilation, int padding, int stride) {
+  const int dkernel = dilation * (filter_size - 1) + 1;
+  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
+  CHECK_GT_OR_FALSE(output_size, 0);
+
+  return output_size;
+}
+
+void ConvCompute::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  conv_impl_.reset(new lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>);
+  conv_impl_->init(param, &ctx);
+}
+
+void ConvCompute::Run() {
+  auto& param = this->Param<param_t>();
+  conv_impl_->run(param);
+}
+
+template <PrecisionType Ptype_out>
+void ConvComputeInt8<Ptype_out>::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+
+  const auto in_dims = param.x->dims();
+  const auto filter_dims = param.filter->dims();
+  std::vector<int64_t> output_shape({in_dims[0]});
+
+  for (size_t i = 0; i < param.strides.size(); ++i) {
+    output_shape.push_back(ConvOutputSize(in_dims[i + 1],
+                                          filter_dims[i + 1],
+                                          param.dilations[i],
+                                          param.paddings[i],
+                                          param.strides[i]));
+  }
+  output_shape.push_back(filter_dims[0]);
+  param.output->Resize(lite::DDim(output_shape));
+
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  conv_impl_.reset(new lite::cuda::math::CudnnConv2DInt8<Ptype_out>);
+  conv_impl_->init(param, &ctx);
+}
+
+template <PrecisionType Ptype_out>
+void ConvComputeInt8<Ptype_out>::Run() {
+  auto& param = this->Param<param_t>();
+  const auto in_dims = param.x->dims();
+  const auto filter_dims = param.filter->dims();
+  std::vector<int64_t> output_shape({in_dims[0]});
+
+  for (size_t i = 0; i < param.strides.size(); ++i) {
+    output_shape.push_back(ConvOutputSize(in_dims[i + 1],
+                                          filter_dims[i + 1],
+                                          param.dilations[i],
+                                          param.paddings[i],
+                                          param.strides[i]));
+  }
+  output_shape.push_back(filter_dims[0]);
+  param.output->Resize(lite::DDim(output_shape));
+
+  conv_impl_->run(param);
+}
+
+template class ConvComputeInt8<PRECISION(kFloat)>;
+template class ConvComputeInt8<PRECISION(kInt8)>;
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Bias",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(depthwise_conv2d,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ConvCompute,
+                     def)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Bias",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    conv2d,
+    kCUDA,
+    kInt8,
+    kNHWC,
+    paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
+    fp32_out)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Bias",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    conv2d,
+    kCUDA,
+    kInt8,
+    kNHWC,
+    paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kInt8)>,
+    int8_out)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Bias",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kInt8),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
diff --git a/lite/kernels/cuda/conv_compute.h b/lite/kernels/cuda/conv_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..71cf4b6331f302467d6c60aae20cc84dc3b0261b
--- /dev/null
+++ b/lite/kernels/cuda/conv_compute.h
@@ -0,0 +1,54 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
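ConvOutputSize above is the standard convolution shape rule, out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1. A minimal standalone check of that arithmetic, offered as a sketch rather than as part of the patch (the numbers mirror the test cases further down):

#include <cassert>

// Same formula as ConvOutputSize in conv_compute.cc.
int conv_out_size(int in, int kernel, int dilation, int pad, int stride) {
  const int dkernel = dilation * (kernel - 1) + 1;
  return (in + 2 * pad - dkernel) / stride + 1;
}

int main() {
  // 3x3 input, 3x3 filter, pad 1, stride 1 -> 3x3 output (the fp32 test case).
  assert(conv_out_size(3, 3, /*dilation=*/1, /*pad=*/1, /*stride=*/1) == 3);
  // 38x38 input, 3x3 filter, pad 1, stride 2 -> 19x19 output.
  assert(conv_out_size(38, 3, 1, 1, 2) == 19);
  return 0;
}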
+
+#pragma once
+#include <memory>
+#include "lite/backends/cuda/math/cudnn_conv.h"
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class ConvCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ConvParam;
+
+  void PrepareForRun() override;
+  void Run() override;
+  virtual ~ConvCompute() = default;
+
+ private:
+  std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_;
+};
+
+template <PrecisionType Ptype_out>
+class ConvComputeInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::ConvParam;
+
+  void PrepareForRun() override;
+  void Run() override;
+  virtual ~ConvComputeInt8() = default;
+
+ private:
+  std::unique_ptr<lite::cuda::math::CudnnConv2DInt8<Ptype_out>> conv_impl_;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/conv_compute_test.cc b/lite/kernels/cuda/conv_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..05175a0debcd687a2e5e06fa799839ad52c50adb
--- /dev/null
+++ b/lite/kernels/cuda/conv_compute_test.cc
@@ -0,0 +1,252 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/conv_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+float random(float low, float high) {
+  static std::mt19937 mt(100);
+  std::uniform_real_distribution<float> dist(low, high);
+  return dist(mt);
+}
+
+TEST(conv_compute, fp32) {
+  ConvCompute conv_fp32;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::ActivationParam act_param;
+  act_param.has_active = true;
+  // act_param.active_type = core::ActiveType::Active_relu;
+  act_param.active_type = lite_api::ActivationType::kLeakyRelu;
+  act_param.Leaky_relu_alpha = 0.1;
+  operators::ConvParam param;
+  param.activation_param = act_param;
+  param.paddings = {1, 1};
+  param.groups = 1;
+
+  Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
+  int n = 1, c = 1, h = 3, w = 3;
+  int c_o = 1, h_o = 3, w_o = 3;
+  y.Resize({n, c_o, h_o, w_o});
+  x_cpu.Resize({n, c, h, w});
+  filter_cpu.Resize({c_o, c / param.groups, 3, 3});
+  y_cpu.Resize({n, c_o, h_o, w_o});
+  bias_cpu.Resize({c_o});
+
+  auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
+  float* x_cpu_data = x_cpu.mutable_data<float>();
+  float* filter_cpu_data = filter_cpu.mutable_data<float>();
+  float* y_cpu_data = y_cpu.mutable_data<float>();
+  float* bias_cpu_data = bias_cpu.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); i++) {
+    x_cpu_data[i] = i;
+  }
+  std::vector<float> weight = {-0.2209115,
+                               -0.17199445,
+                               -0.2059412,
+                               0.6763207,
+                               -0.12260777,
+                               -0.43123743,
+                               -0.49696392,
+                               -0.27471393,
+                               -0.81017196};
+  for (int i = 0; i < filter_cpu.numel(); i++) {
+    filter_cpu_data[i] = weight[i];
+  }
+  for (int i = 0; i < bias_cpu.numel(); i++) {
+    bias_cpu_data[i] = 0;
+  }
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  filter.Assign<float, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
+                                                  filter_cpu.dims());
+  bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
+
+  param.x
= &x; + param.filter = &filter; + param.output = &y; + // param.bias = &bias; + + conv_fp32.SetParam(param); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + conv_fp32.SetContext(std::move(ctx)); + conv_fp32.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + + std::vector real_results = {-0.8, -0.7}; + for (int i = 0; i < y.numel(); i++) { + LOG(INFO) << y_cpu_data[i]; + } +} + +TEST(conv_compute, int8) { + ConvComputeInt8 int8_conv_fp32out; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam act_param; + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + operators::ConvParam param; + // param.activation_param = act_param; + param.groups = 1; + + Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu; + int n = 1, c = 4, h = 3, w = 3; + y.Resize({1, 1, 1, c}); + x_cpu.Resize({n, h, w, c}); + filter_cpu.Resize({c, 3, 3, c / param.groups}); + y_cpu.Resize({1, 1, 1, c}); + bias_cpu.Resize({c}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + auto* x_cpu_data = x_cpu.mutable_data(); + auto* filter_cpu_data = filter_cpu.mutable_data(); + auto* y_cpu_data = x_cpu.mutable_data(); + auto* bias_cpu_data = bias_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = static_cast(1); + } + for (int i = 0; i < filter_cpu.numel(); i++) { + filter_cpu_data[i] = static_cast(1); + } + for (int i = 0; i < bias_cpu.numel(); i++) { + bias_cpu_data[i] = i + 1.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + filter.Assign(filter_cpu_data, + filter_cpu.dims()); + bias.Assign(bias_cpu_data, + filter_cpu.dims()); + + param.x = &x; + param.filter = &filter; + param.output = &y; + param.weight_scale = {1, 2, 3, 4}; + + int8_conv_fp32out.SetParam(param); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + int8_conv_fp32out.SetContext(std::move(ctx)); + int8_conv_fp32out.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + std::vector real_results = {36, 72, 108, 144}; + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5); + } +} + +TEST(conv_compute, int8_int8_out) { + ConvComputeInt8 int8_conv_fp32out; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam act_param; + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + // act_param.active_type = lite_api::ActivationType::kLeakyRelu; + act_param.Leaky_relu_alpha = 0.1; + operators::ConvParam param; + param.activation_param = act_param; + param.groups = 1; + + Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu; + int c_i = 3, h_i = 3, w_i = 3; + int n = 1, c = 4; + y.Resize({1, 1, 1, c}); + x_cpu.Resize({n, h_i, w_i, c_i}); + filter_cpu.Resize({c, 3, 3, c_i / param.groups}); + y_cpu.Resize({1, 1, 1, c}); + bias_cpu.Resize({c}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + auto* x_cpu_data = x_cpu.mutable_data(); + auto* filter_cpu_data = filter_cpu.mutable_data(); + auto* y_cpu_data = x_cpu.mutable_data(); + auto* bias_cpu_data = bias_cpu.mutable_data(); + + std::cout << "input" << std::endl; + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = static_cast(random(-36, 36)); + std::cout << float(x_cpu_data[i]) << std::endl; + } + std::cout << "filter" << std::endl; + for (int i = 0; 
i < filter_cpu.numel(); i++) { + filter_cpu_data[i] = static_cast(random(-10, 10)); + std::cout << float(filter_cpu_data[i]) << std::endl; + } + for (int i = 0; i < bias_cpu.numel(); i++) { + bias_cpu_data[i] = i + 1.0; + // bias_cpu_data[i] = 0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + filter.Assign(filter_cpu_data, + filter_cpu.dims()); + bias.Assign(bias_cpu_data, + filter_cpu.dims()); + + param.x = &x; + param.filter = &filter; + param.output = &y; + param.weight_scale = {0.01, 0.02, 0.03, 0.04}; + param.output_scale = 2; + param.bias = &bias; + + int8_conv_fp32out.SetParam(param); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + int8_conv_fp32out.SetContext(std::move(ctx)); + int8_conv_fp32out.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH); + + std::vector real_results = {0, 7, 8, 1}; + for (int i = 0; i < y.numel(); i++) { + // EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5); + LOG(INFO) << float(y_cpu_data[i]); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/dropout_compute.cc b/lite/kernels/cuda/dropout_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e3a3a62432f3bc5f2e62112b2b220abc17ee2bd --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include +#include "lite/backends/cuda/math/scale.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void DropoutCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* out_data = param.output->mutable_data(TARGET(kCUDA)); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + float scale = 1.0f; + if (param.dropout_implementation == "downgrade_in_infer") { + scale = 1.0f - prob_data; + } + lite::cuda::math::scale(num, x_data, out_data, scale, 0); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(dropout, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::DropoutCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/dropout_compute.h b/lite/kernels/cuda/dropout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..aec0d1966bc8b368484b5c810f133a8e9a6fb410 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class DropoutCompute : public KernelLite { + public: + void Run() override; + + virtual ~DropoutCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/dropout_compute_test.cc b/lite/kernels/cuda/dropout_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e6ed54330c0a109091934ebe48ed341afcae96f9 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void dropout_compute_ref(const operators::DropoutParam& param) { + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + if (param.dropout_implementation.compare( + std::string({"downgrade_in_infer"})) == 0) { + float scale = 1.0 - prob_data; + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i] * scale; + } + } else { + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i]; + } + } +} + +TEST(dropout_cuda, normal) { + DropoutCompute dropout_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::DropoutParam param; + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor x_ref; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4}) { + for (auto c : {1, 3, 4, 256}) { + for (auto h : {1, 3, 4, 6}) { + for (auto w : {1, 3, 4, 6}) { + for (auto prob : {0.2f, 0.8f}) + for (auto impl : {std::string({"downgrade_in_infer"})}) { + x.Resize(DDim(std::vector({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector({n, c, h, w}))); + x_ref.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_cpu.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* x_ref_data = x_ref.mutable_data(); + auto* output_data = output.mutable_data(TARGET(kCUDA)); + auto* output_cpu_data = output_cpu.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + + for 
(int i = 0; i < x.dims().production(); i++) { + x_cpu_data[i] = i; + x_ref_data[i] = i; + } + + x.Assign(x_cpu_data, + x_cpu.dims()); + + param.x = &x; + param.output = &output; + param.dropout_prob = prob; + param.dropout_implementation = impl; + dropout_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + dropout_kernel.SetContext(std::move(ctx)); + dropout_kernel.Launch(); + + CopySync(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + + param.x = &x_ref; + param.output = &output_ref; + dropout_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_add_compute.cu b/lite/kernels/cuda/elementwise_add_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..4bacf532a2b67168679449200b1af721b7a282c8 --- /dev/null +++ b/lite/kernels/cuda/elementwise_add_compute.cu @@ -0,0 +1,139 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/backends/cuda/math/elementwise.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/elementwise_add_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void ElementwiseAddCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const lite::Tensor* x = param.X; + const lite::Tensor* y = param.Y; + lite::Tensor* out = param.Out; + + CHECK(x->dims().production() == y->dims().production()); + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + + int pixel_num = x->numel(); + lite::cuda::math::elementwise_add( + pixel_num, x_data, y_data, out_data, stream); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +void ElementwiseAddComputeNHWC::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const lite::Tensor* x = param.X; + const lite::Tensor* y = param.Y; + lite::Tensor* out = param.Out; + + CHECK(x->dims().production() == y->dims().production()); + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + + int pixel_num = x->numel(); + lite::cuda::math::elementwise_add( + pixel_num, x_data, y_data, out_data, stream); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +void ElementwiseAddComputeInt8::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const lite::Tensor* x = param.X; + const lite::Tensor* y = param.Y; + lite::Tensor* out 
= param.Out; + + CHECK(x->dims().production() == y->dims().production()); + + const int c = x->dims()[3]; + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + + int pixel_num = x->numel(); + float output_scale = param.output_scale; + if (c % 4 == 0) { + lite::cuda::math::elementwise_add_nhwc4_int8( + pixel_num / 4, + static_cast(x_data), + static_cast(y_data), + 1. / output_scale, + static_cast(out_data), + stream); + } else { + lite::cuda::math::elementwise_add_int8( + pixel_num, x_data, y_data, 1. / output_scale, out_data, stream); + } + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(elementwise_add, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ElementwiseAddCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); + +REGISTER_LITE_KERNEL(elementwise_add, + kCUDA, + kFloat, + kNHWC, + paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC, + nhwc_format) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); diff --git a/lite/kernels/cuda/elementwise_add_compute.h b/lite/kernels/cuda/elementwise_add_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..5c3fecc5d894aeea2bc5260b1815bbfa718eb5c6 --- /dev/null +++ b/lite/kernels/cuda/elementwise_add_compute.h @@ -0,0 +1,53 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
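ElementwiseAddComputeInt8 above only exposes the 1 / output_scale factor it hands to the CUDA math routines. A host-side reference for one output element might look like the sketch below, assuming round-to-nearest and saturation to [-127, 127]; the actual rounding and clamping behavior lives in lite::cuda::math and is not visible in this patch:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference for a single int8 output element of elementwise_add (assumed).
int8_t add_int8_ref(float x, float y, float output_scale) {
  float q = std::round((x + y) / output_scale);  // scale into the int8 domain
  q = std::min(127.0f, std::max(-127.0f, q));    // saturate (assumed behavior)
  return static_cast<int8_t>(q);
}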
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class ElementwiseAddCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ElementwiseParam;
+
+  void Run() override;
+  virtual ~ElementwiseAddCompute() = default;
+};
+
+class ElementwiseAddComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::ElementwiseParam;
+
+  void Run() override;
+  virtual ~ElementwiseAddComputeNHWC() = default;
+};
+
+class ElementwiseAddComputeInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::ElementwiseParam;
+
+  void Run() override;
+  virtual ~ElementwiseAddComputeInt8() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/elementwise_add_compute_test.cc b/lite/kernels/cuda/elementwise_add_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cc63f1470b65de37eb73c71701a83146e12778ae
--- /dev/null
+++ b/lite/kernels/cuda/elementwise_add_compute_test.cc
@@ -0,0 +1,166 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/elementwise_add_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include "lite/api/test_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+static void ElementwiseAddRef(float* x, float* y, float* out, int num) {
+  for (int i = 0; i < num; ++i) {
+    out[i] = x[i] + y[i];
+  }
+}
+
+TEST(elementwise_add, normal) {
+  ElementwiseAddCompute elementwise_add_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::ElementwiseParam param;
+  Tensor x, y, out;
+  Tensor x_cpu, y_cpu, out_cpu;
+  Tensor x_ref, y_ref, out_ref;
+
+  const int n = 1;
+  const int c = 3;
+  const int h = 2000;
+  const int w = 2000;
+
+  x.Resize({n, c, h, w});
+  y.Resize({n, c, h, w});
+  out.Resize({n, c, h, w});
+  x_cpu.Resize({n, c, h, w});
+  y_cpu.Resize({n, c, h, w});
+  out_cpu.Resize({n, c, h, w});
+  x_ref.Resize({n, c, h, w});
+  y_ref.Resize({n, c, h, w});
+  out_ref.Resize({n, c, h, w});
+
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+
+  auto* x_cpu_data = x_cpu.mutable_data<float>();
+  auto* y_cpu_data = y_cpu.mutable_data<float>();
+  auto* out_cpu_data = out_cpu.mutable_data<float>();
+
+  auto* x_ref_data = x_ref.mutable_data<float>();
+  auto* y_ref_data = y_ref.mutable_data<float>();
+  auto* out_ref_data = out_ref.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); ++i) {
+    x_cpu_data[i] = i + 5.0;
+    x_ref_data[i] = i + 5.0;
+  }
+  for (int i = 0; i < y_cpu.numel(); ++i) {
+    y_cpu_data[i] = i - 5.0;
+    y_ref_data[i] = i - 5.0;
+  }
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
+
+  param.X = &x;
+  param.Y = &y;
+  param.Out = &out;
+  elementwise_add_kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
elementwise_add_kernel.SetContext(std::move(ctx)); + elementwise_add_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + ElementwiseAddRef(x_ref_data, y_ref_data, out_ref_data, out.numel()); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(elementwise_add, int8_out) { + ElementwiseAddComputeInt8 elementwise_add_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ElementwiseParam param; + Tensor x, y, out; + Tensor x_cpu, y_cpu, out_cpu; + + const int n = 1; + const int h = 36; + const int w = 36; + const int c = 125; + + x.Resize({n, h, w, c}); + y.Resize({n, h, w, c}); + out.Resize({n, h, w, c}); + x_cpu.Resize({n, h, w, c}); + y_cpu.Resize({n, h, w, c}); + out_cpu.Resize({n, h, w, c}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* y_cpu_data = y_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + } + for (int i = 0; i < y_cpu.numel(); ++i) { + y_cpu_data[i] = i; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + y.Assign(y_cpu_data, y_cpu.dims()); + + param.X = &x; + param.Y = &y; + param.Out = &out; + param.output_scale = 50 / 127.; + elementwise_add_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + elementwise_add_kernel.SetContext(std::move(ctx)); + auto start = GetCurrentUS(); + for (int i = 0; i < 1000000; i++) { + elementwise_add_kernel.Launch(); + } + LOG(INFO) << "time: " << (GetCurrentUS() - start) / 1000000.; + + CopySync( + out_cpu_data, out_data, sizeof(int8_t) * out.numel(), IoDirection::DtoH); + for (int i = 0; i < out.numel(); i++) { + // LOG(INFO) << float(out_cpu_data[i]); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/feed_compute.cc b/lite/kernels/cuda/feed_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..cffa8a573d9b12b52ae1448632a56e40cea35b95 --- /dev/null +++ b/lite/kernels/cuda/feed_compute.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
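FeedCompute below resolves the col-th tensor of the feed list on the host, resizes the device output to match, and issues an asynchronous host-to-device copy on the context's exec stream. A minimal usage sketch; device_in and the surrounding setup are illustrative, not part of the patch:

// Hypothetical caller-side setup for the feed kernel defined below.
std::vector<lite::Tensor> feeds(1);
lite::Tensor device_in;                            // output on the CUDA device
feeds[0].Resize({1, 3, 224, 224});
float* host_in = feeds[0].mutable_data<float>();   // fill with input values

operators::FeedParam param;
param.feed_list = &feeds;   // Run() reads (*feed_list)[col]
param.col = 0;
param.out = &device_in;     // resized to feeds[0].dims() inside Run()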
+ +#include "lite/kernels/cuda/feed_compute.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void FeedCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + VLOG(4) << "feed_list.size: " << param.feed_list->size(); + const lite::Tensor& feed_item = (*param.feed_list)[param.col]; + + int num = static_cast(feed_item.numel()); + auto input = feed_item.data(); + param.out->Resize(feed_item.dims()); + auto output = param.out->mutable_data(TARGET(kCUDA)); + VLOG(4) << "col: " << param.col << " num:" << num; + + TargetW::MemcpyAsync( + output, input, num * sizeof(float), IoDirection::HtoD, stream); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + feed, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::FeedCompute, nchw) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + feed, kCUDA, kFloat, kNHWC, paddle::lite::kernels::cuda::FeedCompute, nhwc) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); diff --git a/lite/kernels/cuda/feed_compute.h b/lite/kernels/cuda/feed_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..0510404b2b6ad6c50f69c847bf833afbcfe59b99 --- /dev/null +++ b/lite/kernels/cuda/feed_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
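The io_copy hunks further down widen the kernel bindings from concrete tensor types to PRECISION(kAny) and DATALAYOUT(kAny). That matches what those kernels actually do: they move raw bytes, so neither element type nor layout affects the copy, as the host-to-device body shows:

// From IoCopyHostToCudaCompute::Run(): a byte-level copy, element type unseen.
auto mem_size = param.x->memory_size();  // size in bytes, not elements
auto* data = param.y->mutable_data(TARGET(kCUDA), mem_size);
CopyFromHostSync(data, param.x->raw_data(), mem_size);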
+ +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class FeedCompute : public KernelLite { + public: + using param_t = operators::FeedParam; + using TargetW = TargetWrapper; + + void Run() override; + virtual ~FeedCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/io_copy_compute.cc b/lite/kernels/cuda/io_copy_compute.cc index 011b7c802ac77ae5365c50f5025846e90779934b..9d9aa97999b70e017954f00967861c04d2f5a3a3 100644 --- a/lite/kernels/cuda/io_copy_compute.cc +++ b/lite/kernels/cuda/io_copy_compute.cc @@ -51,7 +51,7 @@ class IoCopyHostToCudaCompute CHECK(param.x->target() == TARGET(kHost) || param.x->target() == TARGET(kX86)); auto mem_size = param.x->memory_size(); - LOG(INFO) << "copy size " << mem_size; + VLOG(4) << "copy size " << mem_size; auto* data = param.y->mutable_data(TARGET(kCUDA), mem_size); CopyFromHostSync(data, param.x->raw_data(), mem_size); } @@ -89,6 +89,7 @@ class IoCopyCudaToHostCompute auto& param = Param(); CHECK(param.x->target() == TARGET(kCUDA)); auto mem_size = param.x->memory_size(); + VLOG(4) << "io copy cuda to host " << mem_size; auto* data = param.y->mutable_data(TARGET(kHost), mem_size); CopyToHostSync(data, param.x->raw_data(), mem_size); } @@ -107,8 +108,14 @@ REGISTER_LITE_KERNEL(io_copy, kAny, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, host_to_device) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy, @@ -117,8 +124,14 @@ REGISTER_LITE_KERNEL(io_copy, kAny, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, device_to_host) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy_once, @@ -127,8 +140,14 @@ REGISTER_LITE_KERNEL(io_copy_once, kAny, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, host_to_device) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy_once, @@ -137,6 +156,12 @@ REGISTER_LITE_KERNEL(io_copy_once, kAny, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, device_to_host) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); diff --git a/lite/kernels/cuda/layout_compute.cc b/lite/kernels/cuda/layout_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..e2d0ae4f2ef10b29247a2f823988e8098aa33795 --- 
/dev/null +++ b/lite/kernels/cuda/layout_compute.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/layout_compute.h" +#include "lite/backends/cuda/math/transpose.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +#define NCHWTONHWC(type) \ + auto& param = this->template Param(); \ + auto& ctx = this->ctx_->template As(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + CHECK(input_dim.size() == 4) \ + << "NCHW to NHWC should guarantee that the input dims should be 4"; \ + int n = input_dim[0]; \ + int c = input_dim[1]; \ + int h = input_dim[2]; \ + int w = input_dim[3]; \ + param.y->Resize({n, h, w, c}); \ + auto output = param.y->template mutable_data(TARGET(kCUDA)); \ + lite::cuda::math::NCHW2NHWC(n, c, h * w, input, output, &ctx); + +#define NHWCTONCHW(type) \ + auto& param = this->template Param(); \ + auto& ctx = this->ctx_->template As(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + CHECK(input_dim.size() == 4) \ + << "NHWC to NCHW should guarantee that the input dims should be 4"; \ + int n = input_dim[0]; \ + int h = input_dim[1]; \ + int w = input_dim[2]; \ + int c = input_dim[3]; \ + param.y->Resize({n, c, h, w}); \ + auto output = param.y->template mutable_data(TARGET(kCUDA)); \ + lite::cuda::math::NHWC2NCHW(n, c, h * w, input, output, &ctx); + +void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) } + +void NCHWToNHWCComputeInt8::Run() { NCHWTONHWC(int8_t) } + +void NHWCToNCHWCompute::Run() { NHWCTONCHW(float) } + +void NHWCToNCHWComputeInt8::Run() { NHWCTONCHW(int8_t) } + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(layout, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::NCHWToNHWCCompute, + nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::NHWCToNCHWCompute, + nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, + kCUDA, + kInt8, + kNCHW, + paddle::lite::kernels::cuda::NCHWToNHWCComputeInt8, + int8_nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, + kCUDA, + kInt8, + kNCHW, + paddle::lite::kernels::cuda::NHWCToNCHWComputeInt8, + 
int8_nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::NCHWToNHWCCompute, + nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::NHWCToNCHWCompute, + nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, + kCUDA, + kInt8, + kNCHW, + paddle::lite::kernels::cuda::NCHWToNHWCComputeInt8, + int8_nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, + kCUDA, + kInt8, + kNCHW, + paddle::lite::kernels::cuda::NHWCToNCHWComputeInt8, + int8_nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/layout_compute.h b/lite/kernels/cuda/layout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..10a0961212dde34a35dcc43b07bc0207ed2c93a3 --- /dev/null +++ b/lite/kernels/cuda/layout_compute.h @@ -0,0 +1,56 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
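The NCHW2NHWC and NHWC2NCHW calls in layout_compute.cc above collapse the two spatial dimensions into a single axis (h * w), so each sample reduces to a c x (h*w) matrix transpose. A CPU sketch of the same mapping, assuming the argument order (n, c, h * w, input, output) used by the macros:

// Equivalent of NCHW2NHWC(n, c, hw, ...): transpose c x hw -> hw x c per sample.
void nchw_to_nhwc_ref(int n, int c, int hw, const float* in, float* out) {
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < c; ++j)
      for (int k = 0; k < hw; ++k)
        out[(i * hw + k) * c + j] = in[(i * c + j) * hw + k];
}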
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class NCHWToNHWCCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::LayoutParam;
+  void Run() override;
+  virtual ~NCHWToNHWCCompute() = default;
+};
+
+class NCHWToNHWCComputeInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::LayoutParam;
+  void Run() override;
+  virtual ~NCHWToNHWCComputeInt8() = default;
+};
+
+class NHWCToNCHWCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::LayoutParam;
+  void Run() override;
+  virtual ~NHWCToNCHWCompute() = default;
+};
+
+class NHWCToNCHWComputeInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::LayoutParam;
+  void Run() override;
+  virtual ~NHWCToNCHWComputeInt8() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/layout_compute_test.cc b/lite/kernels/cuda/layout_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9a781eb7b9aa3fd6b79fce59f6914b2532dca5ee
--- /dev/null
+++ b/lite/kernels/cuda/layout_compute_test.cc
@@ -0,0 +1,184 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
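The IN/OUT macros in the test below are hand-rolled linear offsets into NCHW and NHWC buffers; written out as functions, they compute:

// w varies fastest in NCHW; c varies fastest in NHWC.
inline int offset_nchw(int n, int c, int h, int w, int C, int H, int W) {
  return ((n * C + c) * H + h) * W + w;
}
inline int offset_nhwc(int n, int h, int w, int c, int H, int W, int C) {
  return ((n * H + h) * W + w) * C + c;
}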
+ +#include "lite/kernels/cuda/layout_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +#define IN(n, c, h, w) \ + input_data[w + h * input_w + c * input_h * input_w + \ + n * input_c * input_h * input_w] +#define OUT(n, c, h, w) \ + output_data[w + h * output_w + c * output_h * output_w + \ + n * output_c * output_h * output_w] + +template +void nchw2nhwc_ref(lite::Tensor* input, lite::Tensor* output) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_c = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, h, w, c) = IN(n, c, h, w); + } + } + } + } +} +#undef IN +#undef OUT + +#define IN(n, h, w, c) \ + input_data[c + w * input_c + h * input_w * input_c + \ + n * input_h * input_w * input_c] +#define OUT(n, h, w, c) \ + output_data[c + w * output_c + h * output_w * output_c + \ + n * output_h * output_w * output_c] +template +void nhwc2nchw_ref(lite::Tensor* input, lite::Tensor* output) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_h = input->dims()[1]; + int input_w = input->dims()[2]; + int input_c = input->dims()[3]; + int output_h = output->dims()[1]; + int output_w = output->dims()[2]; + int output_c = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, c, h, w) = IN(n, h, w, c); + } + } + } + } +} + +template +void test_reformat(LayOutCompute* layout_kernel, bool nchw2nhwc) { + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + operators::LayoutParam param; + + lite::Tensor x, x_cpu, x_ref; + lite::Tensor out, out_cpu, out_ref; + int N = 5, C = 6, H = 7, W = 8; + if (nchw2nhwc) { + x.Resize({N, C, H, W}); + out.Resize({N, H, W, C}); + + x_cpu.Resize({N, C, H, W}); + out_cpu.Resize({N, H, W, C}); + + x_ref.Resize({N, C, H, W}); + out_ref.Resize({N, H, W, C}); + } else { + x.Resize({N, H, W, C}); + out.Resize({N, C, H, W}); + + x_cpu.Resize({N, H, W, C}); + out_cpu.Resize({N, C, H, W}); + + x_ref.Resize({N, H, W, C}); + out_ref.Resize({N, C, H, W}); + } + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + auto* x_ref_data = x_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = static_cast((i + 1) % 127); + x_ref_data[i] = static_cast((i + 1) % 127); + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.x = &x; + param.y = &out; + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + layout_kernel->SetParam(param); + layout_kernel->SetContext(std::move(ctx)); + layout_kernel->Launch(); + cudaDeviceSynchronize(); + auto* out_data = out.mutable_data(TARGET(kCUDA)); + CopySync( + out_cpu_data, out_data, sizeof(Dtype) * out.numel(), IoDirection::DtoH); + if (nchw2nhwc) { + nchw2nhwc_ref(&x_ref, &out_ref); + } else { + nhwc2nchw_ref(&x_ref, &out_ref); + } + + auto* out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(static_cast(out_cpu_data[i]), + 
static_cast(out_ref_data[i]), + 1e-5); + } +} + +TEST(normal, nchw2nhwc) { + LayOutCompute* layout_k = new NCHWToNHWCCompute(); + test_reformat(layout_k, true); + delete layout_k; +} + +/* +TEST(normal, nhwc2nchw) { + LayOutCompute * layout_k = new NHWCToNCHWCompute(); + test_reformat(layout_k, false); + delete layout_k; +} + +TEST(normal, nchw2nhwcint8) { + LayOutCompute * layout_k = new NCHWToNHWCCompute(); + test_reformat(layout_k, true); + delete layout_k; +} + +TEST(normal, nhwc2nchwint8) { + LayOutCompute * layout_k = new NHWCToNCHWCompute(); + test_reformat(layout_k, false); + delete layout_k; +} +*/ + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/leaky_relu_compute.cu b/lite/kernels/cuda/leaky_relu_compute.cu index eb6561c9edebb7f61a512ba303adc6f9032fd317..0dc281c493ec840a9e0df4868dcf2df76a771ac5 100644 --- a/lite/kernels/cuda/leaky_relu_compute.cu +++ b/lite/kernels/cuda/leaky_relu_compute.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once #include "lite/core/op_registry.h" #include "lite/kernels/cuda/leaky_relu_compute.h" @@ -67,4 +66,5 @@ REGISTER_LITE_KERNEL(leaky_relu, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .SetVersion("1.5.0") .Finalize(); diff --git a/lite/kernels/cuda/leaky_relu_compute_test.cc b/lite/kernels/cuda/leaky_relu_compute_test.cc index 9fb5a5eddf84807b18ae5f87b4fb65cd3b7355aa..8ced10ce7df8ff874fb03d958961a147031d063c 100644 --- a/lite/kernels/cuda/leaky_relu_compute_test.cc +++ b/lite/kernels/cuda/leaky_relu_compute_test.cc @@ -35,7 +35,6 @@ TEST(leaky_relu, normal) { x_cpu.Resize({h, w}); y_cpu.Resize({h, w}); - auto* x_data = x.mutable_data(TARGET(kCUDA)); auto* y_data = y.mutable_data(TARGET(kCUDA)); float* x_cpu_data = x_cpu.mutable_data(); float* y_cpu_data = x_cpu.mutable_data(); diff --git a/lite/kernels/cuda/mul_compute.h b/lite/kernels/cuda/mul_compute.h index 4a542104d6743b52758cbecfb11c025628e46333..c2fc4364ef77742858b143734d2ecf4d13e201e9 100644 --- a/lite/kernels/cuda/mul_compute.h +++ b/lite/kernels/cuda/mul_compute.h @@ -33,19 +33,36 @@ void mul_compute(const lite::cuda::Blas& blas, int y_h, int y_w, T* out) { + float alpha = 1.0; + float beta = 0.0; + /* blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, x_h, y_w, x_w, - nullptr, + &alpha, x, x_w, y, y_w, - nullptr, + &beta, out, x_h); + */ + blas.sgemm(CUBLAS_OP_N, + CUBLAS_OP_N, + y_w, + x_h, + y_h, + &alpha, + y, + y_w, + x, + x_w, + &beta, + out, + y_w); } class MulCompute : public KernelLite { @@ -56,23 +73,29 @@ class MulCompute : public KernelLite { CHECK(ctx_) << "running context should be set first"; auto& context = this->ctx_->template As(); CHECK(context.cublas_fp32()) << "blas should init first"; - /* auto& blas = *context.cublas_fp32(); - CHECK(param.x->target() == TARGET(kCUDA)); - auto* x = param.x->data(); - int x_h = param.x->dims()[0]; - int x_w = param.x->dims()[1]; - auto* y = param.y->data(); - int y_h = param.y->dims()[0]; - int y_w = param.y->dims()[1]; - */ + auto& param = this->Param(); + const auto* x_data = param.x->data(); + const auto* y_data = param.y->data(); + auto* out_data = param.output->mutable_data(TARGET(kCUDA)); - const auto& param = Param(); - param.output->mutable_data(TARGET(kCUDA)); - LOG(INFO) << "mul output memory size " << param.output->data_size(); + int x_h = static_cast( + param.x->dims().Slice(0, 
param.x_num_col_dims).production()); + int x_w = static_cast( + param.x->dims() + .Slice(param.x_num_col_dims, param.x->dims().size()) + .production()); + int y_h = static_cast( + param.y->dims().Slice(0, param.y_num_col_dims).production()); + int y_w = static_cast( + param.y->dims() + .Slice(param.y_num_col_dims, param.y->dims().size()) + .production()); + CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; + LOG(INFO) << x_h << " " << x_w << " " << y_h << " " << y_w; - // mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out); + mul_compute(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data); } virtual ~MulCompute() = default; diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1c1d63e7dcd46f84cd128fc5b855da2098e179d --- /dev/null +++ b/lite/kernels/cuda/mul_compute_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/mul_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(mul_compute, normal) { + MulCompute mul_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + Tensor x, y, out, x_cpu, y_cpu, out_cpu; + int x_h = 2, x_w_y_h = 3, y_w = 4; + out.Resize({x_h, y_w}); + x_cpu.Resize({x_h, x_w_y_h}); + y_cpu.Resize({x_w_y_h, y_w}); + out_cpu.Resize({x_h, y_w}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + float* out_cpu_data = out_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i + 1.0; + } + for (int i = 0; i < y_cpu.numel(); i++) { + y_cpu_data[i] = i + 1.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + y.Assign(y_cpu_data, y_cpu.dims()); + + operators::MulParam param; + param.x = &x; + param.y = &y; + param.output = &out; + mul_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + mul_kernel.SetContext(std::move(ctx)); + mul_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + for (int i = 0; i < out_cpu.numel(); i++) { + LOG(INFO) << out_cpu_data[i]; + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/nearest_interp_compute.cu b/lite/kernels/cuda/nearest_interp_compute.cu index 8edeacfe5a9339433c922c77f548e69aca4d0bce..1a614e0656b417786deff8df6b7a827433b33f7b 100644 --- a/lite/kernels/cuda/nearest_interp_compute.cu +++ b/lite/kernels/cuda/nearest_interp_compute.cu @@ -120,9 +120,9 @@ void NearestInterpCompute::Run() { int in_chw = c * in_hw; int out_chw = c * out_hw; - int pixelNum = n * out_chw; + int pixel_num = n * out_chw; int threads = 512; - int blocks = (pixelNum + threads - 1) / threads; + int blocks = 
(pixel_num + threads - 1) / threads; blocks = blocks > 8 ? 8 : blocks; KeNearestNeighborInterp<<>>(input_data, @@ -154,7 +154,16 @@ REGISTER_LITE_KERNEL(nearest_interp, kNCHW, paddle::lite::kernels::cuda::NearestInterpCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) - .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kCUDA))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("OutSize", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) .Finalize(); diff --git a/lite/kernels/cuda/nearest_interp_compute.h b/lite/kernels/cuda/nearest_interp_compute.h index d4fb0f43c65d68f5bcac01c952f4c2d392dd724d..7be9d14cf780ec2de142c3a35b7c558868e3a338 100644 --- a/lite/kernels/cuda/nearest_interp_compute.h +++ b/lite/kernels/cuda/nearest_interp_compute.h @@ -21,7 +21,7 @@ namespace kernels { namespace cuda { class NearestInterpCompute - : public KernelLite { + : public KernelLite { public: using param_t = operators::InterpolateParam; diff --git a/lite/kernels/cuda/nearest_interp_compute_test.cc b/lite/kernels/cuda/nearest_interp_compute_test.cc index 4aec6db1a21eba59439ff4d6601bd1d220c4e804..85032016d630f11bbfe150f750470e89e241c61b 100644 --- a/lite/kernels/cuda/nearest_interp_compute_test.cc +++ b/lite/kernels/cuda/nearest_interp_compute_test.cc @@ -16,91 +16,58 @@ #include #include #include -#include "lite/fluid/eigen.h" namespace paddle { namespace lite { namespace kernels { namespace cuda { -template -using EigenTensor = lite::fluid::EigenTensor; using Tensor = lite::Tensor; -static void NearestNeighborInterpolate(const Tensor& input, - Tensor* output, - const float ratio_h, - const float ratio_w, - const int n, - const int c, - const int out_h, - const int out_w, - const bool align_corners) { - auto input_t = EigenTensor::From(input); - auto output_t = EigenTensor::From(*output); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) - : static_cast(ratio_h * k); - for (int l = 0; l < out_w; l++) { - int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) - : static_cast(ratio_w * l); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - output_t(i, j, k, l) = input_t(i, j, in_k, in_l); +void NearestInterpRef(Tensor* input, Tensor* output, bool with_align) { + int hin = input->dims()[2]; + int win = input->dims()[3]; + int channels = input->dims()[1]; + int num = input->dims()[0]; + int hout = output->dims()[2]; + int wout = output->dims()[3]; + float scale_w = (with_align) ? (static_cast(win - 1) / (wout - 1)) + : (static_cast(win) / (wout)); + float scale_h = (with_align) ? (static_cast(hin - 1) / (hout - 1)) + : (static_cast(hin) / (hout)); + const float* src = input->data(); + float* dst = output->mutable_data(); + int dst_stride_w = 1; + int dst_stride_h = wout; + int dst_stride_c = wout * hout; + int dst_stride_batch = wout * hout * channels; + int src_stride_w = 1; + int src_stride_h = win; + int src_stride_c = win * hin; + int src_stride_batch = win * hin * channels; + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + int src_index = n * src_stride_batch + c * src_stride_c; + for (int h = 0; h < hout; ++h) { + for (int w = 0; w < wout; ++w) { + int fw = (with_align) ? 
static_cast(scale_w * w + 0.5) + : static_cast(scale_w * w); + fw = (fw < 0) ? 0 : fw; + int fh = (with_align) ? static_cast(scale_h * h + 0.5) + : static_cast(scale_h * h); + fh = (fh < 0) ? 0 : fh; + int w_start = static_cast(fw); + int h_start = static_cast(fh); + int dst_index = n * dst_stride_batch + c * dst_stride_c + + h * dst_stride_h + w * dst_stride_w; + dst[dst_index] = + src[src_index + w_start * src_stride_w + h_start * src_stride_h]; } } } } } -static void NearestInterpRef(operators::InterpolateParam param, - Tensor* input, - const size_t scale, - const size_t n, - const size_t c, - const size_t in_h, - const size_t in_w, - Tensor* output_size, - Tensor* output, - size_t out_h, - size_t out_w) { - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } - bool align_corners = param.align_corners; - if (output_size != nullptr) { - auto out_size_data = output_size->mutable_data(); - out_h = static_cast(out_size_data[0]); - out_w = static_cast(out_size_data[1]); - } - - float* input_data = input->mutable_data(); - LOG(INFO) << *(input_data + 2); - float* output_data = output->mutable_data(); - LOG(INFO) << *(output_data + 2); - if (in_h == out_h && in_w == out_w) { - std::memcpy(output_data, input_data, sizeof(float) * n * c * in_h * in_w); - LOG(INFO) << *(output_data + 2); - return; - } - float ratio_h = 0.f; - float ratio_w = 0.f; - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; - } - NearestNeighborInterpolate( - *input, output, ratio_h, ratio_w, n, c, out_h, out_w, align_corners); -} - TEST(nearest_interp, normal) { NearestInterpCompute nearest_interp_kernel; std::unique_ptr ctx(new KernelContext); @@ -112,9 +79,8 @@ TEST(nearest_interp, normal) { Tensor x_cpu, osz_cpu, out_cpu; Tensor x_ref, osz_ref, out_ref; - int n = 1, c = 3, in_h = 4, in_w = 4; - int in_chw = c * in_h * in_w; - int out_h = 4, out_w = 4; + int n = 1, c = 3, in_h = 40, in_w = 40; + int out_h = 80, out_w = 80; float scale = 2.0; param.out_h = out_h; @@ -134,8 +100,6 @@ TEST(nearest_interp, normal) { osz_ref.Resize({2}); out_ref.Resize({n, c, out_h, out_w}); - auto* x_data = x.mutable_data(TARGET(kCUDA)); - auto* osz_data = osz.mutable_data(TARGET(kCUDA)); auto* out_data = out.mutable_data(TARGET(kCUDA)); float* x_cpu_data = x_cpu.mutable_data(); @@ -173,8 +137,7 @@ TEST(nearest_interp, normal) { CopySync( out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); - NearestInterpRef( - param, &x_ref, scale, n, c, in_h, in_w, &osz_ref, &out_ref, out_h, out_w); + NearestInterpRef(&x_ref, &out_ref, false); for (int i = 0; i < out.numel(); i++) { EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); } diff --git a/lite/kernels/cuda/pool_compute.cu b/lite/kernels/cuda/pool_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..a2483a2c759e8acc5f5944fd316c83bb49530d36 --- /dev/null +++ b/lite/kernels/cuda/pool_compute.cu @@ -0,0 +1,375 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/pool_compute.h" +#include "lite/utils/macros.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +#define MAX_VAL(a, b) (((a) > (b)) ? (a) : (b)) +#define MIN_VAL(a, b) (((a) < (b)) ? (a) : (b)) + +__global__ void max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + max_val = max_val == -FLT_MAX ? 
0.f : max_val; + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void adaptive_max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + bool exclusive, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = exclusive ? h_loop * w_loop : win_w * win_h; + pool_size = pool_size == 0 ? 
1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void adaptive_avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = h_loop * w_loop; + pool_size = pool_size == 0 ? 1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void global_max_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float max_val = -0.f; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + max_val = MAX_VAL(max_val, *(in_p + 0)); + max_val = MAX_VAL(max_val, *(in_p + 1)); + max_val = MAX_VAL(max_val, *(in_p + 2)); + max_val = MAX_VAL(max_val, *(in_p + 3)); + max_val = MAX_VAL(max_val, *(in_p + 4)); + max_val = MAX_VAL(max_val, *(in_p + 5)); + max_val = MAX_VAL(max_val, *(in_p + 6)); + max_val = MAX_VAL(max_val, *(in_p + 7)); + in_p += 8; + } + for (; i < spatial_in; i++) { + max_val = MAX_VAL(max_val, *in_p); + in_p++; + } + output[gid] = max_val; + } +} + +__global__ void global_avg_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float sum_val = 0.f; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + } + for (; i < spatial_in; i++) { + sum_val += *in_p++; + } + output[gid] = sum_val / spatial_in; + } +} + +void PoolCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + auto x_dims = param.x->dims(); + auto out_dims = param.output->dims(); + const int in_h = x_dims[2]; + const int in_w = x_dims[3]; + const int out_h = out_dims[2]; + const int out_w = out_dims[3]; + const int spatial_in = in_h * in_w; + const int spatial_out = out_h * out_w; + const int win_h = param.ksize[0]; + const int win_w = param.ksize[1]; + const int stride_h = param.strides[0]; + const int 
stride_w = param.strides[1];
+  const int pad_h = param.paddings[0];
+  const int pad_w = param.paddings[1];
+  const int total_threads = out_dims.production();
+  const int threads = 512;
+  const int blocks = (total_threads + threads - 1) / threads;
+  auto input_data = param.x->data<float>();
+  auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
+  if (param.global_pooling) {
+    if (param.pooling_type == "max") {
+      global_max_pool_kernel<<<blocks, threads, 0, stream>>>(
+          input_data, output_data, in_h, in_w, total_threads);
+    } else {
+      global_avg_pool_kernel<<<blocks, threads, 0, stream>>>(
+          input_data, output_data, in_h, in_w, total_threads);
+    }
+  } else {
+    if (!adaptive) {
+      if (param.pooling_type == "max") {
+        max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
+                                                        output_data,
+                                                        spatial_in,
+                                                        spatial_out,
+                                                        in_h,
+                                                        in_w,
+                                                        out_h,
+                                                        out_w,
+                                                        pad_h,
+                                                        pad_w,
+                                                        win_h,
+                                                        win_w,
+                                                        stride_h,
+                                                        stride_w,
+                                                        total_threads);
+      } else {
+        avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
+                                                        output_data,
+                                                        spatial_in,
+                                                        spatial_out,
+                                                        in_h,
+                                                        in_w,
+                                                        out_h,
+                                                        out_w,
+                                                        pad_h,
+                                                        pad_w,
+                                                        win_h,
+                                                        win_w,
+                                                        stride_h,
+                                                        stride_w,
+                                                        exclusive,
+                                                        total_threads);
+      }
+    } else {
+      if (param.pooling_type == "max") {
+        adaptive_max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
+                                                                 output_data,
+                                                                 spatial_in,
+                                                                 spatial_out,
+                                                                 in_h,
+                                                                 in_w,
+                                                                 out_h,
+                                                                 out_w,
+                                                                 pad_h,
+                                                                 pad_w,
+                                                                 win_h,
+                                                                 win_w,
+                                                                 stride_h,
+                                                                 stride_w,
+                                                                 total_threads);
+      } else {
+        adaptive_avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
+                                                                 output_data,
+                                                                 spatial_in,
+                                                                 spatial_out,
+                                                                 in_h,
+                                                                 in_w,
+                                                                 out_h,
+                                                                 out_w,
+                                                                 pad_h,
+                                                                 pad_w,
+                                                                 win_h,
+                                                                 win_w,
+                                                                 stride_h,
+                                                                 stride_w,
+                                                                 total_threads);
+      }
+    }
+  }
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    pool2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::PoolCompute, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/pool_compute.h b/lite/kernels/cuda/pool_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..55b346bfaf4ac139c8d22bff2ac64f0e78bc6023
--- /dev/null
+++ b/lite/kernels/cuda/pool_compute.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
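The pool2d launch above maps one CUDA thread to one output element, so the grid size is derived directly from the output shape. For reference, a minimal host-side sketch of the standard pooling output-size rule (the same formula the test further below implements as PoolOutputSize); the helper name and sample figures here are illustrative only, not part of this patch:

// Sketch of the usual pooling shape rule, assuming symmetric padding.
// floor mode: (in - k + 2 * pad) / stride + 1; ceil mode adds stride - 1
// to the numerator so a partial window at the border still counts.
constexpr int pool_out_size(int in, int k, int pad, int stride, bool ceil_mode) {
  return (in - k + 2 * pad + (ceil_mode ? stride - 1 : 0)) / stride + 1;
}
static_assert(pool_out_size(6, 3, 0, 2, false) == 2, "floor mode drops the tail");
static_assert(pool_out_size(6, 3, 0, 2, true) == 3, "ceil mode keeps a partial window");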
+ +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class PoolCompute + : public KernelLite { + public: + using param_t = operators::PoolParam; + + void Run() override; + virtual ~PoolCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/pool_compute_test.cc b/lite/kernels/cuda/pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe6ff92c0ce943cad36fbdd4f1408e344d9fd5fd --- /dev/null +++ b/lite/kernels/cuda/pool_compute_test.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/pool_compute.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +static int PoolOutputSize( + int input_size, int filter_size, int padding, int stride, bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + return output_size; +} + +static std::vector compute_output_shape(operators::PoolParam* param_) { + const auto x_dims = param_->x->dims(); + std::vector& ksize = param_->ksize; + if (param_->global_pooling) { + ksize.resize(static_cast(x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) { + param_->paddings[i] = 0; + ksize[i] = static_cast(x_dims[i + 2]); + } + } + + std::vector output_shape({x_dims[0], x_dims[1]}); + if (param_->adaptive) { + output_shape.insert( + output_shape.end(), param_->ksize.begin(), param_->ksize.end()); + } else { + for (size_t i = 0; i < param_->ksize.size(); ++i) { + output_shape.push_back(PoolOutputSize(x_dims[i + 2], + param_->ksize[i], + param_->paddings[i], + param_->strides[i], + param_->ceil_mode)); + } + } + return output_shape; +} + +static void pool_compute_ref(const operators::PoolParam& param) { + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* src_ptr = param.x->data(); + float* dst_ptr = param.output->mutable_data(); + + std::vector ksize = param.ksize; + std::vector strides = param.strides; + std::vector paddings = param.paddings; + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + std::string data_format = param.data_format; + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h = 
strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + const float* src = src_ptr + n * size_in_n + c * size_in_c; + float res = src[0]; + if (pooling_type == "max") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res = cur_val > res ? cur_val : res; + } + } else if (pooling_type == "avg") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res += cur_val; + } + res /= size_in_c; + } + dst_ptr[n * size_out_n + c] = res; + } + } + } else { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < out_h; ++h) { + int sh = h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + for (int w = 0; w < out_w; ++w) { + int sw = w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + int pooling_size = (ew - sw) * (eh - sh); + if (pooling_size == 0) { + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = 0.f; + continue; + } + float res = 0.f; + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw; + if (kh == sh && kw == sw) { + res = src_ptr[src_idx]; + } else { + if (pooling_type == "max") { + res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx]; + } + if (pooling_type == "avg") { + res += src_ptr[src_idx]; + } + } + } + } + if (pooling_type == "avg") { + if (exclusive) { + res /= pooling_size; + } else { + res /= window_h * window_w; + } + } + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res; + } + } + } + } + } +} + +TEST(pool_cuda, compute) { + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + PoolCompute pool; + operators::PoolParam param; + pool.SetContext(std::move(ctx)); + + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + for (auto pooling_type : {"max", "avg"}) { + for (auto ceil_mode : {true, false}) { + for (auto global_pooling : {true, false}) { + for (auto exclusive : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto pad : {0, 1}) { + for (auto n : {1, 2}) { + for (auto c : {1, 3}) { + for (auto h : {2, 3, 4, 11}) { + for (auto w : {2, 3, 4, 11}) { + VLOG(3) << "n:" << n << " c:" << c << " h:" << h + << " w:" << w << " ksize:" << ksize + << " stride:" << stride << " pad:" << pad + << " exclusive:" << exclusive + << " global_pooling:" << global_pooling + << " ceil_mode: " << ceil_mode + << " pooling_type:" << pooling_type; + + // init x, output + x.Resize(DDim(std::vector({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_cpu_data = x_cpu.mutable_data(); + for (int i = 0; i < x_cpu.dims().production(); ++i) { + float sign = i % 3 == 0 ? 
-0.03 : 0.05f; + x_cpu_data[i] = sign * (i % 128); + } + x.Assign(x_cpu_data, + x_cpu.dims()); + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + if (global_pooling) { + param.ksize = {h, w}; + } else { + param.ksize = {ksize, ksize}; + } + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = exclusive; + param.ceil_mode = ceil_mode; + param.adaptive = false; + param.use_quantizer = false; + + const std::vector& output_shape = + compute_output_shape(¶m); + if (output_shape[2] * output_shape[3] == 0) continue; + output.Resize(DDim(output_shape)); + output_ref.Resize(DDim(output_shape)); + output_cpu.Resize(DDim(output_shape)); + auto* output_data = + output.mutable_data(TARGET(kCUDA)); + auto* output_ref_data = + output_ref.mutable_data(); + auto* output_cpu_data = + output_cpu.mutable_data(); + + // compute + pool.SetParam(param); + pool.Launch(); + + // compute ref + param.x = &x_cpu; + param.output = &output_ref; + pool_compute_ref(param); + + cudaDeviceSynchronize(); + CopySync(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + // compare + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR( + output_cpu_data[i], output_ref_data[i], 1e-4); + } + VLOG(3) << "compare pass"; + } + } + } + } + } + } + } + } + } + } + } +} +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/relu_compute.cu b/lite/kernels/cuda/relu_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..7c6623a4fe3bc68408a90c7ed2a2e9e35d7061fb --- /dev/null +++ b/lite/kernels/cuda/relu_compute.cu @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/relu_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void ReluKernel(const int num, const T* input, T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { +#if __CUDA_ARCH__ >= 350 + output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) : 0; +#else + output[index] = input[index] >= 0 ? 
input[index] : 0;
+#endif
+  }
+}
+
+void ReluCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  int num = static_cast<int>(param.X->numel());
+  auto input = param.X->data<float>();
+  auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
+
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  ReluKernel<<<blocks, threads, 0, stream>>>(num, input, output);
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    relu, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ReluCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
diff --git a/lite/kernels/cuda/relu_compute.h b/lite/kernels/cuda/relu_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..b0fd500ff4369fc4a4ca256153aa5f0d21cf1e8e
--- /dev/null
+++ b/lite/kernels/cuda/relu_compute.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class ReluCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override;
+  virtual ~ReluCompute() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/relu_compute_test.cc b/lite/kernels/cuda/relu_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..39144bfda13a9eac4ac7ad65d3d426d528fc2beb
--- /dev/null
+++ b/lite/kernels/cuda/relu_compute_test.cc
@@ -0,0 +1,84 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/relu_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+TEST(relu, normal) {
+  ReluCompute relu_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::ActivationParam param;
+
+  Tensor x, y, x_cpu, y_cpu;
+  int h = 256, w = 256;
+  y.Resize({h, w});
+  x_cpu.Resize({h, w});
+  y_cpu.Resize({h, w});
+
+  auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
+  float* x_cpu_data = x_cpu.mutable_data<float>();
+  float* y_cpu_data = y_cpu.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); i++) {
+    x_cpu_data[i] = i - 5.0;
+  }
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+
+  param.X = &x;
+  param.Out = &y;
+  relu_kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  relu_kernel.SetContext(std::move(ctx));
+  relu_kernel.Launch();
+  cudaDeviceSynchronize();
+
+  CopySync<TARGET(kCUDA)>(
+      y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
+  // for (int i = 0; i < y.numel(); i++) {
+  //   LOG(INFO) << y_cpu_data[i];
+  // }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/scale_compute.cc b/lite/kernels/cuda/scale_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bf7414d8c85383a834159678cdd5a09e0b434d9
--- /dev/null
+++ b/lite/kernels/cuda/scale_compute.cc
@@ -0,0 +1,48 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "lite/kernels/cuda/scale_compute.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void ScaleCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + lite::cuda::math::scale( + x_dims.production(), x_data, output_data, scale, bias); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + scale, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ScaleCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/scale_compute.h b/lite/kernels/cuda/scale_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd082122a7e16762c790c8f360e2e0d7939496c --- /dev/null +++ b/lite/kernels/cuda/scale_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/backends/cuda/math/scale.h" +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ScaleCompute : public KernelLite { + public: + void Run() override; + + virtual ~ScaleCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8d2987524cd2e8f9c38aba4da3ff61a80bf53ce --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.cu @@ -0,0 +1,246 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/softmax_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; + +extern __shared__ char tile[]; +template +__global__ void sharemem_softmax_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + int inner_num, + int outer_num, + int axis_size) { + dtype* data = reinterpret_cast(tile) + threadIdx.x; + //! compute thread index and real data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int blocksize = blockDim.x; + int real_index = idx_outer * inner_num + idx_inner; + int loop_idx = real_index; +//! read all data to sharemem in softmax channel +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + data[i * blocksize] = in_data[loop_idx]; + loop_idx += inner_num; + } + //! get maximum value in softmax channel + dtype max_data = data[0]; +#pragma unroll + for (int i = 1; i < axis_size; ++i) { + dtype dt = data[i * blocksize]; + if (max_data < dt) { + max_data = dt; + } + } + //! subtract then summarize + dtype sum = 0; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + dtype* dt = data + i * blocksize; + *dt = expf(*dt - max_data); + sum += *dt; + } + //! write back result + loop_idx = real_index; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + out_data[loop_idx] = data[i * blocksize] / sum; + loop_idx += inner_num; + } + } +} + +//! general kernel for softmax +template +__global__ void softmax_max_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + dtype min_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + //! get maximum data across softmax axis + dtype max_data = min_data; + for (int i = 0; i < axis_size; ++i) { + max_data = + in_data[real_index] > max_data ? in_data[real_index] : max_data; + real_index += inner_num; + } + out_data[idx] = max_data; + } +} + +template +__global__ void softmax_sub_exp_sum_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + const dtype* max_data, + dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + + dtype max_data_cur = max_data[idx]; + dtype sum_data_cur = 0; + int real_index = idx_outer * inner_num + idx_inner; + //! compute exp and summarize across the softmax axis + for (int i = 0; i < axis_size; ++i) { + dtype sub_data = in_data[real_index] - max_data_cur; + sub_data = expf(sub_data); + sum_data_cur += sub_data; + out_data[real_index] = sub_data; + real_index += inner_num; + } + sum_data[idx] = sum_data_cur; + } +} + +template +__global__ void softmax_divid_output_kernel(int total_size, + dtype* io_data, + const dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! 
compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + dtype sum_data_cur = 1.f / sum_data[idx]; + int real_index = idx_outer * inner_num + idx_inner; + //! compute final result + for (int i = 0; i < axis_size; ++i) { + io_data[real_index] = io_data[real_index] * sum_data_cur; + real_index += inner_num; + } + } +} + +void SoftmaxCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int total_threads = inner_num * outer_num; + int axis_size = x_dims[axis]; + + int device_id; + const int threads = 512; + const int blocks = (total_threads + threads - 1) / threads; + cudaGetDevice(&device_id); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, device_id); + size_t sharedmem_size = deviceProp.sharedMemPerBlock; + int max_dimsize = sharedmem_size / sizeof(float) / threads; + + auto input_data = param.x->data(); + auto output_data = param.output->mutable_data(TARGET(kCUDA)); + if (axis_size <= max_dimsize) { + int use_sharemem_size = axis_size * threads * sizeof(float); + sharemem_softmax_kernel<<>>( + total_threads, + input_data, + output_data, + inner_num, + outer_num, + axis_size); + } else { + //! re_alloc device memory + Tensor tmax_data; + Tensor tsum_data; + tmax_data.Resize({1, 1, 1, outer_num * inner_num}); + tsum_data.Resize({1, 1, 1, outer_num * inner_num}); + auto max_data = tmax_data.mutable_data(TARGET(kCUDA)); + auto sum_data = tsum_data.mutable_data(TARGET(kCUDA)); + //! firstly, get maximum data + float min_data = std::numeric_limits::min(); + softmax_max_kernel<<>>(total_threads, + input_data, + max_data, + min_data, + inner_num, + outer_num, + axis_size); + //! then, compute exp and sum data + softmax_sub_exp_sum_kernel<<>>( + total_threads, + input_data, + output_data, + max_data, + sum_data, + inner_num, + outer_num, + axis_size); + //! last, compute divided output + softmax_divid_output_kernel<<>>( + total_threads, output_data, sum_data, inner_num, outer_num, axis_size); + } + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(softmax, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SoftmaxCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("axis", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..4acde4ab072390dd139c3e4e715f9ad288dc4ef8 --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SoftmaxCompute + : public KernelLite { + public: + using param_t = operators::SoftmaxParam; + + void Run() override; + virtual ~SoftmaxCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/softmax_compute_test.cc b/lite/kernels/cuda/softmax_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4d53520911a4868c73d7806fcc1bb5bf8bf33df --- /dev/null +++ b/lite/kernels/cuda/softmax_compute_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/softmax_compute.h" +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +template +static void softmax_compute_ref(const operators::SoftmaxParam& param) { + const dtype* x_data = param.x->mutable_data(); + dtype* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + ASSERT_EQ(x_dims.data(), param.output->dims().data()); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int compute_size = outer_num * inner_num; + for (int i = 0; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int start = idx_outer * inner_num + idx_inner; + int offset; + + offset = start; + dtype max_data = std::numeric_limits::lowest(); + for (int j = 0; j < axis_size; j++) { + max_data = x_data[offset] > max_data ? 
x_data[offset] : max_data;
+      offset += inner_num;
+    }
+
+    offset = start;
+    dtype sum_data = (dtype)0;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] = exp(x_data[offset] - max_data);
+      sum_data += output_data[offset];
+      offset += inner_num;
+    }
+
+    offset = start;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] /= sum_data;
+      offset += inner_num;
+    }
+  }
+}
+
+TEST(softmax_cuda, compute) {
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  SoftmaxCompute softmax;
+  operators::SoftmaxParam param;
+  softmax.SetContext(std::move(ctx));
+  lite::Tensor x;
+  lite::Tensor x_cpu;
+  lite::Tensor output;
+  lite::Tensor output_cpu;
+  lite::Tensor output_ref;
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 4}) {
+      for (auto h : {5, 1, 112}) {
+        for (auto w : {1, 6, 112}) {
+          for (auto axis : {-2, -1, 0, 1, 2}) {
+            x.Resize({n, c, h, w});
+            x_cpu.Resize({n, c, h, w});
+            output.Resize({n, c, h, w});
+            output_cpu.Resize({n, c, h, w});
+            output_ref.Resize({n, c, h, w});
+            auto* x_cpu_data = x_cpu.mutable_data<float>();
+            auto* output_data = output.mutable_data<float>(TARGET(kCUDA));
+            auto* output_cpu_data = output_cpu.mutable_data<float>();
+            auto* output_ref_data = output_ref.mutable_data<float>();
+            for (int i = 0; i < x.dims().production(); i++) {
+              x_cpu_data[i] = i;
+            }
+            x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data,
+                                                       x_cpu.dims());
+            param.x = &x;
+            param.axis = axis;
+            param.output = &output;
+            softmax.SetParam(param);
+            softmax.Launch();
+            param.x = &x_cpu;
+            param.output = &output_ref;
+            softmax_compute_ref<float>(param);
+            cudaDeviceSynchronize();
+            CopySync<TARGET(kCUDA)>(output_cpu_data,
+                                    output_data,
+                                    sizeof(float) * output.numel(),
+                                    IoDirection::DtoH);
+            for (int i = 0; i < output.dims().production(); i++) {
+              EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/transpose_compute.cu b/lite/kernels/cuda/transpose_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0050e5e0f6d67f4eacaadc675b98417b9436b006
--- /dev/null
+++ b/lite/kernels/cuda/transpose_compute.cu
@@ -0,0 +1,99 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
 */
+
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/transpose_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+void TransposeCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+
+  const lite::Tensor* X = param.x;
+  lite::Tensor* Out = param.output;
+  std::vector<int> axes = param.axis;
+
+  const float* in = X->data<float>();
+  float* out = Out->mutable_data<float>(TARGET(kCUDA));
+
+  int ndim = X->dims().size();
+  std::vector<int64_t> dims = X->dims().data();
+
+  // NCHW -> NHWC
+  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
+      axes[3] == 1) {
+    lite::cuda::math::NCHW2NHWC<float>(
+        dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
+    cudaError_t error = cudaGetLastError();
+    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+    return;
+  }
+
+  // NHWC -> NCHW
+  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
+      axes[3] == 2) {
+    lite::cuda::math::NHWC2NCHW<float>(
+        dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
+    cudaError_t error = cudaGetLastError();
+    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+    return;
+  }
+
+  lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(transpose,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::TransposeCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(transpose2,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::TransposeCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
diff --git a/lite/kernels/cuda/transpose_compute.h b/lite/kernels/cuda/transpose_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..f85f43993d60cc9dbe5e665a3b2b0fffcbcbc7c9
--- /dev/null
+++ b/lite/kernels/cuda/transpose_compute.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
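Run() above dispatches the two common layout permutations to the dedicated NCHW2NHWC/NHWC2NCHW kernels and falls back to the generic Transpose otherwise. The shape rule that every caller (and the tests below) relies on is out_dims[i] = in_dims[axes[i]]; a minimal sketch, with a hypothetical helper name that is not part of this patch:

#include <cstdint>
#include <vector>

// out_dims[i] = in_dims[axes[i]]; e.g. axes {0, 2, 3, 1} maps
// NCHW {2, 3, 4, 5} to NHWC {2, 4, 5, 3}.
static std::vector<int64_t> permute_dims(const std::vector<int64_t>& in_dims,
                                         const std::vector<int>& axes) {
  std::vector<int64_t> out_dims(axes.size());
  for (size_t i = 0; i < axes.size(); ++i) {
    out_dims[i] = in_dims[axes[i]];
  }
  return out_dims;
}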
+ +#pragma once +#include "lite/backends/cuda/math/transpose.h" +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class TransposeCompute : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void Run() override; + virtual ~TransposeCompute() = default; + + private: + lite::Tensor axes_, dims_; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/transpose_compute_test.cc b/lite/kernels/cuda/transpose_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..517f761b61268d2c664f74bdb338ffb79f8841f8 --- /dev/null +++ b/lite/kernels/cuda/transpose_compute_test.cc @@ -0,0 +1,285 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/transpose_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +namespace { + +#define IN(n, c, h, w) \ + input_data[w + h * input_w + c * input_h * input_w + \ + n * input_c * input_h * input_w] +#define OUT(n, c, h, w) \ + output_data[w + h * output_w + c * output_h * output_w + \ + n * output_c * output_h * output_w] +void nchw2nhwc_ref(lite::Tensor* input, + lite::Tensor* output, + const std::vector axies) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_c = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, h, w, c) = IN(n, c, h, w); + } + } + } + } +} +#undef IN +#undef OUT + +#define IN(n, h, w, c) \ + input_data[c + w * input_c + h * input_w * input_c + \ + n * input_h * input_w * input_c] +#define OUT(n, h, w, c) \ + output_data[c + w * output_c + h * output_w * output_c + \ + n * output_h * output_w * output_c] +void nhwc2nchw_ref(lite::Tensor* input, + lite::Tensor* output, + const std::vector axies) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_h = input->dims()[1]; + int input_w = input->dims()[2]; + int input_c = input->dims()[3]; + int output_h = output->dims()[1]; + int output_w = output->dims()[2]; + int output_c = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, c, h, w) = IN(n, h, w, c); + } + } + } + } +} + +void transpose_ref(lite::Tensor* input, + lite::Tensor* output, + const std::vector axes) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int 
ndim = input->dims().size(); + auto dims = input->dims(); + std::vector strides(ndim, 0); + std::vector buf(ndim, 0); + int cur_stride = 1; + for (int i = ndim - 1; i >= 0; --i) { + buf[i] = cur_stride; + cur_stride *= dims[i]; + } + for (int i = 0; i < ndim; ++i) { + strides[i] = buf[axes[i]]; + } + + auto y_dims = output->dims(); + int size = input->dims().production(); + for (int i = 0; i < size; ++i) { + int idx = 0; + int v = i; + for (int j = ndim - 1; j >= 0; --j) { + idx += v % y_dims[j] * strides[j]; + v /= y_dims[j]; + } + output_data[i] = input_data[idx]; + } +} +} // namespace + +TEST(transpose_nchw, normal) { + TransposeCompute transpose_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::TransposeParam param; + + lite::Tensor x, x_cpu, x_ref; + lite::Tensor out, out_cpu, out_ref; + + int N = 5, C = 6, H = 7, W = 8; + std::vector axes({0, 2, 3, 1}); + x.Resize({N, C, H, W}); + out.Resize({N, H, W, C}); + + x_cpu.Resize({N, C, H, W}); + out_cpu.Resize({N, H, W, C}); + + x_ref.Resize({N, C, H, W}); + out_ref.Resize({N, H, W, C}); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + auto* x_ref_data = x_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 1; + x_ref_data[i] = i + 1; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.x = &x; + param.output = &out; + param.axis = axes; + transpose_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + transpose_kernel.SetContext(std::move(ctx)); + transpose_kernel.Launch(); + cudaDeviceSynchronize(); + auto* out_data = out.mutable_data(TARGET(kCUDA)); + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + nchw2nhwc_ref(&x_ref, &out_ref, axes); + auto* out_ref_data = out_ref.mutable_data(); + // transpose_ref(&x_ref, &out_ref, axes); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(transpose_nhwc, normal) { + TransposeCompute transpose_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::TransposeParam param; + + lite::Tensor x, x_cpu, x_ref; + lite::Tensor out, out_cpu, out_ref; + + int N = 5, C = 6, H = 7, W = 8; + std::vector axes({0, 3, 1, 2}); + x.Resize({N, H, W, C}); + out.Resize({N, C, H, W}); + + x_cpu.Resize({N, H, W, C}); + out_cpu.Resize({N, C, H, W}); + + x_ref.Resize({N, H, W, C}); + out_ref.Resize({N, C, H, W}); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + auto* x_ref_data = x_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 1; + x_ref_data[i] = i + 1; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + param.x = &x; + param.output = &out; + param.axis = axes; + transpose_kernel.SetParam(param); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + transpose_kernel.SetContext(std::move(ctx)); + transpose_kernel.Launch(); + cudaDeviceSynchronize(); + auto* out_data = out.mutable_data(TARGET(kCUDA)); + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + nhwc2nchw_ref(&x_ref, &out_ref, axes); + // transpose_ref(&x_ref, &out_ref, axes); + auto* out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(transpose, normal) { + TransposeCompute transpose_kernel; + 
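+  // axes {2, 0, 1} moves the last axis to the front: out_dims = {W, C, H}
+  // and out(w, c, h) = in(c, h, w), which transpose_ref reproduces on the
+  // host for comparison.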
+TEST(transpose, normal) {
+  TransposeCompute transpose_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::TransposeParam param;
+
+  lite::Tensor x, x_cpu, x_ref;
+  lite::Tensor out, out_cpu, out_ref;
+
+  int C = 6, H = 7, W = 8;
+  std::vector<int> axes({2, 0, 1});
+  x.Resize({C, H, W});
+  out.Resize({W, C, H});
+
+  x_cpu.Resize({C, H, W});
+  out_cpu.Resize({W, C, H});
+
+  x_ref.Resize({C, H, W});
+  out_ref.Resize({W, C, H});
+
+  auto* x_cpu_data = x_cpu.mutable_data<float>();
+  auto* out_cpu_data = out_cpu.mutable_data<float>();
+  auto* x_ref_data = x_ref.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); ++i) {
+    x_cpu_data[i] = i + 1;
+    x_ref_data[i] = i + 1;
+  }
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  param.x = &x;
+  param.output = &out;
+  param.axis = axes;
+  transpose_kernel.SetParam(param);
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  transpose_kernel.SetContext(std::move(ctx));
+  transpose_kernel.Launch();
+  cudaDeviceSynchronize();
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+  transpose_ref(&x_ref, &out_ref, axes);
+  auto* out_ref_data = out_ref.mutable_data<float>();
+  for (int i = 0; i < out.numel(); i++) {
+    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
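Note: all three tests follow the same host-side launch/readback pattern; a minimal sketch (using only the Paddle-Lite helpers already visible above, with buffer names abbreviated):

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    ctx->As<CUDAContext>().SetExecStream(stream);  // kernel work targets this stream
    kernel.SetContext(std::move(ctx));
    kernel.Launch();                               // enqueues the CUDA kernel
    cudaDeviceSynchronize();                       // block until the kernel finishes
    CopySync<TARGET(kCUDA)>(host_out, dev_out, nbytes, IoDirection::DtoH);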
diff --git a/lite/kernels/cuda/yolo_box_compute.cu b/lite/kernels/cuda/yolo_box_compute.cu
index 99fff9a709338b081cbeed484bebbd694e383617..0a00c06cbfb9200e45d48a59aa26f2350c2cf9ab 100644
--- a/lite/kernels/cuda/yolo_box_compute.cu
+++ b/lite/kernels/cuda/yolo_box_compute.cu
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#pragma once
 #include <vector>
 #include "lite/core/op_registry.h"
 #include "lite/kernels/cuda/yolo_box_compute.h"
+// #include "lite/core/target_wrapper.h"
 
 namespace paddle {
 namespace lite {
@@ -95,7 +95,7 @@ __host__ __device__ inline void CalcLabelScore(T* scores,
 
 template <typename T>
 __global__ void KeYoloBoxFw(const T* input,
-                            const T* imgsize,
+                            const int* imgsize,
                             T* boxes,
                             T* scores,
                             const float conf_thresh,
@@ -118,8 +118,8 @@ __global__ void KeYoloBoxFw(const T* input,
   int l = tid % w;
 
   int an_stride = (5 + class_num) * grid_num;
-  int img_height = static_cast<int>(imgsize[2 * i]);
-  int img_width = static_cast<int>(imgsize[2 * i + 1]);
+  int img_height = imgsize[2 * i];
+  int img_width = imgsize[2 * i + 1];
 
   int obj_idx =
       GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
@@ -168,9 +168,13 @@ void YoloBoxCompute::Run() {
   int downsample_ratio = param.downsample_ratio;
 
   const float* input = X->data<float>();
-  const float* imgsize = ImgSize->data<float>();
+  const int* imgsize = ImgSize->data<int>();
   float* boxes = Boxes->mutable_data<float>(TARGET(kCUDA));
   float* scores = Scores->mutable_data<float>(TARGET(kCUDA));
+  TargetWrapperCuda::MemsetAsync(
+      boxes, 0, Boxes->numel() * sizeof(float), stream);
+  TargetWrapperCuda::MemsetAsync(
+      scores, 0, Scores->numel() * sizeof(float), stream);
 
   const int n = X->dims()[0];
   const int h = X->dims()[2];
@@ -179,8 +183,13 @@ void YoloBoxCompute::Run() {
   const int an_num = anchors.size() / 2;
   int input_size = downsample_ratio * h;
 
-  anchors_.Resize(static_cast({anchors.size()}));
+  anchors_.Resize({static_cast<int64_t>(anchors.size())});
   int* d_anchors = anchors_.mutable_data<int>(TARGET(kCUDA));
+  // TargetWrapperCuda::MemcpyAsync(d_anchors,
+  //                                anchors.data(),
+  //                                sizeof(int) * anchors.size(),
+  //                                IoDirection::HtoD,
+  //                                stream);
   CopySync<TARGET(kCUDA)>(d_anchors,
                           anchors.data(),
                           sizeof(int) * anchors.size(),
@@ -218,8 +227,20 @@ REGISTER_LITE_KERNEL(yolo_box,
                      kNCHW,
                      paddle::lite::kernels::cuda::YoloBoxCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("ImgSize",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Boxes",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .BindOutput("Scores",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
     .Finalize();
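Note: the two TargetWrapperCuda::MemsetAsync calls added to Run() zero the output buffers on the exec stream ahead of the kernel launch, presumably because (as in Paddle's yolo_box) KeYoloBoxFw only writes entries whose objectness clears conf_thresh:

    // zero-fill enqueued on the same stream before the kernel, so skipped
    // low-confidence entries read back as 0.f rather than stale device memory
    TargetWrapperCuda::MemsetAsync(boxes, 0, Boxes->numel() * sizeof(float), stream);
    TargetWrapperCuda::MemsetAsync(scores, 0, Scores->numel() * sizeof(float), stream);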
diff --git a/lite/kernels/cuda/yolo_box_compute_test.cc b/lite/kernels/cuda/yolo_box_compute_test.cc
index 5cd957938319cff216014e3a97d7348d223884e7..994251b249e7dc6d8ae8870937c34cfa0323fd22 100644
--- a/lite/kernels/cuda/yolo_box_compute_test.cc
+++ b/lite/kernels/cuda/yolo_box_compute_test.cc
@@ -89,7 +89,7 @@ inline static void calc_label_score(float* scores,
 
 template <typename T>
 static void YoloBoxRef(const T* input,
-                       const T* imgsize,
+                       const int* imgsize,
                        T* boxes,
                        T* scores,
                        const float conf_thresh,
@@ -106,8 +106,8 @@ static void YoloBoxRef(const T* input,
   float box[4];
 
   for (int i = 0; i < n; i++) {
-    int img_height = static_cast<int>(imgsize[2 * i]);
-    int img_width = static_cast<int>(imgsize[2 * i + 1]);
+    int img_height = imgsize[2 * i];
+    int img_width = imgsize[2 * i + 1];
 
     for (int j = 0; j < an_num; j++) {
       for (int k = 0; k < h; k++) {
@@ -180,18 +180,16 @@ TEST(yolo_box, normal) {
   boxes_ref.Resize({n, m, 4});
   scores_ref.Resize({n, cls, m});
 
-  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
-  auto* sz_data = sz.mutable_data<float>(TARGET(kCUDA));
   auto* boxes_data = boxes.mutable_data<float>(TARGET(kCUDA));
   auto* scores_data = scores.mutable_data<float>(TARGET(kCUDA));
 
   float* x_cpu_data = x_cpu.mutable_data<float>();
-  float* sz_cpu_data = sz_cpu.mutable_data<float>();
+  int* sz_cpu_data = sz_cpu.mutable_data<int>();
   float* boxes_cpu_data = boxes_cpu.mutable_data<float>();
   float* scores_cpu_data = scores_cpu.mutable_data<float>();
 
   float* x_ref_data = x_ref.mutable_data<float>();
-  float* sz_ref_data = sz_ref.mutable_data<float>();
+  int* sz_ref_data = sz_ref.mutable_data<int>();
   float* boxes_ref_data = boxes_ref.mutable_data<float>();
   float* scores_ref_data = scores_ref.mutable_data<float>();
 
@@ -205,7 +203,7 @@ TEST(yolo_box, normal) {
   sz_ref_data[1] = 32;
 
   x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
-  sz.Assign<float, lite::DDim, TARGET(kCUDA)>(sz_cpu_data, sz_cpu.dims());
+  sz.Assign<int, lite::DDim, TARGET(kCUDA)>(sz_cpu_data, sz_cpu.dims());
 
   param.X = &x;
   param.ImgSize = &sz;
diff --git a/lite/kernels/fpga/conv_compute.cc b/lite/kernels/fpga/conv_compute.cc
index fe662c58ee862cae337aaf93eabb499dc80358fc..3e06e103bba61937e48bb4d14eeedd493ab15bba 100644
--- a/lite/kernels/fpga/conv_compute.cc
+++ b/lite/kernels/fpga/conv_compute.cc
@@ -28,10 +28,9 @@ void ConvCompute::PrepareForRun() {
   // ====================================================
   zynqmp::ConvParam& conv_param = pe_.param();
 
-  param.output->mutable_data<float16>();
-  filter_.setDataType(zynqmp::FP32);
+  // filter_.setDataType(zynqmp::FP32);
   conv_param.input = param.x->ZynqTensor();
   conv_param.output = param.output->ZynqTensor();
   conv_param.filter = param.filter->ZynqTensor();
@@ -40,11 +39,17 @@ void ConvCompute::PrepareForRun() {
   conv_param.paddings = param.paddings;
   conv_param.dilations = param.dilations;
   fill_scale_bias_const(&conv_param);
+  conv_param.bias()->copyFrom(param.bias->ZynqTensor());
+  conv_param.relu.enabled = param.fuse_relu;
 
   pe_.init();
   pe_.apply();
 }
 
-void ConvCompute::Run() { pe_.dispatch(); }
+void ConvCompute::Run() {
+  auto& param = this->Param<param_t>();
+  zynqmp::ConvParam& conv_param = pe_.param();
+  pe_.dispatch();
+}
 
 }  // namespace fpga
 }  // namespace kernels
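Note: the FPGA conv kernel now routes bias and the fused-relu flag into the zynqmp PE configuration at prepare time; these two added lines are the whole mechanism, so relu runs on-device instead of as a separate kernel:

    conv_param.bias()->copyFrom(param.bias->ZynqTensor());
    conv_param.relu.enabled = param.fuse_relu;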
diff --git a/lite/kernels/fpga/conv_compute.h b/lite/kernels/fpga/conv_compute.h
index 42909c0fa049772d0b837a3ec690397d58e19cb4..a023fb46fc8af0ad12d07929137f3eb058e92ef4 100644
--- a/lite/kernels/fpga/conv_compute.h
+++ b/lite/kernels/fpga/conv_compute.h
@@ -37,9 +37,6 @@ class ConvCompute
 
  private:
   zynqmp::ConvPE pe_;
-  zynqmp::Tensor input_;
-  zynqmp::Tensor output_;
-  zynqmp::Tensor filter_;
 };
 
 }  // namespace fpga
diff --git a/lite/kernels/fpga/elementwise_compute.h b/lite/kernels/fpga/elementwise_compute.h
index ef60b82f04adae3cb77b09ef19f747d9e19c4bee..7051dd7eeda02537be713ff042a0cf33ac1b618d 100644
--- a/lite/kernels/fpga/elementwise_compute.h
+++ b/lite/kernels/fpga/elementwise_compute.h
@@ -36,9 +36,6 @@ class ElementwiseAddCompute
 
  private:
   zynqmp::ElementwiseAddPE pe_;
-  zynqmp::Tensor input_x_;
-  zynqmp::Tensor input_y_;
-  zynqmp::Tensor output_;
 };
 
 class ElementwiseAddActivationCompute
@@ -51,9 +48,6 @@ class ElementwiseAddActivationCompute
 
  private:
   zynqmp::ElementwiseAddPE pe_;
-  zynqmp::Tensor input_x_;
-  zynqmp::Tensor input_y_;
-  zynqmp::Tensor output_;
 };
 
 }  // namespace fpga
diff --git a/lite/kernels/fpga/pooling_compute.cc b/lite/kernels/fpga/pooling_compute.cc
index 3a727798d88e1dbd18844c429108ce3c48274034..e4979f8e5762400f453e323f98a6b18ba17a0998 100644
--- a/lite/kernels/fpga/pooling_compute.cc
+++ b/lite/kernels/fpga/pooling_compute.cc
@@ -35,9 +35,6 @@ void PoolCompute::PrepareForRun() {
   pool_param.output = param.output->ZynqTensor();
   pool_param.relu.enabled = false;
-  auto& in_dims = param.x->dims();
-  auto& out_dims = param.output->dims();
-
   pool_param.type = param.pooling_type == "max" ? zynqmp::PoolingType::MAX
                                                 : zynqmp::PoolingType::AVERAGE;
   pool_param.globalPooling = param.global_pooling;
diff --git a/lite/kernels/fpga/pooling_compute.h b/lite/kernels/fpga/pooling_compute.h
index 18eee5f21dbcc3f8db9ecaa771a2146990ca4351..0f5bf106dec81b95cc27f43bf3259748552eb0d4 100644
--- a/lite/kernels/fpga/pooling_compute.h
+++ b/lite/kernels/fpga/pooling_compute.h
@@ -36,8 +36,6 @@ class PoolCompute
 
  private:
   zynqmp::PoolingPE pe_;
-  zynqmp::Tensor input_;
-  zynqmp::Tensor output_;
 };
 
 }  // namespace fpga
diff --git a/lite/kernels/fpga/softmax_compute.cc b/lite/kernels/fpga/softmax_compute.cc
index 260f03c114da00b6336245b185e6c3e58ce468d4..63abc76e68ebf15a458ed380d7eabeaf89d5dd2f 100644
--- a/lite/kernels/fpga/softmax_compute.cc
+++ b/lite/kernels/fpga/softmax_compute.cc
@@ -22,7 +22,7 @@ namespace fpga {
 
 using float16 = zynqmp::float16;
 
-void SoftmaxCompute::Run() {
+void SoftmaxCompute::PrepareForRun() {
   zynqmp::SoftmaxParam& softmax_param = pe_.param();
   auto& param = Param<param_t>();
 
@@ -33,6 +33,8 @@ void SoftmaxCompute::Run() {
   pe_.apply();
 }
 
+void SoftmaxCompute::Run() { pe_.dispatch(); }
+
 }  // namespace fpga
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/kernels/fpga/softmax_compute.h b/lite/kernels/fpga/softmax_compute.h
index 5eb4af6223ed9166b72d47cf7e6c052c2a547e53..035c9a60ec369b77778332f789d8b5b2a7db2462 100644
--- a/lite/kernels/fpga/softmax_compute.h
+++ b/lite/kernels/fpga/softmax_compute.h
@@ -29,6 +29,7 @@ using float16 = zynqmp::float16;
 
 class SoftmaxCompute
     : public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
  public:
+  void PrepareForRun() override;
   void Run() override;
 
   virtual ~SoftmaxCompute() = default;
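Note: taken together, these FPGA changes settle the kernels on one lifecycle: PrepareForRun() does the one-time PE setup (bind the ZynqTensors, pe_.init(), pe_.apply()), and Run() is reduced to dispatch. A minimal sketch of the split, following the softmax hunk above:

    void SoftmaxCompute::PrepareForRun() {
      // one-time: wire input/output ZynqTensors into pe_.param(), init, apply
    }
    void SoftmaxCompute::Run() { pe_.dispatch(); }  // per-inference work only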
diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt
index ff950be06048a99a6f122655b52edd8fcf064400..65900fa76a955533984ceb426427274711b5929e 100644
--- a/lite/kernels/host/CMakeLists.txt
+++ b/lite/kernels/host/CMakeLists.txt
@@ -3,7 +3,7 @@ message(STATUS "compile with lite host kernels")
 add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
-add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(multiclass_nms_compute_host Host extra SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
 
-lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
+#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
 #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc
index a5934999cdd9c88037936bbf73f7d810aaffc3e7..02f99787e60e73d91ca8f65cb42dcd4c56e7212b 100644
--- a/lite/kernels/host/reshape_compute.cc
+++ b/lite/kernels/host/reshape_compute.cc
@@ -24,27 +24,9 @@ namespace host {
 void ReshapeCompute::Run() {
   auto& param = Param<operators::ReshapeParam>();
   auto x = param.x;
-  auto actual_shape = param.actual_shape;
   auto output = param.output;
-  bool inplace = param.inplace;
-  auto x_dims = x->dims();
   auto output_dims = output->dims();
-  if (actual_shape) {
-    auto actual_shape_dims = actual_shape->dims();
-    auto* actual_shape_data = actual_shape->data<int>();
-#ifdef LITE_WITH_CUDA
-    lite::Tensor cpu_actual_shape;
-    if (actual_shape->target() == TARGET(kCUDA)) {
-      cpu_actual_shape.CopyDataFrom(*actual_shape);
-      actual_shape_data = cpu_actual_shape.data<int>();
-    }
-#endif
-    auto shape = std::vector<int>(
-        actual_shape_data, actual_shape_data + actual_shape_dims.production());
-    output_dims = lite::operators::ValidateShape(shape, x_dims);
-    output->Resize(output_dims);
-  }
-  if (inplace) {
+  if (param.inplace) {
     output->ShareDataWith(*x);
   } else {
     output->CopyDataFrom(*x);
@@ -66,6 +48,9 @@ REGISTER_LITE_KERNEL(reshape,
     .BindInput("X",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
@@ -86,6 +71,9 @@ REGISTER_LITE_KERNEL(reshape2,
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(
                     TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
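Note: the host reshape kernel no longer reads param.actual_shape; the target shape now arrives either through the new ShapeTensor input or through the shape attribute, and the reworked test below drives both paths roughly like this (field names as used in the test):

    param.shape_tensor = &shape_tensor;   // runtime shape from a 1-D int tensor
    // ...or...
    param.shape_tensor = nullptr;
    param.shape_vct = {-1, 0, 3, 2, 1};   // attribute shape: -1 infers, 0 keeps that input dim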
diff --git a/lite/kernels/host/reshape_compute_test.cc b/lite/kernels/host/reshape_compute_test.cc
index 23f418c3ceff3bc602f1fb949c935a067134a7c4..e09da816469eb3bd8d3505de5cb9dc3d451a527d 100644
--- a/lite/kernels/host/reshape_compute_test.cc
+++ b/lite/kernels/host/reshape_compute_test.cc
@@ -32,40 +32,57 @@ TEST(reshape_host, compute) {
   ReshapeCompute reshape;
   operators::ReshapeParam param;
 
-  Tensor x;
-  Tensor actual_shape;
+  Tensor input;
   Tensor output;
-
-  x.Resize(DDim(std::vector<int64_t>({1, 2, 4, 6})));
-  actual_shape.Resize(DDim(std::vector<int64_t>({2})));
-
-  auto* x_data = x.mutable_data<float>();
-  auto* actual_shape_data = actual_shape.mutable_data<int>();
-  for (int i = 0; i < x.dims().production(); i++) {
-    x_data[i] = i;
+  input.Resize({1, 2, 4, 6});
+  auto* input_data = input.mutable_data<float>();
+  for (int i = 0; i < input.numel(); i++) {
+    input_data[i] = i;
   }
-  actual_shape_data[0] = 6;
-  actual_shape_data[1] = 8;
 
+  Tensor shape_tensor;
+  shape_tensor.Resize({2});
+  auto* shape_tensor_data = shape_tensor.mutable_data<int>();
+  shape_tensor_data[0] = 6;
+  shape_tensor_data[1] = 8;
 
-  param.x = &x;
-  param.shape = {-1, 0, 3, 2, 1};
-  param.output = &output;
-  param.actual_shape = &actual_shape;
+  // set param and run
+  param.x = &input;
+  param.shape_tensor = &shape_tensor;  // use shape_tensor
   param.inplace = false;
+  param.output = &output;
   reshape.SetParam(param);
   reshape.Run();
 
   // check output dims
-  CHECK_EQ(actual_shape.dims().production(), output.dims().size());
+  CHECK_EQ(shape_tensor.numel(), output.numel());
   for (int i = 0; i < output.dims().size(); i++) {
-    CHECK_EQ(output.dims()[i], actual_shape_data[i]);
+    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
   }
 
   // check output data
   auto* output_data = output.mutable_data<float>();
-  CHECK_NE(output_data, x_data);
-  for (int i = 0; i < output.dims().production(); i++) {
-    EXPECT_NEAR(output_data[i], x_data[i], 1e-6);
+  CHECK_NE(output_data, input_data);
+  for (int i = 0; i < output.numel(); i++) {
+    EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
+  }
+
+  // use shape, set param and run
+  param.shape_tensor = nullptr;
+  param.shape_vct = {-1, 0, 3, 2, 1};
+  reshape.SetParam(param);
+  reshape.Run();
+
+  // check output dims
+  CHECK_EQ(shape_tensor.numel(), output.numel());
+  for (int i = 0; i < output.dims().size(); i++) {
+    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
+  }
+
+  // check output data
+  output_data = output.mutable_data<float>();
+  CHECK_NE(output_data, input_data);
+  for (int i = 0; i < output.numel(); i++) {
+    EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
   }
 
   // check output data if inplace = true;
@@ -73,7 +90,7 @@ TEST(reshape_host, compute) {
   reshape.SetParam(param);
   reshape.Run();
   output_data = output.mutable_data<float>();
-  CHECK_EQ(output_data, x_data);
+  CHECK_EQ(output_data, input_data);
 }
 
 TEST(reshape, retrive_op) {
diff --git a/lite/kernels/npu/CMakeLists.txt b/lite/kernels/npu/CMakeLists.txt
index 960dbff8dba4e391761d323cbfc24946853f9e3a..eb1824e1112beec57d93d63a2464fed94fab81c9 100644
--- a/lite/kernels/npu/CMakeLists.txt
+++ b/lite/kernels/npu/CMakeLists.txt
@@ -5,5 +5,9 @@ endif()
 
 message(STATUS "compile with lite NPU kernels")
 
-add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} ${npu_ddk_libs})
+add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} npu_runtime)
 # lite_cc_test(test_graph_compute_npu SRCS graph_compute_test.cc DEPS graph_compute_npu)
+
+if(NOT LITE_ON_TINY_PUBLISH)
+  add_subdirectory(bridges)
+endif()
diff --git a/lite/backends/npu/bridge/CMakeLists.txt b/lite/kernels/npu/bridges/CMakeLists.txt
similarity index 58%
rename from lite/backends/npu/bridge/CMakeLists.txt
rename to lite/kernels/npu/bridges/CMakeLists.txt
index cf3ad9905588d2501952f3eba0f39336e199b54b..032de819743f4aba02e442dd71c26b950d1435b6 100644
--- a/lite/backends/npu/bridge/CMakeLists.txt
+++ b/lite/kernels/npu/bridges/CMakeLists.txt
@@ -1,8 +1,6 @@
+lite_cc_library(npu_bridge_registry SRCS registry.cc)
-lite_cc_library(npu_bridge_registry SRCS registry.cc DEPS ${npu_ddk_libs})
-lite_cc_library(npu_bridge_utils SRCS utils.cc DEPS ${npu_ddk_libs} tensor op mir_node scope)
-
-set(npu_bridge_deps npu_bridge_registry npu_bridge_utils op)
+set(npu_bridge_deps npu_bridge_registry npu_builder op)
 
 lite_cc_library(npu_bridge_fc_op SRCS fc_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_conv_op SRCS conv_op.cc DEPS ${npu_bridge_deps})
@@ -12,7 +10,7 @@ lite_cc_library(npu_bridge_scale_op SRCS scale_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_softmax_op SRCS softmax_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_pool_op SRCS pool_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${npu_bridge_deps})
-lite_cc_library(npu_bridge_elementwise_op SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
+lite_cc_library(npu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_reshape_op SRCS reshape_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_conv_transpose_op SRCS conv_transpose_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_interpolate_op SRCS interpolate_op.cc DEPS ${npu_bridge_deps})
@@ -24,7 +22,6 @@ lite_cc_library(npu_bridge_pad2d_op SRCS pad2d_op.cc DEPS ${npu_bridge_deps})
 
 set(npu_bridges
     npu_bridge_registry
-    npu_bridge_utils
     npu_bridge_fc_op
     npu_bridge_conv_op
     npu_bridge_mul_op
@@ -33,7 +30,7 @@ set(npu_bridges
     npu_bridge_softmax_op
     npu_bridge_pool_op
     npu_bridge_batch_norm_op
-    npu_bridge_elementwise_op
+    npu_bridge_elementwise_ops
    npu_bridge_reshape_op
    npu_bridge_conv_transpose_op
    npu_bridge_interpolate_op
@@ -44,24 +41,24 @@ set(npu_bridges
    npu_bridge_pad2d_op
    CACHE INTERNAL "npu_bridges")
 
-lite_cc_library(npu_test_helper SRCS test_helper.cc DEPS npu_helper
${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops}) +set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops}) -lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_mul_op SRCS mul_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_act_op SRCS act_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_scale_op SRCS scale_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_elementwise_op SRCS elementwise_ops_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_interpolate_op SRCS interpolate_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc DEPS npu_test_helper) -lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_mul_op SRCS mul_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_scale_op SRCS scale_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_interpolate_op SRCS interpolate_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) +lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps}) message(STATUS "+++++ npu_bridges: ${npu_bridges}") diff --git 
a/lite/backends/npu/bridge/act_op.cc b/lite/kernels/npu/bridges/act_op.cc similarity index 65% rename from lite/backends/npu/bridge/act_op.cc rename to lite/kernels/npu/bridges/act_op.cc index 9573f7d7e90035c4dc1a29d40120c88470c0def2..2b3a415ad72d5629d343678f65e2e0040fafda14 100644 --- a/lite/backends/npu/bridge/act_op.cc +++ b/lite/kernels/npu/bridges/act_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" -#include "lite/operators/relu_op.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ActConverter(const std::shared_ptr act_op, const node_map_type& inputs_map) { auto scope = act_op->scope(); auto op_info = act_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; // create act node and set input node from inputs_map @@ -40,8 +34,8 @@ node_map_type ActConverter(const std::shared_ptr act_op, auto act_node = std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); act_node->set_input_x(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(act_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(act_node); // parse and set activation type int act_mode = 1; @@ -73,16 +67,20 @@ node_map_type ActConverter(const std::shared_ptr act_op, return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(sigmod, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(relu, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(tanh, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(elu, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(abs, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(softsign, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(softplus, paddle::lite::npu::bridge::ActConverter); -REGISTER_NPU_BRIDGE(hardsigmoid, paddle::lite::npu::bridge::ActConverter); +REGISTER_NPU_BRIDGE(sigmod, paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(elu, paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(abs, paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(softsign, + paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(softplus, + paddle::lite::kernels::npu::bridges::ActConverter); +REGISTER_NPU_BRIDGE(hardsigmoid, + paddle::lite::kernels::npu::bridges::ActConverter); diff --git a/lite/backends/npu/bridge/act_op_test.cc b/lite/kernels/npu/bridges/act_op_test.cc similarity index 94% rename from 
lite/backends/npu/bridge/act_op_test.cc rename to lite/kernels/npu/bridges/act_op_test.cc index edbfbb416f1ae1798d885e22b9438e05f7e8f3d4..420de655dcdfb2069948399525bc4a8a561d0fd5 100644 --- a/lite/backends/npu/bridge/act_op_test.cc +++ b/lite/kernels/npu/bridges/act_op_test.cc @@ -14,15 +14,16 @@ #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" #include "lite/operators/relu_op.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void relu_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -91,8 +92,9 @@ TEST(NPUBridges, relu) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/batch_norm_op.cc b/lite/kernels/npu/bridges/batch_norm_op.cc similarity index 73% rename from lite/backends/npu/bridge/batch_norm_op.cc rename to lite/kernels/npu/bridges/batch_norm_op.cc index 76b4ac3d9b112701c2c606e3adbb75ff54c70a1b..5b3cbd52133b61f0c0e37e2ba9bf2f6775f7a2b4 100644 --- a/lite/backends/npu/bridge/batch_norm_op.cc +++ b/lite/kernels/npu/bridges/batch_norm_op.cc @@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/batch_norm_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type BatchNormConverter( const std::shared_ptr batch_norm_op, @@ -33,7 +27,7 @@ node_map_type BatchNormConverter( auto scope = batch_norm_op->scope(); auto op_info = batch_norm_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr batch_norm_node = @@ -43,27 +37,27 @@ node_map_type BatchNormConverter( auto scale_var_name = op_info->Input("Scale").front(); lite::Tensor* scale = scope->FindVar(scale_var_name)->GetMutable(); auto npu_scale = std::make_shared(scale_var_name); - npu_scale->set_attr_value(CvtFromLiteTensor(scale)); - OpList::Global().add(npu_scale); + npu_scale->set_attr_value(lite::npu::CvtFromLiteTensor(scale)); + lite::npu::OpList::Global().add(npu_scale); auto bias_var_name = op_info->Input("Bias").front(); lite::Tensor* bias = scope->FindVar(bias_var_name)->GetMutable(); auto npu_bias = std::make_shared(bias_var_name); - npu_bias->set_attr_value(CvtFromLiteTensor(bias)); - OpList::Global().add(npu_bias); + npu_bias->set_attr_value(lite::npu::CvtFromLiteTensor(bias)); + lite::npu::OpList::Global().add(npu_bias); auto mean_var_name = op_info->Input("Mean").front(); lite::Tensor* mean = scope->FindVar(mean_var_name)->GetMutable(); auto npu_mean = std::make_shared(mean_var_name); - npu_mean->set_attr_value(CvtFromLiteTensor(mean)); - 
OpList::Global().add(npu_mean); + npu_mean->set_attr_value(lite::npu::CvtFromLiteTensor(mean)); + lite::npu::OpList::Global().add(npu_mean); auto variance_var_name = op_info->Input("Variance").front(); lite::Tensor* variance = scope->FindVar(variance_var_name)->GetMutable(); auto npu_variance = std::make_shared(variance_var_name); - npu_variance->set_attr_value(CvtFromLiteTensor(variance)); - OpList::Global().add(npu_variance); + npu_variance->set_attr_value(lite::npu::CvtFromLiteTensor(variance)); + lite::npu::OpList::Global().add(npu_variance); float npu_momentum = op_info->GetAttr("momentum"); float npu_epsilon = op_info->GetAttr("epsilon"); @@ -80,17 +74,19 @@ node_map_type BatchNormConverter( batch_norm_node->set_attr_mode(npu_mode); batch_norm_node->set_attr_use_global_stats(npu_use_global_stats); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(batch_norm_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(batch_norm_node); node_map_type outputs_map; outputs_map[op_info->Output("Y").front()] = batch_norm_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(batch_norm, paddle::lite::npu::bridge::BatchNormConverter); +REGISTER_NPU_BRIDGE(batch_norm, + paddle::lite::kernels::npu::bridges::BatchNormConverter); diff --git a/lite/backends/npu/bridge/batch_norm_op_test.cc b/lite/kernels/npu/bridges/batch_norm_op_test.cc similarity index 96% rename from lite/backends/npu/bridge/batch_norm_op_test.cc rename to lite/kernels/npu/bridges/batch_norm_op_test.cc index ec5898f6c8299dc0068391431af60b9f075cc55c..38a876efb7c8ca6c38dee44e3c7a29a141d995d4 100644 --- a/lite/backends/npu/bridge/batch_norm_op_test.cc +++ b/lite/kernels/npu/bridges/batch_norm_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/batch_norm_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void batch_norm_ref(const std::shared_ptr op) { @@ -157,8 +158,9 @@ TEST(NPUBridges, batch_norm) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/concat_op.cc b/lite/kernels/npu/bridges/concat_op.cc similarity index 71% rename from lite/backends/npu/bridge/concat_op.cc rename to lite/kernels/npu/bridges/concat_op.cc index 85482251815a6bb94135c38331e4b0f3e4611e05..9be47339354c5602f98583b5163d11e037570321 100644 --- a/lite/backends/npu/bridge/concat_op.cc +++ b/lite/kernels/npu/bridges/concat_op.cc @@ -12,28 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/operators/concat_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ConcatConverter(const std::shared_ptr concat_op, const node_map_type& inputs_map) { lite::Scope* scope = concat_op->scope(); const lite::OpInfo* op_info = concat_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "converting " << op_type << " ... "; auto x_var_names = op_info->Input("X"); @@ -49,26 +42,28 @@ node_map_type ConcatConverter(const std::shared_ptr concat_op, for (auto x_var_name : x_var_names) { if (inputs_map.find(x_var_name) != inputs_map.end()) { output_node->set_dynamic_input_x(index + 1, *inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); } else { auto consty = std::make_shared(x_var_name); auto* x = scope->FindVar(x_var_name)->GetMutable(); - consty->set_attr_value(CvtFromLiteTensor(x)); + consty->set_attr_value(lite::npu::CvtFromLiteTensor(x)); output_node->set_dynamic_input_x(index + 1, *consty); - OpList::Global().add(consty); + lite::npu::OpList::Global().add(consty); } index++; } - OpList::Global().add(output_node); + lite::npu::OpList::Global().add(output_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = output_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(concat, paddle::lite::npu::bridge::ConcatConverter); +REGISTER_NPU_BRIDGE(concat, + paddle::lite::kernels::npu::bridges::ConcatConverter); diff --git a/lite/backends/npu/bridge/concat_op_test.cc b/lite/kernels/npu/bridges/concat_op_test.cc similarity index 95% rename from lite/backends/npu/bridge/concat_op_test.cc rename to lite/kernels/npu/bridges/concat_op_test.cc index f1bf3101b2dfd6ad363496cb442634ce63e2aa8e..f870bb0e7e2c0e7d854d152a0067bf657c19ada7 100644 --- a/lite/backends/npu/bridge/concat_op_test.cc +++ b/lite/kernels/npu/bridges/concat_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/concat_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { std::vector stride_numel(const DDim& ddim) { std::vector strides(ddim.size()); @@ -119,8 +120,9 @@ TEST(NPUBridges, concat) { test_concat({{3, 3, 5, 2}, {3, 3, 5, 6}}, 3); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/conv_op.cc b/lite/kernels/npu/bridges/conv_op.cc similarity index 87% rename from 
lite/backends/npu/bridge/conv_op.cc rename to lite/kernels/npu/bridges/conv_op.cc index 1be3d17cb6430104c846feb14e76ab48fe43c544..2a4ae56a515b8119324d944e14d20f5ad4295fd3 100644 --- a/lite/backends/npu/bridge/conv_op.cc +++ b/lite/kernels/npu/bridges/conv_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/conv_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ConvConverter(const std::shared_ptr conv_op, const node_map_type& inputs_map) { auto scope = conv_op->scope(); auto op_info = conv_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " << op_type << "... "; // get input, filter and op attributes @@ -78,13 +72,13 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, // check input CHECK(inputs_map.count(input_var_name)); - OpList::Global().add(inputs_map.at(input_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(input_var_name)); // create filter node CHECK(!inputs_map.count(filter_var_name)); auto filter_const_node = std::make_shared(filter_var_name); - filter_const_node->set_attr_value(CvtFromLiteTensor(filter)); - OpList::Global().add(filter_const_node); + filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter)); + lite::npu::OpList::Global().add(filter_const_node); // create bias node if has bias // supports the bias nodes with the following dimensions @@ -93,7 +87,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, // 2: {n, oc, oh, ow} std::shared_ptr bias_node = nullptr; bool is_channel_bias = false; - if (HasInputArg(op_info, scope, "Bias")) { + if (lite::npu::HasInputArg(op_info, scope, "Bias")) { auto bias_var_name = op_info->Input("Bias").front(); auto* bias = scope->FindVar(bias_var_name)->GetMutable(); auto bias_dims = bias->dims(); @@ -121,10 +115,11 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, } else { // bias node with const data auto bias_const_node = std::make_shared(bias_var_name); - bias_const_node->set_attr_value(CvtFromLiteTensor(bias, bias_shape)); + bias_const_node->set_attr_value( + lite::npu::CvtFromLiteTensor(bias, bias_shape)); bias_node = bias_const_node; } - OpList::Global().add(bias_node); + lite::npu::OpList::Global().add(bias_node); } // create conv node and set input, filter, bias nodes and attributes @@ -147,7 +142,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, ge::AttrValue::LIST_INT({strides[0], strides[1]})); depthwise_conv_node->set_attr_kernel( ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]})); - OpList::Global().add(depthwise_conv_node); + lite::npu::OpList::Global().add(depthwise_conv_node); conv_node = depthwise_conv_node; // ConvolutionDepthwise Op doesn't support bias, so append Add node to // support bias @@ -155,7 +150,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto 
add_node = std::make_shared(unique_op_type + "/add"); add_node->set_input_x1(*depthwise_conv_node); add_node->set_input_x2(*bias_node); - OpList::Global().add(add_node); + lite::npu::OpList::Global().add(add_node); conv_node = add_node; } } else { @@ -174,7 +169,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, ge::AttrValue::LIST_INT({strides[0], strides[1]})); common_conv_node->set_attr_kernel( ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]})); - OpList::Global().add(common_conv_node); + lite::npu::OpList::Global().add(common_conv_node); conv_node = common_conv_node; // Convolution Op only support bias with dimension {1, oc, 1, 1}, // so append Add node if dimension is {1, oc, oh, ow} or (n, oc, oh, ow) @@ -185,7 +180,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto add_node = std::make_shared(unique_op_type + "/add"); add_node->set_input_x1(*common_conv_node); add_node->set_input_x2(*bias_node); - OpList::Global().add(add_node); + lite::npu::OpList::Global().add(add_node); conv_node = add_node; } } @@ -199,7 +194,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, std::make_shared(unique_op_type + "/relu"); relu_node->set_input_x(*conv_node); relu_node->set_attr_mode(1); - OpList::Global().add(relu_node); + lite::npu::OpList::Global().add(relu_node); outputs_map[op_info->Output("Output").front()] = relu_node; } else { outputs_map[op_info->Output("Output").front()] = conv_node; @@ -207,10 +202,12 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(conv2d, paddle::lite::npu::bridge::ConvConverter); -REGISTER_NPU_BRIDGE(depthwise_conv2d, paddle::lite::npu::bridge::ConvConverter); +REGISTER_NPU_BRIDGE(conv2d, paddle::lite::kernels::npu::bridges::ConvConverter); +REGISTER_NPU_BRIDGE(depthwise_conv2d, + paddle::lite::kernels::npu::bridges::ConvConverter); diff --git a/lite/backends/npu/bridge/conv_op_test.cc b/lite/kernels/npu/bridges/conv_op_test.cc similarity index 98% rename from lite/backends/npu/bridge/conv_op_test.cc rename to lite/kernels/npu/bridges/conv_op_test.cc index 27e1226eaf471ca8ec8c5a9100cbed09070aa83e..26309aa9e27a1f0a5f6093b44242434d9e29a173 100644 --- a/lite/backends/npu/bridge/conv_op_test.cc +++ b/lite/kernels/npu/bridges/conv_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/conv_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void conv_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -268,8 +269,9 @@ TEST(NPUBridges, conv) { #endif } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/conv_transpose_op.cc b/lite/kernels/npu/bridges/conv_transpose_op.cc similarity index 80% rename from lite/backends/npu/bridge/conv_transpose_op.cc rename to lite/kernels/npu/bridges/conv_transpose_op.cc index e27132c21658d31a857d8ca70fad698ba071a7d0..f8392ec8d9b08c86a571b47187715c5bb251570f 100644 --- a/lite/backends/npu/bridge/conv_transpose_op.cc +++ b/lite/kernels/npu/bridges/conv_transpose_op.cc 
@@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/conv_transpose_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ConvTransposeConverter( const std::shared_ptr conv_transpose_op, @@ -33,7 +27,7 @@ node_map_type ConvTransposeConverter( auto scope = conv_transpose_op->scope(); auto op_info = conv_transpose_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " << op_type << "... "; // get input, output and op attributes @@ -70,21 +64,22 @@ node_map_type ConvTransposeConverter( } auto input_sizes_const_node = std::make_shared(unique_op_type + "/input_size"); - input_sizes_const_node->set_attr_value(CreateTensorAndFillData(output_shape)); + input_sizes_const_node->set_attr_value( + lite::npu::CreateTensorAndFillData(output_shape)); conv_transpose_node->set_input_input_sizes(*input_sizes_const_node); - OpList::Global().add(input_sizes_const_node); + lite::npu::OpList::Global().add(input_sizes_const_node); // create filter node CHECK(!inputs_map.count(filter_var_name)); auto filter_const_node = std::make_shared(filter_var_name); - filter_const_node->set_attr_value(CvtFromLiteTensor(filter)); + filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter)); conv_transpose_node->set_input_filter(*filter_const_node); - OpList::Global().add(filter_const_node); + lite::npu::OpList::Global().add(filter_const_node); // set input node CHECK(inputs_map.count(input_var_name)); conv_transpose_node->set_input_x(*inputs_map.at(input_var_name)); - OpList::Global().add(inputs_map.at(input_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(input_var_name)); // set attributes conv_transpose_node->set_attr_mode(1); @@ -99,11 +94,11 @@ node_map_type ConvTransposeConverter( ge::AttrValue::LIST_INT({strides[0], strides[1]})); conv_transpose_node->set_attr_kernel( ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]})); - OpList::Global().add(conv_transpose_node); + lite::npu::OpList::Global().add(conv_transpose_node); // append add node to add bias if has bias std::shared_ptr output_node = conv_transpose_node; - if (HasInputArg(op_info, scope, "Bias")) { + if (lite::npu::HasInputArg(op_info, scope, "Bias")) { // create bias node auto bias_var_name = op_info->Input("Bias").front(); CHECK(!inputs_map.count(bias_var_name)); @@ -112,13 +107,13 @@ node_map_type ConvTransposeConverter( CHECK_EQ(channel_size, filter_shape[1] * groups); auto bias_const_node = std::make_shared(bias_var_name); bias_const_node->set_attr_value( - CvtFromLiteTensor(bias, {1, channel_size, 1, 1})); - OpList::Global().add(bias_const_node); + lite::npu::CvtFromLiteTensor(bias, {1, channel_size, 1, 1})); + lite::npu::OpList::Global().add(bias_const_node); // append add node to add bias node auto add_node = std::make_shared(unique_op_type + "/add"); 
add_node->set_input_x1(*conv_transpose_node); add_node->set_input_x2(*bias_const_node); - OpList::Global().add(add_node); + lite::npu::OpList::Global().add(add_node); output_node = add_node; } @@ -129,7 +124,7 @@ node_map_type ConvTransposeConverter( std::make_shared(unique_op_type + "/relu"); relu_node->set_input_x(*output_node); relu_node->set_attr_mode(1); - OpList::Global().add(relu_node); + lite::npu::OpList::Global().add(relu_node); outputs_map[op_info->Output("Output").front()] = relu_node; } else { outputs_map[op_info->Output("Output").front()] = output_node; @@ -137,10 +132,12 @@ node_map_type ConvTransposeConverter( return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(conv2d_transpose, - paddle::lite::npu::bridge::ConvTransposeConverter); +REGISTER_NPU_BRIDGE( + conv2d_transpose, + paddle::lite::kernels::npu::bridges::ConvTransposeConverter); diff --git a/lite/backends/npu/bridge/conv_transpose_op_test.cc b/lite/kernels/npu/bridges/conv_transpose_op_test.cc similarity index 98% rename from lite/backends/npu/bridge/conv_transpose_op_test.cc rename to lite/kernels/npu/bridges/conv_transpose_op_test.cc index 02e3c7a1ce1a963474db1aa38ccf743f966cbab0..a009ef588e1ddf9561f895e977fbb08a98b2d51b 100644 --- a/lite/backends/npu/bridge/conv_transpose_op_test.cc +++ b/lite/kernels/npu/bridges/conv_transpose_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/conv_transpose_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void add_bias_with_relu(DType* data, @@ -360,8 +361,9 @@ TEST(NPUBridges, conv_transpose) { #endif } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/elementwise_ops.cc b/lite/kernels/npu/bridges/elementwise_ops.cc similarity index 71% rename from lite/backends/npu/bridge/elementwise_ops.cc rename to lite/kernels/npu/bridges/elementwise_ops.cc index 5459d819bbd2ca7bd8f3dad90bab6a6cf6faa4e8..6ba7acc254c0c352fe46aeee77ac3a5d25c4582f 100644 --- a/lite/backends/npu/bridge/elementwise_ops.cc +++ b/lite/kernels/npu/bridges/elementwise_ops.cc @@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/operators/elementwise_ops.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ElementwiseConverter( const std::shared_ptr elementwise_op, @@ -33,7 +27,7 @@ node_map_type ElementwiseConverter( auto scope = elementwise_op->scope(); auto op_info = elementwise_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "converting elementwise..."; std::shared_ptr elementwise_node = @@ -47,20 +41,20 @@ node_map_type ElementwiseConverter( CHECK(inputs_map.find(x_var_name) != inputs_map.end()); elementwise_node->set_input_x1(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); if (inputs_map.find(y_var_name) != inputs_map.end()) { elementwise_node->set_input_x2(*inputs_map.at(y_var_name)); - OpList::Global().add(inputs_map.at(y_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(y_var_name)); } else { auto consty = std::make_shared(y_var_name); auto* y = scope->FindVar(y_var_name)->GetMutable(); - consty->set_attr_value(CvtFromLiteTensor(y)); + consty->set_attr_value(lite::npu::CvtFromLiteTensor(y)); elementwise_node->set_input_x2(*consty); - OpList::Global().add(consty); + lite::npu::OpList::Global().add(consty); } - OpList::Global().add(elementwise_node); + lite::npu::OpList::Global().add(elementwise_node); // paddlelite has sum only elementwise_node->set_attr_mode(1); @@ -70,10 +64,11 @@ node_map_type ElementwiseConverter( return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle REGISTER_NPU_BRIDGE(elementwise_add, - paddle::lite::npu::bridge::ElementwiseConverter); + paddle::lite::kernels::npu::bridges::ElementwiseConverter); diff --git a/lite/backends/npu/bridge/elementwise_ops_test.cc b/lite/kernels/npu/bridges/elementwise_ops_test.cc similarity index 96% rename from lite/backends/npu/bridge/elementwise_ops_test.cc rename to lite/kernels/npu/bridges/elementwise_ops_test.cc index ff82daec100278be6e577c8efcef3995095adb9f..0e2fc9f2622d839c8eda6f82aab2759053b3e23d 100644 --- a/lite/backends/npu/bridge/elementwise_ops_test.cc +++ b/lite/kernels/npu/bridges/elementwise_ops_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/elementwise_ops.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void elementwise_add_ref(const std::shared_ptr op) { @@ -173,8 +174,9 @@ TEST(NPUBridges, elementwise_add) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff 
--git a/lite/kernels/npu/bridges/fc_op.cc b/lite/kernels/npu/bridges/fc_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1233ccedd4086bfca36fa4f1ba996814cc68127d
--- /dev/null
+++ b/lite/kernels/npu/bridges/fc_op.cc
@@ -0,0 +1,121 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/npu/builder.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace npu {
+namespace bridges {
+
+node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
+                          const node_map_type& inputs_map) {
+  auto scope = fc_op->scope();
+  auto op_info = fc_op->op_info();
+  auto op_type = op_info->Type();
+  auto unique_op_type = lite::npu::UniqueName(op_type);
+  LOG(INFO) << "Converting " + op_type + "...";
+
+  auto fc_node = std::make_shared<ge::op::FullConnection>(unique_op_type);
+
+  auto x_var_name = op_info->Input("Input").front();
+  auto w_var_name = op_info->Input("W").front();
+
+  int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
+  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
+  auto w = scope->FindVar(w_var_name)->GetMutable<lite::Tensor>();
+  auto x_dims = x->dims();
+  auto w_dims = w->dims();
+
+  CHECK_GE(x_dims.size(), 2UL);
+  CHECK_EQ(w_dims.size(), 2UL);
+
+  int m = x_dims.Slice(0, in_num_col_dims).production();
+  int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
+  int n = w_dims[1];
+  CHECK_EQ(k * n, w_dims.production());
+  VLOG(3) << "x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
+          << " k: " << k << " n: " << n;
+
+  CHECK(inputs_map.count(x_var_name));
+  CHECK(!inputs_map.count(w_var_name));
+
+  // reshape x to (m, k, 1, 1)
+  auto reshaped_x_node =
+      std::make_shared<ge::op::Reshape>(x_var_name + "_reshape");
+  reshaped_x_node->set_input_tensor(*inputs_map.at(x_var_name));
+  reshaped_x_node->set_attr_shape({m, k, 1, 1});
+  reshaped_x_node->set_attr_axis(0);
+  fc_node->set_input_x(*reshaped_x_node);
+  lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
+  lite::npu::OpList::Global().add(reshaped_x_node);
+
+  // create w const node, set its shape to (k, n, 1, 1) and fill with
+  // the transposed w tensor
+  auto w_const_node = std::make_shared<ge::op::Const>(w_var_name);
+  ge::TensorDesc w_const_desc(
+      ge::Shape({n, k, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT);
+  ge::TensorPtr w_const_tensor = std::make_shared<ge::Tensor>();
+  w_const_tensor->SetTensorDesc(w_const_desc);
+  auto w_data = w->mutable_data<float>();
+  std::vector<float> transposed_w_data(w_dims.production());
+  for (int i = 0; i < k; i++) {
+    for (int j = 0; j < n; j++) {
+      transposed_w_data[j * k + i] = w_data[i * n + j];
+    }
+  }
+  w_const_tensor->SetData(reinterpret_cast<uint8_t*>(transposed_w_data.data()),
+                          transposed_w_data.size() * sizeof(float));
+  w_const_node->set_attr_value(w_const_tensor);
+  fc_node->set_input_w(*w_const_node);
+  lite::npu::OpList::Global().add(w_const_node);
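+
+  // (The loop above converts Paddle's row-major (k, n) fc weight into the
+  // (n, k) layout the weight const node is declared with: each w[i][j]
+  // lands at transposed_w_data[j * k + i].)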
op_info->Input("Bias").front(); + auto bias = scope->FindVar(bias_var_name)->GetMutable(); + auto bias_dims = bias->dims(); + CHECK(!inputs_map.count(bias_var_name)); + CHECK_EQ(bias_dims.production(), n); + + auto bias_const_node = std::make_shared(bias_var_name); + bias_const_node->set_attr_value( + lite::npu::CvtFromLiteTensor(bias, {1, n, 1, 1})); + fc_node->set_input_b(*bias_const_node); + lite::npu::OpList::Global().add(bias_const_node); + } + lite::npu::OpList::Global().add(fc_node); + + // reshape output of fc_node from (m, n, 1, 1) to (m, n) + auto reshaped_fc_node = + std::make_shared(unique_op_type + "_reshape"); + reshaped_fc_node->set_input_tensor(*fc_node); + reshaped_fc_node->set_attr_shape({m, n}); + reshaped_fc_node->set_attr_axis(0); + lite::npu::OpList::Global().add(reshaped_fc_node); + + node_map_type outputs_map; + outputs_map[op_info->Output("Out").front()] = reshaped_fc_node; + return outputs_map; +} + +} // namespace bridges +} // namespace npu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(fc, paddle::lite::kernels::npu::bridges::FCConverter); diff --git a/lite/backends/npu/bridge/fc_op_test.cc b/lite/kernels/npu/bridges/fc_op_test.cc similarity index 84% rename from lite/backends/npu/bridge/fc_op_test.cc rename to lite/kernels/npu/bridges/fc_op_test.cc index 7bfee2034fd96cf12582b36cf766ec9170ad965b..77015236e2eed847d0ec0ea5c06e646e5893f29a 100644 --- a/lite/backends/npu/bridge/fc_op_test.cc +++ b/lite/kernels/npu/bridges/fc_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/fc_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void fc_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -67,41 +68,36 @@ void fc_ref(const std::shared_ptr op) { } } -void test_fc(const std::vector& x_shape, +void test_fc(const std::vector& input_shape, const std::vector& w_shape, int in_num_col_dims, bool has_bias) { CHECK_EQ(w_shape.size(), 2UL); - const auto& bridges = lite::npu::bridge::Factory::Instance(); + const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); const auto& supported_lists = bridges.AllFunctions(); CHECK(bridges.HasType("fc")); Scope scope; - std::string x_var_name("Input"); + std::string input_var_name("Input"); std::string w_var_name("W"); std::string bias_var_name("Bias"); std::string out_var_name("Out"); std::string out_ref_var_name("out_ref"); - auto* x = scope.Var(x_var_name)->GetMutable(); + auto* input = scope.Var(input_var_name)->GetMutable(); auto* w = scope.Var(w_var_name)->GetMutable(); auto* out = scope.Var(out_var_name)->GetMutable(); auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); - x->Resize(x_shape); - input->Resize({bs, ic, ih, iw}); - - // get w shape - auto in_mat_dims = input->dims().Flatten2D(in_num_col_dims); - std::vector w_shape = {in_mat_dims[1], out_num_classes}; + input->Resize(input_shape); w->Resize(w_shape); - FillTensor(x); + FillTensor(input); FillTensor(w); // create fc op cpp::OpDesc fc_op_desc; fc_op_desc.SetType("fc"); - fc_op_desc.SetInput("Input", {x_var_name}); + fc_op_desc.SetInput("Input", {input_var_name}); fc_op_desc.SetInput("W", {w_var_name}); fc_op_desc.SetOutput("Out", {out_var_name}); fc_op_desc.SetAttr("in_num_col_dims", 
static_cast(in_num_col_dims)); @@ -113,7 +109,7 @@ void test_fc(const std::vector& x_shape, } auto fc_op = CreateOp(fc_op_desc, &scope); - LauchOp(fc_op, {x_var_name}, {out_var_name}); + LauchOp(fc_op, {input_var_name}, {out_var_name}); out_ref->CopyDataFrom(*out); // compare results @@ -123,10 +119,6 @@ void test_fc(const std::vector& x_shape, for (int i = 0; i < out->dims().production(); i++) { EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); } - - // model release - npu::OpList::Global().clear(); - npu::DeviceInfo::Global().Clear(); } TEST(NPUBridges, fc) { @@ -134,11 +126,13 @@ TEST(NPUBridges, fc) { test_fc({1, 8, 8, 1}, {8, 4}, 2, use_bias); test_fc({1, 5, 5, 1}, {5, 7}, 2, use_bias); test_fc({1, 4, 1, 1}, {4, 8}, 1, use_bias); + test_fc({1, 1024, 1, 1}, {1024, 1000}, 1, use_bias); } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/interpolate_op.cc b/lite/kernels/npu/bridges/interpolate_op.cc similarity index 80% rename from lite/backends/npu/bridge/interpolate_op.cc rename to lite/kernels/npu/bridges/interpolate_op.cc index 83cae61e3f895e49638d5ee75e4c98a1503e626d..b0cfa1c28fae68ec936e8715fb25d59853d063bc 100644 --- a/lite/backends/npu/bridge/interpolate_op.cc +++ b/lite/kernels/npu/bridges/interpolate_op.cc @@ -12,19 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type InterpolateConverter( const std::shared_ptr interpolate_op, @@ -32,13 +27,13 @@ node_map_type InterpolateConverter( auto scope = interpolate_op->scope(); auto op_info = interpolate_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; // get input, output and attributes from lite op auto x_var_name = op_info->Input("X").front(); CHECK(inputs_map.count(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); auto x = scope->FindVar(x_var_name)->GetMutable(); auto x_dims = x->dims(); @@ -63,7 +58,7 @@ node_map_type InterpolateConverter( // update out_h and out_w if has OutSize bool inputs_map_has_w = false; - if (HasInputArg(op_info, scope, "OutSize")) { + if (lite::npu::HasInputArg(op_info, scope, "OutSize")) { auto out_size_var_name = op_info->Input("OutSize").front(); if (inputs_map.count(out_size_var_name)) { inputs_map_has_w = true; @@ -82,12 +77,12 @@ node_map_type InterpolateConverter( auto interp_method = op_info->GetAttr("interp_method"); if (interp_method == "bilinear") { auto interp_node = std::make_shared(unique_op_type); - OpList::Global().add(interp_node); + lite::npu::OpList::Global().add(interp_node); interp_node->set_input_x(*inputs_map.at(x_var_name)); if (inputs_map_has_w) { auto out_size_var_name = 
op_info->Input("OutSize").front(); interp_node->set_input_w(*inputs_map.at(out_size_var_name)); - OpList::Global().add(inputs_map.at(out_size_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(out_size_var_name)); } else { const float largest_multiple = 7.0f; float multiple = static_cast(x_h * x_w) / (out_h * out_w); @@ -98,9 +93,9 @@ node_map_type InterpolateConverter( auto w_const_node = std::make_shared(unique_op_type + "/w"); w_const_node->set_attr_value( - CreateTensorAndFillData(std::vector({out_h, out_w}))); + lite::npu::CreateTensorAndFillData(std::vector({out_h, out_w}))); interp_node->set_input_w(*w_const_node); - OpList::Global().add(w_const_node); + lite::npu::OpList::Global().add(w_const_node); } interp_node->set_attr_output_dim_mode( 2); // 0: zoom_factor, 1: shrink_factor, 2: height/width @@ -109,19 +104,19 @@ node_map_type InterpolateConverter( } else if (interp_method == "nearest") { auto interp_node = std::make_shared(unique_op_type); - OpList::Global().add(interp_node); + lite::npu::OpList::Global().add(interp_node); interp_node->set_input_image(*inputs_map.at(x_var_name)); if (inputs_map_has_w) { auto out_size_var_name = op_info->Input("OutSize").front(); interp_node->set_input_size(*inputs_map.at(out_size_var_name)); - OpList::Global().add(inputs_map.at(out_size_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(out_size_var_name)); } else { auto w_const_node = std::make_shared(unique_op_type + "/w"); w_const_node->set_attr_value( - CreateTensorAndFillData(std::vector({out_h, out_w}))); + lite::npu::CreateTensorAndFillData(std::vector({out_h, out_w}))); interp_node->set_input_size(*w_const_node); - OpList::Global().add(w_const_node); + lite::npu::OpList::Global().add(w_const_node); } interp_node->set_attr_align_corners(align_corners); outputs_map[op_info->Output("Out").front()] = interp_node; @@ -132,12 +127,13 @@ node_map_type InterpolateConverter( return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle REGISTER_NPU_BRIDGE(bilinear_interp, - paddle::lite::npu::bridge::InterpolateConverter); + paddle::lite::kernels::npu::bridges::InterpolateConverter); REGISTER_NPU_BRIDGE(nearest_interp, - paddle::lite::npu::bridge::InterpolateConverter); + paddle::lite::kernels::npu::bridges::InterpolateConverter); diff --git a/lite/backends/npu/bridge/interpolate_op_test.cc b/lite/kernels/npu/bridges/interpolate_op_test.cc similarity index 98% rename from lite/backends/npu/bridge/interpolate_op_test.cc rename to lite/kernels/npu/bridges/interpolate_op_test.cc index 79dd612c59c51287710ee239ed069bc27752c488..c061fbfe5ff2741bff0ca7427519a37e14606899 100644 --- a/lite/backends/npu/bridge/interpolate_op_test.cc +++ b/lite/kernels/npu/bridges/interpolate_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/interpolate_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void bilinear_interp_ref(const std::shared_ptr op) { @@ -393,8 +394,9 @@ TEST(NPUBridges, bilinear_interp) { #endif } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git 
a/lite/backends/npu/bridge/mul_op.cc b/lite/kernels/npu/bridges/mul_op.cc similarity index 82% rename from lite/backends/npu/bridge/mul_op.cc rename to lite/kernels/npu/bridges/mul_op.cc index 290f3d88f874169f5bee629dd504791e172a4718..ce1662c71d62a6d73a7a3b9ce594b0dd80b6fec1 100644 --- a/lite/backends/npu/bridge/mul_op.cc +++ b/lite/kernels/npu/bridges/mul_op.cc @@ -12,21 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/mul_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { // Note: inputs_map contains only the data var_names; the weights should be // handled in this converter @@ -35,7 +28,8 @@ node_map_type MulConverter(const std::shared_ptr mul_op, const node_map_type& inputs_map) { LOG(INFO) << "converting mul..."; lite::Scope* scope = mul_op->scope(); const lite::OpInfo* op_info = mul_op->op_info(); - auto output_node = std::make_shared(UniqueName("mul")); + auto output_node = + std::make_shared(lite::npu::UniqueName("mul")); auto x_var_name = op_info->Input("X").front(); auto y_var_name = op_info->Input("Y").front(); @@ -67,8 +61,8 @@ node_map_type MulConverter(const std::shared_ptr mul_op, reshapex->set_input_tensor(*xsrc); reshapex->set_attr_shape({m, k}); reshapex->set_attr_axis(0); - OpList::Global().add(xsrc); - OpList::Global().add(reshapex); + lite::npu::OpList::Global().add(xsrc); + lite::npu::OpList::Global().add(reshapex); output_node->set_input_x(*reshapex); } else { auto constx = std::make_shared(x_var_name); @@ -80,7 +74,7 @@ node_map_type MulConverter(const std::shared_ptr mul_op, auto* pdata = reinterpret_cast(xtensor->mutable_data()); ptensor->SetData(pdata, size * sizeof(float)); constx->set_attr_value(ptensor); - OpList::Global().add(constx); + lite::npu::OpList::Global().add(constx); output_node->set_input_x(*constx); } @@ -90,8 +84,8 @@ node_map_type MulConverter(const std::shared_ptr mul_op, reshapey->set_input_tensor(*ysrc); reshapey->set_attr_shape({k, n}); reshapey->set_attr_axis(0); - OpList::Global().add(ysrc); - OpList::Global().add(reshapey); + lite::npu::OpList::Global().add(ysrc); + lite::npu::OpList::Global().add(reshapey); output_node->set_input_w(*reshapey); } else { auto consty = std::make_shared(y_var_name); @@ -103,20 +97,21 @@ node_map_type MulConverter(const std::shared_ptr mul_op, auto* pdata = reinterpret_cast(ytensor->mutable_data()); ptensor->SetData(pdata, size * sizeof(float)); consty->set_attr_value(ptensor); - OpList::Global().add(consty); + lite::npu::OpList::Global().add(consty); output_node->set_input_w(*consty); } - OpList::Global().add(output_node); + lite::npu::OpList::Global().add(output_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = output_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(mul, paddle::lite::npu::bridge::MulConverter);
+REGISTER_NPU_BRIDGE(mul, paddle::lite::kernels::npu::bridges::MulConverter); diff --git a/lite/backends/npu/bridge/mul_op_test.cc b/lite/kernels/npu/bridges/mul_op_test.cc similarity index 87% rename from lite/backends/npu/bridge/mul_op_test.cc rename to lite/kernels/npu/bridges/mul_op_test.cc index c28d0487cc181c5a5af77fb61191bb20870ee0dd..9bcd72bb35b7bf5de19d880f4ad535fec8e480fa 100644 --- a/lite/backends/npu/bridge/mul_op_test.cc +++ b/lite/kernels/npu/bridges/mul_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/mul_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void mul_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -55,7 +56,7 @@ void test_mul(const std::vector& x_shape, const std::vector& y_shape, int x_num_col_dims, int y_num_col_dims) { - const auto& bridges = lite::npu::bridge::Factory::Instance(); + const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); const auto& supported_lists = bridges.AllFunctions(); CHECK(bridges.HasType("mul")); @@ -69,15 +70,6 @@ void test_mul(const std::vector& x_shape, auto* out = scope.Var(out_var_name)->GetMutable(); auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); x->Resize(x_shape); - - // get y shape - auto x_mat_dims = x->dims().Flatten2D(x_num_col_dims); - std::vector y_shape; - for (int i = 0; i < y_num_col_dims - 1; i++) { - y_shape.push_back(1); - } - y_shape.push_back(x_mat_dims[1]); - y_shape.push_back(o); y->Resize(y_shape); FillTensor(x); @@ -104,10 +96,6 @@ void test_mul(const std::vector& x_shape, for (int i = 0; i < out->dims().production(); i++) { EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); } - - // model release - npu::OpList::Global().clear(); - npu::DeviceInfo::Global().Clear(); } TEST(NPUBridges, mul) { @@ -116,8 +104,9 @@ TEST(NPUBridges, mul) { test_mul({1, 4, 1, 1}, {4, 8}, 1, 1); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/pad2d_op.cc b/lite/kernels/npu/bridges/pad2d_op.cc similarity index 73% rename from lite/backends/npu/bridge/pad2d_op.cc rename to lite/kernels/npu/bridges/pad2d_op.cc index 2c67383c0c9733265202df2ef3f0a1432701cb1a..acc3b6adf9a89ffc4d984082d7330c30d46362ba 100644 --- a/lite/backends/npu/bridge/pad2d_op.cc +++ b/lite/kernels/npu/bridges/pad2d_op.cc @@ -12,34 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type Pad2dConverter(const std::shared_ptr pad2d_op, const node_map_type& inputs_map) { auto scope = pad2d_op->scope(); auto op_info = pad2d_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr pad2d_node = std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); pad2d_node->set_input_x(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(pad2d_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(pad2d_node); auto mode = op_info->GetAttr("mode"); if (mode == "constant") { @@ -58,17 +53,19 @@ node_map_type Pad2dConverter(const std::shared_ptr pad2d_op, padding.insert(padding.begin(), xds * 2 - 4, 0); auto npu_padding = std::make_shared(unique_op_type + "/padding"); - npu_padding->set_attr_value(CreateTensorAndFillData(padding, {xds, 2})); + npu_padding->set_attr_value( + lite::npu::CreateTensorAndFillData(padding, {xds, 2})); pad2d_node->set_input_padding(*npu_padding); - OpList::Global().add(npu_padding); + lite::npu::OpList::Global().add(npu_padding); if (mode == "constant") { auto pad_value = op_info->GetAttr("pad_value"); auto npu_pad_value = std::make_shared(unique_op_type + "/pad_value"); - npu_pad_value->set_attr_value(CreateTensorAndFillData({pad_value})); + npu_pad_value->set_attr_value( + lite::npu::CreateTensorAndFillData({pad_value})); pad2d_node->set_input_constant_values(*npu_pad_value); - OpList::Global().add(npu_pad_value); + lite::npu::OpList::Global().add(npu_pad_value); pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32 } @@ -78,9 +75,10 @@ node_map_type Pad2dConverter(const std::shared_ptr pad2d_op, return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(pad2d, paddle::lite::npu::bridge::Pad2dConverter); +REGISTER_NPU_BRIDGE(pad2d, paddle::lite::kernels::npu::bridges::Pad2dConverter); diff --git a/lite/backends/npu/bridge/pad2d_op_test.cc b/lite/kernels/npu/bridges/pad2d_op_test.cc similarity index 97% rename from lite/backends/npu/bridge/pad2d_op_test.cc rename to lite/kernels/npu/bridges/pad2d_op_test.cc index 7a10e0a5592997ba972f9a8d7d59a5de7287830b..db39deb2e98bfae2c220b8addc0c18f105fd2c9c 100644 --- a/lite/backends/npu/bridge/pad2d_op_test.cc +++ b/lite/kernels/npu/bridges/pad2d_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/pad2d_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { 
+namespace bridges { template void pad2d_ref(const std::shared_ptr op) { @@ -180,8 +181,9 @@ TEST(NPUBridges, pad2d) { #endif } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/paddle_use_npu_bridges.h b/lite/kernels/npu/bridges/paddle_use_npu_bridges.h similarity index 96% rename from lite/backends/npu/bridge/paddle_use_npu_bridges.h rename to lite/kernels/npu/bridges/paddle_use_npu_bridges.h index 404d0039540ae8c37012cb081f22777358d41080..631f6f1499538ae35662a58980b7418e25585669 100644 --- a/lite/backends/npu/bridge/paddle_use_npu_bridges.h +++ b/lite/kernels/npu/bridges/paddle_use_npu_bridges.h @@ -14,7 +14,7 @@ #pragma once -#include "lite/backends/npu/bridge/registry.h" +#include "lite/kernels/npu/bridges/registry.h" USE_NPU_BRIDGE(mul); USE_NPU_BRIDGE(fc); diff --git a/lite/backends/npu/bridge/pool_op.cc b/lite/kernels/npu/bridges/pool_op.cc similarity index 80% rename from lite/backends/npu/bridge/pool_op.cc rename to lite/kernels/npu/bridges/pool_op.cc index aebfd68856da6e5ad416e65861d845ba16d83214..66cb27d7c34be707129f78ff15eaf4848f6878c0 100644 --- a/lite/backends/npu/bridge/pool_op.cc +++ b/lite/kernels/npu/bridges/pool_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/pool_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type PoolConverter(const std::shared_ptr pool_op, const node_map_type& inputs_map) { auto scope = pool_op->scope(); auto op_info = pool_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr pool_node = @@ -73,17 +67,18 @@ node_map_type PoolConverter(const std::shared_ptr pool_op, pool_node->set_attr_ceil_mode(npu_ceil_mode); // output_node->set_attr_data_mode(npu_data_mode); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(pool_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(pool_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = pool_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(pool2d, paddle::lite::npu::bridge::PoolConverter); +REGISTER_NPU_BRIDGE(pool2d, paddle::lite::kernels::npu::bridges::PoolConverter); diff --git a/lite/backends/npu/bridge/pool_op_test.cc b/lite/kernels/npu/bridges/pool_op_test.cc similarity index 97% rename from lite/backends/npu/bridge/pool_op_test.cc rename to lite/kernels/npu/bridges/pool_op_test.cc index 86ad89308489fb1459f4e3c436753758cc612683..d4543a6ae128a0c534b216e42c6f3488a1dbfbf9 100644 --- a/lite/backends/npu/bridge/pool_op_test.cc +++ 
b/lite/kernels/npu/bridges/pool_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/pool_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void pool_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -240,8 +241,9 @@ TEST(NPUBridges, pool) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/registry.cc b/lite/kernels/npu/bridges/registry.cc similarity index 88% rename from lite/backends/npu/bridge/registry.cc rename to lite/kernels/npu/bridges/registry.cc index 180e0aa46eb55ab74498bc3e58990bc7f0767072..ead7567f41d5bb5e8c7e0f70cd9ec7f3542e196b 100644 --- a/lite/backends/npu/bridge/registry.cc +++ b/lite/kernels/npu/bridges/registry.cc @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/npu/bridge/registry.h" +#include "lite/kernels/npu/bridges/registry.h" #include namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { Factory& Factory::Instance() { static Factory g_npu_bridge; @@ -33,7 +34,8 @@ void Factory::Insert(const std::string& op_type, const func_type& func_name) { map_.insert(std::make_pair(op_type, func_name)); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/registry.h b/lite/kernels/npu/bridges/registry.h similarity index 92% rename from lite/backends/npu/bridge/registry.h rename to lite/kernels/npu/bridges/registry.h index 979760c816e9ef24200eb9a1cde8e691d6ee12f7..efbf2461c0c7e9f79b7e053bdf082f243f5d3033 100644 --- a/lite/backends/npu/bridge/registry.h +++ b/lite/kernels/npu/bridges/registry.h @@ -25,8 +25,9 @@ namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { // var_name, npu node point using node_map_type = @@ -49,8 +50,9 @@ class Factory { DISALLOW_COPY_AND_ASSIGN(Factory); }; -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle @@ -73,8 +75,8 @@ class Factory { __reg_npu_bridge_##op_type##__, \ "REGISTER_NPU_BRIDGE must be called in global namespace only once!"); \ int __reg_npu_bridge_##op_type##_Insert() { \ - paddle::lite::npu::bridge::Factory::Instance().Insert(#op_type, \ - cvt_func_name); \ + paddle::lite::kernels::npu::bridges::Factory::Instance().Insert( \ + #op_type, cvt_func_name); \ return 0; \ } diff --git a/lite/backends/npu/bridge/reshape_op.cc b/lite/kernels/npu/bridges/reshape_op.cc similarity index 79% rename from lite/backends/npu/bridge/reshape_op.cc rename to lite/kernels/npu/bridges/reshape_op.cc index af160f9c72d68219979a54bac203b743de733786..50111222dd6e22ad13e675864fc4c8999ee474ff 100644 --- a/lite/backends/npu/bridge/reshape_op.cc +++ b/lite/kernels/npu/bridges/reshape_op.cc @@ -13,26 +13,21 @@ // limitations under the License. 
#include "lite/operators/reshape_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ReshapeConverter(const std::shared_ptr reshape_op, const node_map_type& inputs_map) { auto scope = reshape_op->scope(); auto op_info = reshape_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; // get input, output and op attributes @@ -44,10 +39,10 @@ node_map_type ReshapeConverter(const std::shared_ptr reshape_op, auto reshape_node = std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); reshape_node->set_input_tensor(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); // read shape from actual shape tensor as input "w" if 'Shape' is found - if (HasInputArg(op_info, scope, "Shape")) { + if (lite::npu::HasInputArg(op_info, scope, "Shape")) { auto actual_shape_var_name = op_info->Input("Shape").front(); if (!inputs_map.count(actual_shape_var_name)) { auto actual_shape = @@ -66,13 +61,14 @@ node_map_type ReshapeConverter(const std::shared_ptr reshape_op, } auto actual_shape_const_node = std::make_shared(actual_shape_var_name); - actual_shape_const_node->set_attr_value(CreateTensorAndFillData( - std::vector(out_shape.begin(), out_shape.end()))); + actual_shape_const_node->set_attr_value( + lite::npu::CreateTensorAndFillData( + std::vector(out_shape.begin(), out_shape.end()))); reshape_node->set_input_w(*actual_shape_const_node); - OpList::Global().add(actual_shape_const_node); + lite::npu::OpList::Global().add(actual_shape_const_node); } else { reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name)); - OpList::Global().add(inputs_map.at(actual_shape_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(actual_shape_var_name)); } } else { auto shape = op_info->GetAttr>("shape"); @@ -86,7 +82,7 @@ node_map_type ReshapeConverter(const std::shared_ptr reshape_op, reshape_node->set_attr_shape( ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); } - OpList::Global().add(reshape_node); + lite::npu::OpList::Global().add(reshape_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = reshape_node; @@ -106,16 +102,19 @@ node_map_type ReshapeConverter(const std::shared_ptr reshape_op, xshape_node->set_input_tensor(*inputs_map.at(x_var_name)); xshape_node->set_attr_shape( ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end())); - OpList::Global().add(xshape_node); + lite::npu::OpList::Global().add(xshape_node); outputs_map[op_info->Output("XShape").front()] = xshape_node; } return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(reshape, paddle::lite::npu::bridge::ReshapeConverter); -REGISTER_NPU_BRIDGE(reshape2, paddle::lite::npu::bridge::ReshapeConverter); 
+REGISTER_NPU_BRIDGE(reshape, + paddle::lite::kernels::npu::bridges::ReshapeConverter); +REGISTER_NPU_BRIDGE(reshape2, + paddle::lite::kernels::npu::bridges::ReshapeConverter); diff --git a/lite/backends/npu/bridge/reshape_op_test.cc b/lite/kernels/npu/bridges/reshape_op_test.cc similarity index 97% rename from lite/backends/npu/bridge/reshape_op_test.cc rename to lite/kernels/npu/bridges/reshape_op_test.cc index 4a75961fdf9c62192d4637f783ff07eff5783a30..d675b5cac2bc8975e6ed9f8521a700f579d0e2b7 100644 --- a/lite/backends/npu/bridge/reshape_op_test.cc +++ b/lite/kernels/npu/bridges/reshape_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/reshape_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void reshape_ref(const std::shared_ptr op) { auto scope = op->scope(); @@ -190,8 +191,9 @@ TEST(NPUBridges, reshape) { #endif } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/scale_op.cc b/lite/kernels/npu/bridges/scale_op.cc similarity index 73% rename from lite/backends/npu/bridge/scale_op.cc rename to lite/kernels/npu/bridges/scale_op.cc index a884b34856d336682036408c329efe3b0323909d..4e305b15f2f485317d5040be11cd92269d08baa8 100644 --- a/lite/backends/npu/bridge/scale_op.cc +++ b/lite/kernels/npu/bridges/scale_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/operators/scale_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ScaleConverter(const std::shared_ptr scale_op, const node_map_type& inputs_map) { auto scope = scale_op->scope(); auto op_info = scale_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; // get input, output and op attributes @@ -52,26 +46,26 @@ node_map_type ScaleConverter(const std::shared_ptr scale_op, auto scale_node = std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); scale_node->set_input_x(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(scale_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(scale_node); // add filter node(fill with scale) auto filter_const_node = std::make_shared(unique_op_type + "/filter"); filter_const_node->set_attr_value( - CreateTensorAndFillData(scale, scale_bias_shape)); + lite::npu::CreateTensorAndFillData(scale, scale_bias_shape)); scale_node->set_input_filter(*filter_const_node); - OpList::Global().add(filter_const_node); + lite::npu::OpList::Global().add(filter_const_node); // add bias node(fill with bias) if (fabs(bias) > 1e-6f) { auto bias_const_node = std::make_shared(unique_op_type + "/bias"); bias_const_node->set_attr_value( - CreateTensorAndFillData(bias, scale_bias_shape)); + lite::npu::CreateTensorAndFillData(bias, scale_bias_shape)); scale_node->set_input_bias(*bias_const_node); scale_node->set_attr_has_bias_value(true); - OpList::Global().add(bias_const_node); + lite::npu::OpList::Global().add(bias_const_node); } scale_node->set_attr_axis(1); @@ -81,9 +75,10 @@ node_map_type ScaleConverter(const std::shared_ptr scale_op, return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(scale, paddle::lite::npu::bridge::ScaleConverter); +REGISTER_NPU_BRIDGE(scale, paddle::lite::kernels::npu::bridges::ScaleConverter); diff --git a/lite/backends/npu/bridge/scale_op_test.cc b/lite/kernels/npu/bridges/scale_op_test.cc similarity index 95% rename from lite/backends/npu/bridge/scale_op_test.cc rename to lite/kernels/npu/bridges/scale_op_test.cc index f4a241c8d915d39abeaaf3a84ae53e89de162210..e3a75059030e27f547456c8a3ae85fbab40eb419 100644 --- a/lite/backends/npu/bridge/scale_op_test.cc +++ b/lite/kernels/npu/bridges/scale_op_test.cc @@ -15,14 +15,15 @@ #include "lite/operators/scale_op.h" #include #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void 
scale_ref(const std::shared_ptr op) { Scope* scope = op->scope(); @@ -114,8 +115,9 @@ TEST(NPUBridges, scale) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/shuffle_channel_op.cc b/lite/kernels/npu/bridges/shuffle_channel_op.cc similarity index 67% rename from lite/backends/npu/bridge/shuffle_channel_op.cc rename to lite/kernels/npu/bridges/shuffle_channel_op.cc index ac4ae58d34489155d8a359ef7d5ab663ef8b239a..d1e7bc83dd90f07fd1e0f2811a1492e9bfcc0660 100644 --- a/lite/backends/npu/bridge/shuffle_channel_op.cc +++ b/lite/kernels/npu/bridges/shuffle_channel_op.cc @@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/shuffle_channel_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type ShuffleChannelConverter( const std::shared_ptr shuffle_channel_op, @@ -33,7 +27,7 @@ node_map_type ShuffleChannelConverter( auto scope = shuffle_channel_op->scope(); auto op_info = shuffle_channel_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr shuffle_channel_node = @@ -43,18 +37,20 @@ node_map_type ShuffleChannelConverter( shuffle_channel_node->set_input_x(*inputs_map.at(x_var_name)); shuffle_channel_node->set_attr_group(op_info->GetAttr("group")); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(shuffle_channel_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(shuffle_channel_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = shuffle_channel_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(shuffle_channel, - paddle::lite::npu::bridge::ShuffleChannelConverter); +REGISTER_NPU_BRIDGE( + shuffle_channel, + paddle::lite::kernels::npu::bridges::ShuffleChannelConverter); diff --git a/lite/backends/npu/bridge/shuffle_channel_op_test.cc b/lite/kernels/npu/bridges/shuffle_channel_op_test.cc similarity index 95% rename from lite/backends/npu/bridge/shuffle_channel_op_test.cc rename to lite/kernels/npu/bridges/shuffle_channel_op_test.cc index c37c97a3b4b85b746ea16c475646da6560919c41..cbf2eac9f3d4805e1b5bc4573189194f962c2d03 100644 --- a/lite/backends/npu/bridge/shuffle_channel_op_test.cc +++ b/lite/kernels/npu/bridges/shuffle_channel_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/shuffle_channel_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { 
namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void shuffle_channel_ref( const std::shared_ptr op) { @@ -106,8 +107,9 @@ TEST(NPUBridges, softmax) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/softmax_op.cc b/lite/kernels/npu/bridges/softmax_op.cc similarity index 72% rename from lite/backends/npu/bridge/softmax_op.cc rename to lite/kernels/npu/bridges/softmax_op.cc index 6c556e6ca776e05c4f34695e825a4426ec8ca5de..24712315646d8d83349c47d415ab41cdfcadad88 100644 --- a/lite/backends/npu/bridge/softmax_op.cc +++ b/lite/kernels/npu/bridges/softmax_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/softmax_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type SoftmaxConverter(const std::shared_ptr softmax_op, const node_map_type& inputs_map) { auto scope = softmax_op->scope(); auto op_info = softmax_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr softmax_node = @@ -51,17 +45,19 @@ node_map_type SoftmaxConverter(const std::shared_ptr softmax_op, softmax_node->set_input_x(*inputs_map.at(x_var_name)); softmax_node->set_attr_axis(axis); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(softmax_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(softmax_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = softmax_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(softmax, paddle::lite::npu::bridge::SoftmaxConverter); +REGISTER_NPU_BRIDGE(softmax, + paddle::lite::kernels::npu::bridges::SoftmaxConverter); diff --git a/lite/backends/npu/bridge/softmax_op_test.cc b/lite/kernels/npu/bridges/softmax_op_test.cc similarity index 95% rename from lite/backends/npu/bridge/softmax_op_test.cc rename to lite/kernels/npu/bridges/softmax_op_test.cc index c3114f5360fa24b7694c946323d3272eefc46a31..50415c4b965215c34ebd73f7ec6b11abc4dee2dd 100644 --- a/lite/backends/npu/bridge/softmax_op_test.cc +++ b/lite/kernels/npu/bridges/softmax_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/softmax_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void softmax_ref(const std::shared_ptr op) { @@ -125,8 
+126,9 @@ TEST(NPUBridges, softmax) { } } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/split_op.cc b/lite/kernels/npu/bridges/split_op.cc similarity index 75% rename from lite/backends/npu/bridge/split_op.cc rename to lite/kernels/npu/bridges/split_op.cc index 86de45fedfbaaf9380857f26d507d20142a57676..0caa51c53035ef46b0f29be5a3047860c900a403 100644 --- a/lite/backends/npu/bridge/split_op.cc +++ b/lite/kernels/npu/bridges/split_op.cc @@ -12,27 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/operators/split_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" -#include "lite/backends/npu/npu_helper.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { + node_map_type SplitConverter(const std::shared_ptr split_op, const node_map_type& inputs_map) { lite::Scope* scope = split_op->scope(); const lite::OpInfo* op_info = split_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " << op_type << " ... "; auto x_var_name = op_info->Input("X").front(); @@ -45,7 +39,7 @@ node_map_type SplitConverter(const std::shared_ptr split_op, std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); output_node->set_input_x(*inputs_map.at(x_var_name)); - OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); output_node->set_attr_axis(static_cast(axis)); if (num > 0) { @@ -63,24 +57,25 @@ node_map_type SplitConverter(const std::shared_ptr split_op, for (auto out_var_name : out_var_names) { auto const_node = std::make_shared( unique_op_type + "/const_zero" + std::to_string(index)); - const_node->set_attr_value(CreateTensorAndFillData(0)); - OpList::Global().add(const_node); + const_node->set_attr_value(lite::npu::CreateTensorAndFillData(0)); + lite::npu::OpList::Global().add(const_node); auto add_node = std::make_shared(unique_op_type + "/add" + std::to_string(index)); add_node->set_input_x1(*output_node, "y" + std::to_string(index)); add_node->set_input_x2(*const_node); outputs_map[out_var_name] = add_node; - OpList::Global().add(add_node); + lite::npu::OpList::Global().add(add_node); index++; } - OpList::Global().add(output_node); + lite::npu::OpList::Global().add(output_node); return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(split, paddle::lite::npu::bridge::SplitConverter); +REGISTER_NPU_BRIDGE(split, paddle::lite::kernels::npu::bridges::SplitConverter); diff --git a/lite/backends/npu/bridge/split_op_test.cc b/lite/kernels/npu/bridges/split_op_test.cc similarity index 95% rename from lite/backends/npu/bridge/split_op_test.cc rename to lite/kernels/npu/bridges/split_op_test.cc index 
91629a70fc47b5cd89f11943d44cd0c4cbd67af7..9bbac09a986ec81593b5a46ca3096ec7b192025a 100644 --- a/lite/backends/npu/bridge/split_op_test.cc +++ b/lite/kernels/npu/bridges/split_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/split_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template void split_ref(const std::shared_ptr op) { @@ -99,7 +100,7 @@ void test_split(int bs, int axis, int num, std::vector sections) { - const auto& bridges = lite::npu::bridge::Factory::Instance(); + const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); const auto& supported_lists = bridges.AllFunctions(); CHECK(bridges.HasType("split")); // prepare input&output variables @@ -161,8 +162,9 @@ TEST(NPUBridges, split) { test_split(4, 2, 3, 6, 3, 0, {5, 1}); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/test_helper.cc b/lite/kernels/npu/bridges/test_helper.cc similarity index 80% rename from lite/backends/npu/bridge/test_helper.cc rename to lite/kernels/npu/bridges/test_helper.cc index 3d6dc034816a7d37c28829e6b84573f852d5c935..b410a4190d86f2ddf020e7f223787acc0108a398 100644 --- a/lite/backends/npu/bridge/test_helper.cc +++ b/lite/kernels/npu/bridges/test_helper.cc @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/npu/bridge/test_helper.h" +#include "lite/kernels/npu/bridges/test_helper.h" #include -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" #include "lite/operators/graph_op.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { void LauchOp(const std::shared_ptr op, const std::vector& input_var_names, @@ -32,7 +32,7 @@ void LauchOp(const std::shared_ptr op, auto op_type = op->op_info()->Type(); // convert op to IR graph - const auto& bridges = lite::npu::bridge::Factory::Instance(); + const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); const auto& supported_lists = bridges.AllFunctions(); CHECK(bridges.HasType(op_type)); @@ -43,7 +43,7 @@ void LauchOp(const std::shared_ptr op, ge::Shape(input->dims().Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT); auto input_node = std::make_shared(input_var_name); input_node->update_input_desc_x(input_desc); - npu::OpList::Global().add(input_node); + lite::npu::OpList::Global().add(input_node); inputs_map[input_var_name] = input_node; } auto outputs_map = supported_lists.at(op_type)(op, inputs_map); @@ -58,15 +58,20 @@ void LauchOp(const std::shared_ptr op, for (auto output_var_name : output_var_names) { graph_outputs.push_back(*outputs_map[output_var_name]); } - std::string model_name(UniqueName("test_" + op_type) + ".om"); - CHECK(npu::BuildNPUClient(graph_inputs, graph_outputs, model_name)); + std::string weight_var_name = "weight"; + auto weight = scope->Var(weight_var_name)->GetMutable(); + weight->set_persistable(true); + 
weight->set_precision(PRECISION(kInt8)); + CHECK(lite::npu::BuildModel(graph_inputs, graph_outputs, weight)); + CHECK_GT(weight->numel(), 0); + CHECK_NE(weight->data(), 0); // create graph op and set inputs and outputs cpp::OpDesc graph_op_desc; graph_op_desc.SetType("graph_op"); graph_op_desc.SetInput("Inputs", input_var_names); + graph_op_desc.SetInput("Weight", {weight_var_name}); graph_op_desc.SetOutput("Outputs", output_var_names); - graph_op_desc.SetAttr("model_name", model_name); auto graph_op = std::make_shared(graph_op_desc.Type()); @@ -88,12 +93,12 @@ void LauchOp(const std::shared_ptr op, graph_kernel->Launch(); // release all of resources of generated model - npu::OpList::Global().clear(); - npu::DeviceInfo::Global().Clear(); + lite::npu::OpList::Global().clear(); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/test_helper.h b/lite/kernels/npu/bridges/test_helper.h similarity index 95% rename from lite/backends/npu/bridge/test_helper.h rename to lite/kernels/npu/bridges/test_helper.h index 537f7376409b54b441e226a6824013cdff735000..4fe22ba28b8f4d7af32518c8a25739903f18c4d1 100644 --- a/lite/backends/npu/bridge/test_helper.h +++ b/lite/kernels/npu/bridges/test_helper.h @@ -22,8 +22,9 @@ namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { template std::shared_ptr CreateOp(const cpp::OpDesc& opdesc, lite::Scope* scope) { @@ -58,7 +59,8 @@ void LauchOp(const std::shared_ptr op, const std::vector& input_var_names, const std::vector& output_var_names); -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/bridge/transpose_op.cc b/lite/kernels/npu/bridges/transpose_op.cc similarity index 70% rename from lite/backends/npu/bridge/transpose_op.cc rename to lite/kernels/npu/bridges/transpose_op.cc index ad00e599ce77d7727c692db157db64d17cd13a5c..5e9a69837b9e253845e6a1df35a897cfe342a84e 100644 --- a/lite/backends/npu/bridge/transpose_op.cc +++ b/lite/kernels/npu/bridges/transpose_op.cc @@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/operators/transpose_op.h" -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/utils.h" +#include "lite/backends/npu/builder.h" +#include "lite/kernels/npu/bridges/registry.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { node_map_type TransposeConverter( const std::shared_ptr transpose_op, @@ -33,7 +27,7 @@ node_map_type TransposeConverter( auto scope = transpose_op->scope(); auto op_info = transpose_op->op_info(); auto op_type = op_info->Type(); - auto unique_op_type = UniqueName(op_type); + auto unique_op_type = lite::npu::UniqueName(op_type); LOG(INFO) << "Converting " + op_type + "..."; std::shared_ptr transpose_node = @@ -50,8 +44,8 @@ node_map_type TransposeConverter( w_data[i] = 1.f; } auto npu_w = std::make_shared(w_var_name); - npu_w->set_attr_value(CvtFromLiteTensor(w)); - OpList::Global().add(npu_w); + npu_w->set_attr_value(lite::npu::CvtFromLiteTensor(w)); + lite::npu::OpList::Global().add(npu_w); auto axis = op_info->GetAttr>("axis"); auto npu_axis = ge::AttrValue::LIST_INT(axis.begin(), axis.end()); @@ -61,18 +55,21 @@ node_map_type TransposeConverter( transpose_node->set_input_w(*npu_w); transpose_node->set_attr_order(npu_axis); - OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(transpose_node); + lite::npu::OpList::Global().add(inputs_map.at(x_var_name)); + lite::npu::OpList::Global().add(transpose_node); node_map_type outputs_map; outputs_map[op_info->Output("Out").front()] = transpose_node; return outputs_map; } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle -REGISTER_NPU_BRIDGE(transpose, paddle::lite::npu::bridge::TransposeConverter); -REGISTER_NPU_BRIDGE(transpose2, paddle::lite::npu::bridge::TransposeConverter); +REGISTER_NPU_BRIDGE(transpose, + paddle::lite::kernels::npu::bridges::TransposeConverter); +REGISTER_NPU_BRIDGE(transpose2, + paddle::lite::kernels::npu::bridges::TransposeConverter); diff --git a/lite/backends/npu/bridge/transpose_op_test.cc b/lite/kernels/npu/bridges/transpose_op_test.cc similarity index 96% rename from lite/backends/npu/bridge/transpose_op_test.cc rename to lite/kernels/npu/bridges/transpose_op_test.cc index 9bbfb11123fc3148968049d3b35faa308d7efcc0..9ad2610caa4f1674c1a07afd62a4b85361ec6645 100644 --- a/lite/backends/npu/bridge/transpose_op_test.cc +++ b/lite/kernels/npu/bridges/transpose_op_test.cc @@ -14,14 +14,15 @@ #include "lite/operators/transpose_op.h" #include -#include "lite/backends/npu/bridge/registry.h" -#include "lite/backends/npu/bridge/test_helper.h" #include "lite/core/op_registry.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/npu/bridges/test_helper.h" namespace paddle { namespace lite { +namespace kernels { namespace npu { -namespace bridge { +namespace bridges { int data_index(std::vector pos, DDimLite dims) { int d1 = dims[1]; @@ -139,8 +140,9 @@ TEST(NPUBridges, transpose) { // test_transpose(1, 1, 1, 2, std::vector{0,1,2,3}); } -} // namespace bridge +} // namespace bridges } // namespace npu +} // namespace kernels } // namespace lite } // namespace paddle diff --git 
a/lite/kernels/npu/graph_compute.cc b/lite/kernels/npu/graph_compute.cc index 9f0f557f5cd9038cdf3ee6029129fbe069ef9674..f2b42c658d11edfed65eea2af48a3c0202ba3114 100644 --- a/lite/kernels/npu/graph_compute.cc +++ b/lite/kernels/npu/graph_compute.cc @@ -30,10 +30,16 @@ void GraphCompute::PrepareForRun() { auto& ctx = this->ctx_->template As(); auto& param = this->Param(); - exec_ = ctx.client(param.model_name); - CHECK(exec_); + CHECK(param.weight); + CHECK(lite::npu::LoadModel(*param.weight, &model_client_, &model_name_)); + // TODO(hong19860320): find a good way to free the model data. + // No interface exists to free a tensor's data, so I resize the dims to {1} + // and change the target to force it to reallocate a small block of memory. + param.weight->Resize({1}); + param.weight->mutable_data(TargetType::kARM); + CHECK(model_client_); int ret = - exec_->GetModelIOTensorDim(param.model_name, npu_idims_, npu_odims_); + model_client_->GetModelIOTensorDim(model_name_, npu_idims_, npu_odims_); CHECK_EQ(ret, hiai::AI_SUCCESS) << "[NPU] Get dims failed."; npu_itensors_.resize(npu_idims_.size()); @@ -43,8 +49,8 @@ void GraphCompute::PrepareForRun() { VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << "," << npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight() << "," << npu_idims_[i].GetWidth(); - VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i]->dims(); - CHECK_EQ(param.inputs[i]->dims().production(), + VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i].second->dims(); + CHECK_EQ(param.inputs[i].second->dims().production(), npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() * npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth()); npu_itensors_[i].reset(new hiai::AiTensor); @@ -55,16 +61,16 @@ void GraphCompute::PrepareForRun() { VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << "," << npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight() << "," << npu_odims_[i].GetWidth(); - VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i]->dims(); + VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i].second->dims(); auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() * npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth(); - if (param.outputs[i]->dims().production() != out_size) { - param.outputs[i]->Resize({npu_odims_[i].GetNumber(), - npu_odims_[i].GetChannel(), - npu_odims_[i].GetHeight(), - npu_odims_[i].GetWidth()}); + if (param.outputs[i].second->dims().production() != out_size) { + param.outputs[i].second->Resize({npu_odims_[i].GetNumber(), + npu_odims_[i].GetChannel(), + npu_odims_[i].GetHeight(), + npu_odims_[i].GetWidth()}); } - LOG(INFO) << param.outputs[i]->dims(); + LOG(INFO) << param.outputs[i].second->dims(); npu_otensors_[i].reset(new hiai::AiTensor); npu_otensors_[i]->Init(&(npu_odims_[i])); } @@ -74,7 +80,7 @@ bool GraphCompute::input_dims_changed() const { auto& param = this->Param(); CHECK_EQ(param.inputs.size(), npu_idims_.size()); for (size_t i = 0; i < param.inputs.size(); ++i) { - auto param_idims = param.inputs[i]->dims(); + auto param_idims = param.inputs[i].second->dims(); CHECK(!param_idims.empty()); CHECK_EQ(param_idims.size(), 4); std::vector idims{static_cast(npu_idims_[i].GetNumber()), @@ -99,7 +105,7 @@ void GraphCompute::Run() { CHECK_EQ(param.outputs.size(), npu_otensors_.size()); for (size_t i = 0; i < param.inputs.size(); ++i) { - auto* itensor = param.inputs[i]; + auto* itensor = param.inputs[i].second; CHECK(itensor); const auto* i_data = itensor->data(); std::memcpy( @@
-108,7 +114,7 @@ void GraphCompute::Run() {
         sizeof(float) * static_cast<size_t>(itensor->dims().production()));
   }
   std::string key = "model_name";  // Note: the key must be "model_name"
-  npu_context_.AddPara(key, param.model_name);
+  model_context_.AddPara(key, model_name_);
 
   auto GetCurrentUS = []() -> double {
     struct timeval time;
@@ -117,13 +123,13 @@ void GraphCompute::Run() {
   };
   int istamp;
   auto start_time = GetCurrentUS();
-  CHECK_EQ(
-      hiai::AI_SUCCESS,
-      exec_->Process(npu_context_, npu_itensors_, npu_otensors_, 1000, istamp));
-  LOG(INFO) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
+  CHECK_EQ(hiai::AI_SUCCESS,
+           model_client_->Process(
+               model_context_, npu_itensors_, npu_otensors_, 1000, istamp));
+  VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
 
   for (size_t i = 0; i < param.outputs.size(); ++i) {
-    auto* otensor = param.outputs[i];
+    auto* otensor = param.outputs[i].second;
     CHECK(otensor);
     auto* o_data = otensor->mutable_data<float>();
     auto* npu_obuffer = static_cast<float*>(npu_otensors_[i]->GetBuffer());
@@ -147,5 +153,6 @@ REGISTER_LITE_KERNEL(graph_op,
                      paddle::lite::kernels::npu::GraphCompute,
                      def)
     .BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kHost))})
     .BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
diff --git a/lite/kernels/npu/graph_compute.h b/lite/kernels/npu/graph_compute.h
index 908dbc55dd0369b606b96cf5e2b924c2e0957839..f4aac57506bbaf93eab85de47990c8ad486ccfec 100644
--- a/lite/kernels/npu/graph_compute.h
+++ b/lite/kernels/npu/graph_compute.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <memory>
+#include <string>
 #include "ai_ddk_lib/include/HiAiModelManagerService.h"
 #include "lite/core/kernel.h"
@@ -39,15 +40,15 @@ class GraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kFloat)> {
   bool input_dims_changed() const;
 
  private:
-  hiai::AiModelMngerClient* exec_;
+  std::shared_ptr<hiai::AiModelMngerClient> model_client_;
+  std::string model_name_;
+  hiai::AiContext model_context_;
+
   std::vector<hiai::TensorDimension> npu_idims_;
   std::vector<hiai::TensorDimension> npu_odims_;
   std::vector<std::shared_ptr<hiai::AiTensor>> npu_itensors_;
   std::vector<std::shared_ptr<hiai::AiTensor>> npu_otensors_;
-
-  // TODO(TJ): find better place
-  hiai::AiContext npu_context_;
 };
 
 }  // namespace npu
diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt
index 65145c40b818a3ad0e43b30ab7758bce2dc18a65..d070eb84c5313e7539f28da0a90dcc3662be01a1 100644
--- a/lite/kernels/opencl/CMakeLists.txt
+++ b/lite/kernels/opencl/CMakeLists.txt
@@ -15,6 +15,7 @@ add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${te
 add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps})
 add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps})
 add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps})
+add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
 
 lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc
     DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context
@@ -28,17 +29,19 @@ lite_cc_test(test_fc_opencl SRCS fc_compute_test.cc
     DEPS fc_opencl op_registry program context
     ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
 
-lite_cc_test(test_mul_opencl SRCS mul_compute_test.cc
-    DEPS mul_opencl op_registry program context
-    ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
+# TODO(ysh329): comment for buffer-impl mul
+#lite_cc_test(test_mul_opencl SRCS mul_compute_test.cc
+#    DEPS mul_opencl op_registry program context
+#
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_io_copy_compute_opencl SRCS io_copy_compute_test.cc DEPS io_copy_compute_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_relu_opencl SRCS relu_compute_test.cc - DEPS relu_opencl op_registry program context - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +#TODO(ysh329): comment buffer-impl relu +#lite_cc_test(test_relu_opencl SRCS relu_compute_test.cc +# DEPS relu_opencl op_registry program context +# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context @@ -47,3 +50,7 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + +lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc + DEPS layout_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/image_helper.h b/lite/kernels/opencl/image_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..d164f1ef777a02e5fd3bd33f5cab117de17834b8 --- /dev/null +++ b/lite/kernels/opencl/image_helper.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
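+//
+// InitImageDimInfoWith() below maps an NCHW tensor onto the 2-D extent of an
+// OpenCL image: channels are packed four per RGBA texel, so
+// width = W * ceil(C / 4) and height = N * H. For example, a {1, 3, 224, 224}
+// tensor maps to a 224 x 224 image, while {2, 8, 16, 16} maps to 32 x 32.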
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+#include "lite/core/tensor.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+static std::map<std::string, size_t> InitImageDimInfoWith(
+    const DDim& tensor_dim) {
+  size_t new_dims[] = {1, 1, 1, 1};
+  for (size_t j = 0; j < tensor_dim.size(); ++j) {
+    new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
+  }
+  size_t N, C, H, W;
+  N = new_dims[0];
+  C = new_dims[1];
+  H = new_dims[2];
+  W = new_dims[3];
+  size_t width = W * ((C + 3) / 4);
+  size_t height = H * N;
+  return std::map<std::string, size_t>({{"width", width}, {"height", height}});
+}
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/opencl/io_copy_compute.cc b/lite/kernels/opencl/io_copy_compute.cc
index 1d43f7d97eef212393e316cbd12d8115bb773cdb..dc4bdfe64c65f21e8f68a26df3e2962087f50bef 100644
--- a/lite/kernels/opencl/io_copy_compute.cc
+++ b/lite/kernels/opencl/io_copy_compute.cc
@@ -42,7 +42,16 @@ class IoCopyHostToOpenCLCompute
     CHECK(param.x->target() == TARGET(kHost) ||
           param.x->target() == TARGET(kARM));
     auto mem_size = param.x->memory_size();
+    VLOG(4) << "copy size " << mem_size;
+    VLOG(4) << "param.x->dims().size():" << param.x->dims().size();
+    VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " "
+            << param.x->dims()[1] << " " << param.x->dims()[2] << " "
+            << param.x->dims()[3];
+    VLOG(4) << "param.y->dims().size():" << param.y->dims().size();
+    VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " "
+            << param.y->dims()[1] << " " << param.y->dims()[2] << " "
+            << param.y->dims()[3];
     auto* data = param.y->mutable_data(TARGET(kOpenCL), mem_size);
     CopyFromHostSync(data, param.x->raw_data(), mem_size);
   }
@@ -81,10 +90,21 @@ class IoCopykOpenCLToHostCompute
     CHECK(param.x->target() == TARGET(kOpenCL));
     auto mem_size = param.x->memory_size();
     VLOG(4) << "copy size " << mem_size;
+    VLOG(4) << "param.x->dims().size():" << param.x->dims().size();
+    VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " "
+            << param.x->dims()[1] << " " << param.x->dims()[2] << " "
+            << param.x->dims()[3];
+    VLOG(4) << "param.y->dims().size():" << param.y->dims().size();
+    VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " "
+            << param.y->dims()[1] << " " << param.y->dims()[2] << " "
+            << param.y->dims()[3];
     auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
     auto& context = ctx_->As<OpenCLContext>();
     auto* wait_list = context.cl_wait_list();
     auto* x_ptr = param.x->data<float, cl::Buffer>();
+
+    /* TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list`
+       in kernel and enable wait_list
     auto it = wait_list->find(x_ptr);
     if (it != wait_list->end()) {
       VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
@@ -93,6 +113,8 @@ class IoCopykOpenCLToHostCompute
     } else {
       LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
     }
+    */
+
     CopyToHostSync(data, param.x->raw_data(), mem_size);
   }
diff --git a/lite/kernels/opencl/layout_compute.cc b/lite/kernels/opencl/layout_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a2869457fc8dabdfb39d3d447404c0a6f6f77375
--- /dev/null
+++ b/lite/kernels/opencl/layout_compute.cc
@@ -0,0 +1,295 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/type_system.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class LayoutComputeBufferChwToImage2DHwc + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + } + + void Run() override { + auto& param = Param(); + auto* x_data = param.x->data(); + auto x_dims = param.x->dims(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* y_data = param.y->mutable_data( + image_shape["width"], image_shape["height"]); + auto y_dims = param.y->dims(); + + // out info + std::vector new_dims = {1, 1, 1, 1}; + for (int tidx = 0; tidx < x_dims.size(); ++tidx) { + new_dims[4 - x_dims.size() + tidx] = x_dims[tidx]; + } + const int out_C = new_dims[1]; + const int out_H = new_dims[2]; + const int out_W = new_dims[3]; + const int Stride2 = out_C * out_H * out_W; + const int Stride1 = out_H * out_W; + const int Stride0 = out_W; + + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; + VLOG(4) << "out_C:" << out_C; + VLOG(4) << "out_H:" << out_H; + VLOG(4) << "out_W:" << out_W; + VLOG(4) << "Stride2:" << Stride2; + VLOG(4) << "Stride1:" << Stride1; + VLOG(4) << "Stride0:" << Stride0; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_H)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_W)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_C)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride0)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride1)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride2)); + CL_CHECK_FATAL(status); + + VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + << " " << (new_dims[0] * new_dims[2]); + auto global_work_size = + cl::NDRange{static_cast((new_dims[1] + 3) / 4), + static_cast(new_dims[3]), + static_cast(new_dims[0] * new_dims[2])}; + status = 
context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(y_data, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + std::string doc() const override { + return "Trans Layout from cl::Buffer(NCHW) to cl::Image2D(RGBA)"; + } + + private: + std::string kernel_func_name_{"buffer_to_image2d"}; + std::string build_options_{"-DCL_DTYPE_float "}; + std::shared_ptr event_{new cl::Event}; +}; + +class LayoutComputeImage2DHwcToBufferChw + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + } + + void Run() override { + auto& param = Param(); + auto* y_data = param.y->mutable_data(TARGET(kOpenCL)); + auto y_dims = param.y->dims(); + auto* x_data = param.x->data(); + auto x_dims = param.x->dims(); + + std::vector new_dims = {1, 1, 1, 1}; + for (int j = 0; j < x_dims.size(); ++j) { + new_dims[4 - x_dims.size() + j] = x_dims[j]; + } + + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; + + size_t C = new_dims[1]; + size_t in_height = new_dims[2]; + size_t in_width = new_dims[3]; + int size_ch = in_height * in_width; + int size_block = size_ch * 4; + int size_batch = size_ch * C; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(in_width)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(in_height)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_ch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_ch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size_batch)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(C)); + CL_CHECK_FATAL(status); + VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + << " " << (new_dims[0] * new_dims[2]); + auto global_work_size = + cl::NDRange{static_cast((new_dims[1] + 3) / 4), + static_cast(new_dims[3]), + static_cast(new_dims[0] * new_dims[2])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(y_data, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + std::string doc() const override { + return "Trans Layout from cl::Image2D(RGBA) to cl::Buffer(NCHW)"; + } + + private: + std::string 
kernel_func_name_{"image2d_to_buffer"}; + std::string build_options_{"-DCL_DTYPE_float"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +// BufferChwToImage2DHwc +// [chw] -> [hwc] +REGISTER_LITE_KERNEL( + layout, + kOpenCL, + kAny, + kNHWC, + paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DHwc, + buffer_chw_to_image2d_hwc_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNHWC))}) + .Finalize(); + +// [chw] -> [hwc] +REGISTER_LITE_KERNEL( + layout_once, + kOpenCL, + kAny, + kNHWC, + paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DHwc, + buffer_chw_to_image2d_hwc_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNHWC))}) + .Finalize(); + +// Image2DHwcBufferChw +// [hwc] -> [chw] +REGISTER_LITE_KERNEL( + layout, + kOpenCL, + kAny, + kNCHW, + paddle::lite::kernels::opencl::LayoutComputeImage2DHwcToBufferChw, + image2d_hwc_to_buffer_chw_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNCHW))}) + .Finalize(); + +// [hwc] -> [chw] +REGISTER_LITE_KERNEL( + layout_once, + kOpenCL, + kAny, + kNCHW, + paddle::lite::kernels::opencl::LayoutComputeImage2DHwcToBufferChw, + image2d_hwc_to_buffer_chw_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/opencl/layout_compute_test.cc b/lite/kernels/opencl/layout_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e8dd78f616d4d1e3fabf51ba8d3ddf43dd561f1 --- /dev/null +++ b/lite/kernels/opencl/layout_compute_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <gtest/gtest.h>
+#include <memory>
+#include "lite/backends/opencl/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/kernels/opencl/image_helper.h"
+
+namespace paddle {
+namespace lite {
+
+// #define LOOP_TEST
+// #define PRINT_RESULT
+TEST(layout, compute) {
+  LOG(INFO) << "main steps of test: host -> layout(buf2img) -> layout(img2buf) "
+               "-> device";
+
+#ifdef LOOP_TEST
+  for (int n = 1; n <= 100; n += 21) {
+    for (auto c : {1, 3}) {
+      for (int h = 1; h <= 100; h += 13) {
+        for (int w = 1; w <= 100; w += 17) {
+#else
+  const int n = 1;
+  const int c = 1;
+  const int h = 1;
+  const int w = 100;
+#endif  // LOOP_TEST
+
+          LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
+                    << h << " " << w << " ========";
+          // set layout kernels
+          auto buf_to_img_kernels = KernelRegistry::Global().Create(
+              "layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNHWC));
+          auto img_to_buf_kernels = KernelRegistry::Global().Create(
+              "layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
+          ASSERT_FALSE(buf_to_img_kernels.empty());
+          ASSERT_FALSE(img_to_buf_kernels.empty());
+
+          auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
+          auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
+          LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
+          LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
+
+          // set tensors about op param
+          LOG(INFO) << "set tensors about op param";
+          lite::Tensor x, y_image, y;
+          operators::LayoutParam BufferToImageParam;
+          operators::LayoutParam ImageToBufferParam;
+          BufferToImageParam.x = &x;
+          BufferToImageParam.y = &y_image;
+          ImageToBufferParam.x = &y_image;
+          ImageToBufferParam.y = &y;
+
+          const DDim x_dim = DDim(std::vector<int64_t>{n, c, h, w});
+          x.Resize(x_dim);
+          y_image.Resize(x_dim);  // useless for image2D
+          y.Resize(x_dim);
+
+          // initialize tensors
+          LOG(INFO) << "initialize tensors";
+          auto* x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
+          auto* y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
+          auto image_shape =
+              paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
+          auto* y_image_data = y_image.mutable_data<float, cl::Image2D>(
+              image_shape["width"], image_shape["height"]);
+          auto* mapped_x = static_cast<float*>(TargetWrapperCL::Map(
+              x_data, 0, sizeof(float) * x_dim.production()));
+          auto* mapped_y = static_cast<float*>(TargetWrapperCL::Map(
+              y_data, 0, sizeof(float) * x_dim.production()));
+          for (int i = 0; i < x_dim.production(); ++i) {
+            mapped_x[i] = static_cast<float>(i);
+            mapped_y[i] = static_cast<float>(0);
+          }
+
+          // set context and kernel args
+          LOG(INFO) << "set context and kernel args";
+          std::unique_ptr<KernelContext> context(new KernelContext);
+          context->As<OpenCLContext>().InitOnce();
+
+          buf_to_img_kernel->SetParam(BufferToImageParam);
+          std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
+          context->As<OpenCLContext>().CopySharedTo(
+              &(buf_to_img_context->As<OpenCLContext>()));
+          buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
+
+          img_to_buf_kernel->SetParam(ImageToBufferParam);
+          std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
+          context->As<OpenCLContext>().CopySharedTo(
+              &(img_to_buf_context->As<OpenCLContext>()));
+          img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
+
+          // run kernels
+          LOG(INFO) << "run kernel: buf_to_img_kernel";
+          buf_to_img_kernel->Launch();
+          LOG(INFO) << "run kernel: img_to_buf_kernel";
+          img_to_buf_kernel->Launch();
+
+// result
+#ifdef PRINT_RESULT
+          LOG(INFO) << "---- print result ----";
+          for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
+            std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
+                      << std::endl;
+          }
+#endif  //
PRINT_RESULT + + // check result: compare input and output + for (int eidx = 0; eidx < x_dim.production(); eidx++) { + EXPECT_NEAR(mapped_x[eidx], mapped_y[eidx], 1e-6); + if (abs(mapped_x[eidx] - mapped_y[eidx]) > 1e-6) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx + << " / " << x_dim.production() << ", mapped_x[" << eidx + << "]:" << mapped_x[eidx] << ", mapped_y[" << eidx + << "]:" << mapped_y[eidx]; + break; + } + } + + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data, mapped_x); + TargetWrapperCL::Unmap(y_data, mapped_y); +#ifdef LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL( + layout, kOpenCL, kAny, kNHWC, buffer_chw_to_image2d_hwc_opencl_fp32); +USE_LITE_KERNEL( + layout, kOpenCL, kAny, kNCHW, image2d_hwc_to_buffer_chw_opencl_fp32); diff --git a/lite/kernels/opencl/relu_compute.cc b/lite/kernels/opencl/relu_compute.cc index 93d1dec6743b6835b7a955994d4ebaeeef081597..c7b89c939b0bf571f27ac1dfdd272a9324f8e89f 100644 --- a/lite/kernels/opencl/relu_compute.cc +++ b/lite/kernels/opencl/relu_compute.cc @@ -15,6 +15,7 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/replace_stl/stream.h" @@ -75,17 +76,96 @@ class ReluCompute std::shared_ptr event_{new cl::Event}; }; +class ReluComputeFloatImage + : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/relu_kernel.cl", build_options_); + } + + void Run() override { + auto& param = *param_.get_mutable(); + const auto& x_dims = param.X->dims(); + auto* x_buf = param.X->data(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* out_buf = param.Out->mutable_data( + image_shape["width"], image_shape["height"]); + const auto& y_dims = param.Out->dims(); // useless: check dim only + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + + VLOG(4) << TargetToStr(param.X->target()); + VLOG(4) << TargetToStr(param.Out->target()); + VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " + << image_shape["height"]; + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + + auto global_work_size = + cl::NDRange{static_cast(image_shape["width"]), + static_cast(image_shape["height"])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(out_buf, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + private: + std::string kernel_func_name_{"relu"}; + std::string build_options_{"-DCL_DTYPE_float 
-DRELU"}; + std::shared_ptr event_{new cl::Event}; +}; + } // namespace opencl } // namespace kernels } // namespace lite } // namespace paddle +// REGISTER_LITE_KERNEL(relu, +// kOpenCL, +// kFloat, +// kNCHW, +// paddle::lite::kernels::opencl::ReluCompute, +// def) +// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .Finalize(); + REGISTER_LITE_KERNEL(relu, kOpenCL, kFloat, - kNCHW, - paddle::lite::kernels::opencl::ReluCompute, - def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + kNHWC, + paddle::lite::kernels::opencl::ReluComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) .Finalize(); diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 7080cc8c554da5698f4462302f2fcf4f94db6649..da955e4fd5902373cd881f85a8bc715eef7cec94 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -1,33 +1,69 @@ -if(NOT LITE_WITH_X86) - return() -endif() - -# lite_cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) +add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops math_function) # lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps}) -# lite_cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps}) -# lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps}) add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function) +add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) +add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) # lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op) # lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) # lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) -# lite_cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) # lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) -# lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) +add_kernel(pool_compute_x86 X86 basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) +add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function) +# add_kernel(fc_compute_x86 X86 
basic SRCS fc_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} ) +add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute) +#add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) -# lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) -# lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) -# lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) -# lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) -# lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS relu_compute_x86) -# lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86 operator) # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86) # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) +add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} blas) +add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) +add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) +add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps}) + +if(NOT LITE_WITH_X86) + return() +endif() +add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas) + +lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) +lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) +lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86) +lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86) +lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86) +lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86) +lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) +lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) 
+lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) +lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) +lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) +lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) +lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86) +lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86) +lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86) +lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86) +lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86) +lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86) + +lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) +lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) +lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86) diff --git a/lite/kernels/x86/activation_compute.cc b/lite/kernels/x86/activation_compute.cc index 94d877de288c47bd79c8c8713a8e7a5de5179472..f2f911dd7d037a3f4e0f28592cff07383c8a49b6 100644 --- a/lite/kernels/x86/activation_compute.cc +++ b/lite/kernels/x86/activation_compute.cc @@ -12,116 +12,58 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/core/kernel.h" -#include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/activation_op.h" +#include "lite/kernels/x86/activation_compute.h" -namespace paddle { -namespace lite { -namespace kernels { -namespace x86 { - -template -void Activate(const platform::CPUDeviceContext& context, - const framework::LoDTensor* X, - framework::LoDTensor* Out) { - using T = typename Functor::ELEMENT_TYPE; - auto* place = context.eigen_device(); - auto x = - framework::EigenVector::Flatten(paddle::operators::detail::Ref(X)); - auto out = - framework::EigenVector::Flatten(paddle::operators::detail::Ref(Out)); - Functor()(*place, x, out); -} - -template -void ActivateGrad(const platform::CPUDeviceContext& context, - const framework::LoDTensor* X, - const framework::LoDTensor* Out, - const framework::LoDTensor* Out_grad, - framework::LoDTensor* X_grad) { - using T = typename Functor::ELEMENT_TYPE; - auto* place = context.eigen_device(); - auto x = - framework::EigenVector::Flatten(paddle::operators::detail::Ref(X)); - auto out = - framework::EigenVector::Flatten(paddle::operators::detail::Ref(Out)); - auto x_grad = framework::EigenVector::Flatten( - paddle::operators::detail::Ref(X_grad)); - auto out_grad = framework::EigenVector::Flatten( - paddle::operators::detail::Ref(Out_grad)); - Functor()(*place, x, out, out_grad, x_grad); -} - -template -class SquareCompute : public KernelLite { - public: - using param_t = operators::ActivationParam; - - void Run() override { - auto& context = ctx_->As(); - auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context()); - - param.Out->template mutable_data(); - Activate>(*context.x86_device_context(), - ¶m.X->raw_tensor(), - ¶m.Out->raw_tensor()); - } - - virtual ~SquareCompute() = default; -}; - -template -class 
SquareGradCompute : public KernelLite { - public: - using param_t = operators::ActivationGradParam; - - void Run() override { - auto& context = ctx_->As(); - auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context()); - param.X_grad->template mutable_data(); - - ActivateGrad>( - *context.x86_device_context(), - ¶m.X->raw_tensor(), - ¶m.Out->raw_tensor(), - ¶m.Out_grad->raw_tensor(), - ¶m.X_grad->raw_tensor()); - } +// float +REGISTER_LITE_KERNEL(square, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SquareCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); - virtual ~SquareGradCompute() = default; -}; +// float +REGISTER_LITE_KERNEL(relu, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ReluCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); -} // namespace x86 -} // namespace kernels -} // namespace lite -} // namespace paddle +// float +REGISTER_LITE_KERNEL(tanh, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::TanhCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); // float -REGISTER_LITE_KERNEL(square, +REGISTER_LITE_KERNEL(gelu, kX86, kFloat, kNCHW, - paddle::lite::kernels::x86::SquareCompute, + paddle::lite::kernels::x86::GeluCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); -REGISTER_LITE_KERNEL(square_grad, +REGISTER_LITE_KERNEL(softsign, kX86, kFloat, kNCHW, - paddle::lite::kernels::x86::SquareGradCompute, + paddle::lite::kernels::x86::SoftsignCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput(paddle::framework::GradVarName("Out"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput(paddle::framework::GradVarName("X"), - {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/lite/kernels/x86/activation_compute.h b/lite/kernels/x86/activation_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..14d0ffe000311c87dac513a65f731e9654042db2 --- /dev/null +++ b/lite/kernels/x86/activation_compute.h @@ -0,0 +1,218 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
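+//
+// Every activation in this header follows the same two-piece pattern: a
+// stateless Eigen functor (SquareFunctor, ReluFunctor, TanhFunctor,
+// GeluFunctor, SoftsignFunctor) holding the elementwise expression, and a
+// thin KernelLite wrapper whose Run() allocates the output and calls
+// Activate<Functor<T>>(), which flattens X and Out into 1-D Eigen vectors
+// and applies the functor on the CPU Eigen device.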
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+#include "lite/backends/x86/math/blas.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/fluid/eigen.h"
+#include "lite/operators/activation_ops.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+enum ActBwdOpFwdDeps {
+  kNoDeps = 0x00,  // Do not need any forward input/output
+  kDepX = 0x01,    // Only need forward input X
+  kDepOut = 0x02,  // Only need forward output Out
+
+  // Never add kDepXOut, because Out can always be calculated
+  // from the forward input X in the backward pass.
+  // FIXME(zjl): but in MKLDNN abs, X and Out are all needed...
+  // Developers should not rely on this enum value!
+  kDepXOut = 0x03
+};
+
+template <typename T>
+struct BaseActivationFunctor {
+  using ELEMENT_TYPE = T;
+
+  using AttrPair = std::vector<std::pair<const char*, float*>>;
+
+  AttrPair GetAttrs() { return AttrPair(); }
+
+  /* NOTE(*): Output can reuse X's memory if X is not depended on by its
+     gradient. For example, sigmoid's gradient does not involve x, so its
+     output can reuse the input memory; abs' gradient does use x, so it
+     cannot be computed in place. */
+  bool Inplace() const { return false; }
+};
+
+template <typename Functor>
+bool Activate(const lite::Tensor* X, lite::Tensor* Out) {
+  using T = typename Functor::ELEMENT_TYPE;
+  auto place = lite::fluid::EigenDeviceType<TARGET(kX86)>();
+  CHECK_OR_FALSE(X)
+  CHECK_OR_FALSE(Out)
+  auto x = lite::fluid::EigenVector<T>::Flatten(*X);
+  auto out = lite::fluid::EigenVector<T>::Flatten(*Out);
+  Functor()(place, x, out);
+  return true;
+}
+
+// square(x) = x^2
+template <typename T>
+struct SquareFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.square();
+  }
+};
+
+template <typename T>
+class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<operators::ActivationParam>();
+
+    param.Out->template mutable_data<T>();
+    Activate<SquareFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~SquareCompute() = default;
+};
+
+// relu(x) = max(x, 0)
+template <typename T>
+struct ReluFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.cwiseMax(static_cast<T>(0));
+  }
+};
+
+template <typename T>
+class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<operators::ActivationParam>();
+
+    param.Out->template mutable_data<T>();
+    Activate<ReluFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~ReluCompute() = default;
+};
+
+// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
+template <typename T>
+struct TanhFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.tanh();
+  }
+};
+
+template <typename T>
+class TanhCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<operators::ActivationParam>();
+
+    param.Out->template mutable_data<T>();
+    Activate<TanhFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~TanhCompute() = default;
+};
+
+// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+template <typename T>
+struct GeluFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+// Because the execution or device context cannot be delivered here, keep the
+// macro for NVCC.
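+// With MKLML available (and neither Windows, macOS, nor CUDA in play), the
+// 0.5 * x * (1 + erf(x / sqrt(2))) form is assembled from vectorized
+// BLAS/VML primitives: AXPY writes x / sqrt(2) into the zeroed output,
+// VMERF applies erf in place, VMUL folds in the multiply by x, and a final
+// loop scales by 0.5. The #else branch evaluates the same expression
+// directly through Eigen.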
+#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
+    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    auto x_data = x.data();
+    auto out_data = out.data();
+    int n = std::min(x.size(), out.size());
+
+    std::memset(out_data, 0, n * sizeof(T));
+    paddle::lite::x86::math::CBlas<T>::AXPY(
+        n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
+    paddle::lite::x86::math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
+    for (int i = 0; i < n; i++) {
+      out_data[i] += static_cast<T>(1);
+    }
+    paddle::lite::x86::math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
+    for (int i = 0; i < n; i++) {
+      out_data[i] *= static_cast<T>(0.5);
+    }
+#else
+    auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
+    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
+#endif
+  }
+};
+
+template <typename T>
+class GeluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<operators::ActivationParam>();
+
+    param.Out->template mutable_data<T>();
+    Activate<GeluFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~GeluCompute() = default;
+};
+
+// softsign(x) = x / (1 + |x|)
+template <typename T>
+struct SoftsignFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) {
+    out.device(d) = x / (static_cast<T>(1) + x.abs());
+  }
+};
+
+template <typename T>
+class SoftsignCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    // auto& context = ctx_->As<X86Context>();
+    auto& param = *param_.get_mutable<operators::ActivationParam>();
+    param.Out->template mutable_data<T>();
+
+    Activate<SoftsignFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~SoftsignCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/activation_compute_test.cc b/lite/kernels/x86/activation_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8cc2607e73e605214e08e42e70de457a206e2468
--- /dev/null
+++ b/lite/kernels/x86/activation_compute_test.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
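+
+// The relu test below fills x with values of alternating sign
+// (0, -1, 2, -3, ...), so every negative entry should come back as zero and
+// every non-negative entry unchanged. A reference check along those lines
+// (a sketch; the test as written only logs the output) could be:
+//   for (int i = 0; i < out.dims().production(); i++) {
+//     EXPECT_NEAR(out_data[i], std::max(x_data[i], 0.f), 1e-6);
+//   }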
+ +#include "lite/kernels/x86/activation_compute.cc" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(relu_x86, retrive_op) { + auto relu = + KernelRegistry::Global().Create("relu"); + ASSERT_FALSE(relu.empty()); + ASSERT_TRUE(relu.front()); +} + +TEST(relu_x86, init) { + ReluComputeCompute relu; + ASSERT_EQ(relu.precision(), PRECISION(kFloat)); + ASSERT_EQ(relu.target(), TARGET(kX86)); +} + +TEST(relu_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector out_shape{batch_size, 3, 2, 2}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + int sign = i % 2 == 0 ? 1 : -1; + x_data[i] = static_cast(i * sign); + } + + // ReluCompute relu; + ReluCompute relu; + operators::Param param; + + param.x = &x; + param.y = &y; + param.out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + sequence_expand_as.SetContext(std::move(ctx)); + sequence_expand_as.SetParam(param); + sequence_expand_as.Run(); + auto out_data = out.mutable_data(); + + LOG(INFO) << "output: "; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_expand_as, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/batch_norm_compute.h b/lite/kernels/x86/batch_norm_compute.h index 3a94b99b171e684db9923fc7180195f136f4c414..092280752cb92e1784eefc09cb26fa3bea8eb939 100644 --- a/lite/kernels/x86/batch_norm_compute.h +++ b/lite/kernels/x86/batch_norm_compute.h @@ -13,12 +13,14 @@ // limitations under the License. 
#pragma once +#include #include #include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" +#include "lite/core/types.h" +#include "lite/fluid/eigen.h" +#include "lite/operators/batch_norm_op.h" namespace paddle { namespace lite { @@ -42,7 +44,9 @@ class BatchNormCompute : public KernelLite { public: using param_t = operators::BatchNormParam; void Run() override { + // auto &context = ctx_->As(); auto ¶m = *param_.get_mutable(); + param.is_test = true; bool global_stats = param.is_test || param.use_global_stats; const auto *x = param.x; @@ -55,12 +59,12 @@ class BatchNormCompute : public KernelLite { const int sample_size = x->dims().production() / N / C; // alloc memory - param.y->template mutable_data(); + param.y->mutable_data(); if (!param.is_test) { - param.mean_out->template mutable_data(); - param.variance_out->template mutable_data(); - param.saved_mean->template mutable_data(); - param.saved_variance->template mutable_data(); + param.mean_out->mutable_data(); + param.variance_out->mutable_data(); + param.saved_mean->mutable_data(); + param.saved_variance->mutable_data(); } if (!global_stats) { // saved_xx is use just in this batch of data @@ -79,8 +83,7 @@ class BatchNormCompute : public KernelLite { if ((N * sample_size) == 1) { LOG(WARNING) << "Only 1 element in normalization dimension, " << "we skip the batch norm calculation, let y = x."; - framework::TensorCopy( - x->raw_tensor(), platform::CPUPlace(), ¶m.y->raw_tensor()); + param.y->CopyDataFrom(*x); return; } diff --git a/lite/kernels/x86/batch_norm_compute_test.cc b/lite/kernels/x86/batch_norm_compute_test.cc index 254a6a7379e9ab18128020adfe18b206663b7877..5ec2cdcdda0e9ff3698c80584b36396b38328e03 100644 --- a/lite/kernels/x86/batch_norm_compute_test.cc +++ b/lite/kernels/x86/batch_norm_compute_test.cc @@ -15,6 +15,8 @@ #include "lite/kernels/x86/batch_norm_compute.h" #include #include +#include +#include #include #include "lite/core/op_registry.h" @@ -102,7 +104,7 @@ TEST(batch_norm_x86, run_test) { operators::BatchNormParam param; param.x = &x; - param.is_test = false; + param.is_test = true; param.scale = &scale; param.bias = &bias; param.mean = &mean; @@ -116,6 +118,9 @@ TEST(batch_norm_x86, run_test) { param.saved_mean = &saved_mean; param.saved_variance = &saved_variance; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + batch_norm.SetContext(std::move(ctx)); batch_norm.SetParam(param); batch_norm.Run(); diff --git a/lite/kernels/x86/concat_compute.h b/lite/kernels/x86/concat_compute.h index 280320867dd2edf33a600a59bce80ce3375f27d1..3fd1e9f233d2022a1fa0735bd1bc849923e64745 100644 --- a/lite/kernels/x86/concat_compute.h +++ b/lite/kernels/x86/concat_compute.h @@ -18,13 +18,20 @@ #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include "lite/core/types.h" -#include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { namespace lite { namespace kernels { namespace x86 { +inline int count(int start_axis, int end_axis, const lite::DDim& dim) { + int count = 1; + for (int i = start_axis; i < end_axis; ++i) { + count *= dim[i]; + } + return count; +} + template class ConcatCompute : public KernelLite { public: @@ -33,67 +40,31 @@ class ConcatCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); int64_t axis = static_cast(param.axis); + auto x_dims = param.x[0]->dims(); auto out = param.output; + if (param.x.size() == 1) { + 
param.output->ShareDataWith(*param.x[0]); + return; + } - if (axis == 0 && param.x.size() < 10) { - size_t output_offset = 0; - for (auto* in : param.x) { - if (!in || in->dims().production() == 0UL) { - continue; - } - auto in_stride = framework::stride_numel(in->dims().data()); - auto out_stride = framework::stride_numel(out->dims().data()); - paddle::operators::StridedNumelCopyWithAxis( - platform::CPUDeviceContext(), - axis, - out->mutable_data() + output_offset, - out_stride, - in->data(), - in_stride, - in_stride[axis]); - - output_offset += in_stride[axis]; - } - } else { - std::vector inputs; - for (size_t j = 0; j < param.x.size(); ++j) { - if (param.x[j] && param.x[j]->dims().production() > 0) { - inputs.push_back(*param.x[j]); - } else { - continue; - } - } - - int num = inputs.size(); - int rows = 1; - auto dim_0 = inputs[0].dims(); - for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; - } - int out_rows = rows, out_cols = 0; - - std::vector input_cols(inputs.size()); - for (int i = 0; i < num; ++i) { - int t_cols = inputs[i].dims().production() / rows; - out_cols += t_cols; - input_cols[i] = t_cols; - } - // computation - auto output_data = param.output->template mutable_data(); - int col_idx = 0; - for (int j = 0; j < num; ++j) { - int col_len = input_cols[j]; - auto input_data = inputs[j].data(); - for (int k = 0; k < out_rows; ++k) { - std::memcpy(output_data + k * out_cols + col_idx, - input_data + k * col_len, - sizeof(T) * col_len); - } - col_idx += col_len; + auto output_data = param.output->template mutable_data(); + int offset_concat_axis = 0; + int num_concat = count(0, axis, x_dims); + int concat_input_size = count(axis + 1, x_dims.size(), x_dims); + const int top_concat_axis = out->dims()[axis]; + for (size_t i = 0; i < param.x.size(); ++i) { + auto bottom_data = param.x[i]->data(); + const int64_t bottom_concat_axis = param.x[i]->dims()[axis]; + for (int n = 0; n < num_concat; ++n) { + std::memcpy( + output_data + + (n * top_concat_axis + offset_concat_axis) * concat_input_size, + bottom_data + n * bottom_concat_axis * concat_input_size, + (bottom_concat_axis * concat_input_size) * sizeof(T)); } + offset_concat_axis += bottom_concat_axis; } } - virtual ~ConcatCompute() = default; }; diff --git a/lite/kernels/x86/concat_compute_test.cc b/lite/kernels/x86/concat_compute_test.cc index 5a08903f827982610d094af640d37f45303a521f..468e9422752561ff6416e8859b485462b9e2abbe 100644 --- a/lite/kernels/x86/concat_compute_test.cc +++ b/lite/kernels/x86/concat_compute_test.cc @@ -14,7 +14,6 @@ #include "lite/kernels/x86/concat_compute.h" #include -#include #include #include "lite/core/op_registry.h" @@ -68,11 +67,11 @@ TEST(concat_x86, run_test) { concat.SetParam(param); concat.Run(); - std::cout << "output: "; + std::vector ref_results = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}; for (int i = 0; i < out.dims().production(); i++) { - std::cout << out_data[i] << " "; + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); } - std::cout << std::endl; } } // namespace x86 diff --git a/lite/kernels/x86/conv_compute.h b/lite/kernels/x86/conv_compute.h index 39114e1716a0f1830a739fc034a0845a36c35702..48cb3c74ef3c05675115ab7cec09f16322d1410a 100644 --- a/lite/kernels/x86/conv_compute.h +++ b/lite/kernels/x86/conv_compute.h @@ -16,15 +16,14 @@ #include #include #include +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/im2col.h" +#include "lite/backends/x86/math/vol2col.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include 
"lite/core/types.h" +#include "lite/fluid/eigen.h" #include "lite/operators/conv_op.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/depthwise_conv.h" -#include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/operators/math/vol2col.h" namespace paddle { namespace lite { @@ -50,15 +49,14 @@ class Conv2dCompute : public KernelLite { public: using param_t = operators::ConvParam; void Run() override { + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); lite::Tensor filter = *param.filter; - param.output->template mutable_data(); - + param.output->mutable_data(); const int batch_size = static_cast(param.x->dims()[0]); std::vector filter_shape_vec(filter.dims().Vectorize()); std::vector output_shape_vec(param.output->dims().Vectorize()); - size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); col_shape_vec[0] = param.x->dims()[1] / param.groups; @@ -70,7 +68,6 @@ class Conv2dCompute : public KernelLite { lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1); bool is_expand = IsExpand( filter_shape_vec, param.strides, param.paddings, param.dilations); - lite::Tensor col; lite::Tensor col_matrix; if (is_expand) { @@ -80,40 +77,37 @@ class Conv2dCompute : public KernelLite { col_matrix.Resize(col_matrix_shape); } lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size()); - lite::DDim filter_matrix_shape(std::vector{ filter.dims()[0], filter.dims().production() / filter.dims()[0]}); filter.Resize(filter_matrix_shape); - lite::DDim output_matrix_shape(std::vector{ param.output->dims()[1], param.output->dims().production() / (param.output->dims()[0] * param.output->dims()[1])}); - int in_step = static_cast(param.x->dims()[1]) / param.groups; int out_step = static_cast(param.output->dims()[1]) / param.groups; - - paddle::operators::math::Vol2ColFunctor - vol2col; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, - platform::CPUDeviceContext, + paddle::lite::x86::math::Vol2ColFunctor vol2col; + paddle::lite::x86::math::Im2ColFunctor< + paddle::lite::x86::math::ColFormat::kCFO, + lite::TargetType::kX86, T> im2col; - auto blas = paddle::operators::math::GetBlas( - platform::CPUDeviceContext()); + auto blas = + paddle::lite::x86::math::GetBlas(context); for (int i = 0; i < batch_size; i++) { lite::Tensor in_batch; - in_batch.ShareDataWith( - param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); + lite::Tensor tmp_in_batch = param.x->Slice(i, i + 1); + tmp_in_batch.Resize(input_shape); + in_batch.ShareDataWith(tmp_in_batch); lite::Tensor out_batch; - out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( - output_matrix_shape.data())); - + lite::Tensor tmp_out_batch = param.output->Slice(i, i + 1); + tmp_out_batch.Resize(output_matrix_shape); + out_batch.ShareDataWith(tmp_out_batch); for (int g = 0; g < param.groups; g++) { lite::Tensor in_slice; in_slice.ShareDataWith( - in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step)); + in_batch.Slice(static_cast(g * in_step), + static_cast((g + 1) * in_step))); if (!is_expand) { col.ShareDataWith(in_slice); @@ -121,38 +115,40 @@ class Conv2dCompute : public KernelLite { col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { // im2col - im2col(platform::CPUDeviceContext(), - in_slice.raw_tensor(), + im2col(context, + in_slice, param.dilations, param.strides, std::vector{param.paddings[0], param.paddings[1], 
param.paddings[0], param.paddings[1]}, - &(col.raw_tensor())); + &(col)); } else if (data_dim == 3U) { // vol2col - vol2col(platform::CPUDeviceContext(), - in_slice.raw_tensor(), + vol2col(context, + in_slice, param.dilations, param.strides, param.paddings, - &(col.raw_tensor())); + &(col)); } // gemm lite::Tensor out_slice; out_slice.ShareDataWith( - out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + out_batch.Slice(static_cast(g * out_step), + static_cast((g + 1) * out_step))); lite::Tensor filter_slice; filter_slice.ShareDataWith( - filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + filter.Slice(static_cast(g * out_step), + static_cast((g + 1) * out_step))); - blas.MatMul(filter_slice.raw_tensor(), + blas.MatMul(filter_slice, false, - col_matrix.raw_tensor(), + col_matrix, false, T(1.0), - &(out_slice.raw_tensor()), + &(out_slice), T(0.0)); } } diff --git a/lite/kernels/x86/conv_compute_test.cc b/lite/kernels/x86/conv_compute_test.cc index 17efae41601925e217067ce07677bfc10da75bc9..f2dde962b9e77ce26336d17f07f29f5874ef9722 100644 --- a/lite/kernels/x86/conv_compute_test.cc +++ b/lite/kernels/x86/conv_compute_test.cc @@ -14,6 +14,8 @@ #include "lite/kernels/x86/conv_compute.h" #include +#include +#include #include #include "lite/core/op_registry.h" @@ -38,7 +40,7 @@ TEST(conv2d_x86, init) { TEST(conv2d_x86, run_test) { lite::Tensor x, filter, b, out; - constexpr int batch_size = 1; + const int batch_size = 1; std::vector x_shape{batch_size, 3, 3, 3}; x.Resize(lite::DDim(x_shape)); std::vector filter_shape{1, 3, 3, 3}; @@ -74,13 +76,17 @@ TEST(conv2d_x86, run_test) { param.paddings = {0, 0}; param.groups = 1; param.dilations = {1, 1}; - + std::unique_ptr ctx(new KernelContext); + ctx->As(); + conv2d.SetContext(std::move(ctx)); conv2d.SetParam(param); conv2d.Run(); LOG(INFO) << "output: "; + float ref_result[1] = {27.}; for (int i = 0; i < out.dims().production(); i++) { - LOG(INFO) << out_data[i] << " "; + EXPECT_NEAR(out_data[i], ref_result[i], 1e-5); } } diff --git a/lite/kernels/x86/dropout_compute.h b/lite/kernels/x86/dropout_compute.h index de8730d1981573c439b90c3e0933340abc78a76d..2ba383bdbdc99e7643f3bf09350f833665c8548e 100644 --- a/lite/kernels/x86/dropout_compute.h +++ b/lite/kernels/x86/dropout_compute.h @@ -13,12 +13,14 @@ // limitations under the License. #pragma once +#include #include #include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" +#include "lite/core/types.h" +#include "lite/fluid/eigen.h" +#include "lite/operators/dropout_op.h" namespace paddle { namespace lite { @@ -28,7 +30,7 @@ namespace x86 { template -using EigenMatrix = framework::EigenMatrix; +using EigenMatrix = lite::fluid::EigenMatrix; template class DropoutCompute : public KernelLite { @@ -37,16 +39,16 @@ class DropoutCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); const auto* x_data = param.x->data(); - auto* out_data = param.output->template mutable_data(); + auto* out_data = param.output->mutable_data(); if (!param.is_test) { - auto* mask_data = param.mask->template mutable_data(); + auto* mask_data = param.mask->mutable_data(); std::random_device rnd; std::minstd_rand engine; int seed = param.fix_seed ?
param.seed : rnd(); engine.seed(seed); std::uniform_real_distribution dist(0, 1); - size_t size = framework::product(param.mask->dims().data()); + size_t size = param.mask->dims().production(); for (size_t i = 0; i < size; ++i) { if (dist(engine) < param.dropout_prob) { mask_data[i] = 0; @@ -62,13 +64,13 @@ class DropoutCompute : public KernelLite { } } } else { - auto X = EigenMatrix::Reshape(param.x->raw_tensor(), 1); - auto Y = EigenMatrix::Reshape(param.output->raw_tensor(), 1); - auto& place = *platform::CPUDeviceContext().eigen_device(); + auto X = EigenMatrix::Reshape(*param.x, 1); + auto Y = EigenMatrix::Reshape(*param.output, 1); if (param.dropout_implementation == "upscale_in_train") { - Y.device(place) = X; + Y.device(lite::fluid::EigenDeviceType()) = X; } else { - Y.device(place) = X * static_cast(1.0f - param.dropout_prob); + Y.device(lite::fluid::EigenDeviceType()) = + X * static_cast(1.0f - param.dropout_prob); } } } diff --git a/lite/kernels/x86/dropout_compute_test.cc b/lite/kernels/x86/dropout_compute_test.cc index f68b92a1722fac9f794d9a3b56db1b5a5e0da511..279f639f40ece0a10e45fe16f36fcb443cea550a 100644 --- a/lite/kernels/x86/dropout_compute_test.cc +++ b/lite/kernels/x86/dropout_compute_test.cc @@ -15,6 +15,8 @@ #include "lite/kernels/x86/dropout_compute.h" #include #include +#include +#include #include #include "lite/core/op_registry.h" @@ -60,7 +62,9 @@ TEST(dropout_x86, run_test) { param.is_test = true; param.fix_seed = true; param.output = &out; - + std::unique_ptr ctx(new KernelContext); + ctx->As(); + dropout.SetContext(std::move(ctx)); dropout.SetParam(param); dropout.Run(); diff --git a/lite/kernels/x86/elementwise_compute.cc b/lite/kernels/x86/elementwise_compute.cc index b0c9e958a0fead9e866e9a5471e78a5f9d386da3..710e67956b055b84323a23443c671682704dd2c2 100644 --- a/lite/kernels/x86/elementwise_compute.cc +++ b/lite/kernels/x86/elementwise_compute.cc @@ -35,21 +35,3 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); - -#ifdef LITE_WITH_X86 -REGISTER_LITE_KERNEL( - elementwise_sub_grad, - kX86, - kFloat, - kNCHW, - paddle::lite::kernels::x86::ElementwiseSubGradCompute, - def) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput(paddle::framework::GradVarName("Out"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput(paddle::framework::GradVarName("X"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput(paddle::framework::GradVarName("Y"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .Finalize(); -#endif diff --git a/lite/kernels/x86/elementwise_compute.h b/lite/kernels/x86/elementwise_compute.h index d93f11312594afab64738f28062eea138448abe4..c5598545f112e1d44739c6c88980f74875127836 100644 --- a/lite/kernels/x86/elementwise_compute.h +++ b/lite/kernels/x86/elementwise_compute.h @@ -15,11 +15,8 @@ #include "lite/core/kernel.h" #include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "lite/fluid/eigen.h" +#include "lite/kernels/x86/elementwise_op_function.h" namespace paddle { namespace lite { @@ -45,74 +42,17 @@ class ElementwiseSubCompute void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context()); 
param.Out->template mutable_data(); - paddle::operators::ElementwiseComputeEx, - platform::CPUDeviceContext, - T>(*context.x86_execution_context(), - ¶m.X->raw_tensor(), - ¶m.Y->raw_tensor(), - param.axis, - SubFunctor(), - ¶m.Out->raw_tensor()); + paddle::lite::kernels::x86::ElementwiseComputeEx, + lite::TargetType::kX86, + T>( + context, param.X, param.Y, param.axis, SubFunctor(), param.Out); } virtual ~ElementwiseSubCompute() = default; }; -template -struct SubGradDX { - T operator()(T x, T y, T out, T dout) const { return dout; } -}; - -template -struct SubGradDY { - T operator()(T x, T y, T out, T dout) const { return -dout; } -}; - -#ifdef LITE_WITH_X86 -template -class ElementwiseSubGradCompute - : public KernelLite { - public: - using param_t = operators::ElementwiseGradParam; - void Run() override { - auto& param = *param_.get_mutable(); - auto& context = ctx_->As(); - CHECK(context.x86_device_context()); - - param.X_grad->template mutable_data(); - // skip out, x, y - auto dout = param.Out_grad->raw_tensor(); - auto dx = param.X_grad->raw_tensor(); - - framework::Tensor* dy = nullptr; - if (param.Y_grad) { - param.Y_grad->template mutable_data(); - dy = ¶m.Y_grad->raw_tensor(); - } - auto& skip = dout; - paddle::operators::ElemwiseExplicitGradCompute, - SubGradDY>( - *context.x86_execution_context(), - skip, - skip, - skip, - dout, - param.axis, - &dx, - dy, - SubGradDX(), - SubGradDY()); - } - - virtual ~ElementwiseSubGradCompute() = default; -}; -#endif - template class ElementwiseAddCompute : public KernelLite { @@ -121,16 +61,11 @@ class ElementwiseAddCompute void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - CHECK(context.x86_device_context()); param.Out->template mutable_data(); - paddle::operators::ElementwiseComputeEx, - platform::CPUDeviceContext, - T>(*context.x86_execution_context(), - ¶m.X->raw_tensor(), - ¶m.Y->raw_tensor(), - param.axis, - AddFunctor(), - ¶m.Out->raw_tensor()); + paddle::lite::kernels::x86::ElementwiseComputeEx, + lite::TargetType::kX86, + T>( + context, param.X, param.Y, param.axis, AddFunctor(), param.Out); } virtual ~ElementwiseAddCompute() = default; diff --git a/lite/kernels/x86/elementwise_compute_test.cc b/lite/kernels/x86/elementwise_compute_test.cc index 5d0f9fd57a4a54a63820a2eeff34434f2e2669af..9850c0ce86756cd12e28ab95688b79a1c539189c 100644 --- a/lite/kernels/x86/elementwise_compute_test.cc +++ b/lite/kernels/x86/elementwise_compute_test.cc @@ -74,9 +74,9 @@ TEST(elementwise_add_x86, run_test) { elementwise_add.SetContext(std::move(ctx)); elementwise_add.Run(); - LOG(INFO) << "output: "; + std::vector ref_results = {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}; for (int i = 0; i < out.dims().production(); i++) { - LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); } } diff --git a/lite/kernels/x86/elementwise_op_function.h b/lite/kernels/x86/elementwise_op_function.h new file mode 100644 index 0000000000000000000000000000000000000000..40116479f6f4d6dc8658c2d781a48b7a07dd20c9 --- /dev/null +++ b/lite/kernels/x86/elementwise_op_function.h @@ -0,0 +1,642 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "lite/fluid/eigen.h" +#include "lite/fluid/transform.h" +#include "lite/utils/paddle_enforce.h" + +#include "lite/backends/x86/math/math_function.h" +#include "lite/fluid/for_range.h" +#include "lite/utils/variant.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/* + * Out = X ⊙ Y + * If Y's shape does not match X's shape, they will be reshaped. + * For example: + * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + * pre=2, n=3*4, post=5 + * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) + * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) + * pre=2*3, n=4*5, post=1 + * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) + * + * New parameter: *mid_flag* is added to solve m*n*k & m*1*k + * broadcast cases. + * 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5) + * mid_flag should not be NULL. + * x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20) + */ +inline void get_mid_dims(const lite::DDim &x_dims, + const lite::DDim &y_dims, + const int axis, + int *pre, + int *n, + int *post, + int *mid_flag = NULL) { + *pre = 1; + *n = 1; + *post = 1; + if (mid_flag != NULL) { + *mid_flag = 0; + int mid = 0; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + if (x_dims[i + axis] != y_dims[i]) { + // only a single y_dims[i] == 1 is supported now. + PADDLE_ENFORCE_EQ( + *mid_flag, 0, "Broadcast only supports y_dims with a single 1."); + PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch."); + // m*n*k m*1*k + for (int j = 0; j < i; ++j) { + (*pre) *= y_dims[j]; + } + *n = std::max(x_dims[i + axis], y_dims[i]); + *mid_flag = 1; + mid = i; + break; + } + (*n) *= y_dims[i]; + } + if (*mid_flag) { + for (int i = mid + 1; i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } else { + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } + } else { // for fused_elementwise_activation_op. keep the old version.
+ for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + + for (int i = 0; i < y_dims.size(); ++i) { + PADDLE_ENFORCE_EQ( + x_dims[i + axis], y_dims[i], "Broadcast dimension mismatch."); + (*n) *= y_dims[i]; + } + + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } +} + +inline lite::DDim trim_trailing_singular_dims(const lite::DDim &dims) { + // Remove trailing dimensions of size 1 for y + auto actual_dims_size = dims.size(); + for (; actual_dims_size != 0; --actual_dims_size) { + if (dims[actual_dims_size - 1] != 1) break; + } + + std::vector trim_dims; + trim_dims.resize(actual_dims_size); + for (int i = 0; i < actual_dims_size; ++i) { + trim_dims[i] = dims[i]; + } + if (trim_dims.size() == 0) { + return lite::DDim(); + } + lite::DDim actual_dims = lite::DDim(trim_dims); + return actual_dims; +} + +template +class RowwiseTransformIterator; + +template +class MidWiseTransformIterator; + +// NOTE(dzhwinter): ptrdiff_t in iterator is deprecated in c++17 +template +class RowwiseTransformIterator + : public std::iterator { + public: + RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator &operator++() { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + return *this; + } + + RowwiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + + return *this; + } + + bool operator==( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int i_; + int64_t n_; +}; + +template +class MidWiseTransformIterator + : public std::iterator { + public: + MidWiseTransformIterator(const T *ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator &operator++() { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + return *this; + } + + MidWiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + } + return *this; + } + + bool operator==( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int64_t i_; + int64_t j_; + int64_t n_; + int64_t post_; +}; + +template +class TransformFunctor { + public: + TransformFunctor(const lite::Tensor *x, + const lite::Tensor *y, + lite::Tensor *z, + const lite::Context &ctx, + Functor func) + : x_(x->data()), + y_(y->data()), + z_(z->mutable_data()), + nx_(x->numel()), + ctx_(ctx), + func_(func) {} + + inline void Run() const { + lite::fluid::Transform trans; + trans(ctx_, x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + lite::fluid::Transform trans; + trans(ctx_, + x_, + x_ + nx_, + RowwiseTransformIterator(y_, n), + z_, + func_); + } + + inline void RunMidWise(int n, int pre, int post) const { + lite::fluid::Transform trans; + trans(ctx_, + x_, + x_ + nx_, + MidWiseTransformIterator(y_, n, post), + z_, + func_); + } + + inline void RunMidRowWise(int n, int pre, int post) const { + lite::fluid::Transform trans; + for (int i = 0; i < pre; i++) { + trans(ctx_, + x_
+ i * n * post, + x_ + (i + 1) * n * post, + RowwiseTransformIterator(y_ + i * post, post), + z_ + i * n * post, + func_); + } + } + + private: + const T *x_; + const T *y_; + OutType *z_; + int64_t nx_; + const lite::Context &ctx_; + Functor func_; +}; + +template + +void ElementwiseComputeEx(const lite::Context &ctx, + const lite::Tensor *x, + const lite::Tensor *y, + int axis, + Functor func, + lite::Tensor *z) { + TransformFunctor functor(x, y, z, ctx, func); + auto x_dims = x->dims(); + auto y_dims_untrimed = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + y_dims_untrimed.size(), + "Rank of first input must >= rank of second input."); + if (x_dims == y_dims_untrimed) { + functor.Run(); + return; + } + + axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); + axis = (y_dims.size() == 0) ? x_dims.size() : axis; + int pre, n, post, mid_flag = 0; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag); + if (mid_flag) { + functor.RunMidRowWise(n, pre, post); + return; + } + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + +// FusedElemwiseAndAct +// --- forward +template +struct FusedElemwiseAndActNoBroadcast { + HOSTDEVICE void operator()(size_t i) { + T y_val = y_[i]; + T x_val = x_[i]; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val); + intermediate_out_[i] = intermeidiate_out; + out_[i] = + compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out_[i] = compound_functor_.GetOut(x_val, y_val); + } + } + + const T *x_; + const T *y_; + CompoundFunctor compound_functor_; + T *out_; + T *intermediate_out_; +}; + +// FusedElemwiseAndActBroadcast1: +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2, +// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20) +template +static void FusedElemwiseAndActBroadcast1CPU(const T *x, + const T *y, + CompoundFunctor compound_functor, + int h, + int w, + T *out, + T *intermediate_out) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int offset = i * w + j; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } +} + +// FusedElemwiseAndActBroadcast2 +// In this case, X and Y can be reshaped to a matrix. 
+// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1, +// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1) +// pre = 2, n = 12, post = 5 +template +static void FusedElemwiseAndActBroadcast2CPU(const T *x, + const T *y, + int pre, + int n, + int post, + CompoundFunctor compound_functor, + T *out, + T *intermediate_out) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int offset = i * n * post + j * post + k; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = + compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = compound_functor.GetOutUseIntermediateOut( + x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } + } +} + +template +void FusedElemwiseAndActComputeNoBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + size_t N = static_cast(x_dim.production()); + + lite::fluid::ForRange for_range(ctx, N); + + for_range( + FusedElemwiseAndActNoBroadcast{ + x.data(), + y.data(), + compound_functor, + out->mutable_data(), + intermediate_out == nullptr ? nullptr + : intermediate_out->mutable_data()}); +} + +template +void FusedElemwiseAndActComputeWithBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::DDim &y_dim_untrimed, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + int axis, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); + auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + + if (post == 1) { + int h = pre; + int w = n; + FusedElemwiseAndActBroadcast1CPU( + x.data(), + y.data(), + compound_functor, + h, + w, + out->mutable_data(), + intermediate_out == nullptr ? nullptr + : intermediate_out->mutable_data()); + + } else { + FusedElemwiseAndActBroadcast2CPU( + x.data(), + y.data(), + pre, + n, + post, + compound_functor, + out->mutable_data(), + intermediate_out == nullptr ? nullptr + : intermediate_out->mutable_data()); + } +} + +template +void FusedElemwiseAndActComputeEx(const lite::Context &ctx, + const lite::Tensor &x, + const lite::Tensor &y, + int axis, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + if (KeepIntermediateOut) { + PADDLE_ENFORCE(intermediate_out, + "The save_intermediate_out is opened, " + "intermediate_out should not be nullptr."); + } + + const lite::DDim &x_dim = x.dims(); + const lite::DDim &y_dim = y.dims(); + if (x.dims() == y.dims()) { + FusedElemwiseAndActComputeNoBroadcast( + ctx, x_dim, x, y, compound_functor, out, intermediate_out); + } else { + // Whether the shape of Y is a continuous subsequence of X, + // For more information please refer to the op's introduction. 
+ bool bcast_y = x.dims().size() >= y.dims().size(); + if (x.dims().size() == y.dims().size()) { + for (int i = 0; i < x.dims().size(); ++i) { + if (x.dims()[i] < y.dims()[i]) { + bcast_y = false; + break; + } + } + } + + // z = f1(x, f2(y)) + // z = f1(f2(x, y)) + if (bcast_y) { // Y should be broadcast. + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Y. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of X. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + x_dim /*OutShape*/, + y_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } else { + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Out. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of Y. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + y_dim /*OutShape*/, + x_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute.cc b/lite/kernels/x86/fill_constant_batch_size_like_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..75c0ce8a8860d514677c8cfa2791ebda170f0105 --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h" + +REGISTER_LITE_KERNEL( + fill_constant_batch_size_like, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::FillConstantBatchSizeLikeCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute.h b/lite/kernels/x86/fill_constant_batch_size_like_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..411a114e3f3ec82775c60f5f9a0642aae606eeda --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute.h @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
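The pre/n/post decomposition documented above is the core of ElementwiseComputeEx: everything before the broadcast axis collapses into pre, the dimensions Y covers collapse into n, and the remainder into post. A minimal standalone sketch that reproduces the two worked examples from the comment block; GetMidDims and the plain std::vector shapes are illustrative stand-ins, not Lite API:

#include <cassert>
#include <cstdint>
#include <vector>

static void GetMidDims(const std::vector<int64_t>& x_dims,
                       const std::vector<int64_t>& y_dims,
                       int axis, int* pre, int* n, int* post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];            // dims before Y starts
  for (size_t i = 0; i < y_dims.size(); ++i) *n *= y_dims[i];  // dims Y covers
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    *post *= x_dims[i];                                        // dims after Y ends
}

int main() {
  int pre = 0, n = 0, post = 0;
  // Example 1 above: X=(2,3,4,5), Y=(3,4), axis=1 -> pre=2, n=12, post=5,
  // so RunMidWise is selected.
  GetMidDims({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  assert(pre == 2 && n == 12 && post == 5);
  // Example 2 above: X=(2,3,4,5), Y=(4,5), axis=2 -> post==1, so RunRowWise.
  GetMidDims({2, 3, 4, 5}, {4, 5}, 2, &pre, &n, &post);
  assert(pre == 6 && n == 20 && post == 1);
  return 0;
}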
+ +#pragma once + +#include +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class FillConstantBatchSizeLikeCompute + : public KernelLite { + public: + using param_t = operators::FillConstantBatchSizeLikeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& ctx = ctx_->As(); + auto* out = param.Out; + auto* in = param.Input; + if (in->lod().size() && param.input_dim_idx == 0) { + // set the correct batch size for the LoDTensor. + auto odims = out->dims(); + int output_dim_idx = param.output_dim_idx; + odims[output_dim_idx] = static_cast(in->lod().back().size()) - 1; + out->Resize(odims); + } + out->mutable_data(); + auto value = param.value; + + paddle::lite::x86::math::SetConstant setter; + setter(ctx, out, static_cast(value)); + } + + virtual ~FillConstantBatchSizeLikeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc b/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2504e19e149fe8494df53ab22584bebcb295c4f --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
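The LoD branch in FillConstantBatchSizeLikeCompute above is the subtle part: when the input carries LoD and input_dim_idx is 0, the batch size is the number of sequences (offsets minus one), not dims()[0]. A hedged sketch of just that rule, with standard containers and the InferBatchSize helper standing in for the Lite types:

#include <cassert>
#include <cstdint>
#include <vector>

using LoD = std::vector<std::vector<uint64_t>>;

static int64_t InferBatchSize(const LoD& lod,
                              const std::vector<int64_t>& input_dims,
                              int input_dim_idx) {
  if (!lod.empty() && input_dim_idx == 0) {
    // A LoD level {0, 2, 6, 9} describes 3 sequences, so the batch size is
    // the number of offsets minus one.
    return static_cast<int64_t>(lod.back().size()) - 1;
  }
  return input_dims[input_dim_idx];  // dense input: copy the dim through
}

int main() {
  assert(InferBatchSize({{0, 2, 6, 9}}, {9, 15}, 0) == 3);  // LoD input
  assert(InferBatchSize({}, {219, 232}, 0) == 219);         // dense input
  return 0;
}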
+ +#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(fill_constant_batch_size_like_x86, retrive_op) { + auto fill_constant_batch_size_like = + KernelRegistry::Global().Create( + "fill_constant_batch_size_like"); + ASSERT_FALSE(fill_constant_batch_size_like.empty()); + ASSERT_TRUE(fill_constant_batch_size_like.front()); +} + +TEST(fill_constant_batch_size_like_x86, init) { + lite::kernels::x86::FillConstantBatchSizeLikeCompute + fill_constant_batch_size_like; + ASSERT_EQ(fill_constant_batch_size_like.precision(), PRECISION(kFloat)); + ASSERT_EQ(fill_constant_batch_size_like.target(), TARGET(kX86)); +} + +TEST(fill_constant_batch_size_like_x86, run_test) { + lite::Tensor input; + lite::Tensor out; + std::vector input_shape{219, 232}; + input.Resize(input_shape); + std::vector out_shape{219, 132, 7}; + out.Resize(out_shape); + + auto input_data = input.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < input.dims().production(); ++i) { + input_data[i] = static_cast(i); + } + + FillConstantBatchSizeLikeCompute fill_constant_batch_size_like; + operators::FillConstantBatchSizeLikeParam param; + param.Input = &input; + param.Out = &out; + std::vector shape{-1, 132, 7}; + float value = 3.5; + param.shape = shape; + param.value = value; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + fill_constant_batch_size_like.SetContext(std::move(ctx)); + fill_constant_batch_size_like.SetParam(param); + fill_constant_batch_size_like.Run(); + + std::vector ref_results{ + 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5}; + for (int i = 0; i < ref_results.size(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(fill_constant_batch_size_like, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/gelu_compute_test.cc b/lite/kernels/x86/gelu_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..20479760e916613f14745d8b7316e094950f6a46 --- /dev/null +++ b/lite/kernels/x86/gelu_compute_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/x86/activation_compute.cc" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(gelu_x86, retrive_op) { + auto gelu = + KernelRegistry::Global().Create("gelu"); + ASSERT_FALSE(gelu.empty()); + ASSERT_TRUE(gelu.front()); +} + +TEST(gelu_x86, init) { + GeluCompute gelu; + ASSERT_EQ(gelu.precision(), PRECISION(kFloat)); + ASSERT_EQ(gelu.target(), TARGET(kX86)); +} + +TEST(gelu_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector out_shape{batch_size, 3, 2, 2}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + int sign = i % 2 == 0 ? 1 : -1; + x_data[i] = static_cast(i * sign) * 0.8f; + } + GeluCompute gelu; + operators::ActivationParam param; + + param.X = &x; + param.Out = &out; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + gelu.SetContext(std::move(ctx)); + gelu.SetParam(param); + gelu.Run(); + + LOG(INFO) << "output: "; + std::vector ref_data{0., + -0.169484, + 1.512321, + -0.019674, + 3.197801, + -0.000126719, + 4.8, + -0., + 6.4000001, + -0., + 8., + -0.}; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(gelu, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/gru_compute.cc b/lite/kernels/x86/gru_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8e70833aaa9b4e2914c13f3ae40c84a5083c909 --- /dev/null +++ b/lite/kernels/x86/gru_compute.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
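The ref_data values in the gelu test above match the exact erf-based GELU definition, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); the common tanh approximation would not land within the 1e-5 tolerance. A small sketch that regenerates the expected values; GeluRef is an illustrative helper, not the kernel's code:

#include <cmath>
#include <cstdio>

static float GeluRef(float x) {
  // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

int main() {
  // Same input pattern as the test: alternating-sign multiples of 0.8.
  for (int i = 0; i < 12; ++i) {
    int sign = (i % 2 == 0) ? 1 : -1;
    float x = static_cast<float>(i * sign) * 0.8f;
    std::printf("gelu(%g) = %g\n", x, GeluRef(x));  // e.g. gelu(-0.8) ~ -0.169484
  }
  return 0;
}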
+ +#include "lite/kernels/x86/gru_compute.h" + +DEFINE_int32(paddle_num_threads, + 1, + "Number of threads for each paddle instance."); + +REGISTER_LITE_KERNEL(gru, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::GRUCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("H0", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/gru_compute.h b/lite/kernels/x86/gru_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..e3c6f70fdbe3d0e0ff025c7b41528b50ff06fca3 --- /dev/null +++ b/lite/kernels/x86/gru_compute.h @@ -0,0 +1,221 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/detail/gru_cpu_kernel.h" +#include "lite/backends/x86/math/detail/gru_kernel.h" +#include "lite/backends/x86/math/gru_compute.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/backends/x86/math/sequence2batch.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +#include "lite/fluid/eigen.h" + +DECLARE_int32(paddle_num_threads); + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +using Tensor = lite::Tensor; + +template +inline void ReorderInitState(const lite::Context& context, + const Tensor& src, + const std::vector& index_lod, + Tensor* dst, + bool indexed_src) { + lite::x86::math::CopyMatrixRowsFunctor row_shuffle; + dst->Resize(src.dims()); + dst->mutable_data(); + row_shuffle(context, src, index_lod, dst, indexed_src); +} + +template +class GRUCompute : public KernelLite { + public: + void Run() override { + auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + + bool origin_mode = param.origin_mode; + bool is_reverse = param.is_reverse; + + auto* input = param.input; + auto* h0 = param.h0; + auto* weight = param.weight; + const T* weight_data = weight->data(); + auto* bias = param.bias; + + auto* batch_gate = param.batch_gate; + batch_gate->mutable_data(); + auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev; + batch_reset_hidden_prev->mutable_data(); + auto* batch_hidden = param.batch_hidden; + batch_hidden->mutable_data(); + auto* hidden = param.hidden; + hidden->mutable_data(); + + auto hidden_dims = hidden->dims(); + + lite::x86::math::LoDTensor2BatchFunctor to_batch; + to_batch(context, *input, batch_gate, true, is_reverse); + + if (bias) { + lite::x86::math::RowwiseAdd add_bias; + add_bias(context, *batch_gate, *bias, 
batch_gate); + } + + int frame_size = hidden_dims[1]; + lite::x86::math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + + std::vector order(batch_gate->lod()[2]); + + if (h0) { + // Since the batch computation for GRU reorders the input sequences + // according to their length, the initial cell state also needs to be + // reordered. + ReorderInitState(context, *h0, order, &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.mutable_data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t seq_len = batch_starts.size() - 1; + auto active_node = + lite::x86::math::detail::GetActivationType(param.activation); + auto active_gate = + lite::x86::math::detail::GetActivationType(param.gate_activation); + +#ifdef PADDLE_WITH_MKLML + // use MKL packed to speedup GEMM + if (FLAGS_paddle_num_threads >= 4) { + auto blas = lite::x86::math::GetBlas(context); + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of weight*/); + CHECK(packed_gate); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size * 2, + frame_size, + T(1.0), + gru_value.gate_weight, + frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, + 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of weight*/); + CHECK(packed_state); + blas.GEMM_PACK(CblasBMatrix, + CblasNoTrans, + 1 /*cur bs?*/, + frame_size, + frame_size, + T(1.0), + gru_value.state_weight, + frame_size, + packed_state); + for (size_t n = 0; n < seq_len; n++) { + int64_t bstart = static_cast(batch_starts[n]); + int64_t bend = static_cast(batch_starts[n + 1]); + int64_t cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.mutable_data(); + gru_value.gate_value = gate_t.mutable_data(); + gru_value.reset_output_value = reset_hidden_prev_t.mutable_data(); + + if (gru_value.prev_out_value) { + blas.GEMM_COMPUTE(CblasNoTrans, + CblasPacked, + cur_batch_size, + frame_size * 2, + frame_size, + gru_value.prev_out_value, + frame_size, + packed_gate, + frame_size * 2, + T(1), + gru_value.gate_value, + frame_size * 3); + } + + lite::x86::math::detail::forward_final_output( + lite::x86::math::detail::forward::gru_finalOutput(), + gru_value, + frame_size, + cur_batch_size, + active_node, + origin_mode); + + gru_value.prev_out_value = gru_value.output_value; + } + + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); + } else { +#endif + for (size_t n = 0; n < seq_len; n++) { + int64_t bstart = static_cast(batch_starts[n]); + int64_t bend = static_cast(batch_starts[n + 1]); + int64_t cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = + batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.mutable_data(); + gru_value.gate_value = gate_t.mutable_data(); + gru_value.reset_output_value = reset_hidden_prev_t.mutable_data(); + + lite::x86::math::GRUUnitFunctor::compute( + context, + gru_value, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode); + + gru_value.prev_out_value =
gru_value.output_value; + } +#ifdef PADDLE_WITH_MKLML + } +#endif + lite::x86::math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(context, *batch_hidden, hidden); + } +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/gru_compute_test.cc b/lite/kernels/x86/gru_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e0e944f23bafda6a5eb742a8e4b023c268c9955 --- /dev/null +++ b/lite/kernels/x86/gru_compute_test.cc @@ -0,0 +1,155 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/gru_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(gru_x86, retrive_op) { + auto gru = + KernelRegistry::Global().Create("gru"); + ASSERT_FALSE(gru.empty()); + ASSERT_TRUE(gru.front()); +} + +TEST(gru_x86, init) { + GRUCompute gru; + ASSERT_EQ(gru.precision(), PRECISION(kFloat)); + ASSERT_EQ(gru.target(), TARGET(kX86)); +} + +TEST(gru_x86, run_test) { + lite::Tensor input, h0, weight, bias; + lite::Tensor batch_gate, batch_reset_hidden_prev, batch_hidden, hidden; + constexpr int batch_size = 9; + std::vector input_shape{batch_size, 15}; + input.Resize(lite::DDim(input_shape)); + std::vector weight_shape{5, 15}; + weight.Resize(lite::DDim(weight_shape)); + std::vector h0_shape{3, 5}; + h0.Resize(lite::DDim(h0_shape)); + std::vector bias_shape{1, 15}; + bias.Resize(lite::DDim(bias_shape)); + std::vector batch_gate_shape{batch_size, 15}; + batch_gate.Resize(lite::DDim(batch_gate_shape)); + std::vector batch_reset_hidden_prev_shape{batch_size, 5}; + batch_reset_hidden_prev.Resize(lite::DDim(batch_reset_hidden_prev_shape)); + std::vector batch_hidden_shape{batch_size, 5}; + batch_hidden.Resize(lite::DDim(batch_hidden_shape)); + std::vector hidden_shape{batch_size, 5}; + hidden.Resize(lite::DDim(hidden_shape)); + + std::vector> lod{{0, 2, 6, 9}}; + input.set_lod(lod); + + auto input_data = input.mutable_data(); + auto weight_data = weight.mutable_data(); + auto h0_data = h0.mutable_data(); + auto bias_data = bias.mutable_data(); + + for (int64_t i = 0; i < input.dims().production(); i++) { + input_data[i] = static_cast(0); + } + for (int64_t i = 0; i < weight.dims().production(); i++) { + weight_data[i] = static_cast(0); + } + for (int64_t i = 0; i < h0.dims().production(); i++) { + h0_data[i] = static_cast(0); + } + for (int64_t i = 0; i < bias.dims().production(); i++) { + bias_data[i] = static_cast(0); + } + GRUCompute gru; + operators::GRUParam param; + + param.input = &input; + param.h0 = &h0; + param.weight = &weight; + param.bias = &bias; + param.batch_gate = &batch_gate; + param.batch_reset_hidden_prev = &batch_reset_hidden_prev; + param.batch_hidden = &batch_hidden; + param.hidden = &hidden; + param.gate_activation =
"sigmoid"; + param.activation = "tanh"; + param.is_reverse = false; + param.origin_mode = false; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + gru.SetContext(std::move(ctx)); + gru.SetParam(param); + gru.Run(); + + auto batch_gate_data = batch_gate.mutable_data(); + auto batch_reset_hidden_prev_data = + batch_reset_hidden_prev.mutable_data(); + auto batch_hidden_data = batch_hidden.mutable_data(); + auto hidden_data = hidden.mutable_data(); + std::vector batch_gate_out{ + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0}; + std::vector batch_reset_hidden_prev_out{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector batch_hidden_out{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector hidden_out{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + LOG(INFO) << "output: "; + for (int i = 0; i < batch_gate.dims().production(); i++) { + LOG(INFO) << batch_gate_data[i]; + EXPECT_NEAR(batch_gate_data[i], batch_gate_out[i], 1e-3); + } + for (int i = 0; i < batch_reset_hidden_prev.dims().production(); i++) { + LOG(INFO) << batch_reset_hidden_prev_data[i]; + EXPECT_NEAR( + batch_reset_hidden_prev_data[i], batch_reset_hidden_prev_out[i], 1e-3); + } + for (int i = 0; i < batch_hidden.dims().production(); i++) { + LOG(INFO) << batch_hidden_data[i]; + EXPECT_NEAR(batch_hidden_data[i], batch_hidden_out[i], 1e-3); + } + for (int i = 0; i < hidden.dims().production(); i++) { + LOG(INFO) << hidden_data[i]; + EXPECT_NEAR(hidden_data[i], hidden_out[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(gru, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/lookup_table_compute.cc b/lite/kernels/x86/lookup_table_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..364593251e17453011bad5b2c1057fc25d54d7c8 --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/x86/lookup_table_compute.h" + +// REGISTER_LITE_KERNEL(lookup_table, kX86, kFloat, kNCHW, +// paddle::lite::kernels::x86::LookupTableCompute, +// def) +// .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) +// .Finalize(); +//, +REGISTER_LITE_KERNEL(lookup_table, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::LookupTableCompute, + def) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/lookup_table_compute.h b/lite/kernels/x86/lookup_table_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..e0d7752ca77c810700f57722c4186b4e02d6411f --- /dev/null +++ b/lite/kernels/x86/lookup_table_compute.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class LookupTableCompute : public KernelLite { + public: + using param_t = operators::LookupTableParam; + + void Run() override { + auto ¶m = *param_.get_mutable(); + // auto& context = context_->As(); + auto *ids_t = param.Ids; + auto *output_t = param.Out; + int64_t padding_idx = param.padding_idx; + auto *ids = ids_t->data(); + int64_t ids_numel = ids_t->dims().production(); + + auto *table_t = param.W; + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(); + memset(output, 0, output_t->dims().production() * sizeof(float)); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != -1 && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(float)); + } else { + CHECK_LT(ids[i], row_number); + CHECK_GE(ids[i], 0); + memcpy(output + i * row_width, + table + ids[i] * row_width, + row_width * sizeof(float)); + } + } + } + + virtual ~LookupTableCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/matmul_compute.cc b/lite/kernels/x86/matmul_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6949e018cb2764c712c306df71c784d0134787e9 --- /dev/null +++ b/lite/kernels/x86/matmul_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/matmul_compute.h" + +REGISTER_LITE_KERNEL(matmul, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::MatMulCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/matmul_compute.h b/lite/kernels/x86/matmul_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2b3c7482c266d0c8771c9be1dbac540a315528 --- /dev/null +++ b/lite/kernels/x86/matmul_compute.h @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "lite/backends/x86/math/blas.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static lite::DDim RowMatrixFromVector(const lite::DDim &x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return lite::DDim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the rank of y_dim > 1, the + * original y_dim is returned.
+ */ +static lite::DDim ColumnMatrixFromVector(const lite::DDim &y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return lite::DDim({y_dim[0], 1}); +} + +template +class MatMulCompute : public KernelLite { + public: + using param_t = operators::MatMulParam; + + void Run() override { + auto &context = ctx_->As(); + auto ¶m = *param_.get_mutable(); + + auto *x = param.X; + auto *y = param.Y; + auto *out = param.Out; + out->mutable_data(); + + auto blas = lite::x86::math::GetBlas(context); + auto mat_dim_a = lite::x86::math::CreateMatrixDescriptor( + RowMatrixFromVector(x->dims()), 0, param.transpose_X); + auto mat_dim_b = lite::x86::math::CreateMatrixDescriptor( + ColumnMatrixFromVector(y->dims()), 0, param.transpose_Y); + auto scale = static_cast(param.alpha); + blas.MatMul(*x, mat_dim_a, *y, mat_dim_b, scale, out, T(0)); + } + + virtual ~MatMulCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/matmul_compute_test.cc b/lite/kernels/x86/matmul_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..53d2d1a47a0cdbdaf5dfa83a79987d908171a36d --- /dev/null +++ b/lite/kernels/x86/matmul_compute_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
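Two things are worth checking by hand here: MatMulCompute above promotes rank-1 operands so that every case reduces to one GEMM, and the ref_result in the test that follows is a plain (3, 2) x (2, 4) product over iota-filled inputs. A sketch of both, with std::vector and raw arrays standing in for lite::DDim and lite::Tensor:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

using Dims = std::vector<int64_t>;

// Rank-1 X becomes a 1 x N row matrix; rank-1 Y becomes an N x 1 column.
static Dims RowMatrixFromVector(const Dims& x) {
  return x.size() > 1 ? x : Dims{1, x[0]};
}
static Dims ColumnMatrixFromVector(const Dims& y) {
  return y.size() > 1 ? y : Dims{y[0], 1};
}

int main() {
  assert(RowMatrixFromVector({8}) == (Dims{1, 8}));
  assert(ColumnMatrixFromVector({8}) == (Dims{8, 1}));
  assert(RowMatrixFromVector({3, 2}) == (Dims{3, 2}));  // rank >= 2 passes through

  // Recompute the test's expectation: x and y filled with 0, 1, 2, ...
  float x[3][2], y[2][4], out[3][4] = {};
  for (int i = 0; i < 6; ++i) x[i / 2][i % 2] = static_cast<float>(i);
  for (int i = 0; i < 8; ++i) y[i / 4][i % 4] = static_cast<float>(i);
  for (int m = 0; m < 3; ++m)
    for (int n = 0; n < 4; ++n)
      for (int k = 0; k < 2; ++k) out[m][n] += x[m][k] * y[k][n];
  // Prints 4 5 6 7 / 12 17 22 27 / 20 29 38 47, matching ref_result.
  for (int m = 0; m < 3; ++m)
    std::printf("%g %g %g %g\n", out[m][0], out[m][1], out[m][2], out[m][3]);
  return 0;
}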
+ +#include "lite/kernels/x86/matmul_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(matmul_x86, retrive_op) { + auto matmul = + KernelRegistry::Global().Create( + "matmul"); + ASSERT_FALSE(matmul.empty()); + ASSERT_TRUE(matmul.front()); +} + +TEST(matmul_x86, init) { + lite::kernels::x86::MatMulCompute matmul; + ASSERT_EQ(matmul.precision(), PRECISION(kFloat)); + ASSERT_EQ(matmul.target(), TARGET(kX86)); +} + +TEST(matmul_x86, run_test) { + lite::Tensor x, y, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector y_shape{2, 4}; + y.Resize(lite::DDim(y_shape)); + std::vector out_shape{batch_size, 3, 4}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto y_data = y.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < y.dims().production(); i++) { + y_data[i] = static_cast(i); + } + // MatMulCompute matmul; + MatMulCompute matmul; + operators::MatMulParam param; + + param.X = &x; + param.Y = &y; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + matmul.SetContext(std::move(ctx)); + matmul.SetParam(param); + matmul.Run(); + + std::vector ref_result = {4, 5, 6, 7, 12, 17, 22, 27, 20, 29, 38, 47}; + + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_result[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(matmul, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/mul_compute.cc b/lite/kernels/x86/mul_compute.cc index d021a73532527b57895c635bf7a554562a98953f..64558f66772381ad402a3eb203bb6efd9fceff60 100644 --- a/lite/kernels/x86/mul_compute.cc +++ b/lite/kernels/x86/mul_compute.cc @@ -25,18 +25,20 @@ REGISTER_LITE_KERNEL(mul, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); -REGISTER_LITE_KERNEL(mul_grad, - kX86, - kFloat, - kNCHW, - paddle::lite::kernels::x86::MulGradCompute, - def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput(paddle::framework::GradVarName("Out"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput(paddle::framework::GradVarName("X"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput(paddle::framework::GradVarName("Y"), - {LiteType::GetTensorTy(TARGET(kX86))}) - .Finalize(); +// #ifdef LITE_WITH_TRAIN +// REGISTER_LITE_KERNEL(mul_grad, +// kX86, +// kFloat, +// kNCHW, +// paddle::lite::kernels::x86::MulGradCompute, +// def) +// .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindInput(paddle::framework::GradVarName("Out"), +// {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindOutput(paddle::framework::GradVarName("X"), +// {LiteType::GetTensorTy(TARGET(kX86))}) +// .BindOutput(paddle::framework::GradVarName("Y"), +// {LiteType::GetTensorTy(TARGET(kX86))}) +// .Finalize(); +// #endif diff --git a/lite/kernels/x86/mul_compute.h b/lite/kernels/x86/mul_compute.h index ae47d4a59e63fcac0beb360642b4061c42487e04..e204fc81f28de4af43d63e289b01d81188502988 100644 --- a/lite/kernels/x86/mul_compute.h +++ b/lite/kernels/x86/mul_compute.h @@ -13,17 +13,26 @@ // limitations under the License. 
#pragma once +#include "lite/backends/x86/math/blas.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include "lite/core/types.h" -#include "paddle/fluid/operators/math/blas.h" - namespace paddle { namespace lite { namespace kernels { namespace x86 { -using Tensor = framework::Tensor; +// using Tensor = framework::Tensor; +inline lite::Tensor ReshapeToMatrix(const lite::Tensor& src, int num_col_dims) { + int rank = src.dims().size(); + if (rank == 2) { + return src; + } + lite::Tensor res; + res.ShareDataWith(src); + res.Resize(src.dims().Flatten2D(num_col_dims)); + return res; +} template class MulCompute : public KernelLite { @@ -33,36 +42,35 @@ class MulCompute : public KernelLite { void Run() override { auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context()); + // CHECK(context.x86_device_context()); - param.output->template mutable_data(); + auto* z = param.output; - auto* x = ¶m.x->raw_tensor(); - auto* y = ¶m.y->raw_tensor(); + auto* x = param.x; + auto* y = param.y; Tensor x_matrix, y_matrix; if (x->dims().size() > 2) { - x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims); + x_matrix = ReshapeToMatrix(*x, param.x_num_col_dims); } else { x_matrix = *x; } if (y->dims().size() > 2) { - y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims); + y_matrix = ReshapeToMatrix(*y, param.y_num_col_dims); } else { y_matrix = *y; } - auto* z = ¶m.output->raw_tensor(); + z->mutable_data(); auto z_dim = z->dims(); if (z_dim.size() != 2) { z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - auto blas = paddle::operators::math::GetBlas( - *context.x86_device_context()); + auto blas = lite::x86::math::GetBlas(context); blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { @@ -73,6 +81,7 @@ class MulCompute : public KernelLite { virtual ~MulCompute() = default; }; +#ifdef LITE_WITH_TRAIN template class MulGradCompute : public KernelLite { public: @@ -142,6 +151,7 @@ class MulGradCompute : public KernelLite { virtual ~MulGradCompute() = default; }; +#endif } // namespace x86 } // namespace kernels diff --git a/lite/kernels/x86/mul_compute_test.cc b/lite/kernels/x86/mul_compute_test.cc index 6737b750414866b23458ba6bcf560ec32370bff1..32d82cbb77aeb71dcd1c172ec0c1e343c3954fea 100644 --- a/lite/kernels/x86/mul_compute_test.cc +++ b/lite/kernels/x86/mul_compute_test.cc @@ -19,7 +19,6 @@ #include #include #include "lite/core/op_registry.h" - namespace paddle { namespace lite { namespace kernels { @@ -33,7 +32,7 @@ TEST(mul_x86, retrive_op) { } TEST(mul_x86, init) { - MulCompute mul; + lite::kernels::x86::MulCompute mul; ASSERT_EQ(mul.precision(), PRECISION(kFloat)); ASSERT_EQ(mul.target(), TARGET(kX86)); } @@ -72,9 +71,10 @@ TEST(mul_x86, run_test) { mul.SetParam(param); mul.Run(); - LOG(INFO) << "output: "; + std::vector ref_result = {20, 23, 26, 29}; + for (int i = 0; i < out.dims().production(); i++) { - LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_result[i], 1e-3); } } diff --git a/lite/kernels/x86/pool_compute.h b/lite/kernels/x86/pool_compute.h index 1e3ba36a7e68e4d8108cbe485d241a5035d78401..57bcddcec9512d626962465e717b7a202cfe0b17 100644 --- a/lite/kernels/x86/pool_compute.h +++ b/lite/kernels/x86/pool_compute.h @@ -14,12 +14,12 @@ #pragma once #include +#include "lite/backends/x86/math/math_function.h" +#include "lite/backends/x86/math/pooling.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" #include "lite/core/types.h" -#include "paddle/fluid/framework/eigen.h" -#include 
"paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/math/pooling.h" +#include "lite/fluid/eigen.h" namespace paddle { namespace lite { @@ -31,6 +31,7 @@ class PoolCompute : public KernelLite { public: using param_t = operators::PoolParam; void Run() override { + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); if (param.global_pooling) { for (size_t i = 0; i < param.ksize.size(); ++i) { @@ -41,37 +42,37 @@ class PoolCompute : public KernelLite { switch (param.ksize.size()) { case 2: { if (param.pooling_type == "max") { - paddle::operators::math::Pool2dFunctor< - platform::CPUDeviceContext, - paddle::operators::math::MaxPool, + paddle::lite::x86::math::Pool2dFunctor< + lite::TargetType::kX86, + paddle::lite::x86::math::MaxPool, T> pool2d_forward; - paddle::operators::math::MaxPool pool_process; - pool2d_forward(platform::CPUDeviceContext(), - param.x->raw_tensor(), + paddle::lite::x86::math::MaxPool pool_process; + pool2d_forward(context, + param.x, param.ksize, param.strides, param.paddings, pool_process, true, false, - &(param.output->raw_tensor())); + param.output); } else if (param.pooling_type == "avg") { - paddle::operators::math::Pool2dFunctor< - platform::CPUDeviceContext, - paddle::operators::math::AvgPool, + paddle::lite::x86::math::Pool2dFunctor< + lite::TargetType::kX86, + paddle::lite::x86::math::AvgPool, T> pool2d_forward; - paddle::operators::math::AvgPool pool_process; - pool2d_forward(platform::CPUDeviceContext(), - param.x->raw_tensor(), + paddle::lite::x86::math::AvgPool pool_process; + pool2d_forward(context, + param.x, param.ksize, param.strides, param.paddings, pool_process, param.exclusive, param.adaptive, - &(param.output->raw_tensor())); + param.output); } } break; case 3: { diff --git a/lite/kernels/x86/pool_compute_test.cc b/lite/kernels/x86/pool_compute_test.cc index 9b073b35edf452d04f797ec67fcd30cd726ef059..87b75a0760bca45057f25b2cb948a66feb22496c 100644 --- a/lite/kernels/x86/pool_compute_test.cc +++ b/lite/kernels/x86/pool_compute_test.cc @@ -15,6 +15,8 @@ #include "lite/kernels/x86/pool_compute.h" #include #include +#include +#include #include #include "lite/core/op_registry.h" @@ -61,13 +63,18 @@ TEST(pool2d_x86, run_test) { param.paddings = {0, 0}; param.ksize = {2, 2}; param.pooling_type = "max"; - + std::unique_ptr ctx(new KernelContext); + ctx->As(); + pool2d.SetContext(std::move(ctx)); pool2d.SetParam(param); pool2d.Run(); LOG(INFO) << "output: "; + float ref_result[12] = { + 5., 7., 13., 15., 21., 23., 29., 31., 37., 39., 45., 47.}; for (int i = 0; i < out.dims().production(); i++) { LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_result[i], 1e-5); } } diff --git a/lite/kernels/x86/reduce_compute.cc b/lite/kernels/x86/reduce_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..f95f4cfb881fef329ea940ca8b9fa6b4fd6ff7b6 --- /dev/null +++ b/lite/kernels/x86/reduce_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/reduce_compute.h" + +REGISTER_LITE_KERNEL(reduce_sum, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ReduceSumCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/reduce_compute.h b/lite/kernels/x86/reduce_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..655f104ce65906f1904a7cf02d703069b0a7a2bf --- /dev/null +++ b/lite/kernels/x86/reduce_compute.h @@ -0,0 +1,83 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" +#include "lite/kernels/x86/reduce_op_function.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +struct SumFunctor { + template + void operator()(X* x, Y* y, const Dim& dim) { + y->device(lite::fluid::EigenDeviceType()) = x->sum(dim); + } +}; + +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + paddle::lite::kernels::x86:: \ + ReduceFunctor( \ + *input, output, dims, keep_dim); \ + } + +template +class ReduceSumCompute : public KernelLite { + public: + using param_t = operators::ReduceParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = ctx_->As(); + bool reduce_all = param.reduce_all; + auto* input = param.x; + auto* output = param.output; + param.output->mutable_data(); + + auto dims = param.dim; + bool keep_dim = param.keep_dim; + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = lite::fluid::EigenVector::Flatten(*input); + auto out = lite::fluid::EigenScalar::From(output); + // auto& place = *platform::CPUDeviceContext().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + SumFunctor functor; + functor(&x, &out, reduce_dim); + } else { + int ndim = input->dims().size(); + int rdim = dims.size(); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); + } + } + + virtual ~ReduceSumCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reduce_op_function.h b/lite/kernels/x86/reduce_op_function.h new file mode 100644 index 0000000000000000000000000000000000000000..b3ddab64e4bf8dc72cec3b86398f42269c5a947c --- /dev/null +++ b/lite/kernels/x86/reduce_op_function.h @@ -0,0 +1,84 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenTensor = lite::fluid::EigenTensor; +template +using EigenScalar = lite::fluid::EigenScalar; +template +using EigenVector = lite::fluid::EigenVector; + +template +// const lite::Context& context, +void ReduceFunctor(const lite::Tensor& input, + lite::Tensor* output, + const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + lite::DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = out_dims.Vectorize(); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = lite::DDim(dims_vector); + } + // auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(output); + functor(&x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(&x, &out, reduce_dim); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/relu_compute_test.cc b/lite/kernels/x86/relu_compute_test.cc index ec446de73f0fa210c0eaf7740b719a7667c4a6e6..37ed6db7f919e31828f89462fa46d5263c480fcc 100644 --- a/lite/kernels/x86/relu_compute_test.cc +++ b/lite/kernels/x86/relu_compute_test.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/x86/relu_compute.h" #include #include #include #include "lite/core/op_registry.h" +#include "lite/kernels/x86/activation_compute.h" namespace paddle { namespace lite { @@ -64,6 +64,8 @@ TEST(relu_x86, run_test) { LOG(INFO) << "output: "; for (int i = 0; i < out.dims().production(); i++) { LOG(INFO) << out_data[i]; + int sign = i % 2 == 0 ? 1 : 0; + ASSERT_EQ(out_data[i], i * sign); } } diff --git a/lite/kernels/x86/reshape_compute.cc b/lite/kernels/x86/reshape_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7afe4f6d8bc4740c00d3ed95fafc4e32f59b6d02 --- /dev/null +++ b/lite/kernels/x86/reshape_compute.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/reshape_compute.h" + +REGISTER_LITE_KERNEL(reshape, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ReshapeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(reshape2, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::Reshape2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); +REGISTER_LITE_KERNEL(reshape2, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::Reshape2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("XShape", + {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/x86/reshape_compute.h b/lite/kernels/x86/reshape_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..948c4ec31d7b3a7cf16f23582e6e17ea54dd081c --- /dev/null +++ b/lite/kernels/x86/reshape_compute.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
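The float and int64 reshape2 registrations above can share one copy-based kernel (defined just below) because lite tensors are contiguous: a reshape never moves data, it only swaps the shape metadata, and the one hard invariant is that the flat element counts agree. A standalone sketch of that invariant, with plain vectors instead of the lite API:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Flat element count of a shape, like DDim::production().
int64_t production(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<float> data(8);                  // contiguous buffer
  std::vector<int64_t> in_shape{1, 2, 4, 1};   // 8 elements
  std::vector<int64_t> out_shape{1, 8, 1, 1};  // still 8 elements
  // A reshape is legal iff the flat sizes agree; the data never moves.
  assert(production(in_shape) == production(out_shape));
  (void)data;
  return 0;
}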
+#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/operators/reshape_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +void Compute(const lite::Tensor* in, lite::Tensor* out) { + auto out_dims = out->dims(); + auto in_dims = in->dims(); + out->CopyDataFrom(*in); + out->Resize(out_dims); +} + +template +class ReshapeCompute : public KernelLite { + public: + using param_t = operators::ReshapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + Compute(param.x, param.output); + } + + virtual ~ReshapeCompute() = default; +}; + +template +void reshape2_compute() {} + +template +class Reshape2Compute : public KernelLite { + public: + using param_t = operators::ReshapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + Compute(param.x, param.output); + } + + virtual ~Reshape2Compute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/reshape_compute_test.cc b/lite/kernels/x86/reshape_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..16fc8f31aded0ef62fdf14aa671a73ccf6635fb7 --- /dev/null +++ b/lite/kernels/x86/reshape_compute_test.cc @@ -0,0 +1,157 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
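One subtlety in the Compute helper above: out->dims() is saved before CopyDataFrom and re-applied afterwards, since copying from the source tensor also syncs its shape. A minimal sketch of why that ordering matters, using a hypothetical ToyTensor rather than lite::Tensor:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: copying syncs data and shape, like CopyDataFrom.
struct ToyTensor {
  std::vector<float> data;
  std::vector<int64_t> dims;
  void CopyFrom(const ToyTensor& src) {
    data = src.data;
    dims = src.dims;  // the destination shape is clobbered here
  }
};

int main() {
  ToyTensor in{{0, 1, 2, 3, 4, 5, 6, 7}, {1, 2, 4, 1}};
  ToyTensor out;
  out.dims = {1, 8, 1, 1};
  auto out_dims = out.dims;  // save the target shape first
  out.CopyFrom(in);          // now out.dims == {1, 2, 4, 1}
  out.dims = out_dims;       // restore the reshaped dims
  assert((out.dims == std::vector<int64_t>{1, 8, 1, 1}));
  return 0;
}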
+
+#include "lite/kernels/x86/reshape_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// reshape
+TEST(reshape_x86, retrive_op) {
+  auto reshape =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "reshape");
+  ASSERT_FALSE(reshape.empty());
+  ASSERT_TRUE(reshape.front());
+}
+
+TEST(reshape_x86, init) {
+  lite::kernels::x86::ReshapeCompute<float> reshape;
+  ASSERT_EQ(reshape.precision(), PRECISION(kFloat));
+  ASSERT_EQ(reshape.target(), TARGET(kX86));
+}
+
+TEST(reshape_x86, run_test) {
+  lite::Tensor x, actual_shape;
+  lite::Tensor out;
+  std::vector<int64_t> x_shape({1, 2, 4, 1});
+  x.Resize(lite::DDim(x_shape));
+  actual_shape.Resize(lite::DDim(std::vector<int64_t>({4})));
+  std::vector<int64_t> out_shape({1, 8, 1, 1});
+  out.Resize(lite::DDim(out_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto actual_data = actual_shape.mutable_data<int>();
+  auto out_data = out.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); ++i) {
+    x_data[i] = static_cast<float>(i);
+  }
+  actual_data[0] = 1;
+  actual_data[1] = 8;
+  actual_data[2] = 1;
+  actual_data[3] = 1;
+
+  std::vector<int> shape({1, 8, 1, 1});
+
+  // ReshapeCompute<float> reshape;
+  ReshapeCompute<float> reshape;
+  operators::ReshapeParam param;
+
+  param.x = &x;
+  param.output = &out;
+  param.shape_vct = shape;
+  param.shape_tensor = &actual_shape;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  for (int i = 0; i < 2; ++i) {
+    if (1 == i) param.shape_tensor = nullptr;
+    reshape.SetContext(std::move(ctx));
+    reshape.SetParam(param);
+    reshape.Run();
+
+    for (int j = 0; j < out.dims().production(); ++j) {
+      EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
+    }
+  }
+}
+
+// reshape2
+TEST(reshape2_x86, retrive_op) {
+  auto reshape2 =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "reshape2");
+  ASSERT_FALSE(reshape2.empty());
+  ASSERT_TRUE(reshape2.front());
+}
+
+TEST(reshape2_x86, init) {
+  lite::kernels::x86::Reshape2Compute<float> reshape2;
+  ASSERT_EQ(reshape2.precision(), PRECISION(kFloat));
+  ASSERT_EQ(reshape2.target(), TARGET(kX86));
+}
+
+TEST(reshape2_x86, run_test) {
+  lite::Tensor x, actual_shape;
+  lite::Tensor out, xshape;
+  std::vector<int64_t> x_shape({1, 2, 4});
+  x.Resize(lite::DDim(x_shape));
+  actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
+  std::vector<int64_t> out_shape({1, 4, 2});
+  out.Resize(lite::DDim(out_shape));
+  std::vector<int64_t> xshape_shape({1, 4, 2});
+  xshape.Resize(lite::DDim(xshape_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto actual_data = actual_shape.mutable_data<int>();
+  auto out_data = out.mutable_data<float>();
+  auto xshape_data = xshape.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); ++i) {
+    x_data[i] = static_cast<float>(i);
+    xshape_data[i] = static_cast<float>(i);
+  }
+  actual_data[0] = 1;
+  actual_data[1] = 4;
+  actual_data[2] = 2;
+
+  std::vector<int> shape({1, 4, 2});
+
+  // Reshape2Compute<float> reshape2;
+  Reshape2Compute<float> reshape2;
+  operators::ReshapeParam param;
+
+  param.x = &x;
+  param.output = &out;
+  param.xshape = &xshape;
+  param.shape_vct = shape;
+  param.shape_tensor = &actual_shape;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  for (int i = 0; i < 2; ++i) {
+    if (1 == i) param.shape_tensor = nullptr;
+    reshape2.SetContext(std::move(ctx));
+    reshape2.SetParam(param);
+    reshape2.Run();
+
+    for (int j = 0; j < out.dims().production(); ++j) {
+      EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
+    }
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(reshape, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(reshape2, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/sequence_expand_as_compute.cc b/lite/kernels/x86/sequence_expand_as_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e030969764bd7e45ebc4c76b509c9217ba4d216 --- /dev/null +++ b/lite/kernels/x86/sequence_expand_as_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_expand_as_compute.h" + +REGISTER_LITE_KERNEL(sequence_expand_as, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequenceExpandAsCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_expand_as_compute.h b/lite/kernels/x86/sequence_expand_as_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..16759c1b9f1d136d5aaf58d4531882ab6a2618a2 --- /dev/null +++ b/lite/kernels/x86/sequence_expand_as_compute.h @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/types.h"
+#include "lite/fluid/eigen.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+using Tensor = lite::Tensor;
+
+template <typename T>
+struct SequenceExpandFunctor {
+  void operator()(const Tensor &x,
+                  const std::vector<uint64_t> &ref_lod, /*expand referenced lod*/
+                  Tensor *out) {
+    int64_t height = x.dims()[0];
+    int64_t width = x.data_size() / height;
+
+    const T *in_data = x.data<T>();
+    T *out_data = out->mutable_data<T>();
+
+    for (int h_id = 0; h_id < height; ++h_id) {
+      size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
+      if (span == 0) continue;
+      const T *src = in_data + h_id * width;
+      for (int64_t w_id = 0; w_id < width; ++w_id) {
+        T ele = src[w_id];
+        size_t offset = ref_lod[h_id] * width;
+        for (size_t k = 0; k < span; ++k) {
+          out_data[offset + k * width + w_id] = ele;
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+class SequenceExpandAsCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  void Run() override {
+    auto &param = *param_.get_mutable<operators::SequenceExpandAsParam>();
+
+    auto *x = param.x;
+    auto *y = param.y;
+    auto *out = param.out;
+
+    auto &y_lod = y->lod();
+    CHECK_EQ(y_lod.size(), 1);
+    CHECK_GT(y_lod[0].size(), 1);
+
+    out->mutable_data<T>();
+
+    SequenceExpandFunctor<T> seq_expand_functor;
+    seq_expand_functor(*x, y_lod[0], out);
+  }
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/sequence_expand_as_compute_test.cc b/lite/kernels/x86/sequence_expand_as_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d49fdbb7a6164435abb9eb7189b18376066d55df
--- /dev/null
+++ b/lite/kernels/x86/sequence_expand_as_compute_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
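The expansion functor above is easiest to read against concrete numbers. With the inputs used in the test that follows (x is 4x1 holding {0, 1, 2, 3} and the reference lod is {0, 3, 6, 7, 8}), row i is repeated lod[i+1] - lod[i] times; a plain-vector sketch of the same loop:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0, 1, 2, 3};      // 4 rows, width 1
  std::vector<uint64_t> ref_lod = {0, 3, 6, 7, 8};
  std::vector<float> out(ref_lod.back());   // 8 expanded rows
  for (size_t h = 0; h + 1 < ref_lod.size(); ++h) {
    for (uint64_t k = ref_lod[h]; k < ref_lod[h + 1]; ++k) {
      out[k] = x[h];  // repeat row h over its span
    }
  }
  // Prints 0 0 0 1 1 1 2 3: spans of length 3, 3, 1, 1.
  for (float v : out) printf("%g ", v);
  printf("\n");
  return 0;
}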
+ +#include "lite/kernels/x86/sequence_expand_as_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(sequence_expand_as_x86, retrive_op) { + auto sequence_expand_as = + KernelRegistry::Global().Create( + "sequence_expand_as"); + ASSERT_FALSE(sequence_expand_as.empty()); + ASSERT_TRUE(sequence_expand_as.front()); +} + +TEST(sequence_expand_as_x86, init) { + SequenceExpandAsCompute sequence_expand_as; + ASSERT_EQ(sequence_expand_as.precision(), PRECISION(kFloat)); + ASSERT_EQ(sequence_expand_as.target(), TARGET(kX86)); +} + +TEST(sequence_expand_as_x86, run_test) { + lite::Tensor x, y, out; + std::vector x_shape{4, 1}; + x.Resize(lite::DDim(x_shape)); + std::vector y_shape{1, 5}; + y.Resize(lite::DDim(y_shape)); + std::vector out_shape{8, 1}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto y_data = y.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < y.dims().production(); i++) { + y_data[i] = static_cast(i); + } + + std::vector> lod{{0, 3, 6, 7, 8}}; + y.set_lod(lod); + // MulCompute mul; + SequenceExpandAsCompute sequence_expand_as; + operators::SequenceExpandAsParam param; + + param.x = &x; + param.y = &y; + param.out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + sequence_expand_as.SetContext(std::move(ctx)); + sequence_expand_as.SetParam(param); + sequence_expand_as.Run(); + auto out_data = out.mutable_data(); + + int index = 1; + int lod_sum = lod[0][index]; + LOG(INFO) << "output: "; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + if (i >= lod_sum) { + index++; + lod_sum = lod[0][index]; + } + ASSERT_EQ(out_data[i], x_data[index - 1]); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_expand_as, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/sequence_pool_compute.cc b/lite/kernels/x86/sequence_pool_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..46b38b7e8cb521c9a3ce343f66778c21acf659f0 --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/x86/sequence_pool_compute.h" + +REGISTER_LITE_KERNEL(sequence_pool, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequencePoolCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_pool_compute.h b/lite/kernels/x86/sequence_pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..329a76658d342078ed5d708125d9ff01e0ecef02 --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute.h @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/backends/x86/math/math_function.h" +#include "lite/backends/x86/math/sequence_pooling.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequencePoolCompute : public KernelLite { + public: + using param_t = operators::SequencePoolParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + auto* out = param.Out; + auto dims = param.X->dims(); + auto lod = param.X->lod(); + CHECK_EQ(lod.size(), 1UL); + CHECK_GE(dims[0], static_cast(lod[0].size() - 1)); + + dims[0] = lod[0].size() - 1; + out->Resize({dims}); + out->mutable_data(); + lite::Tensor* index = nullptr; + + const bool is_test = true; + float pad_value = 0.0; + + lite::x86::math::SequencePoolFunctor pool; + pool(context, param.pool_type, pad_value, *param.X, out, is_test, index); + } + + virtual ~SequencePoolCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/sequence_pool_compute_test.cc b/lite/kernels/x86/sequence_pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..93cc122f7a6c5c19602bda53e697b6768120870f --- /dev/null +++ b/lite/kernels/x86/sequence_pool_compute_test.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
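SequencePoolCompute above collapses each LoD span to one output row. The reference values in the test that follows are consistent with an average pool (pool_type is left at its default there): with ten rows of width 8 filled with 1.1 * i, column j averages to 39.6 + 1.1 * j. A quick check, assuming that average interpretation:

#include <cstdio>

int main() {
  const int rows = 10, width = 8;
  for (int j = 0; j < width; ++j) {
    float sum = 0.f;
    for (int r = 0; r < rows; ++r) sum += 1.1f * (r * width + j);
    // Prints 39.6 40.7 41.8 ... 47.3, matching the test's ref_results.
    printf("%g ", sum / rows);
  }
  printf("\n");
  return 0;
}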
+ +#include "lite/kernels/x86/sequence_pool_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(sequence_pool_x86, retrive_op) { + auto sequence_pool = + KernelRegistry::Global().Create( + "sequence_pool"); + ASSERT_FALSE(sequence_pool.empty()); + ASSERT_TRUE(sequence_pool.front()); +} + +TEST(sequence_pool_x86, init) { + SequencePoolCompute sequence_pool; + ASSERT_EQ(sequence_pool.precision(), PRECISION(kFloat)); + ASSERT_EQ(sequence_pool.target(), TARGET(kX86)); +} + +TEST(sequence_pool_x86, run_test) { + lite::Tensor x, out; + lite::LoD lod; + lod.push_back(std::vector{0, 10}); + + x.set_lod(lod); + const size_t second_dim = 8u; + std::vector input_shape{static_cast(lod[0].back()), + static_cast(second_dim)}; + lite::DDim in_dims(input_shape); + x.Resize(in_dims); + + const size_t out_first_dim = lod[0].size() - 1; + std::vector output_shape{static_cast(out_first_dim), + static_cast(second_dim)}; + lite::DDim out_dims(output_shape); + out.Resize(out_dims); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = 1.1f * i; + } + + SequencePoolCompute sequence_pool; + operators::SequencePoolParam param; + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + sequence_pool.SetContext(std::move(ctx)); + sequence_pool.SetParam(param); + sequence_pool.Run(); + + std::vector ref_results = { + 39.6, 40.7, 41.8, 42.9, 44, 45.1, 46.2, 47.3}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_pool, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/sequence_reshape_compute.cc b/lite/kernels/x86/sequence_reshape_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..ccaeef27d7439b739b298f3b0756e2a2eddef2c1 --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_reshape_compute.h" + +REGISTER_LITE_KERNEL( + sequence_reshape, + kX86, + kInt64, + kNCHW, + paddle::lite::kernels::x86::SequenceReshapeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_reshape_compute.h b/lite/kernels/x86/sequence_reshape_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..68a573c2f674edcf0a09cccec730a8d7dbcea844 --- /dev/null +++ b/lite/kernels/x86/sequence_reshape_compute.h @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SequenceReshapeCompute + : public KernelLite { + public: + using param_t = operators::SequenceReshapeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = context_->As(); + auto* in = param.x; + auto* out = param.output; + int out_width = param.new_dim; + + auto in_dims = in->dims(); + int64_t in_width = in_dims[1]; + // LOG(INFO)<<"sequence_reshape in tensor:"<<*in; + auto& in_lod = in->lod(); + + CHECK_EQ(in_lod.size(), 1UL); + CHECK_EQ((uint64_t)in_dims[0], in_lod[0].back()); + + auto in_lod_l0 = in_lod[0]; + int seq_num = in_lod_l0.size() - 1; + + if (in_width == out_width) { + out->set_lod(in->lod()); + } else { + auto& out_lod = *out->mutable_lod(); + out_lod.resize(1); + out_lod[0].resize(seq_num + 1); + out_lod[0][0] = 0; + for (int i = 0; i < seq_num; ++i) { + size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; + size_t offset = 0; + offset = (seq_len * in_width) / out_width; + CHECK_EQ(offset * out_width, seq_len * in_width); + out_lod[0][i + 1] = out_lod[0][i] + offset; + } + } + + out->Resize(in_dims); + auto* dst_ptr = out->mutable_data(); + auto size = in->numel() * sizeof(T); + std::memcpy(dst_ptr, in->data(), size); + std::vector out_shape{static_cast(out->lod()[0].back()), + out_width}; + out->Resize(lite::DDim(out_shape)); + } + + virtual ~SequenceReshapeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/shape_compute.cc b/lite/kernels/x86/shape_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..eed4c8d77099ca203a7dcd4637f106ae48fd6728 --- /dev/null +++ b/lite/kernels/x86/shape_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
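The LoD arithmetic in SequenceReshapeCompute above is a pure width change: a sequence's element count is fixed, so its new row count is seq_len * in_width / out_width, and the CHECK_EQ insists the division is exact. A worked sketch with hypothetical numbers:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> in_lod = {0, 2, 6};  // two sequences: 2 and 4 rows
  const int in_width = 4, out_width = 2;
  std::vector<size_t> out_lod(in_lod.size());
  out_lod[0] = 0;
  for (size_t i = 0; i + 1 < in_lod.size(); ++i) {
    size_t seq_len = in_lod[i + 1] - in_lod[i];
    size_t offset = seq_len * in_width / out_width;
    assert(offset * out_width == seq_len * in_width);  // must divide evenly
    out_lod[i + 1] = out_lod[i] + offset;
  }
  // Prints 0 4 12: halving the width doubles each sequence's row count.
  for (size_t v : out_lod) printf("%zu ", v);
  printf("\n");
  return 0;
}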
+
+#include "lite/kernels/x86/shape_compute.h"
+
+REGISTER_LITE_KERNEL(shape,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::ShapeCompute<float>,
+                     def)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/shape_compute.h b/lite/kernels/x86/shape_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee3678a7f1c6651226c479aeedcacce91085b295
--- /dev/null
+++ b/lite/kernels/x86/shape_compute.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+class ShapeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ShapeParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+    // auto& context = context_->As<X86Context>();
+    auto out_data = param.Out->mutable_data<int32_t>();
+    auto in_dims = param.X->dims();
+    for (int i = 0; i < in_dims.size(); ++i) {
+      out_data[i] = in_dims[i];
+    }
+  }
+
+  virtual ~ShapeCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/shape_compute_test.cc b/lite/kernels/x86/shape_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..88bd98f33ffc7a727de584543bc7392cdbb2883f
--- /dev/null
+++ b/lite/kernels/x86/shape_compute_test.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
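ShapeCompute above is one of the few kernels whose output is pure metadata: it writes the input's dims into a 1-D integer tensor. The whole contract fits in a few lines; a sketch with plain vectors standing in for lite tensors (the int32 output type is an assumption consistent with the test's integer references):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for the kernel body: write each input dim into the output buffer.
std::vector<int32_t> ShapeOf(const std::vector<int64_t>& in_dims) {
  std::vector<int32_t> out(in_dims.size());
  for (size_t i = 0; i < in_dims.size(); ++i)
    out[i] = static_cast<int32_t>(in_dims[i]);
  return out;
}

int main() {
  // Matches the test below: a {1, 1, 3, 3} input yields {1, 1, 3, 3}.
  assert((ShapeOf({1, 1, 3, 3}) == std::vector<int32_t>{1, 1, 3, 3}));
  return 0;
}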
+ +#include "lite/kernels/x86/shape_compute.h" +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(shape_x86, retrive_op) { + auto shape = + KernelRegistry::Global().Create("shape"); + ASSERT_FALSE(shape.empty()); + ASSERT_TRUE(shape.front()); +} + +TEST(shape_x86, init) { + ShapeCompute shape; + ASSERT_EQ(shape.precision(), PRECISION(kFloat)); + ASSERT_EQ(shape.target(), TARGET(kX86)); +} + +TEST(shape_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 1, 3, 3}; + x.Resize(lite::DDim(x_shape)); + + std::vector out_shape{4}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = 1; + } + + ShapeCompute shape; + operators::ShapeParam param; + param.X = &x; + param.Out = &out; + + shape.SetParam(param); + shape.Run(); + + std::vector ref_results = {1, 1, 3, 3}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(shape, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/slice_compute.cc b/lite/kernels/x86/slice_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..00602ce62b80814e2b78460122bd3ed3cc8b81a8 --- /dev/null +++ b/lite/kernels/x86/slice_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/slice_compute.h" + +REGISTER_LITE_KERNEL(slice, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SliceCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/slice_compute.h b/lite/kernels/x86/slice_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..a3540cafdf4f219cae659b1a818d793302aab12c --- /dev/null +++ b/lite/kernels/x86/slice_compute.h @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
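Before the slice kernel below reads any data, possibly-negative starts have to become concrete offsets: a negative index counts back from the end of its axis, and everything clamps to the valid range. A standalone sketch of the conventional normalization, run against the shapes from test_case3 further down (the kernel itself derives extents from the pre-resized output dims rather than from ends):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {3, 4, 5};
  std::vector<int> axes = {0, 1, 2};
  std::vector<int> starts = {-3, 0, 2};
  std::vector<int> ends = {3, 100, -1};
  for (size_t i = 0; i < axes.size(); ++i) {
    int d = dims[axes[i]];
    int s = starts[i] < 0 ? starts[i] + d : starts[i];  // -3 + 3 -> 0
    int e = ends[i] < 0 ? ends[i] + d : ends[i];        // -1 + 5 -> 4
    s = std::max(s, 0);
    e = std::min(e, d);  // oversized ends (100) clamp to the axis size
    // Prints axis 0: [0, 3)  axis 1: [0, 4)  axis 2: [2, 4)
    printf("axis %d: [%d, %d)\n", axes[i], s, e);
  }
  return 0;
}

The resulting extents (3, 4, 2) are exactly the out_shape that test_case3 allocates.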
+#pragma once + +#include +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/fluid/eigen.h" +#include "lite/operators/relu_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +void slice_compute(const lite::Tensor* in, + lite::Tensor* out, + std::vector axes, + std::vector starts, + std::vector ends, + std::vector decrease_axis) { + auto out_dims = out->dims(); + auto in_dims = in->dims(); + + // resize out_dims + if (decrease_axis.size() > 0) { + if (decrease_axis.size() == (size_t)in_dims.size()) { + std::vector vec_origin_out_shape(decrease_axis.size(), 1); + // lite::DDim dims(vec_origin_out_shape); + out->Resize(vec_origin_out_shape); + } else { + std::vector vec_origin_out_shape( + out_dims.size() + decrease_axis.size(), -1); + for (size_t i = 0; i < decrease_axis.size(); ++i) { + vec_origin_out_shape[decrease_axis[i]] = 1; + } + int index = 0; + for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) { + if (-1 == vec_origin_out_shape[i]) { + vec_origin_out_shape[i] = out_dims[index]; + ++index; + } + } + // lite::DDim dims(vec_origin_out_shape); + out->Resize(vec_origin_out_shape); + } + } + + out->mutable_data(lite::TargetType::kX86); + + auto new_out_dims = out->dims(); + auto offsets = Eigen::array(); + auto extents = Eigen::array(); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = new_out_dims[i]; + } + int start; + for (size_t i = 0; i < axes.size(); ++i) { + start = starts[i]; + if (start < 0) { + start = (start + in_dims[axes[i]]); + } + start = std::max(start, 0); + offsets[axes[i]] = start; + } + auto in_t = + lite::fluid::EigenTensor:: + From(*in, in->dims()); + auto out_t = + lite::fluid::EigenTensor:: + From(*out, new_out_dims); + out_t = in_t.slice(offsets, extents); + + out->Resize(out_dims); +} + +template +void slice_compute_(const lite::Tensor* Input, + lite::Tensor* Out, + std::vector axes, + std::vector starts, + std::vector ends, + std::vector decrease_axis) { + int rank = Input->dims().size(); + switch (rank) { + case 1: + slice_compute<1>(Input, Out, axes, starts, ends, decrease_axis); + break; + case 2: + slice_compute<2>(Input, Out, axes, starts, ends, decrease_axis); + break; + case 3: + slice_compute<3>(Input, Out, axes, starts, ends, decrease_axis); + break; + case 4: + slice_compute<4>(Input, Out, axes, starts, ends, decrease_axis); + break; + case 5: + slice_compute<5>(Input, Out, axes, starts, ends, decrease_axis); + break; + case 6: + slice_compute<6>(Input, Out, axes, starts, ends, decrease_axis); + break; + } +} + +template +class SliceCompute : public KernelLite { + public: + using param_t = operators::SliceParam; + + void Run() override { + auto& param = *param_.get_mutable(); + slice_compute_(param.X, + param.Out, + param.axes, + param.starts, + param.ends, + param.decrease_axis); + } + + virtual ~SliceCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/slice_compute_test.cc b/lite/kernels/x86/slice_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..db3cb35ccbe248d800a5975bcd62d9f1216f3997 --- /dev/null +++ b/lite/kernels/x86/slice_compute_test.cc @@ -0,0 +1,265 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/slice_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(slice_x86, retrive_op) { + auto slice = + KernelRegistry::Global().Create("slice"); + ASSERT_FALSE(slice.empty()); + ASSERT_TRUE(slice.front()); +} + +TEST(slice_x86, init) { + lite::kernels::x86::SliceCompute slice; + ASSERT_EQ(slice.precision(), PRECISION(kFloat)); + ASSERT_EQ(slice.target(), TARGET(kX86)); +} + +void test_case1(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3}); + std::vector ends({3}); + std::vector axes({0}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +void test_case2(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3, 4}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 4}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3, 0}); + std::vector ends({3, 100}); + std::vector axes({0, 1}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +void test_case3(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3, 4, 5}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 4, 2}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3, 0, 2}); + std::vector ends({3, 100, -1}); + std::vector axes({0, 1, 2}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} +void test_case4(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3, 4, 5, 6}); + 
x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 4, 2, 6}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3, 0, 2}); + std::vector ends({3, 100, -1}); + std::vector axes({0, 1, 2}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +void test_case5(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3, 4, 5, 6, 3}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 4, 2, 6, 3}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3, 0, 2}); + std::vector ends({3, 100, -1}); + std::vector axes({0, 1, 2}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} +void test_case6(lite::Tensor x, lite::Tensor out) { + std::vector x_shape({3, 4, 5, 6, 5, 2}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 4, 2, 6, 5, 2}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + std::vector starts({-3, 0, 2}); + std::vector ends({3, 100, -1}); + std::vector axes({0, 1, 2}); + + // SliceCompute slice; + SliceCompute slice; + operators::SliceParam param; + + param.X = &x; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + slice.SetContext(std::move(ctx)); + slice.SetParam(param); + slice.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +TEST(slice_x86, run_test) { + lite::Tensor x; + lite::Tensor out; + + test_case1(x, out); + test_case2(x, out); + test_case3(x, out); + test_case4(x, out); + test_case5(x, out); + test_case6(x, out); +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(slice, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/softmax_compute.h b/lite/kernels/x86/softmax_compute.h index 8769ffcf03a6781a1ad41aa10d9cdf71d98ad4c5..169644db05e2fc9b83b11e068e03d6a44d5d06b7 100644 --- a/lite/kernels/x86/softmax_compute.h +++ b/lite/kernels/x86/softmax_compute.h @@ -14,12 +14,9 @@ #pragma once #include +#include "lite/backends/x86/math/softmax.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/softmax.h" namespace paddle { namespace lite { namespace kernels { @@ -55,7 +52,7 @@ class SoftmaxCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); - // auto& context = context_->As(); + auto& context = ctx_->As(); 
CHECK(param.output); CHECK(param.x); param.output->mutable_data(); @@ -72,13 +69,8 @@ class SoftmaxCompute : public KernelLite { out_2d.ShareDataWith(*param.output); out_2d.Resize(lite::DDim(shape)); - paddle::operators::math::SoftmaxFunctor()( - platform::CPUDeviceContext(), - axis_dim, - &input_2d.raw_tensor(), - &out_2d.raw_tensor()); + lite::x86::math::SoftmaxFunctor()( + context, axis_dim, &input_2d, &out_2d); } virtual ~SoftmaxCompute() = default; diff --git a/lite/kernels/x86/softmax_compute_test.cc b/lite/kernels/x86/softmax_compute_test.cc index 2ea20b8c0a344a53870bb1f4020746d251683466..6f18931d6bbcc8b7274ae3d294acd2e0dd1dc636 100644 --- a/lite/kernels/x86/softmax_compute_test.cc +++ b/lite/kernels/x86/softmax_compute_test.cc @@ -14,7 +14,8 @@ #include "lite/kernels/x86/softmax_compute.h" #include -#include +#include +#include #include #include "lite/core/op_registry.h" @@ -54,15 +55,24 @@ TEST(softmax_x86, run_test) { SoftmaxCompute softmax; operators::SoftmaxParam param; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + softmax.SetContext(std::move(ctx)); + param.x = &x; param.output = &out; softmax.SetParam(param); softmax.Run(); - LOG(INFO) << "output: "; + std::vector ref_results = { + 0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241, + 0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241, + 0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241, + 0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241, + 0.0900306, 0.244728, 0.665241}; for (int i = 0; i < out.dims().production(); i++) { - LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_results[i], 1e-3); } } diff --git a/lite/kernels/x86/squeeze_compute.cc b/lite/kernels/x86/squeeze_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..17ecd0c49bd0ee96b525f688b9d1f7bce100232a --- /dev/null +++ b/lite/kernels/x86/squeeze_compute.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/squeeze_compute.h" + +REGISTER_LITE_KERNEL(squeeze, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SqueezeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(squeeze2, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::Squeeze2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/squeeze_compute.h b/lite/kernels/x86/squeeze_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..67086f8c732d412064c6bb0bece7e8208f8a0799 --- /dev/null +++ b/lite/kernels/x86/squeeze_compute.h @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/operators/squeeze_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SqueezeCompute : public KernelLite { + public: + using param_t = operators::SqueezeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto x = param.X; + auto output = param.Out; + auto x_dims = x->dims(); + auto* x_data = x->data(); + auto* out_data = output->mutable_data(); + memcpy(out_data, x_data, x_dims.production() * sizeof(T)); + } + + virtual ~SqueezeCompute() = default; +}; + +template +class Squeeze2Compute : public KernelLite { + public: + using param_t = operators::SqueezeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto x = param.X; + auto output = param.Out; + auto xshape = param.XShape; + auto x_dims = x->dims(); + auto* x_data = x->data(); + auto* out_data = output->mutable_data(); + auto* xshape_data = xshape->mutable_data(); + memcpy(out_data, x_data, x_dims.production() * sizeof(T)); + memcpy(xshape_data, x_data, x_dims.production() * sizeof(T)); + } + + virtual ~Squeeze2Compute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/squeeze_compute_test.cc b/lite/kernels/x86/squeeze_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0799a522b339951521dc2a80b00e447e19657a62 --- /dev/null +++ b/lite/kernels/x86/squeeze_compute_test.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
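+
+// The tests below only need to check data pass-through: squeeze drops
+// size-1 dimensions, so for a contiguous tensor the element order is
+// unchanged and SqueezeCompute::Run() can be a plain memcpy. For example,
+// squeezing x.dims() = {1, 3, 1, 5} down to {3, 5} keeps the same 15
+// floats in the same order; only the reported shape differs.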
+ +#include "lite/kernels/x86/squeeze_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +// squeeze +TEST(squeeze_x86, retrive_op) { + auto squeeze = + KernelRegistry::Global().Create( + "squeeze"); + ASSERT_FALSE(squeeze.empty()); + ASSERT_TRUE(squeeze.front()); +} + +TEST(squeeze_x86, init) { + lite::kernels::x86::SqueezeCompute squeeze; + ASSERT_EQ(squeeze.precision(), PRECISION(kFloat)); + ASSERT_EQ(squeeze.target(), TARGET(kX86)); +} + +TEST(squeeze_x86, run_test) { + lite::Tensor x; + lite::Tensor out; + std::vector x_shape({1, 3, 1, 5}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 5}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + // SqueezeCompute squeeze; + SqueezeCompute squeeze; + operators::SqueezeParam param; + + param.X = &x; + param.Out = &out; + std::vector> ref_res({{3, 5}, {3, 5}}); + std::vector> axes({{0, -2}, {}}); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + for (int i = 0; i < 2; ++i) { + param.axes = axes[i]; + squeeze.SetContext(std::move(ctx)); + squeeze.SetParam(param); + squeeze.Run(); + + for (int j = 0; j < out.dims().production(); ++j) { + EXPECT_NEAR(out_data[j], x_data[j], 1e-5); + } + } +} + +// squeeze2 +TEST(squeeze2_x86, retrive_op) { + auto squeeze2 = + KernelRegistry::Global().Create( + "squeeze2"); + ASSERT_FALSE(squeeze2.empty()); + ASSERT_TRUE(squeeze2.front()); +} + +TEST(squeeze2_x86, init) { + lite::kernels::x86::Squeeze2Compute squeeze2; + ASSERT_EQ(squeeze2.precision(), PRECISION(kFloat)); + ASSERT_EQ(squeeze2.target(), TARGET(kX86)); +} + +TEST(squeeze2_x86, run_test) { + lite::Tensor x; + lite::Tensor xshape; + lite::Tensor out; + std::vector x_shape({1, 3, 1, 5}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 5}); + out.Resize(lite::DDim(out_shape)); + std::vector xshape_shape({1, 3, 1, 5}); + xshape.Resize(lite::DDim(xshape_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + auto xshape_data = xshape.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + xshape_data[i] = static_cast(i); + } + + // Squeeze2Compute squeeze2; + Squeeze2Compute squeeze2; + operators::SqueezeParam param; + + param.X = &x; + param.Out = &out; + param.XShape = &xshape; + std::vector> ref_res({{3, 5}, {3, 5}}); + std::vector> axes({{0, -2}, {}}); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + for (int i = 0; i < 2; ++i) { + param.axes = axes[i]; + squeeze2.SetContext(std::move(ctx)); + squeeze2.SetParam(param); + squeeze2.Run(); + + for (int j = 0; j < out.dims().production(); ++j) { + EXPECT_NEAR(out_data[j], x_data[j], 1e-5); + } + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(squeeze, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(squeeze2, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/tanh_compute_test.cc b/lite/kernels/x86/tanh_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa65ca02df27642fc0114a075ad8a4249f3b70de --- /dev/null +++ b/lite/kernels/x86/tanh_compute_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/x86/activation_compute.cc" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(tanh_x86, retrive_op) { + auto tanh = + KernelRegistry::Global().Create("tanh"); + ASSERT_FALSE(tanh.empty()); + ASSERT_TRUE(tanh.front()); +} + +TEST(tanh_x86, init) { + TanhCompute tanh; + ASSERT_EQ(tanh.precision(), PRECISION(kFloat)); + ASSERT_EQ(tanh.target(), TARGET(kX86)); +} + +TEST(tanh_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector out_shape{batch_size, 3, 2, 2}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + int sign = i % 2 == 0 ? 1 : -1; + x_data[i] = static_cast(i * sign) * 0.08f; + } + // TanhCompute tanh; + TanhCompute tanh; + operators::ActivationParam param; + + param.X = &x; + param.Out = &out; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + tanh.SetContext(std::move(ctx)); + tanh.SetParam(param); + tanh.Run(); + + LOG(INFO) << "output: "; + std::vector ref_data{0., + -0.079829, + 0.158648, + -0.235495, + 0.309506, + -0.379949, + 0.446243, + -0.507977, + 0.564899, + -0.616909, + 0.664036, + -0.706419}; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(tanh, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/transpose_compute.cc b/lite/kernels/x86/transpose_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..58041a0cd39e9c985ccde6247b3df0002a39103d --- /dev/null +++ b/lite/kernels/x86/transpose_compute.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
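+
+// transpose2 differs from transpose only in binding an extra XShape
+// output (a record of the input dimensions, kept for later shape
+// recovery); the computation itself is identical, so both registrations
+// below share the same float/NCHW kernel path.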
+ +#include "lite/kernels/x86/transpose_compute.h" + +REGISTER_LITE_KERNEL(transpose, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::TransposeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(transpose2, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::Transpose2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/transpose_compute.h b/lite/kernels/x86/transpose_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..603b96015e267aa24d20bf20f2c3090a2daab74c --- /dev/null +++ b/lite/kernels/x86/transpose_compute.h @@ -0,0 +1,108 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/backends/x86/math/math_function.h" +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/operators/transpose_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +inline void TransCompute(const int dim, + const lite::Context& context, + const lite::Tensor& in, + lite::Tensor* out, + const std::vector& axis) { + switch (dim) { + case 1: + paddle::lite::x86::math::Transpose trans1; + trans1(context, in, out, axis); + break; + case 2: + paddle::lite::x86::math::Transpose trans2; + trans2(context, in, out, axis); + break; + case 3: + paddle::lite::x86::math::Transpose trans3; + trans3(context, in, out, axis); + break; + case 4: + paddle::lite::x86::math::Transpose trans4; + trans4(context, in, out, axis); + break; + case 5: + paddle::lite::x86::math::Transpose trans5; + trans5(context, in, out, axis); + break; + case 6: + paddle::lite::x86::math::Transpose trans6; + trans6(context, in, out, axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); + } +} + +template +class TransposeCompute : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto* x = param.x; + auto* out = param.output; + out->mutable_data(); + int ndims = param.axis.size(); + auto& context = ctx_->As(); + TransCompute( + ndims, context, *x, out, param.axis); + } + + virtual ~TransposeCompute() = default; +}; + +template +class Transpose2Compute : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto* x = param.x; + auto* out = param.output; + out->mutable_data(); + int ndims = param.axis.size(); + auto& context = ctx_->As(); + TransCompute( + ndims, context, *x, out, param.axis); + } + + virtual ~Transpose2Compute() = default; +}; + 
+} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/transpose_compute_test.cc b/lite/kernels/x86/transpose_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8533d98258637eba516974e03cd4d88fd452293 --- /dev/null +++ b/lite/kernels/x86/transpose_compute_test.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/transpose_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +// transpose +TEST(transpose_x86, retrive_op) { + auto transpose = + KernelRegistry::Global().Create( + "transpose"); + ASSERT_FALSE(transpose.empty()); + ASSERT_TRUE(transpose.front()); +} + +TEST(transpose_x86, init) { + lite::kernels::x86::TransposeCompute transpose; + ASSERT_EQ(transpose.precision(), PRECISION(kFloat)); + ASSERT_EQ(transpose.target(), TARGET(kX86)); +} + +TEST(transpose_x86, run_test) { + lite::Tensor x; + lite::Tensor out; + std::vector x_shape({3, 4, 5}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 5, 4}); + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + + // TransposeCompute transpose; + TransposeCompute transpose; + operators::TransposeParam param; + + param.x = &x; + param.output = &out; + std::vector axis({0, 2, 1}); + param.axis = axis; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + transpose.SetContext(std::move(ctx)); + transpose.SetParam(param); + transpose.Run(); + + for (int j = 0; j < out.dims().production(); ++j) { + // EXPECT_NEAR(out_data[j], x_data[j], 1e-5); + LOG(INFO) << out_data[j]; + } +} + +// transpose2 +TEST(transpose2_x86, retrive_op) { + auto transpose2 = + KernelRegistry::Global().Create( + "transpose2"); + ASSERT_FALSE(transpose2.empty()); + ASSERT_TRUE(transpose2.front()); +} + +TEST(transpose2_x86, init) { + lite::kernels::x86::Transpose2Compute transpose2; + ASSERT_EQ(transpose2.precision(), PRECISION(kFloat)); + ASSERT_EQ(transpose2.target(), TARGET(kX86)); +} + +TEST(transpose2_x86, run_test) { + lite::Tensor x; + lite::Tensor xshape; + lite::Tensor out; + std::vector x_shape({3, 4, 5}); + x.Resize(lite::DDim(x_shape)); + std::vector out_shape({3, 5, 4}); + out.Resize(lite::DDim(out_shape)); + std::vector xshape_shape({3, 4, 5}); + xshape.Resize(lite::DDim(xshape_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + auto xshape_data = xshape.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + xshape_data[i] = static_cast(i); + } + + // Transpose2Compute transpose2; + Transpose2Compute transpose2; + operators::TransposeParam param; + + param.x = &x; + param.output = &out; + 
param.xshape = &xshape; + std::vector axis({0, 2, 1}); + param.axis = axis; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + transpose2.SetContext(std::move(ctx)); + transpose2.SetParam(param); + transpose2.Run(); + + for (int j = 0; j < out.dims().production(); ++j) { + LOG(INFO) << out_data[j]; + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(transpose, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(transpose2, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..72c48ceab079bc65e4f2363a1702de52586733d6 --- /dev/null +++ b/lite/kernels/xpu/CMakeLists.txt @@ -0,0 +1,9 @@ + +if(NOT LITE_WITH_XPU) + return () +endif() + +add_kernel(graph_compute_xpu XPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} xpu_runtime) +# lite_cc_test(test_graph_compute_xpu SRCS graph_compute_test.cc DEPS graph_compute_xpu) + +add_subdirectory(bridges) diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..47724728dfdb270ae5beb85852af6037735fda71 --- /dev/null +++ b/lite/kernels/xpu/bridges/CMakeLists.txt @@ -0,0 +1,32 @@ +lite_cc_library(xpu_bridge_registry SRCS registry.cc) + +set(xpu_bridge_deps xpu_bridge_registry xpu_builder op) + +lite_cc_library(xpu_bridge_act_op SRCS act_op.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_conv_op SRCS conv_op.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_pool_op SRCS pool_op.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_softmax_op SRCS softmax_op.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_mul_op SRCS mul_op.cc DEPS ${xpu_bridge_deps}) +lite_cc_library(xpu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${xpu_bridge_deps}) + +set(xpu_bridges + xpu_bridge_registry + xpu_bridge_act_op + xpu_bridge_conv_op + xpu_bridge_elementwise_ops + xpu_bridge_pool_op + xpu_bridge_softmax_op + xpu_bridge_mul_op + xpu_bridge_batch_norm_op + CACHE INTERNAL "xpu_bridges") + +set(xpu_bridge_test_deps ${xpu_bridges} ${xpu_kernels} ${ops}) + +lite_cc_test(test_xpu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_mul_op SRCS mul_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) +lite_cc_test(test_xpu_bridge_batch_norm_op SRCS batch_norm_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps}) diff --git a/lite/kernels/xpu/bridges/act_op.cc b/lite/kernels/xpu/bridges/act_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8e11caa96fdbff3a853a192a8d16f2eccd96337 --- /dev/null +++ b/lite/kernels/xpu/bridges/act_op.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type ActConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // create act node and set params from op + auto x_var_name = op_info->Input("X").front(); + CHECK(input_nodes.count(x_var_name)); + std::shared_ptr act_node = nullptr; + if (op_type == "relu") { + act_node = std::make_shared( + graph_ctx->builder->CreateRelu(*input_nodes.at(x_var_name))); + } else { + // TODO(hong19860320) supports more activation ops + LOG(FATAL) << "[XPU] Unsupported activation type " << op_type; + } + graph_ctx->builder->SetLayer(unique_op_type); + + // output converted nodes + node_map_type output_nodes; + output_nodes[op_info->Output("Out").front()] = act_node; + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(relu, paddle::lite::kernels::xpu::bridges::ActConverter); diff --git a/lite/kernels/xpu/bridges/act_op_test.cc b/lite/kernels/xpu/bridges/act_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a3efab46e3c7caee08bf646a560a0ab9abcf5c7 --- /dev/null +++ b/lite/kernels/xpu/bridges/act_op_test.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
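+
+// Test strategy: build a standalone relu op, convert and run it through
+// the XPU bridge via LauchOp, then compare the bridge output against
+// relu_ref, a plain CPU reference computing out[i] = max(0, x[i]).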
+ +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" +#include "lite/operators/activation_ops.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +void relu_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto x_data = x->data(); + auto out_data = out->mutable_data(); + DDim x_dims = x->dims(); + DDim out_dims = out->dims(); + CHECK_EQ(x_dims.production(), out_dims.production()); + for (int i = 0; i < out_dims.production(); i++) { + out_data[i] = std::max(0.f, x_data[i]); + } +} + +void test_relu(int bs, int ic, int ih, int iw) { + // prepare input&output variables + Scope scope; + std::string x_var_name("x"); + std::string out_var_name("out"); + std::string out_ref_var_name("out_ref"); + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("relu"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + + // create and convert op to XPU model, and run it on XPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + relu_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(NPUBridges, relu) { + for (auto bs : {1, 3}) { + for (auto ic : {3, 4}) { + for (auto ih : {2, 5}) { + for (auto iw : {5, 9}) { + VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih + << " iw: " << iw; + test_relu(bs, ic, ih, iw); + } + } + } + } +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(relu); +USE_XPU_BRIDGE(relu); diff --git a/lite/kernels/xpu/bridges/batch_norm_op.cc b/lite/kernels/xpu/bridges/batch_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c46b7878cdbb6987a11215d4dfcb80a2672aad2 --- /dev/null +++ b/lite/kernels/xpu/bridges/batch_norm_op.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
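+
+// The converter below materializes Scale/Bias/Mean/Variance as constant
+// tensors in the graph, then emits CreateBatchNorm over axis 1 (the
+// channel axis of an NCHW input) with the op's epsilon; field 0 of the
+// result is the normalized output bound to Y.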
+ +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type BatchNormConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // get input, and attributes + auto x_var_name = op_info->Input("X").front(); + auto scale_var_name = op_info->Input("Scale").front(); + auto* scale = scope->FindMutableTensor(scale_var_name); + auto bias_var_name = op_info->Input("Bias").front(); + auto* bias = scope->FindMutableTensor(bias_var_name); + auto mean_var_name = op_info->Input("Mean").front(); + auto* mean = scope->FindMutableTensor(mean_var_name); + auto variance_var_name = op_info->Input("Variance").front(); + auto* variance = scope->FindMutableTensor(variance_var_name); + auto epsilon = op_info->GetAttr("epsilon"); + + // create scale node + CHECK(!input_nodes.count(scale_var_name)); + auto scale_const_node = std::make_shared( + graph_ctx->builder->CreateTensor(scale_var_name, + lite::xpu::CvtShape(scale->dims()), + ::xtcl::Float(32))); + auto scale_const_tensor = lite::xpu::CvtTensor(scale); + graph_ctx->params->emplace( + std::make_pair(scale_var_name, *scale_const_tensor)); + + // create bias node + CHECK(!input_nodes.count(bias_var_name)); + auto bias_const_node = + std::make_shared(graph_ctx->builder->CreateTensor( + bias_var_name, lite::xpu::CvtShape(bias->dims()), ::xtcl::Float(32))); + auto bias_const_tensor = lite::xpu::CvtTensor(bias); + graph_ctx->params->emplace(std::make_pair(bias_var_name, *bias_const_tensor)); + + // create mean node + CHECK(!input_nodes.count(mean_var_name)); + auto mean_const_node = + std::make_shared(graph_ctx->builder->CreateTensor( + mean_var_name, lite::xpu::CvtShape(mean->dims()), ::xtcl::Float(32))); + auto mean_const_tensor = lite::xpu::CvtTensor(mean); + graph_ctx->params->emplace(std::make_pair(mean_var_name, *mean_const_tensor)); + + // create variance node + CHECK(!input_nodes.count(variance_var_name)); + auto variance_const_node = std::make_shared( + graph_ctx->builder->CreateTensor(variance_var_name, + lite::xpu::CvtShape(variance->dims()), + ::xtcl::Float(32))); + auto variance_const_tensor = lite::xpu::CvtTensor(variance); + graph_ctx->params->emplace( + std::make_pair(variance_var_name, *variance_const_tensor)); + + // create batch_norm node and set params from op + CHECK(input_nodes.count(x_var_name)); + auto batch_norm_node = std::make_shared( + graph_ctx->builder->CreateBatchNorm(*input_nodes.at(x_var_name), + *scale_const_node, + *bias_const_node, + *mean_const_node, + *variance_const_node, + 1, + epsilon)); + batch_norm_node = std::make_shared( + graph_ctx->builder->GetField(*batch_norm_node, 0)); + graph_ctx->builder->SetLayer(unique_op_type); + + // output converted nodes + node_map_type output_nodes; + output_nodes[op_info->Output("Y").front()] = batch_norm_node; + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(batch_norm, + paddle::lite::kernels::xpu::bridges::BatchNormConverter); diff --git 
a/lite/kernels/xpu/bridges/batch_norm_op_test.cc b/lite/kernels/xpu/bridges/batch_norm_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dec475530a5bb5c692946bc8d185ea81990a6408 --- /dev/null +++ b/lite/kernels/xpu/bridges/batch_norm_op_test.cc @@ -0,0 +1,164 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/batch_norm_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +template +void batch_norm_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto y = scope->FindVar(op_info->Output("Y").front())->GetMutable(); + auto bias = + scope->FindVar(op_info->Input("Bias").front())->GetMutable(); + auto scale = + scope->FindVar(op_info->Input("Scale").front())->GetMutable(); + auto mean = + scope->FindVar(op_info->Input("Mean").front())->GetMutable(); + auto variance = + scope->FindVar(op_info->Input("Variance").front())->GetMutable(); + + auto x_data = x->data(); + auto y_data = y->mutable_data(); + auto scale_data = scale->mutable_data(); + auto bias_data = bias->mutable_data(); + auto mean_data = mean->mutable_data(); + auto variance_data = variance->mutable_data(); + DDim x_dims = x->dims(); + + float epsilon = op_info->GetAttr("epsilon"); + auto data_layout = op_info->GetAttr("data_layout"); + + bool global_stats = op_info->GetAttr("use_global_stats"); + if (global_stats) { + int64_t outer_size = 0; + int64_t channel_size = 0; + int64_t inner_size = 0; + if (data_layout == "NCHW") { + outer_size = x_dims[0]; + channel_size = x_dims[1]; + inner_size = x_dims.Slice(2, x_dims.size()).production(); + } else { + LOG(FATAL) << "Unknown storage order: " << data_layout; + } + auto x_ptr = x_data; + auto y_ptr = y_data; + for (int o = 0; o < outer_size; o++) { + for (int c = 0; c < channel_size; c++) { + for (int i = 0; i < inner_size; i++) { + dtype norm_x = + (*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon); + *y_ptr = norm_x * scale_data[c] + bias_data[c]; + x_ptr++; + y_ptr++; + } + } + } + } +} + +void test_batch_norm(int bs, int ic, int ih, int iw, float epsilon) { + // prepare input&output variables + Scope scope; + std::string x_var_name = "x"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + std::string scale_var_name = "scale"; + std::string bias_var_name = "bias"; + std::string mean_var_name = "mean"; + std::string variance_var_name = "variance"; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* scale = scope.Var(scale_var_name)->GetMutable(); + auto* bias = scope.Var(bias_var_name)->GetMutable(); + auto* mean = scope.Var(mean_var_name)->GetMutable(); + auto* variance = 
scope.Var(variance_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + scale->Resize({ic}); + bias->Resize({ic}); + mean->Resize({ic}); + variance->Resize({ic}); + + // initialize input&output data + FillTensor(x); + FillTensor(scale); + FillTensor(bias); + FillTensor(mean); + // variance > 0 + FillTensor(variance, 1.f, 5.f); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("batch_norm"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetInput("Scale", {scale_var_name}); + opdesc.SetInput("Bias", {bias_var_name}); + opdesc.SetInput("Mean", {mean_var_name}); + opdesc.SetInput("Variance", {variance_var_name}); + opdesc.SetOutput("Y", {out_var_name}); + opdesc.SetAttr("is_test", 1); + opdesc.SetAttr("use_global_stats", true); + opdesc.SetAttr("epsilon", epsilon); + opdesc.SetAttr("momentum", 0.9f); + opdesc.SetAttr("data_layout", std::string("NCHW")); + + // create and convert op to XPU model, then run it on XPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + batch_norm_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(NPUBridges, batch_norm) { + for (auto bs : {1, 3}) { + for (auto ic : {2, 3}) { + for (auto ih : {4}) { + for (auto iw : {5}) { + for (auto epsilon : {1e-5f}) { + test_batch_norm(bs, ic, ih, iw, epsilon); + } + } + } + } + } +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(batch_norm); +USE_XPU_BRIDGE(batch_norm); diff --git a/lite/kernels/xpu/bridges/conv_op.cc b/lite/kernels/xpu/bridges/conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c758cf9507087fb53d476ff86a64707e0c6249b --- /dev/null +++ b/lite/kernels/xpu/bridges/conv_op.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type ConvConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " << op_type << "... 
"; + + // get input, filter and op attributes + auto input_var_name = op_info->Input("Input").front(); + auto input = scope->FindVar(input_var_name)->GetMutable(); + auto input_dims = input->dims(); + auto filter_var_name = op_info->Input("Filter").front(); + auto filter = scope->FindVar(filter_var_name)->GetMutable(); + auto filter_dims = filter->dims(); + auto bs = input_dims[0]; + auto oc = filter_dims[0]; + CHECK_EQ(input_dims.size(), 4); + CHECK_EQ(filter_dims.size(), 4); + auto strides = op_info->GetAttr>("strides"); + auto paddings = op_info->GetAttr>("paddings"); + auto groups = op_info->GetAttr("groups"); + auto dilations = op_info->GetAttr>("dilations"); + auto fuse_relu = op_info->GetAttr("fuse_relu"); + CHECK_EQ(strides.size(), 2); + CHECK_EQ(paddings.size(), 2); + CHECK_EQ(dilations.size(), 2); + std::vector output_shape({bs, oc}); + for (size_t i = 0; i < 2; i++) { + const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1; + output_shape.push_back( + (input_dims[i + 2] + 2 * paddings[i] - dkernel) / strides[i] + 1); + } + DDim output_dims(output_shape); + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // create filter node + CHECK(!input_nodes.count(filter_var_name)); + auto filter_const_node = std::make_shared( + graph_ctx->builder->CreateTensor(filter_var_name, + lite::xpu::CvtShape(filter_dims), + ::xtcl::Float(32))); + auto filter_const_tensor = lite::xpu::CvtTensor(filter); + graph_ctx->params->emplace( + std::make_pair(filter_var_name, *filter_const_tensor)); + + // create conv node and set input, filter, bias nodes and attributes + auto conv_attrs = xtcl::make_node(); + conv_attrs->strides = std::move(lite::xpu::CvtShape(strides)); + conv_attrs->padding = std::move(lite::xpu::CvtShape(paddings)); + conv_attrs->dilation = std::move(lite::xpu::CvtShape(dilations)); + conv_attrs->groups = groups; + // conv_attrs->channels = nullptr; + conv_attrs->kernel_size = std::move(xtcl::Array(nullptr)); + conv_attrs->data_layout = "NCHW"; + conv_attrs->kernel_layout = "OIHW"; + conv_attrs->out_layout = ""; + // conv_attrs->out_dtype = ""; + CHECK(input_nodes.count(input_var_name)); + auto conv_node = + std::make_shared(graph_ctx->builder->CreateConv2D( + *input_nodes.at(input_var_name), *filter_const_node, conv_attrs)); + graph_ctx->builder->SetLayer(unique_op_type); + + // create bias node if has bias + // supports the bias nodes with the following dimensions + // 0: {oc} + // 1: {1, oc, oh, ow} + // 2: {n, oc, oh, ow} + if (lite::xpu::HasInputArg(op_info, scope, "Bias")) { + auto bias_var_name = op_info->Input("Bias").front(); + auto* bias = scope->FindVar(bias_var_name)->GetMutable(); + auto bias_dims = bias->dims(); + auto bias_data_size = bias_dims.production(); + auto output_data_size = output_dims.production(); + std::vector bias_shape; + bool is_channel_bias = false; + if (bias_data_size == oc) { + // 0: {oc} + bias_shape = {oc}; + is_channel_bias = true; + } else if (bias_data_size == output_data_size / bs) { + // 1: {1, oc, oh, ow} + bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]}; + } else if (bias_data_size == output_data_size) { + // 2: {n, oc, oh, ow} + bias_shape = output_dims.Vectorize(); + } else { + LOG(ERROR) << "bias dimension " << bias_dims + << " isn't supported in conv2d Op when output dimension is " + << output_dims; + } + std::shared_ptr bias_node = nullptr; + if (input_nodes.count(bias_var_name)) { + // bias node from input node + bias_node = 
input_nodes.at(bias_var_name); + } else { + // bias node with const tensor + auto bias_const_node = std::make_shared( + graph_ctx->builder->CreateTensor(bias_var_name, + lite::xpu::CvtShape(bias_shape), + ::xtcl::Float(32))); + auto bias_const_tensor = lite::xpu::CvtTensor(bias, bias_shape); + graph_ctx->params->emplace( + std::make_pair(bias_var_name, *bias_const_tensor)); + bias_node = bias_const_node; + } + std::shared_ptr add_node = nullptr; + if (is_channel_bias) { + add_node = std::make_shared( + graph_ctx->builder->CreateBiasAdd(*conv_node, 1, *bias_node)); + } else { + add_node = std::make_shared( + graph_ctx->builder->CreateBinaryOp("add", *conv_node, *bias_node)); + } + graph_ctx->builder->SetLayer(unique_op_type + "/add"); + conv_node = add_node; + } + + // output converted nodes + node_map_type output_nodes; + if (fuse_relu) { + // append relu node if fuse_relu is true + auto relu_node = std::make_shared( + graph_ctx->builder->CreateRelu(*conv_node)); + graph_ctx->builder->SetLayer(unique_op_type + "/relu"); + output_nodes[op_info->Output("Output").front()] = relu_node; + } else { + output_nodes[op_info->Output("Output").front()] = conv_node; + } + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(conv2d, paddle::lite::kernels::xpu::bridges::ConvConverter); +REGISTER_XPU_BRIDGE(depthwise_conv2d, + paddle::lite::kernels::xpu::bridges::ConvConverter); diff --git a/lite/kernels/xpu/bridges/conv_op_test.cc b/lite/kernels/xpu/bridges/conv_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebdb67bd0d2801a9036696f52790f7104279b0cb --- /dev/null +++ b/lite/kernels/xpu/bridges/conv_op_test.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
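+
+// conv_ref below is a naive direct convolution used as the ground truth.
+// Output spatial sizes follow
+//   out = (in + 2 * pad - dkernel) / stride + 1,
+//   dkernel = dilation * (k - 1) + 1
+// so for the test_conv(1, 1, 1, 4, 4, ..., 1, 1, 1, 3) cases at the
+// bottom of this file: dkernel = 1 * (3 - 1) + 1 = 3 and
+// out = (4 + 2 * 1 - 3) / 1 + 1 = 4, i.e. a 4x4 output.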
+ +#include "lite/operators/conv_op.h" +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +void conv_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto input = + scope->FindVar(op_info->Input("Input").front())->GetMutable(); + auto filter = + scope->FindVar(op_info->Input("Filter").front())->GetMutable(); + auto output = + scope->FindVar(op_info->Output("Output").front())->GetMutable(); + std::vector strides = + op_info->GetAttr>("strides"); + std::vector paddings = + op_info->GetAttr>("paddings"); + int32_t groups = op_info->GetAttr("groups"); + std::vector dilations = + op_info->GetAttr>("dilations"); + bool fuse_relu = op_info->GetAttr("fuse_relu"); + auto input_dims = input->dims(); + auto filter_dims = filter->dims(); + auto output_dims = output->dims(); + auto input_data = input->mutable_data(); + auto filter_data = filter->mutable_data(); + auto output_data = output->mutable_data(); + int kernel_w = filter_dims[3]; + int kernel_h = filter_dims[2]; + int stride_w = strides[1]; + int stride_h = strides[0]; + int dila_w = dilations[1]; + int dila_h = dilations[0]; + int pad_w = paddings[1]; + int pad_h = paddings[0]; + int batch_size = input_dims[0]; + int in_ch_size = input_dims[1]; + int in_h = input_dims[2]; + int in_w = input_dims[3]; + int out_ch_size = output_dims[1]; + int out_h = output_dims[2]; + int out_w = output_dims[3]; + int out_c_group = out_ch_size / groups; + int in_c_group = in_ch_size / groups; + Tensor* bias = nullptr; + float* bias_data = nullptr; + bool is_channel_bias = false; + if (op_info->HasInput("Bias")) { + auto bias_var_names = op_info->Input("Bias"); + if (bias_var_names.size() > 0) { + auto bias_var_name = bias_var_names.front(); + bias = scope->FindVar(bias_var_name)->GetMutable(); + auto bias_dims = bias->dims(); + is_channel_bias = bias_dims.production() == out_ch_size; + bias_data = bias->mutable_data(); + } + } + for (int n = 0; n < batch_size; ++n) { + for (int g = 0; g < groups; ++g) { + for (int oc = 0; oc < out_c_group; ++oc) { + for (int oh = 0; oh < out_h; ++oh) { + for (int ow = 0; ow < out_w; ++ow) { + int out_idx = n * groups * out_c_group * out_h * out_w + + g * out_c_group * out_h * out_w + oc * out_h * out_w + + oh * out_w + ow; + float out_value = + bias_data != nullptr + ? (is_channel_bias ? bias_data[g * out_c_group + oc] + : bias_data[out_idx]) + : 0; + // + out_value *= beta; + for (int ic = 0; ic < in_c_group; ++ic) { + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int iw = ow * stride_w - pad_w + kw * (dila_w); + int ih = oh * stride_h - pad_h + kh * (dila_h); + if (iw < 0 || iw >= in_w) continue; + if (ih < 0 || ih >= in_h) continue; + int in_idx = n * in_ch_size * in_h * in_w + + g * in_c_group * in_h * in_w + ic * in_h * in_w + + ih * in_w + iw; + int filter_idx = + g * out_c_group * in_c_group * kernel_h * kernel_w + + oc * in_c_group * kernel_h * kernel_w + + ic * kernel_h * kernel_w + kh * kernel_w + kw; + out_value += input_data[in_idx] * filter_data[filter_idx]; + } + } + } + if (fuse_relu) { + out_value = out_value > 0 ? 
out_value : 0; + } + output_data[out_idx] = out_value; + } + } + } + } + } +} + +void test_conv(int bs, + int ic, + int oc, + int ih, + int iw, + bool has_bias, + bool is_channel_bias, + bool fuse_relu, + bool depthwise, + int dilation, + int stride, + int padding, + int kernel) { + // prepare input&output variables + Scope scope; + std::string input_var_name("input"); + std::string filter_var_name("filter"); + std::string bias_var_name("bias"); + std::string output_var_name("output"); + std::string output_ref_var_name("output_ref"); + auto* input = scope.Var(input_var_name)->GetMutable(); + auto* filter = scope.Var(filter_var_name)->GetMutable(); + auto* bias = scope.Var(bias_var_name)->GetMutable(); + auto* output = scope.Var(output_var_name)->GetMutable(); + auto* output_ref = scope.Var(output_ref_var_name)->GetMutable(); + + // get group size and input&filter shape + int groups = 1; + if (depthwise) { // depthwise convolution ? + groups = oc = ic; + } + std::vector input_shape = {bs, ic, ih, iw}; + std::vector filter_shape = {oc, ic / groups, kernel, kernel}; + std::vector output_shape({bs, oc}); + for (size_t i = 0; i < 2; i++) { + const int dkernel = dilation * (kernel - 1) + 1; + int output_size = (input_shape[i + 2] + 2 * padding - dkernel) / stride + 1; + output_shape.push_back(output_size); + } + input->Resize(input_shape); + filter->Resize(filter_shape); + + // initialize input&output data + FillTensor(input); + FillTensor(filter); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType(depthwise ? "depthwise_conv2d" : "conv2d"); + opdesc.SetInput("Input", {input_var_name}); + opdesc.SetInput("Filter", {filter_var_name}); + opdesc.SetOutput("Output", {output_var_name}); + opdesc.SetAttr("dilations", std::vector({dilation, dilation})); + opdesc.SetAttr("strides", std::vector({stride, stride})); + opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("groups", groups); + opdesc.SetAttr("fuse_relu", static_cast(fuse_relu)); + if (has_bias) { + if (is_channel_bias) { + bias->Resize({1, oc, 1, 1}); + } else { + bias->Resize({1, output_shape[1], output_shape[2], output_shape[3]}); + } + FillTensor(bias); + opdesc.SetInput("Bias", {bias_var_name}); + } + + // create and convert op to NPU model, then run it on NPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {input_var_name}, {output_var_name}); + output_ref->CopyDataFrom(*output); + + // execute reference implementation and save to output tensor('out') + conv_ref(op); + + // compare results + auto* output_data = output->mutable_data(); + auto* output_ref_data = output_ref->mutable_data(); + for (int i = 0; i < output->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } +} + +TEST(NPUBridges, conv) { +#if 0 + for (auto bs : {1, 2}) { + for (auto ic : {3, 6}) { + for (auto oc : {6, 9}) { + for (auto ih : {14, 28}) { + for (auto iw : {14, 28}) { + for (auto has_bias : {false, true}) { + for (auto is_channel_bias : {false, true}) { + for (auto fuse_relu : {false, true}) { + for (auto depthwise : {false, true}) { + for (auto dilation : {1, 2}) { + for (auto stride : {1, 2}) { + for (auto kernel : {1, 3, 5}) { + std::vector paddings = {kernel / 2}; + if (kernel / 2 != 0) { + paddings.push_back(0); + } + for (auto padding : paddings) { + VLOG(3) << "bs: " << bs << " ic: " << ic + << " oc: " << oc << " ih: " << ih + << " iw: " << iw + << " has_bias: " << has_bias + << " is_channel_bias: " << is_channel_bias + << " fuse_relu: " << fuse_relu + << " 
depthwise: " << depthwise + << " dilation: " << dilation + << " stride: " << stride + << " padding: " << padding + << " kernel: " << kernel; + test_conv(bs, + ic, + oc, + ih, + iw, + has_bias, + is_channel_bias, + fuse_relu, + depthwise, + dilation, + stride, + padding, + kernel); + } + } + } + } + } + } + } + } + } + } + } + } + } +#else + test_conv(1, 1, 1, 4, 4, false, false, false, false, 1, 1, 1, 3); + test_conv(1, 1, 1, 4, 4, true, true, false, false, 1, 1, 1, 3); + test_conv(1, 1, 1, 4, 4, true, false, false, false, 1, 1, 1, 3); +#endif +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(conv2d); +USE_XPU_BRIDGE(conv2d); + +USE_LITE_OP(depthwise_conv2d); +USE_XPU_BRIDGE(depthwise_conv2d); diff --git a/lite/kernels/xpu/bridges/elementwise_ops.cc b/lite/kernels/xpu/bridges/elementwise_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9fe7db14d2dfd00a7e74c77d2fe3b84e9593f72 --- /dev/null +++ b/lite/kernels/xpu/bridges/elementwise_ops.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type ElementwiseConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // get input, and attributes + auto x_var_name = op_info->Input("X").front(); + auto y_var_name = op_info->Input("Y").front(); + auto axis = op_info->GetAttr("axis"); + auto x_tensor = scope->FindMutableTensor(x_var_name); + auto y_tensor = scope->FindMutableTensor(y_var_name); + auto x_dims = x_tensor->dims(); + auto y_dims = y_tensor->dims(); + + // create x and y node + std::shared_ptr x_node = nullptr; + if (input_nodes.count(x_var_name)) { + x_node = input_nodes.at(x_var_name); + } else { + x_node = std::make_shared(graph_ctx->builder->CreateTensor( + x_var_name, lite::xpu::CvtShape(x_dims), ::xtcl::Float(32))); + auto x_const_tensor = lite::xpu::CvtTensor(x_tensor); + graph_ctx->params->emplace(std::make_pair(x_var_name, *x_const_tensor)); + } + + std::shared_ptr y_node = nullptr; + if (input_nodes.count(y_var_name)) { + y_node = input_nodes.at(y_var_name); + } else { + y_node = std::make_shared(graph_ctx->builder->CreateTensor( + y_var_name, lite::xpu::CvtShape(y_dims), ::xtcl::Float(32))); + auto y_const_tensor = lite::xpu::CvtTensor(y_tensor); + graph_ctx->params->emplace(std::make_pair(y_var_name, *y_const_tensor)); + } + + // create elementwise node and 
set input, attributes
+  std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr;
+  if (y_dims.size() == 1) {
+    elementwise_node = std::make_shared<xtcl::xExpr>(
+        graph_ctx->builder->CreateBiasAdd(*x_node, axis, *y_node));
+  } else if (x_dims.size() == y_dims.size()) {
+    elementwise_node = std::make_shared<xtcl::xExpr>(
+        graph_ctx->builder->CreateBinaryOp("add", *x_node, *y_node));
+  } else {
+    LOG(ERROR) << "XPU elementwise_add only supports y of one dimension, or x "
+                  "and y of the same dimension. But received x's dimension: "
+               << x_dims << ", y's dimension: " << y_dims << ", axis: " << axis;
+  }
+  graph_ctx->builder->SetLayer(unique_op_type);
+
+  // output converted nodes
+  node_map_type output_nodes;
+  output_nodes[op_info->Output("Out").front()] = elementwise_node;
+  return output_nodes;
+}
+
+}  // namespace bridges
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_XPU_BRIDGE(elementwise_add,
+                    paddle::lite::kernels::xpu::bridges::ElementwiseConverter);
diff --git a/lite/kernels/xpu/bridges/elementwise_ops_test.cc b/lite/kernels/xpu/bridges/elementwise_ops_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2abda822e3ae380ad376e92db99b5ad204a2a2a4
--- /dev/null
+++ b/lite/kernels/xpu/bridges/elementwise_ops_test.cc
@@ -0,0 +1,188 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/elementwise_ops.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/xpu/bridges/registry.h"
+#include "lite/kernels/xpu/bridges/test_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+namespace bridges {
+
+template <typename dtype>
+void elementwise_add_ref(const std::shared_ptr<lite::OpLite> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto y = scope->FindVar(op_info->Input("Y").front())->GetMutable<Tensor>();
+  auto out =
+      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
+
+  auto x_data = x->data<dtype>();
+  auto y_data = y->data<dtype>();
+  dtype* out_data = out->mutable_data<dtype>();
+
+  auto x_dims = x->dims();
+  auto y_dims = y->dims();
+  int axis = op_info->GetAttr<int>("axis");
+
+  if (axis < 0) {
+    axis = x_dims.size() - y_dims.size();
+  }
+  int batch = 1;
+  int channels = 1;
+  int num = 1;
+  for (int i = 0; i < axis; ++i) {
+    batch *= x_dims[i];
+  }
+  for (int i = 0; i < y_dims.size(); ++i) {
+    channels *= y_dims[i];
+  }
+  for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
+    num *= x_dims[i];
+  }
+  // do elementwise add/sub/max...
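+  // With y broadcast along `axis`, x is treated as a [batch, channels, num]
+  // view: batch = prod(x_dims[0:axis]), channels = prod(y_dims), and
+  // num = prod(x_dims[axis + y_dims.size():]); each y_data[j] is applied to
+  // one contiguous run of `num` elements.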
+ std::string elt_type = "add"; + if (elt_type == "add") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr + diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } else if (elt_type == "sub") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr - diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } else if (elt_type == "mul") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr * diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } else if (elt_type == "max") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = std::max(*din_ptr, diny_data); + dout_ptr++; + din_ptr++; + } + } + } + } else { + LOG(FATAL) << "unsupported Elementwise type: " << elt_type; + } +} + +void test_elementwise_add(std::vector x_dims, + std::vector y_dims, + int axis) { + // prepare input&output variables + Scope scope; + std::string x_var_name = "x"; + std::string y_var_name = "y"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* y = scope.Var(y_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize(x_dims); + if (y_dims.size() == 0) { + y->Resize(x_dims); + } else { + y->Resize(y_dims); + } + + // initialize input&output data + FillTensor(x); + FillTensor(y); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("elementwise_add"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetInput("Y", {y_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("axis", axis); + + // create and convert op to XPU model, then run it on XPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name, y_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + elementwise_add_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +// xpu's bias_add only support y with one dimension +TEST(XPUBridges, elementwise_add) { + test_elementwise_add({1, 2, 3, 4}, {1}, 0); + test_elementwise_add({1, 2, 3, 4}, {2}, 1); + test_elementwise_add({2, 2, 3, 4}, {3}, 2); + test_elementwise_add({2, 2, 3, 4}, {4}, 3); + test_elementwise_add({2, 2, 3, 4}, {4}, -1); + test_elementwise_add({2, 2, 3, 4}, {}, -1); +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); 
+USE_XPU_BRIDGE(elementwise_add); diff --git a/lite/kernels/xpu/bridges/mul_op.cc b/lite/kernels/xpu/bridges/mul_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..549abd3b1370a0fb90b4e9f4606ab15b3f9444ba --- /dev/null +++ b/lite/kernels/xpu/bridges/mul_op.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type MulConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // get input, and attributes + auto x_var_name = op_info->Input("X").front(); + auto y_var_name = op_info->Input("Y").front(); + auto y_tensor = scope->FindMutableTensor(y_var_name); + auto y_dims = y_tensor->dims(); + CHECK_EQ(y_dims.size(), 2) << "xpu now only supports y_dims.size() == 2"; + + auto x_num_col_dims = op_info->GetAttr("x_num_col_dims"); + CHECK_EQ(x_num_col_dims, 1) << "xpu now only supports x_num_col_dims == 1"; + auto y_num_col_dims = op_info->GetAttr("y_num_col_dims"); + CHECK_EQ(y_num_col_dims, 1) << "xpu now only supports y_num_col_dims == 1"; + + // create x node + std::shared_ptr x_node = nullptr; + x_node = std::make_shared( + graph_ctx->builder->CreateBatchFlatten(*input_nodes.at(x_var_name))); + graph_ctx->builder->SetLayer(unique_op_type + "/X"); + + // transpose y + DDimLite y_dims_t(std::vector{1, 1}); + y_dims_t[0] = y_dims[1]; + y_dims_t[1] = y_dims[0]; + auto y_var_name_t = unique_op_type + "/Y"; + Tensor* y_tensor_t = new Tensor(); + y_tensor_t->Resize(y_dims_t); + auto y_data_t = y_tensor_t->mutable_data(); + auto y_data = y_tensor->mutable_data(); + for (int i = 0; i < y_dims_t[0]; i++) { + for (int j = 0; j < y_dims_t[1]; j++) { + y_data_t[i * y_dims_t[1] + j] = y_data[j * y_dims_t[0] + i]; + } + } + + // create y node + std::shared_ptr y_const_node = nullptr; + y_const_node = std::make_shared(graph_ctx->builder->CreateTensor( + y_var_name_t, lite::xpu::CvtShape(y_dims_t), ::xtcl::Float(32))); + auto y_const_tensor = lite::xpu::CvtTensor(y_tensor_t); + graph_ctx->params->emplace(std::make_pair(y_var_name_t, *y_const_tensor)); + delete y_tensor_t; + + // create mul node and set params from op + std::shared_ptr mul_node = nullptr; + mul_node = std::make_shared( + graph_ctx->builder->CreateDense(*x_node, + static_cast(y_dims[1]), + ::xtcl::NullValue<::xtcl::DataType>(), + *y_const_node)); + graph_ctx->builder->SetLayer(unique_op_type); + + // output converted nodes + node_map_type output_nodes; + 
output_nodes[op_info->Output("Out").front()] = mul_node; + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(mul, paddle::lite::kernels::xpu::bridges::MulConverter); diff --git a/lite/kernels/xpu/bridges/mul_op_test.cc b/lite/kernels/xpu/bridges/mul_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd439b68cb7286a919a8fce97371443f53ed40db --- /dev/null +++ b/lite/kernels/xpu/bridges/mul_op_test.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/mul_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +void mul_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto y = scope->FindVar(op_info->Input("Y").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + int32_t x_num_col_dims = op_info->GetAttr("x_num_col_dims"); + int32_t y_num_col_dims = op_info->GetAttr("y_num_col_dims"); + auto x_data = x->mutable_data(); + auto y_data = y->mutable_data(); + auto out_data = out->mutable_data(); + auto x_mat_dims = x->dims().Flatten2D(x_num_col_dims); + auto y_mat_dims = y->dims().Flatten2D(y_num_col_dims); + CHECK_EQ(x_mat_dims[1], y_mat_dims[0]); + const int M = x_mat_dims[0]; + const int K = x_mat_dims[1]; + const int N = y_mat_dims[1]; + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + out_data[m * N + n] = 0; + for (int k = 0; k < K; ++k) { + out_data[m * N + n] += x_data[m * K + k] * y_data[k * N + n]; + } + } + } +} + +void test_mul(const std::vector& x_shape, + const std::vector& y_shape, + int x_num_col_dims, + int y_num_col_dims) { + Scope scope; + std::string x_var_name("X"); + std::string y_var_name("Y"); + std::string out_var_name("Out"); + std::string out_ref_var_name("out_ref"); + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* y = scope.Var(y_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize(x_shape); + y->Resize(y_shape); + + FillTensor(x); + FillTensor(y); + + // create mul op + cpp::OpDesc mul_op_desc; + mul_op_desc.SetType("mul"); + mul_op_desc.SetInput("X", {x_var_name}); + mul_op_desc.SetInput("Y", {y_var_name}); + mul_op_desc.SetOutput("Out", {out_var_name}); + mul_op_desc.SetAttr("x_num_col_dims", static_cast(x_num_col_dims)); + mul_op_desc.SetAttr("y_num_col_dims", static_cast(y_num_col_dims)); + + auto mul_op = CreateOp(mul_op_desc, &scope); + LauchOp(mul_op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + 
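+ // mul_ref (defined above) flattens x to [M, K] via x_num_col_dims and y to
+ // [K, N] via y_num_col_dims before a plain O(M*N*K) GEMM; e.g. x of shape
+ // {1, 2, 3, 4} with x_num_col_dims = 1 becomes 1 x 24, matching the 24 x 2
+ // weights used in the TEST cases below.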
mul_ref(mul_op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(XPUBridges, mul) { + test_mul({1, 2, 3, 4}, {24, 2}, 1, 1); + test_mul({2, 2, 3, 4}, {24, 2}, 1, 1); + test_mul({2, 7}, {7, 3}, 1, 1); + // test_mul({1, 8, 8, 1}, {1, 8, 2, 2}, 2, 2); + // test_mul({1, 5, 5, 1}, {1, 5, 7, 7}, 2, 2); + // test_mul({1, 4, 1, 1}, {4, 8}, 1, 1); +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +USE_XPU_BRIDGE(mul); diff --git a/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h b/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h new file mode 100644 index 0000000000000000000000000000000000000000..3c76e0e8b5cf0842cb8d5a613cef7aee3cd13bdb --- /dev/null +++ b/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/kernels/xpu/bridges/registry.h" + +USE_XPU_BRIDGE(relu); +USE_XPU_BRIDGE(conv2d); +USE_XPU_BRIDGE(depthwise_conv2d); +USE_XPU_BRIDGE(elementwise_add); +USE_XPU_BRIDGE(pool2d); +USE_XPU_BRIDGE(softmax); +USE_XPU_BRIDGE(mul); +USE_XPU_BRIDGE(batch_norm); diff --git a/lite/kernels/xpu/bridges/pool_op.cc b/lite/kernels/xpu/bridges/pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fbc6a9919c446508afa5a3b8a1c35352f9b8ecfa --- /dev/null +++ b/lite/kernels/xpu/bridges/pool_op.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type PoolConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // get input, and attributes + auto x_var_name = op_info->Input("X").front(); + auto pooling_type = op_info->GetAttr("pooling_type"); + auto ceil_mode = op_info->GetAttr("ceil_mode"); + auto paddings = op_info->GetAttr>("paddings"); + auto global_pooling = op_info->GetAttr("global_pooling"); + auto ksize = op_info->GetAttr>("ksize"); + auto strides = op_info->GetAttr>("strides"); + auto exclusive = op_info->GetAttr("exclusive"); + + // create pool node and set params from op + CHECK(input_nodes.count(x_var_name)); + std::shared_ptr pool_node = nullptr; + if (pooling_type == "max") { + if (global_pooling) { + pool_node = std::make_shared( + graph_ctx->builder->CreateGlobalMaxPool2D( + *input_nodes.at(x_var_name))); + } else { + pool_node = std::make_shared( + graph_ctx->builder->CreateMaxPool2D(*input_nodes.at(x_var_name), + lite::xpu::CvtShape(ksize), + lite::xpu::CvtShape(strides), + lite::xpu::CvtShape(paddings), + "NCHW", + ceil_mode)); + } + } else if (pooling_type == "avg") { + if (global_pooling) { + pool_node = std::make_shared( + graph_ctx->builder->CreateGlobalAvgPool2D( + *input_nodes.at(x_var_name))); + } else { + pool_node = std::make_shared( + // !exclusive ---> count_include_pad + graph_ctx->builder->CreateAvgPool2D(*input_nodes.at(x_var_name), + lite::xpu::CvtShape(ksize), + lite::xpu::CvtShape(strides), + lite::xpu::CvtShape(paddings), + "NCHW", + ceil_mode, + !exclusive)); + } + } else { + LOG(FATAL) << "Unsupported pooling type: " << pooling_type; + } + graph_ctx->builder->SetLayer(unique_op_type); + + // output converted nodes + node_map_type output_nodes; + output_nodes[op_info->Output("Out").front()] = pool_node; + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(pool2d, paddle::lite::kernels::xpu::bridges::PoolConverter); diff --git a/lite/kernels/xpu/bridges/pool_op_test.cc b/lite/kernels/xpu/bridges/pool_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed5f922d59b5ca5e387076c9a533c4b4c251cc87 --- /dev/null +++ b/lite/kernels/xpu/bridges/pool_op_test.cc @@ -0,0 +1,267 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
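A quick aside on the geometry used by the pool2d converter above and the CPU reference below: per spatial dimension, the output extent follows the usual pooling arithmetic, where ceil_mode keeps partially covered border windows and exclusive only changes the averaging denominator. A minimal sketch (the helper name PooledSize is illustrative, not part of this patch):

// Illustrative only: output extent of one pool2d spatial dimension.
inline int PooledSize(int in, int ksize, int pad, int stride, bool ceil_mode) {
  int span = in + 2 * pad - ksize;
  // ceil_mode rounds the number of window steps up instead of down.
  int steps = ceil_mode ? (span + stride - 1) / stride : span / stride;
  return steps + 1;
}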
+ +#include "lite/operators/pool_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +void pool_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto& in_dims = x->dims(); + auto& out_dims = out->dims(); + + const float* src_ptr = x->data(); + float* dst_ptr = out->mutable_data(); + + std::vector ksize = op_info->GetAttr>("ksize"); + std::vector strides = op_info->GetAttr>("strides"); + std::vector paddings = op_info->GetAttr>("paddings"); + bool exclusive = op_info->GetAttr("exclusive"); + std::string pooling_type = op_info->GetAttr("pooling_type"); + bool global_pooling = op_info->GetAttr("global_pooling"); + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + const float* src = src_ptr + n * size_in_n + c * size_in_c; + float res = src[0]; + if (pooling_type == "max") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res = cur_val > res ? cur_val : res; + } + } else if (pooling_type == "avg") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res += cur_val; + } + res /= size_in_c; + } + dst_ptr[n * size_out_n + c] = res; + } + } + } else { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < out_h; ++h) { + int sh = h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + for (int w = 0; w < out_w; ++w) { + int sw = w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + int pooling_size = (ew - sw) * (eh - sh); + if (pooling_size == 0) continue; + float res = 0.f; + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw; + if (kh == sh && kw == sw) { + res = src_ptr[src_idx]; + } else { + if (pooling_type == "max") { + res = res >= src_ptr[src_idx] ? 
res : src_ptr[src_idx]; + } + if (pooling_type == "avg") { + res += src_ptr[src_idx]; + } + } + } + } + if (pooling_type == "avg") { + if (exclusive) { + res /= pooling_size; + } else { + res /= window_h * window_w; + } + } + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res; + } + } + } + } + } +} + +void test_pool(int bs, + int ic, + int ih, + int iw, + std::string pooling_type, + bool ceil_mode, + bool global_pooling, + bool exclusive, + int ksize, + int stride, + int padding) { + // prepare input&output variables + Scope scope; + std::string x_var_name = "x"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("pool2d"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("pooling_type", pooling_type); + opdesc.SetAttr("ksize", std::vector({ksize, ksize})); + opdesc.SetAttr("global_pooling", global_pooling); + opdesc.SetAttr("exclusive", exclusive); + opdesc.SetAttr("strides", std::vector({stride, stride})); + opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("ceil_mode", ceil_mode); + + // create and convert op to XPU model, then run it on XPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + pool_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(XPUBridges, pool) { + for (auto pooling_type : {"max", "avg"}) { + for (auto bs : {1, 3}) { + for (auto ic : {2}) { + for (auto ih : {3}) { + for (auto iw : {4}) { + test_pool(bs, ic, ih, iw, pooling_type, true, true, true, 0, 1, 0); + } + } + } + } + } + + for (auto pooling_type : {"max"}) { + for (auto ceil_mode : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto padding : {0, 1}) { + for (auto bs : {1, 3}) { + for (auto ic : {2}) { + for (auto ih : {3}) { + for (auto iw : {4}) { + test_pool(bs, + ic, + ih, + iw, + pooling_type, + ceil_mode, + false, + true, + ksize, + stride, + padding); + } + } + } + } + } + } + } + } + } + + for (auto pooling_type : {"avg"}) { + for (auto ceil_mode : {true, false}) { + for (auto exclusive : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto padding : {0, 1}) { + for (auto bs : {1, 3}) { + for (auto ic : {2}) { + for (auto ih : {3}) { + for (auto iw : {4}) { + test_pool(bs, + ic, + ih, + iw, + pooling_type, + ceil_mode, + false, + exclusive, + ksize, + stride, + padding); + } + } + } + } + } + } + } + } + } + } +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(pool2d); +USE_XPU_BRIDGE(pool2d); diff --git a/lite/kernels/xpu/bridges/registry.cc b/lite/kernels/xpu/bridges/registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ab1b69a25a29aeb1c1ceaff25525459ef2e94cd --- /dev/null +++ b/lite/kernels/xpu/bridges/registry.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/bridges/registry.h" +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +Factory& Factory::Instance() { + static Factory g_xpu_bridge; + return g_xpu_bridge; +} + +bool Factory::HasType(const std::string& op_type) const { + return map_.count(op_type); +} + +void Factory::Insert(const std::string& op_type, const func_type& func_name) { + map_.insert(std::make_pair(op_type, func_name)); +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/bridges/registry.h b/lite/kernels/xpu/bridges/registry.h new file mode 100644 index 0000000000000000000000000000000000000000..c990399c1cdeb865dc214d2f1c6d1970b6d27b85 --- /dev/null +++ b/lite/kernels/xpu/bridges/registry.h @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
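To make the registry's role concrete, a hedged sketch of how a caller (for example a subgraph pass, or the test helper further below) dispatches one op through the Factory, using only the API defined above; op_type, op, graph_ctx and input_nodes are assumed to be in scope:

// Fragment, illustrative only: look up and invoke a bridge converter.
const auto& bridges = paddle::lite::kernels::xpu::bridges::Factory::Instance();
if (bridges.HasType(op_type)) {
  // AllFunctions() maps op type -> converter; calling the converter appends
  // the op to the XPU graph under construction and returns its output nodes.
  auto output_nodes =
      bridges.AllFunctions().at(op_type)(op, &graph_ctx, input_nodes);
}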
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "lite/core/op_lite.h" +#include "lite/utils/macros.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +// xpu network builder and constant tensors +class graph_ctx_type { + public: + std::shared_ptr builder; + std::shared_ptr params; +}; + +// var_name, xpu node pointer +using node_map_type = + std::unordered_map>; + +using func_type = std::function, graph_ctx_type*, const node_map_type&)>; +using cvt_map_type = std::unordered_map; +class Factory { + public: + static Factory& Instance(); + + const cvt_map_type& AllFunctions() const { return map_; } + bool HasType(const std::string& op_type) const; + void Insert(const std::string& op_type, const func_type& func_name); + Factory() = default; + + private: + cvt_map_type map_; + DISALLOW_COPY_AND_ASSIGN(Factory); +}; + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +// some platform-independent definitions +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + +#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +#define REGISTER_XPU_BRIDGE(op_type, cvt_func_name) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_xpu_bridge_##op_type##__, \ + "REGISTER_XPU_BRIDGE must be called in global namespace only once!"); \ + int __reg_xpu_bridge_##op_type##_Insert() { \ + paddle::lite::kernels::xpu::bridges::Factory::Instance().Insert( \ + #op_type, cvt_func_name); \ + return 0; \ + } + +#define USE_XPU_BRIDGE(op_type) \ + extern int __reg_xpu_bridge_##op_type##_Insert(); \ + static int __reg_xpu_bridge_##op_type##_Insert_return UNUSED = \ + __reg_xpu_bridge_##op_type##_Insert(); diff --git a/lite/kernels/xpu/bridges/softmax_op.cc b/lite/kernels/xpu/bridges/softmax_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3972496762a1d399ab59e7a69b0e9e18a9c28300 --- /dev/null +++ b/lite/kernels/xpu/bridges/softmax_op.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
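Expanded by hand for one concrete op, the two macros above do the following (shown for softmax, whose converter appears next; the static_assert guard is omitted here for brevity):

// REGISTER_XPU_BRIDGE(softmax, ...::SoftmaxConverter) defines, roughly:
int __reg_xpu_bridge_softmax_Insert() {
  paddle::lite::kernels::xpu::bridges::Factory::Instance().Insert(
      "softmax", paddle::lite::kernels::xpu::bridges::SoftmaxConverter);
  return 0;
}
// USE_XPU_BRIDGE(softmax) then forces that registration to run in any
// binary that links it:
extern int __reg_xpu_bridge_softmax_Insert();
static int __reg_xpu_bridge_softmax_Insert_return UNUSED =
    __reg_xpu_bridge_softmax_Insert();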
+ +#include "lite/backends/xpu/builder.h" +#include "lite/kernels/xpu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +node_map_type SoftmaxConverter(const std::shared_ptr op, + graph_ctx_type* graph_ctx, + const node_map_type& input_nodes) { + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = lite::xpu::UniqueName(op_type); + LOG(INFO) << "[XPU] Converting " + op_type + "..."; + + // check context + CHECK(graph_ctx != nullptr); + CHECK(graph_ctx->builder != nullptr); + CHECK(graph_ctx->params != nullptr); + + // get op's attributes + auto x_var_name = op_info->Input("X").front(); + auto axis = op_info->GetAttr("axis"); + + // create softmax node and set params from ops + CHECK(input_nodes.count(x_var_name)); + std::shared_ptr softmax_node = nullptr; + softmax_node = std::make_shared( + graph_ctx->builder->CreateSoftmax(*input_nodes.at(x_var_name), axis)); + graph_ctx->builder->SetLayer(unique_op_type); + + // output converted nodes + node_map_type output_nodes; + output_nodes[op_info->Output("Out").front()] = softmax_node; + return output_nodes; +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_XPU_BRIDGE(softmax, + paddle::lite::kernels::xpu::bridges::SoftmaxConverter); diff --git a/lite/kernels/xpu/bridges/softmax_op_test.cc b/lite/kernels/xpu/bridges/softmax_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cd12cbf4e8dc108ac43fec55a568ecec72a51ab --- /dev/null +++ b/lite/kernels/xpu/bridges/softmax_op_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
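The softmax_ref implementation that follows relies on the standard max-subtraction identity: softmax is unchanged when a constant is subtracted from every element of a slice, so subtracting the slice maximum keeps exp() from overflowing. A self-contained sketch of one strided slice (function name invented):

#include <algorithm>
#include <cmath>
// Illustrative only: stable softmax over a slice of length axis_size with
// stride inner_num starting at offset start, as softmax_ref does per slice.
void stable_softmax_slice(
    const float* x, float* out, int start, int axis_size, int inner_num) {
  float max_v = x[start];
  for (int j = 1; j < axis_size; ++j)
    max_v = std::max(max_v, x[start + j * inner_num]);
  float sum = 0.f;
  for (int j = 0; j < axis_size; ++j) {
    out[start + j * inner_num] = std::exp(x[start + j * inner_num] - max_v);
    sum += out[start + j * inner_num];
  }
  // exp(x - max) / sum(exp(x - max)) equals exp(x) / sum(exp(x)).
  for (int j = 0; j < axis_size; ++j) out[start + j * inner_num] /= sum;
}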
+ +#include "lite/operators/softmax_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/test_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +template +void softmax_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto x_data = x->data(); + auto out_data = out->mutable_data(); + DDim x_dims = x->dims(); + + auto x_rank = x_dims.size(); + int axis = op_info->GetAttr("axis"); + if (axis < 0) { + axis += x_rank; + } + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int compute_size = outer_num * inner_num; + for (int i = 0; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int start = idx_outer * inner_num + idx_inner; + int offset; + + offset = start; + dtype max_data = std::numeric_limits::lowest(); + for (int j = 0; j < axis_size; j++) { + max_data = x_data[offset] > max_data ? x_data[offset] : max_data; + offset += inner_num; + } + + offset = start; + dtype sum_data = (dtype)0; + for (int j = 0; j < axis_size; j++) { + out_data[offset] = exp(x_data[offset] - max_data); + sum_data += out_data[offset]; + offset += inner_num; + } + + offset = start; + for (int j = 0; j < axis_size; j++) { + out_data[offset] /= sum_data; + offset += inner_num; + } + } +} + +void test_softmax(int bs, int ic, int ih, int iw, int axis) { + // prepare input&output variables + Scope scope; + std::string x_var_name = "x"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("softmax"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("axis", axis); + + // create and convert op to XPU model, then run it on XPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + softmax_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(XPUBridges, softmax) { + for (auto bs : {2, 3}) { + for (auto ic : {4}) { + for (auto ih : {5}) { + for (auto iw : {6}) { + for (auto axis : {-3, -1, 0, 1, 2, 3}) { + test_softmax(bs, ic, ih, iw, axis); + } + } + } + } + } +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_OP(softmax); +USE_XPU_BRIDGE(softmax); diff --git a/lite/kernels/xpu/bridges/test_helper.cc b/lite/kernels/xpu/bridges/test_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a19324b946203c008093136d7a207ffaf23fbd6 --- /dev/null +++ b/lite/kernels/xpu/bridges/test_helper.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/bridges/test_helper.h" +#include +#include "lite/backends/xpu/builder.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/operators/graph_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +void LauchOp(const std::shared_ptr op, + const std::vector& input_var_names, + const std::vector& output_var_names) { + auto scope = op->scope(); + auto op_type = op->op_info()->Type(); + + // convert lite op to XPU op + const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); + const auto& supported_lists = bridges.AllFunctions(); + CHECK(bridges.HasType(op_type)); + graph_ctx_type graph_ctx; + graph_ctx.builder = std::make_shared(); + graph_ctx.params = + std::make_shared(); + node_map_type input_nodes; + for (auto input_var_name : input_var_names) { + auto input = scope->FindVar(input_var_name)->GetMutable(); + auto input_node = std::make_shared( + graph_ctx.builder->CreateTensor(input_var_name, + lite::xpu::CvtShape(input->dims()), + ::xtcl::Float(32))); + input_nodes[input_var_name] = input_node; + } + auto output_nodes = supported_lists.at(op_type)(op, &graph_ctx, input_nodes); + CHECK_GT(output_nodes.size(), 0); + + // build network graph and output model data + std::vector> ordered_output_nodes; + for (auto output_var_name : output_var_names) { + ordered_output_nodes.push_back(output_nodes.at(output_var_name)); + } + std::string weight_var_name = "weight"; + auto weight = scope->Var(weight_var_name)->GetMutable(); + weight->set_persistable(true); + weight->set_precision(PRECISION(kInt8)); + CHECK(lite::xpu::BuildModel( + graph_ctx.builder, graph_ctx.params, &ordered_output_nodes, weight)); + CHECK_GT(weight->numel(), 0); + CHECK(weight->data() != nullptr); + + // create graph op and set inputs and outputs + cpp::OpDesc graph_op_desc; + graph_op_desc.SetType("graph_op"); + graph_op_desc.SetInput("Inputs", input_var_names); + graph_op_desc.SetInput("Weight", {weight_var_name}); + graph_op_desc.SetOutput("Outputs", output_var_names); + + auto graph_op = + std::make_shared(graph_op_desc.Type()); + graph_op->SetValidPlaces({Place{TARGET(kXPU), PRECISION(kFloat)}}); + CHECK(graph_op->Attach(graph_op_desc, scope)); + CHECK(graph_op->CheckShape()); + CHECK(graph_op->InferShape()); + + // create graph op kernel and set XPU context + auto graph_kernels = + graph_op->CreateKernels({Place{TARGET(kXPU), PRECISION(kFloat)}}); + CHECK(!graph_kernels.empty()); + auto graph_kernel = + std::move(graph_kernels.front()); // use the first kernel by default + auto graph_device = ContextScheduler::Global().NewContext(TARGET(kXPU)); + graph_kernel->SetContext(std::move(graph_device)); + + // perform graph op kernel and store to output variables + graph_kernel->Launch(); + + lite::xpu::DeviceInfo::Global().Clear(); +} + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // 
namespace lite +} // namespace paddle + +USE_LITE_OP(graph_op); +USE_LITE_KERNEL(graph_op, kXPU, kFloat, kNCHW, def); diff --git a/lite/kernels/xpu/bridges/test_helper.h b/lite/kernels/xpu/bridges/test_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..c8bba5da66550a9eccaefa8b2d9a31a233f5f706 --- /dev/null +++ b/lite/kernels/xpu/bridges/test_helper.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { +namespace bridges { + +template +std::shared_ptr CreateOp(const cpp::OpDesc& opdesc, lite::Scope* scope) { + auto op = std::make_shared(opdesc.Type()); + op->SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kXPU), PRECISION(kFloat)}}); + CHECK(op->Attach(opdesc, scope)); + CHECK(op->CheckShape()); + CHECK(op->InferShape()); + return op; +} + +// T is the target data type +// R is the range data type, e.g. int, half +template +void FillTensor(Tensor* x, + T lower = static_cast(-2), + T upper = static_cast(2)) { + static unsigned int seed = 100; + std::mt19937 rng(seed++); + std::uniform_real_distribution uniform_dist(0, 1); + + T* x_data = x->mutable_data(); + for (int i = 0; i < x->dims().production(); ++i) { + auto r = uniform_dist(rng) * (upper - lower) + lower; + x_data[i] = static_cast(static_cast(r)); + } +} + +void LauchOp(const std::shared_ptr op, + const std::vector& input_var_names, + const std::vector& output_var_names); + +} // namespace bridges +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/graph_compute.cc b/lite/kernels/xpu/graph_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9e5be1a1d5c764c378f3fdf29d73148743962a4 --- /dev/null +++ b/lite/kernels/xpu/graph_compute.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
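Putting the helpers above together, every bridge unit test in this patch follows one pattern; a condensed sketch for softmax, with the template arguments the tests pass to CreateOp and FillTensor spelled out, is:

// Fragment, illustrative only: the common convert-run-compare test pattern.
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
x->Resize({2, 4, 5, 6});
FillTensor<float>(x);  // uniform random values in [-2, 2) by default
cpp::OpDesc desc;
desc.SetType("softmax");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"out"});
desc.SetAttr("axis", 1);
auto op = CreateOp<operators::SoftmaxOp>(desc, &scope);
LauchOp(op, {"x"}, {"out"});  // convert to an XPU graph, build it, run it
// ...then recompute "out" on the CPU (e.g. softmax_ref) and EXPECT_NEAR.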
+ +#include "lite/kernels/xpu/graph_compute.h" +#include +#include +#include +#include +#include "lite/backends/xpu/runtime.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void GraphCompute::PrepareForRun() { + // auto& ctx = this->ctx_->template As(); + auto& param = this->Param(); + CHECK(param.weight); + CHECK(lite::xpu::LoadModel(*param.weight, &runtime_)); + CHECK(runtime_ != nullptr); +} + +void GraphCompute::Run() { + auto& param = this->Param(); + auto GetCurrentUS = []() -> double { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; + }; + auto start_time = GetCurrentUS(); + for (int i = 0; i < param.inputs.size(); i++) { + auto input_var_name = param.inputs[i].first; + auto input_tensor = param.inputs[i].second; + LOG(INFO) << "input dims[" << i << ":" << input_var_name + << "]: " << input_tensor->dims(); + auto input_tensor_data = input_tensor->data(); + for (int j = 0; j < input_tensor->dims().production(); j++) { + VLOG(3) << input_tensor_data[j]; + } + auto input_ndarray = xtcl::xNDArray::Empty( + input_tensor->dims().Vectorize(), {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto input_ndarray_data = + static_cast(input_ndarray.ToDLPack()->dl_tensor.data); + std::memcpy(input_ndarray_data, + input_tensor_data, + sizeof(float) * input_tensor->dims().production()); + runtime_->SetInputZeroCopy(input_var_name, + &input_ndarray.ToDLPack()->dl_tensor); + } + runtime_->Run(); + for (int i = 0; i < param.outputs.size(); i++) { + auto output_ndarray = runtime_->GetOutput(i); + auto output_var_name = param.outputs[i].first; + auto output_tensor = param.outputs[i].second; + output_tensor->Resize(output_ndarray.Shape()); + LOG(INFO) << "output dims[" << i << ":" << output_var_name + << "]: " << output_tensor->dims(); + auto output_ndarray_data = + static_cast(output_ndarray.ToDLPack()->dl_tensor.data); + auto output_tensor_data = output_tensor->mutable_data(); + std::memcpy(output_tensor_data, + output_ndarray_data, + sizeof(float) * output_tensor->dims().production()); + for (int j = 0; j < output_tensor->dims().production(); j++) { + VLOG(3) << output_tensor_data[j]; + } + } + LOG(INFO) << "[XPU] Process cost " << GetCurrentUS() - start_time << " us"; +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(graph_op, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::GraphCompute, + def) + .BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kHost))}) + .Finalize(); diff --git a/lite/kernels/xpu/graph_compute.h b/lite/kernels/xpu/graph_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..5406daa8a1b757989d006f4e0ea09baedc809e33 --- /dev/null +++ b/lite/kernels/xpu/graph_compute.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class GraphCompute : public KernelLite { + public: + using param_t = operators::GraphParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~GraphCompute() = default; + + private: + std::shared_ptr runtime_{nullptr}; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/CMakeLists.txt b/lite/model_parser/CMakeLists.txt index 2df9b97e552e00e0c04faa0f4995cf9ef348b84e..34d524c5c1b86fb6b689b86089c355e3de42a34e 100644 --- a/lite/model_parser/CMakeLists.txt +++ b/lite/model_parser/CMakeLists.txt @@ -28,7 +28,16 @@ lite_cc_library(model_parser SRCS model_parser.cc DEPS target_wrapper_host compatible_pb memory - CUDA_DEPS target_wrapper_cuda - NPU_DEPS npu_helper) - + CUDA_DEPS target_wrapper_cuda) lite_cc_test(test_compatible_pb SRCS compatible_pb_test.cc DEPS compatible_pb) + +if (LITE_WITH_CUDA AND NOT LITE_ON_TINY_PUBLISH) + lite_cc_library(compatibility SRCS compatibility.cc DEPS + kernel + variable + compatible_pb + type_system + ${cpp_wrapper} + ${naive_wrapper}) + lite_cc_test(test_compatibility SRCS compatibility_test.cc DEPS compatibility leaky_relu_compute_cuda) +endif() diff --git a/lite/model_parser/compatibility.cc b/lite/model_parser/compatibility.cc new file mode 100644 index 0000000000000000000000000000000000000000..d2fcfaa49cb3505c3cbbc0c9efb2034739301915 --- /dev/null +++ b/lite/model_parser/compatibility.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/model_parser/compatibility.h" + +#include "lite/core/type_system.h" +#include "lite/model_parser/naive_buffer/block_desc.h" +#include "lite/model_parser/naive_buffer/op_desc.h" +#include "lite/model_parser/naive_buffer/program_desc.h" +#include "lite/model_parser/naive_buffer/var_desc.h" +#ifndef LITE_ON_TINY_PUBLISH +#include "lite/model_parser/cpp/block_desc.h" +#include "lite/model_parser/cpp/op_desc.h" +#include "lite/model_parser/cpp/program_desc.h" +#include "lite/model_parser/cpp/var_desc.h" +#endif + +namespace paddle { +namespace lite { + +template +bool CompatibleChecker::CheckKernelVersion(const std::string& type, + const lite_api::Place& place) { + int64_t impl_version = ParamTypeRegistry::Global().GetVersion(type, place); + const int64_t prog_version = program_.Version(); + VLOG(3) << "Kernel implementation version: " << type << ", " << impl_version; + VLOG(3) << "Kernel program version: " << type << ", " << prog_version; + if (impl_version == -1) { + impl_version = mini_version_; + } + return prog_version <= impl_version; +} + +template +std::unordered_set CompatibleChecker::OpsType(T* program) { + LOG(WARNING) << "OpsType() is not yet implemented."; + return std::unordered_set(); +} + +#ifndef LITE_ON_TINY_PUBLISH +template <> +std::unordered_set CompatibleChecker::OpsType( + cpp::ProgramDesc* program) { + std::unordered_set ops_type; + for (size_t i = 0; i < program->BlocksSize(); ++i) { + auto* block = program->GetBlock(i); + for (size_t j = 0; j < block->OpsSize(); ++j) { + auto* op = block->GetOp(j); + ops_type.insert(op->Type()); + } + } + return ops_type; +} + +template class CompatibleChecker; +#endif // LITE_ON_TINY_PUBLISH + +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/compatibility.h b/lite/model_parser/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..132f5c941a82bb4361300dcd29565069a22c165e --- /dev/null +++ b/lite/model_parser/compatibility.h @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
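One note on the numbers used by CompatibleChecker: the mini_version default of 1005000 (see the header below) appears to follow Paddle's packed MAJOR * 10^6 + MINOR * 10^3 + PATCH versioning, i.e. 1.5.0; this reading is an assumption, not stated in the patch. Decoding it:

#include <cstdio>
int main() {
  // Assumed packed-version convention; 1005000 would decode to "1.5.0".
  long long v = 1005000;
  std::printf("%lld.%lld.%lld\n", v / 1000000, (v / 1000) % 1000, v % 1000);
  // CheckKernelVersion() above never decodes; it compares packed values.
  return 0;
}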
+ +#pragma once + +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/model_parser/desc_apis.h" + +namespace paddle { +namespace lite { + +template +class CompatibleChecker { + public: + explicit CompatibleChecker(const T& program, + const int64_t mini_version = 1005000) + : program_(program), mini_version_(mini_version) {} + + bool operator()(const lite_api::Place& place) { + bool status = true; + const std::unordered_set& ops_type = OpsType(&program_); + if (ops_type.empty()) { + VLOG(3) << "You are checking the compatibility of an empty program."; + } + for (const auto& type : ops_type) { + bool ret = CheckKernelVersion(type, place); + VLOG(3) << "Kernel version is supported: " << type << ", " << ret; + status = status && ret; + } + return status; + } + + private: + std::unordered_set OpsType(T* program); + bool CheckKernelVersion(const std::string& type, + const lite_api::Place& place); + T program_; + int64_t mini_version_; +}; + +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/compatibility_test.cc b/lite/model_parser/compatibility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3cb38f1c95649567b72d73b8938420537ec7b5b --- /dev/null +++ b/lite/model_parser/compatibility_test.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/model_parser/compatibility.h" +#include +#include "lite/api/paddle_lite_factory_helper.h" + +#include "lite/model_parser/compatible_pb.h" +#include "lite/model_parser/cpp/block_desc.h" +#include "lite/model_parser/cpp/op_desc.h" +#include "lite/model_parser/cpp/program_desc.h" +#include "lite/model_parser/cpp/var_desc.h" + +USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); + +namespace paddle { +namespace lite { + +static constexpr int64_t version = 1005000; + +TEST(CompatibleChecker, CppProgramDesc) { + cpp::ProgramDesc program; + program.SetVersion(version); + auto* block = program.AddBlock(); + auto* op = block->AddOp(); + op->SetType("leaky_relu"); + + CompatibleChecker checker(program); + lite_api::Place place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}; + CHECK(checker(place)); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc index 09604b014adbde810516eebc60ad226d05d17fe2..2df4a92270466b1f3b56dec8deecf8e9a8e62390 100644 --- a/lite/model_parser/compatible_pb.cc +++ b/lite/model_parser/compatible_pb.cc @@ -116,10 +116,8 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { name, any_desc.template GetAttr>(name)); break; case AttrType::BLOCK: { - LOG(INFO) << "loading block " << name; auto i = any_desc.template GetAttr(name); - LOG(INFO) << i; - cpp_desc->SetAttr(name, i); + cpp_desc->SetAttr(name, i); // naive_buffer::BlockDesc* sub_block = any_desc.template // GetAttr(name); // LOG(INFO) << sub_block->OpsSize(); @@ -152,6 +150,8 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { IMPL_ONE(FLOATS, std::vector); IMPL_ONE(INTS, std::vector); IMPL_ONE(BOOLEAN, bool); + IMPL_ONE(LONG, int64_t); + IMPL_ONE(LONGS, std::vector); default: LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); } diff --git a/lite/model_parser/cpp/op_desc.cc b/lite/model_parser/cpp/op_desc.cc index 4c99fdfb3dfbd3d93a8da5ce97c4109a29b8f867..f4be0106fcdce351056c648a35f93d410fd5712c 100644 --- a/lite/model_parser/cpp/op_desc.cc +++ b/lite/model_parser/cpp/op_desc.cc @@ -28,7 +28,6 @@ namespace cpp { } SET_ATTR_IMPL(int32_t, INT); -SET_ATTR_IMPL(int16_t, INT); SET_ATTR_IMPL(float, FLOAT); SET_ATTR_IMPL(std::string, STRING); SET_ATTR_IMPL(bool, BOOLEAN); @@ -108,7 +107,6 @@ bool OpDesc::HasOutput(const std::string& param) const { } GET_IMPL_ONE(float, FLOAT); -GET_IMPL_ONE(int16_t, INT); GET_IMPL_ONE(std::string, STRING); GET_IMPL_ONE(int64_t, LONG); GET_IMPL_ONE(bool, BOOLEAN); diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc index 7f50726c80d52557fa741ae060f93eb889df64ad..13b6cb5b77d00a2a5f733a0015dec4dbebc088b7 100644 --- a/lite/model_parser/model_parser.cc +++ b/lite/model_parser/model_parser.cc @@ -31,10 +31,6 @@ #endif #include "lite/utils/io.h" -#ifdef LITE_WITH_NPU -#include "lite/backends/npu/npu_helper.h" -#endif - namespace paddle { namespace lite { @@ -266,25 +262,6 @@ void LoadModelPb(const std::string &model_dir, } } -#ifdef LITE_WITH_NPU - auto main_block = pb_proto_prog.blocks(0); - for (auto &op : main_block.ops()) { - LOG(INFO) << "op type:" << op.type(); - if (op.type() != "graph_op") { - continue; - } - auto xs = op.attrs(); - auto it = std::find_if( - xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == "model_name"; - }); - CHECK(it != xs.end()); - auto model_name = it->s(); - std::string file_path = model_dir + "/" + model_name; - 
CHECK(npu::BuildNPUClient(file_path, model_name)) - << "NPU model load failed!"; - } -#endif VLOG(4) << "Load protobuf model in '" << model_dir << "'' successfully"; } @@ -466,7 +443,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc, #define SET_DATA_TYPE(precision, type_desc) \ case precision: \ desc.SetDataType(type_desc); \ - break + break; SET_DATA_TYPE(PRECISION(kFloat), VarDescAPI::VarDataType::FP32); SET_DATA_TYPE(PRECISION(kInt8), VarDescAPI::VarDataType::INT8); @@ -487,14 +464,14 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc, if (tensor.target() == TARGET(kCUDA)) { switch (tensor.precision()) { #define DO(precision, type) \ - case precision: \ + case precision: { \ std::unique_ptr tmp_buffer(new type[tensor.data_size()]); \ TargetWrapperCuda::MemcpySync(tmp_buffer.get(), \ tensor.data(), \ tensor.data_size(), \ IoDirection::DtoH); \ desc.SetData(tmp_buffer.get(), tensor.data_size()); \ - break + } break; DO(PRECISION(kFloat), float); DO(PRECISION(kInt8), int8_t); DO(PRECISION(kInt16), int16_t); @@ -512,7 +489,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc, #define DO(precision, type) \ case precision: \ desc.SetData(tensor.data(), tensor.data_size()); \ - break + break; DO(PRECISION(kFloat), float); DO(PRECISION(kInt8), int8_t); DO(PRECISION(kInt16), int16_t); @@ -737,21 +714,6 @@ void LoadModelNaive(const std::string &model_dir, } } -#ifdef LITE_WITH_NPU - auto &prog = *cpp_prog; - auto &main_block_desc = *prog.GetBlock(0); - for (size_t i = 0; i < main_block_desc.OpsSize(); ++i) { - auto &op = *main_block_desc.GetOp(i); - if (op.Type() != "graph_op") { - continue; - } - auto model_name = op.GetAttr("model_name"); - std::string file_path = model_dir + "/" + model_name; - CHECK(npu::BuildNPUClient(file_path, model_name)) - << "NPU model load failed!"; - } -#endif - VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully"; } @@ -765,10 +727,8 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer, // Load model - std::string prog_path = model_buffer; - naive_buffer::BinaryTable table; - table.LoadFromMemory(prog_path.c_str(), prog_path.length()); + table.LoadFromMemory(model_buffer.c_str(), model_buffer.length()); naive_buffer::proto::ProgramDesc nb_proto_prog(&table); nb_proto_prog.Load(); @@ -780,12 +740,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer, // Load Params // NOTE: Only main block be used now. 
// only combined Params are supported in Loading Model from memory - std::string combined_params_path = param_buffer; - LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true); - -#ifdef LITE_WITH_NPU - LOG(FATAL) << "load from memory is not supported by NPU"; -#endif + LoadCombinedParamsNaive(param_buffer, scope, *cpp_prog, true); VLOG(4) << "Load model from naive buffer memory successfully"; } diff --git a/lite/model_parser/model_parser.h b/lite/model_parser/model_parser.h index 81be2579e3932d7165480afd5bb89f567155cf36..bca7533c24af517994dae677c7b63e088f2ef1ca 100644 --- a/lite/model_parser/model_parser.h +++ b/lite/model_parser/model_parser.h @@ -72,7 +72,7 @@ void SerializeTensor(std::ostream& os, // LoDTensor to ostream void TensorToStream(std::ostream& os, const lite::Tensor& tensor); - +void TensorFromStream(std::istream& is, lite::Tensor* tensor); void ReadBinaryFile(const std::string& filename, std::string* contents); // For naive buffer diff --git a/lite/model_parser/naive_buffer/naive_buffer.h b/lite/model_parser/naive_buffer/naive_buffer.h index e2e2f7fb1ea3cb5b226bf09bd16074f51e171c75..717dd3c5a6b0c48d6a1f2ae0d7dba9f08a6d99f3 100644 --- a/lite/model_parser/naive_buffer/naive_buffer.h +++ b/lite/model_parser/naive_buffer/naive_buffer.h @@ -126,6 +126,41 @@ using UInt64Builder = PrimaryBuilder; using Float32Builder = PrimaryBuilder; using Float64Builder = PrimaryBuilder; +template +class PrimaryListBuilder : public FieldBuilder { + std::vector data_; + + public: + using value_type = Primary; + + explicit PrimaryListBuilder(BinaryTable* table) : FieldBuilder(table) {} + PrimaryListBuilder(BinaryTable* table, const std::vector& val) + : FieldBuilder(table), data_(val) {} + + /// Set data. + void set(const std::vector& x) { data_ = x; } + + const std::vector& data() const { return data_; } + + /// Save information to the corresponding BinaryTable. + void Save() override; + + /// Load information from the corresponding BinaryTable. + void Load() override; + + /// Number of elements. + size_t size() const { return data_.size(); } + + Type type() const override { + return core::StdTypeToRepr>(); + } + + /// clear builder + void Clear() { data_.clear(); } + + ~PrimaryListBuilder() = default; +}; + /* * Builder for all the primary types. int32, float, bool and so on. */ @@ -344,6 +379,36 @@ void PrimaryBuilder::Load() { table()->Consume(sizeof(value_type)); } +template +void PrimaryListBuilder::Load() { + CHECK(data_.empty()) << "Duplicate load"; + // Load number of elements first. + uint64_t num_elems{}; + memcpy(&num_elems, table()->cursor(), sizeof(uint64_t)); + table()->Consume(sizeof(uint64_t)); + + data_.resize(num_elems); + for (uint64_t i = 0; i < num_elems; i++) { + memcpy(&data_[i], table()->cursor(), sizeof(value_type)); + table()->Consume(sizeof(value_type)); + } +} + +template +void PrimaryListBuilder::Save() { + // store number of elements in the head. 
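+ // For clarity: the resulting byte layout is a native-endian uint64_t
+ // element count followed by num_elems * sizeof(value_type) raw bytes;
+ // Load() above consumes exactly this layout when reading the table back.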
+ uint64_t num_elems = size(); + table()->Require(sizeof(uint64_t)); + memcpy(table()->cursor(), &num_elems, sizeof(uint64_t)); + table()->Consume(sizeof(uint64_t)); + + table()->Require(num_elems * sizeof(value_type)); + memcpy(table()->cursor(), + reinterpret_cast(&data_[0]), + num_elems * sizeof(value_type)); + table()->Consume(num_elems * sizeof(value_type)); +} + template void EnumBuilder::Save() { value_type holder = static_cast(data_); diff --git a/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc b/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc index 45224de12248589d127fc4e0a3da44c1a52961da..46fbec1b67cbc3741e73fcbc9b6ad9934531d0ff 100644 --- a/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc +++ b/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc @@ -293,7 +293,7 @@ TEST(NaiveBufferWrapper, ProgramDesc) { // Set ProgramDesc nb_desc0.SetVersion(1); for (int i = 0; i < 3; ++i) { - auto* item = nb_desc0.AddBlock(); + nb_desc0.AddBlock(); } // Save model diff --git a/lite/model_parser/naive_buffer/op_desc.cc b/lite/model_parser/naive_buffer/op_desc.cc index 8d36a4ad3d2406cce027dbb3b92811986f5684b1..8a2ad55807a07c2ea79e0bab2b4368b22bf3b13c 100644 --- a/lite/model_parser/naive_buffer/op_desc.cc +++ b/lite/model_parser/naive_buffer/op_desc.cc @@ -54,6 +54,7 @@ SET_ATTR_IMPL(int, INT, Int32, i); SET_ATTR_IMPL(float, FLOAT, Float32, f); SET_ATTR_IMPL(bool, BOOLEAN, Bool, b); SET_ATTR_IMPL(std::string, STRING, String, s); +SET_ATTR_IMPL(int64_t, LONG, Int64, l); #undef SET_ATTR_IMPL #define SET_ATTRS_IMPL(T, ty__, bd__, pb_f__) \ @@ -77,6 +78,7 @@ SET_ATTR_IMPL(std::string, STRING, String, s); SET_ATTRS_IMPL(int, INTS, Int32, ints); SET_ATTRS_IMPL(float, FLOATS, Float32, floats); SET_ATTRS_IMPL(std::string, STRINGS, String, strings); +SET_ATTRS_IMPL(int64_t, LONGS, Int64, longs); #undef SET_ATTRS_IMPL const proto::OpDesc::Attr& GetFindAttr(const proto::OpDesc& desc, diff --git a/lite/model_parser/naive_buffer/op_desc.h b/lite/model_parser/naive_buffer/op_desc.h index c29229316917cc06369619ff3edf1ed56f660125..907f33a2a70939005f8a404d08b83e65312d7072 100644 --- a/lite/model_parser/naive_buffer/op_desc.h +++ b/lite/model_parser/naive_buffer/op_desc.h @@ -130,6 +130,7 @@ class OpDesc : public OpDescAPI { DEF_ONE(LONGS); default: LOG(FATAL) << "Unknown attribute type"; + return static_cast(-1); } #undef DEF_ONE } diff --git a/lite/model_parser/naive_buffer/param_desc.cc b/lite/model_parser/naive_buffer/param_desc.cc index 4d38ca4a8de2a9a52eabf631c54501ad59d1cbf1..4397b3c413e8a09d2e5e5b41b8f9222bcfab4e20 100644 --- a/lite/model_parser/naive_buffer/param_desc.cc +++ b/lite/model_parser/naive_buffer/param_desc.cc @@ -97,6 +97,7 @@ VarDescAPI::VarDataType ParamDesc::GetDataType() const { default: LOG(FATAL) << "Unknown var data type"; } + return VarDescAPI::VarDataType(); #undef GET_DATA_TYPE_CASE_ITEM } @@ -148,15 +149,16 @@ void ParamDesc::SetDim(const std::vector& dim) { CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \ << "Data Type mismatch"; \ std::vector res; \ - auto& data_builder = desc_->GetField>("data"); \ - auto data = RepeatedToVector(data_builder); \ + auto& data_builder = desc_->GetField>("data"); \ + auto& data = data_builder.data(); \ size_t size = data.size() / sizeof(T); \ - auto* data_ptr = reinterpret_cast(&data[0]); \ + auto* data_ptr = reinterpret_cast(&data[0]); \ for (size_t i = 0; i < size; ++i) { \ res.push_back(data_ptr[i]); \ } \ return res; \ } + GET_DATA_IMPL(uint8_t, UINT8); GET_DATA_IMPL(int8_t, INT8); 
GET_DATA_IMPL(int16_t, INT16); @@ -171,14 +173,13 @@ GET_DATA_IMPL(double, FP64); CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \ << "Data Type mismatch, call SetDataType first."; \ auto* data_builder = \ - desc_->GetMutableField>("data"); \ + desc_->GetMutableField>("data"); \ CHECK(data_builder); \ data_builder->Clear(); \ size_t size = size__ * sizeof(T); \ auto* data_ptr = reinterpret_cast(data_ptr__); \ - for (size_t i = 0; i < size; ++i) { \ - data_builder->New()->set(data_ptr[i]); \ - } + std::vector data_vec(data_ptr, data_ptr + size); \ + data_builder->set(data_vec); #define SET_DATA_IMPL(T, type__) \ template <> \ diff --git a/lite/model_parser/naive_buffer/proto/framework.nb.h b/lite/model_parser/naive_buffer/proto/framework.nb.h index f495a12b460c57e2464a76409d69778f4e2754a8..2427e49d2690811ded0a19d7a7bd6dec1ef6394a 100644 --- a/lite/model_parser/naive_buffer/proto/framework.nb.h +++ b/lite/model_parser/naive_buffer/proto/framework.nb.h @@ -191,7 +191,7 @@ class ParamDesc : public StructBuilder { New("lod"); NewUInt32("tensor_version"); New("tensor_desc"); - New>("data"); + New>("data"); } }; diff --git a/lite/model_parser/naive_buffer/var_desc.cc b/lite/model_parser/naive_buffer/var_desc.cc index 2e001999294766fcefca2683b2aa5627d543707a..cccf7582912d1edff2c91fbfa5ed602f028be648 100644 --- a/lite/model_parser/naive_buffer/var_desc.cc +++ b/lite/model_parser/naive_buffer/var_desc.cc @@ -51,6 +51,7 @@ VarDescAPI::Type VarDesc::GetType() const { GET_TYPE_CASE_ITEM(READER); default: LOG(FATAL) << "Unknown var type"; + return VarDescAPI::Type(); } #undef GET_TYPE_CASE_ITEM } diff --git a/lite/model_parser/pb/op_desc.cc b/lite/model_parser/pb/op_desc.cc index 34b83d55b55b0efb5189516ae57363798ee2f3f0..37ed07a2c5af60e9754e330ada79b587077aceaa 100644 --- a/lite/model_parser/pb/op_desc.cc +++ b/lite/model_parser/pb/op_desc.cc @@ -46,6 +46,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) { SET_IMPL_ONE(int, INT, i); SET_IMPL_ONE(float, FLOAT, f); SET_IMPL_ONE(bool, BOOLEAN, b); +SET_IMPL_ONE(int64_t, LONG, l); template <> void OpDesc::SetAttr>(const std::string &name, @@ -88,6 +89,16 @@ void OpDesc::SetAttr>( } } +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(desc_, name); + it->set_type(framework::proto::LONGS); + it->clear_longs(); + for (auto &i : v) { + it->add_longs(i); + } +} google::protobuf::internal::RepeatedPtrIterator< const framework::proto::OpDesc_Attr> GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) { diff --git a/lite/model_parser/pb/op_desc.h b/lite/model_parser/pb/op_desc.h index 1a0af22f272d2307dc3eb0b553fee7edf140bac4..a9c2f863a087790317653b916389cddfd457a3f2 100644 --- a/lite/model_parser/pb/op_desc.h +++ b/lite/model_parser/pb/op_desc.h @@ -121,6 +121,7 @@ class OpDesc : public OpDescAPI { DEF_ONE(LONGS); default: LOG(FATAL) << "Unknown attribute type"; + return static_cast(-1); } #undef DEF_ONE } @@ -142,8 +143,6 @@ class OpDesc : public OpDescAPI { template T GetAttr(const std::string &name) const; - std::string DebugString() const { return desc_->DebugString(); } - private: std::vector GetArguments( const google::protobuf::RepeatedPtrField diff --git a/lite/model_parser/pb/var_desc.cc b/lite/model_parser/pb/var_desc.cc index 91800c88b593180913ca09d44b784748de064f05..517f4cc6dcefbb5e517b6f84ac1b695dbbbc5925 100644 --- a/lite/model_parser/pb/var_desc.cc +++ b/lite/model_parser/pb/var_desc.cc @@ -39,6 +39,7 @@ VarDescAPI::Type VarDesc::GetType() 
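Note on the PrimaryListBuilder<char> storage above: it amounts to a length-prefixed byte blob. Save() writes a uint64 element count followed by the raw elements in one memcpy, and GET_DATA_IMPL reinterprets those bytes as the typed array on the way out. A minimal standalone sketch of that round trip, with hypothetical SaveList/LoadList helpers standing in for BinaryTable and POD element types assumed:

#include <cstdint>
#include <cstring>
#include <vector>

// Append a POD array to `buf` as: uint64 count, then raw element bytes.
template <typename T>
void SaveList(std::vector<char>* buf, const std::vector<T>& data) {
  uint64_t num_elems = data.size();
  size_t offset = buf->size();
  buf->resize(offset + sizeof(uint64_t) + num_elems * sizeof(T));
  std::memcpy(buf->data() + offset, &num_elems, sizeof(uint64_t));
  std::memcpy(buf->data() + offset + sizeof(uint64_t),
              reinterpret_cast<const char*>(data.data()),
              num_elems * sizeof(T));
}

// Read the same layout back, advancing `cursor` the way BinaryTable::Consume does.
template <typename T>
std::vector<T> LoadList(const char* buf, size_t* cursor) {
  uint64_t num_elems = 0;
  std::memcpy(&num_elems, buf + *cursor, sizeof(uint64_t));
  *cursor += sizeof(uint64_t);
  std::vector<T> out(num_elems);
  std::memcpy(out.data(), buf + *cursor, num_elems * sizeof(T));
  *cursor += num_elems * sizeof(T);
  return out;
}

One bulk memcpy per tensor, instead of one CharBuilder node per byte as before, is what makes the new SET_DATA_IMPL path cheap for large weights.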
const { GET_TYPE_CASE_ITEM(READER); default: LOG(FATAL) << "Unknown var type"; + return VarDescAPI::Type(); } #undef GET_TYPE_CASE_ITEM } diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 12f96121db0a5160c7e416aa2e7f6391f7374b69..21b8ec278a6df16711bef3d1b3be34f77c52c9b3 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -5,6 +5,7 @@ lite_cc_library(op_params SRCS op_params.cc DEPS tensor any) add_operator(conv_op basic SRCS conv_op.cc DEPS ${op_DEPS}) add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS}) add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS}) +add_operator(assign_op extra SRCS assign_op.cc DEPS ${op_DEPS}) add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS}) add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS}) add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS}) @@ -58,18 +59,25 @@ add_operator(norm_op basic SRCS norm_op.cc DEPS ${op_DEPS}) add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS}) add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS}) add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS}) +add_operator(unsqueeze_op_lite extra SRCS unsqueeze_op.cc DEPS ${op_DEPS}) add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS}) -add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS}) -add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS}) -add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS}) -add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS}) -add_operator(affine_channel_op basic SRCS affine_channel_op.cc DEPS ${op_DEPS}) -add_operator(anchor_generator_op basic SRCS anchor_generator_op.cc DEPS ${op_DEPS}) -add_operator(generate_proposals_op basic SRCS generate_proposals_op.cc DEPS ${op_DEPS}) -add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS}) -add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS}) -add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS}) -add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS}) +add_operator(gather_op extra SRCS gather_op.cc DEPS ${op_DEPS}) +add_operator(reduce_mean_op extra SRCS reduce_mean_op.cc DEPS ${op_DEPS}) +add_operator(stack_op extra SRCS stack_op.cc DEPS ${op_DEPS}) +add_operator(cast_op_lite extra SRCS cast_op.cc DEPS ${op_DEPS}) +add_operator(affine_channel_op extra SRCS affine_channel_op.cc DEPS ${op_DEPS}) +add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS}) +add_operator(generate_proposals_op extra SRCS generate_proposals_op.cc DEPS ${op_DEPS}) +add_operator(roi_align_op extra SRCS roi_align_op.cc DEPS ${op_DEPS}) +add_operator(box_clip_op extra SRCS box_clip_op.cc DEPS ${op_DEPS}) +add_operator(flatten_op extra SRCS flatten_op.cc DEPS ${op_DEPS}) +add_operator(fake_quantize_range_abs_max_op extra SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS}) +add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ${op_DEPS}) +add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS}) +add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS}) +add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) +add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) # for OCR 
specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) @@ -88,13 +96,14 @@ add_operator(greater_than extra SRCS compare_op.cc DEPS ${op_DEPS}) add_operator(greater_equal extra SRCS compare_op.cc DEPS ${op_DEPS}) add_operator(read_from_array_op extra SRCS read_from_array_op.cc DEPS ${op_DEPS}) add_operator(beam_search_op extra SRCS beam_search_op.cc DEPS ${op_DEPS}) -add_operator(sequence_pool_op_lite extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) +add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) add_operator(lod_reset_op extra SRCS lod_reset_op.cc DEPS ${op_DEPS}) add_operator(is_empty extra SRCS is_empty_op.cc DEPS ${op_DEPS}) -add_operator(slice_op_lite extra SRCS slice_op.cc DEPS ${op_DEPS}) +add_operator(slice_op_lite basic SRCS slice_op.cc DEPS ${op_DEPS}) add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS}) add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS}) add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS}) +add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS}) add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index c9d9d49d2e39181c8b65cef44141ac7189c92fcd..c3c5de311f41f88fbeed4b03f9bfd618cf51c3b3 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -51,6 +51,11 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { if (opdesc.Type() == "swish") { param_.Swish_beta = opdesc.GetAttr("beta"); } + + if (opdesc.Type() == "hard_sigmoid") { + param_.hard_sigmoid_slope = opdesc.GetAttr("slope"); + param_.hard_sigmoid_offset = opdesc.GetAttr("offset"); + } param_.Out = scope->FindVar(out_name)->GetMutable(); return true; } @@ -111,6 +116,9 @@ REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp); #ifdef LITE_WITH_TRAIN REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); diff --git a/lite/operators/argmax_op.cc b/lite/operators/argmax_op.cc index ccfce32bb63a829fb28cd57cffe8f2e7a902ceca..6b246603e1f640316e32804465a72c01b7984bfd 100644 --- a/lite/operators/argmax_op.cc +++ b/lite/operators/argmax_op.cc @@ -50,7 +50,7 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.X = scope->FindVar(x)->GetMutable(); param_.Out = scope->FindVar(out)->GetMutable(); - param_.Axis = op_desc.GetAttr("Axis"); + param_.Axis = op_desc.GetAttr("axis"); return true; } @@ -59,4 +59,4 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { } // namespace lite } // namespace paddle -REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite); +REGISTER_LITE_OP(arg_max, paddle::lite::operators::ArgmaxOpLite); diff --git a/lite/operators/assign_value_op.cc b/lite/operators/assign_value_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..046c5222283fc73bd3af1e53520b1fc5539bcd31 --- /dev/null +++ b/lite/operators/assign_value_op.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
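For reference, the hard_sigmoid slope/offset attributes read above conventionally parameterize y = clip(slope * x + offset, 0, 1). A sketch of that semantics (reference only, not the Lite kernel):

#include <algorithm>

// Reference hard_sigmoid; defaults mirror the ActivationParam defaults
// hard_sigmoid_slope{0.2} and hard_sigmoid_offset{0.5} added later in this patch.
inline float HardSigmoid(float x, float slope = 0.2f, float offset = 0.5f) {
  return std::min(1.0f, std::max(0.0f, slope * x + offset));
}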
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/assign_value_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool AssignValueOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.Out); + auto shape = param_.shape; + auto int32_values = param_.int32_values; + auto fp32_values = param_.fp32_values; + size_t shape_num = 1; + for (int i = 0; i < shape.size(); i++) { + shape_num *= shape[i]; + } + CHECK_OR_FALSE(shape_num == int32_values.size() || + shape_num == fp32_values.size()); + return true; +} + +bool AssignValueOpLite::InferShape() const { + std::vector shape = param_.shape; + std::vector out_shape; + for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); + param_.Out->Resize(out_shape); + return true; +} + +bool AssignValueOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + param_.shape = op_desc.GetAttr>("shape"); + param_.dtype = op_desc.GetAttr("dtype"); + param_.fp32_values = op_desc.GetAttr>("fp32_values"); + param_.int32_values = op_desc.GetAttr>("int32_values"); + + auto out = op_desc.Output("Out").front(); + param_.Out = scope->FindVar(out)->GetMutable(); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(assign_value, paddle::lite::operators::AssignValueOpLite); diff --git a/lite/operators/assign_value_op.h b/lite/operators/assign_value_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7bf220615935f02051ed606adb894bf9842378f3 --- /dev/null +++ b/lite/operators/assign_value_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
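assign_value materializes a compile-time constant tensor: CheckShape above requires the product of `shape` to match the number of supplied values (e.g. shape {2, 3} needs exactly 6 fp32_values or 6 int32_values), and InferShape simply resizes Out to `shape`. The values are taken as an already-flattened row-major buffer, roughly:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference for what assign_value produces in the fp32 case.
std::vector<float> AssignValueRef(const std::vector<int>& shape,
                                  const std::vector<float>& fp32_values) {
  int64_t numel = 1;
  for (int d : shape) numel *= d;
  assert(static_cast<size_t>(numel) == fp32_values.size());
  return fp32_values;  // row-major data of the `shape`-sized output
}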
+ +#pragma once +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class AssignValueOpLite : public OpLite { + public: + AssignValueOpLite() {} + + explicit AssignValueOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "assign value"; } + + private: + mutable AssignValueParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index cbc946dbb0df7f6c23d7871f12dfd091c154b65c..dfd95e4658ddbfe244659887e9c738722be439ec 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -21,7 +21,7 @@ namespace lite { namespace operators { bool ConcatOpLite::CheckShape() const { - CHECK_GT_OR_FALSE(param_.x.size(), 1UL); + CHECK_GE_OR_FALSE(param_.x.size(), 1UL); CHECK_OR_FALSE(param_.output); return true; } @@ -60,6 +60,7 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("X"); auto out = op_desc.Output("Out").front(); + param_.x.clear(); for (auto var : inputs) { param_.x.push_back(scope->FindVar(var)->GetMutable()); } diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 640cec1a6c8c5ff897c98a3102fe73288adcebfa..ceca1a61ce3457ed0a2c25541d02bd868c380b3b 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -13,6 +13,7 @@ // limitations under the License. 
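Context for the conv_op.cc hunk just below: the output extent follows output = (input + 2*pad - (dilation*(kernel-1)+1)) / stride + 1, and the new SAME branch solves for the padding that yields output = ceil(input / stride), forcing dilation to 1. A hedged sanity-check sketch (helper names are illustrative):

#include <algorithm>

// Mirrors ConvOutputSize below: dkernel is the dilated kernel extent.
int ConvOutputSizeRef(int input, int k, int dilation, int pad, int stride) {
  int dkernel = dilation * (k - 1) + 1;
  return (input + 2 * pad - dkernel) / stride + 1;
}

// Mirrors the SAME branch of UpdatePaddingAndDilation below:
// out = ceil(input / stride), pad_sum = max((out - 1) * stride + k - input, 0).
int SamePaddingRef(int input, int k, int stride) {
  int out = (input + stride - 1) / stride;
  int pad_sum = std::max((out - 1) * stride + k - input, 0);
  return pad_sum / 2;  // per-side padding; an odd remainder is dropped
}
// e.g. input=224, k=3, stride=2: out=112, pad_sum=1, per-side pad=0.

Because pad_sum/2 floors, an odd pad_sum undershoots the exact SAME size by one, which may be why the CHECK_GT on output_size is commented out in the hunk below.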
#include "lite/operators/conv_op.h" +#include #include #include "lite/core/op_registry.h" @@ -33,10 +34,6 @@ bool ConvOpLite::CheckShape() const { CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U); - CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size()); - - CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups); - CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0); CHECK_EQ_OR_FALSE(filter_dims.size(), 4UL); return true; @@ -46,15 +43,46 @@ inline int ConvOutputSize( int input_size, int filter_size, int dilation, int padding, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; int output_size = (input_size + 2 * padding - dkernel) / stride + 1; - CHECK_GT_OR_FALSE(output_size, 0); + // CHECK_GT_OR_FALSE(output_size, 0); return output_size; } +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilations, + const std::vector& strides, + const std::string padding_algorithm, + const lite::DDim data_dims, + const lite::DDim& ksize) { + // when padding_desc is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (size_t i = 0; i < strides.size(); ++i) { + int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i]; + int pad_sum = + std::max((out_size - 1) * strides[i] + ksize[i] - data_dims[i + 2], + (int64_t)0); + // pad + *(paddings->begin() + i) = pad_sum / 2; + // dilation + *(dilations->begin() + i) = 1; + } + } else if (padding_algorithm == "VALID") { + for (auto& it : *paddings) { + it = 0; + } + } +} + bool ConvOpLite::InferShape() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); + UpdatePaddingAndDilation(¶m_.paddings, + ¶m_.dilations, + param_.strides, + padding_algorithm_, + in_dims, + filter_dims); std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < param_.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 2], diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index fe0393ee59cd99c6e000fecca974edbb3e1cbbfb..1d6e1c93490a394723d34de76fc3ff8040d31e81 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -76,8 +76,26 @@ class ConvOpLite : public OpLite { } } } - if (op_desc.HasAttr("fuse_relu")) { - param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + + if (op_desc.HasAttr("with_act") && op_desc.GetAttr("with_act")) { + param_.activation_param.has_active = true; + auto act_type = op_desc.GetAttr("act_type"); + if (act_type == "relu") { + param_.activation_param.active_type = lite_api::ActivationType::kRelu; + param_.fuse_relu = true; + } else if (act_type == "leaky_relu") { + param_.activation_param.active_type = + lite_api::ActivationType::kLeakyRelu; + param_.activation_param.Leaky_relu_alpha = + op_desc.GetAttr("leaky_relu_alpha"); + } else { + CHECK(false) + << "The fused conv only supports fuse with relu and leaky relu"; + } + } + + if (op_desc.HasAttr("padding_algorithm")) { + padding_algorithm_ = op_desc.GetAttr("padding_algorithm"); } // For Int8 if (op_desc.HasAttr("enable_int8")) { @@ -100,6 +118,7 @@ class ConvOpLite : public OpLite { private: mutable ConvParam param_; + std::string padding_algorithm_{""}; }; } // namespace operators diff --git a/lite/operators/dropout_op.cc b/lite/operators/dropout_op.cc index 332475bf6b602b650e61d6c7a82f300d05447102..bef089184751342545d56f6b16ed8554be775fae 100644 --- a/lite/operators/dropout_op.cc +++ b/lite/operators/dropout_op.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "lite/operators/dropout_op.h" #include #include #include "lite/core/op_lite.h" @@ -20,59 +21,48 @@ namespace paddle { namespace lite { namespace operators { -class DropoutOpLite : public OpLite { - public: - explicit DropoutOpLite(const std::string& type) : OpLite(type) {} +bool DropoutOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + return true; +} - bool CheckShape() const override { - CHECK_OR_FALSE(param_.x); - return true; +bool DropoutOp::InferShape() const { + const auto x_dims = param_.x->dims(); + param_.output->Resize(x_dims); + if (param_.is_test == false) { + param_.mask->Resize(x_dims); } + // share LoD + // param_.output->set_lod(param_.input->lod()); + return true; +} - bool InferShape() const override { - const auto x_dims = param_.x->dims(); - param_.output->Resize(x_dims); - if (param_.is_test == false) { - param_.mask->Resize(x_dims); - } - // share LoD - // param_.output->set_lod(param_.input->lod()); - return true; - } - - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } - // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { - auto input = op_desc.Input("X").front(); - auto out = op_desc.Output("Out").front(); - auto Mask = op_desc.Output("Mask").front(); - - param_.x = GetVar(scope, input); - param_.output = GetMutableVar(scope, out); - param_.mask = GetMutableVar(scope, Mask); - - param_.dropout_prob = op_desc.GetAttr("dropout_prob"); - param_.is_test = true; - // TODO(sangoly): `is_test` has different attr type in x86 and arm, set - // `true` now. - // if (op_desc.HasAttr("is_test")) { - // param_.is_test = op_desc.GetAttr("is_test"); - // } - param_.fix_seed = op_desc.GetAttr("fix_seed"); - param_.seed = op_desc.GetAttr("seed"); - param_.dropout_implementation = - op_desc.GetAttr("dropout_implementation"); - return true; - } +// TODO(Superjomn) replace framework::OpDesc with a lite one. +bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { + auto input = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + auto Mask = op_desc.Output("Mask").front(); - std::string DebugString() const override { return "dropout"; } + param_.x = GetVar(scope, input); + param_.output = GetMutableVar(scope, out); + param_.mask = GetMutableVar(scope, Mask); - private: - mutable DropoutParam param_; -}; + param_.dropout_prob = op_desc.GetAttr("dropout_prob"); + param_.is_test = true; + // TODO(sangoly): `is_test` has different attr type in x86 and arm, set + // `true` now. + // if (op_desc.HasAttr("is_test")) { + // param_.is_test = op_desc.GetAttr("is_test"); + // } + param_.fix_seed = op_desc.GetAttr("fix_seed"); + param_.seed = op_desc.GetAttr("seed"); + param_.dropout_implementation = + op_desc.GetAttr("dropout_implementation"); + return true; +} } // namespace operators } // namespace lite } // namespace paddle -REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOpLite); +REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOp); diff --git a/lite/operators/dropout_op.h b/lite/operators/dropout_op.h new file mode 100644 index 0000000000000000000000000000000000000000..97e17e350c6a87a82e3cf05635d9575269489d7a --- /dev/null +++ b/lite/operators/dropout_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
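Since the refactored AttachImpl above pins is_test to true, the dropout kernels only ever see the inference path. Under the usual Paddle semantics, "downgrade_in_infer" scales the output by (1 - dropout_prob), while "upscale_in_train" makes inference an identity copy; a hedged reference:

#include <string>

// Inference-time dropout (is_test == true): no masking, only scaling.
inline float DropoutInferRef(float x, float dropout_prob,
                             const std::string& dropout_implementation) {
  if (dropout_implementation == "upscale_in_train") return x;  // identity
  return x * (1.0f - dropout_prob);  // "downgrade_in_infer"
}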
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class DropoutOp : public OpLite { + public: + explicit DropoutOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. + bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override; + + std::string DebugString() const override { return "dropout"; } + + private: + mutable DropoutParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a86d3e4681ae6a039aa5a7f5610c9a0762e4c17 --- /dev/null +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP( + fake_quantize_dequantize_moving_average_abs_max, + paddle::lite::operators::FakeQuantizeDequantizeMovingAvgMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h new file mode 100644 index 0000000000000000000000000000000000000000..8efa46c41501be79ccc69f4cc9f9646c11673d2d --- /dev/null +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { + public: + FakeQuantizeDequantizeMovingAvgMaxAbsOpLite() {} + + explicit FakeQuantizeDequantizeMovingAvgMaxAbsOpLite(const std::string &type) + : OpLite(type) {} + + bool CheckShape() const override { return true; } + + bool InferShape() const override { return true; } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("InScale").front(); + + auto out = op_desc.Output("Out").front(); + auto out_scale = op_desc.Output("OutScale").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.out_scale = scope->FindVar(out_scale)->GetMutable(); + param_.bit_length = op_desc.GetAttr("bit_length"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fake_quantize_dequantize_moving_avg_max_abs"; + } + + private: + mutable FakeQuantizeMovingAvgMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index d2772bf890f2ec777519170c5c7ed0d0639addbb..3f2a69dfbc76a3a7c0bdcac69866b901b239d1e4 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -87,6 +87,10 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.output = scope->FindVar(out)->GetMutable(); param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims"); + if (op_desc.HasAttr("activation_type")) { + param_.activation_type = op_desc.GetAttr("activation_type"); + } + // For Int8 if (op_desc.HasAttr("enable_int8")) { param_.enable_int8 = op_desc.GetAttr("enable_int8"); diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b073ba8379e5e52fcd3a2d0ee28aaaf5ceaea678 --- /dev/null +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
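The fake_quantize_dequantize_moving_average_abs_max op above only wires up the scale tensors and bit_length; the arithmetic it models is standard abs-max fake quantization, which simulates integer rounding error while staying in fp32. A reference sketch for a fixed scale (the moving-average update of the scale is omitted here):

#include <algorithm>
#include <cmath>

// Reference fake quantize->dequantize for a given abs-max scale.
inline float FakeQuantDequantRef(float x, float scale, int bit_length = 8) {
  const float bin_cnt = static_cast<float>((1 << (bit_length - 1)) - 1);  // 127
  float q = std::round(x / scale * bin_cnt);
  q = std::min(bin_cnt, std::max(-bin_cnt, q));  // clamp to the integer range
  return q / bin_cnt * scale;
}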
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FillConstantBatchSizeLikeOp : public OpLite { + public: + FillConstantBatchSizeLikeOp() {} + + explicit FillConstantBatchSizeLikeOp(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { + return "fill_constant_batch_size_like"; + } + + private: + mutable FillConstantBatchSizeLikeParam param_; +}; + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index 50b1372248ba5ad370dd4171a3469e2141152f28..6e4bee4da87095245d90c6af5db98d2e95d7d3d8 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -52,8 +52,67 @@ class FillConstantOp : public OpLite { mutable operators::FillConstantParam param_; }; +class FillConstantBatchLikeOp : public OpLite { + public: + explicit FillConstantBatchLikeOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.out); + CHECK_OR_FALSE(param_.input); + CHECK_GT_OR_FALSE(param_.shape.size(), 0); + CHECK_GE_OR_FALSE(param_.input_dim_idx, 0); + CHECK_GE_OR_FALSE(param_.output_dim_idx, 0); + return true; + } + + bool InferShape() const override { + auto output_dim = param_.shape; + output_dim[param_.output_dim_idx] = + param_.input->dims()[param_.input_dim_idx]; + param_.out->Resize(output_dim); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + auto Out_name = opdesc.Output("Out").front(); + auto In_name = opdesc.Input("Input").front(); + + param_.out = GetMutableVar(scope, Out_name); + param_.input = GetMutableVar(scope, In_name); + param_.dtype = opdesc.GetAttr("dtype"); + auto shape = opdesc.GetAttr>("shape"); + std::vector outshape; + for (auto i : shape) { + outshape.push_back(i); + } + param_.shape = outshape; + if (opdesc.HasAttr("value")) { + param_.value = opdesc.GetAttr("value"); + } + if (opdesc.HasAttr("input_dim_idx")) { + param_.input_dim_idx = opdesc.GetAttr("input_dim_idx"); + } + if (opdesc.HasAttr("output_dim_idx")) { + param_.output_dim_idx = opdesc.GetAttr("output_dim_idx"); + } + + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fill_constant_batch_size_like"; + } + + private: + mutable operators::FillConstantBatchLikeParam param_; +}; + } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp); +REGISTER_LITE_OP(fill_constant_batch_size_like, + paddle::lite::operators::FillConstantBatchLikeOp); diff --git a/lite/operators/gather_op.cc b/lite/operators/gather_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6de2e97a3c079e373e8747dba4c1c1d4779aa70a --- /dev/null +++ b/lite/operators/gather_op.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
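InferShape for FillConstantBatchLikeOp above copies the shape attribute and overwrites one dimension from the input: shape {1, 128} with input dims {4, 10} and input_dim_idx = output_dim_idx = 0 yields an output of {4, 128}, every element set to `value`. The dims derivation, sketched:

#include <cstdint>
#include <vector>

// Output dims the way FillConstantBatchLikeOp::InferShape derives them.
std::vector<int64_t> BatchSizeLikeDims(std::vector<int64_t> shape,
                                       const std::vector<int64_t>& input_dims,
                                       int input_dim_idx, int output_dim_idx) {
  shape[output_dim_idx] = input_dims[input_dim_idx];
  return shape;  // the op then fills a tensor of these dims with `value`
}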
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lite/operators/gather_op.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool GatherOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Index); + CHECK_OR_FALSE(param_.Out); + return true; +} + +bool GatherOp::InferShape() const { + auto index_dims = param_.Index->dims(); + CHECK(index_dims.size() == 1 || + (index_dims.size() == 2 && index_dims[1] == 1)) + << "index dims unmatch"; + int batch_size = index_dims[0]; + auto out_dims = param_.X->dims(); + out_dims[0] = batch_size; + param_.Out->Resize(out_dims); + return true; +} + +bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.X = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.Out = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + param_.Index = + scope->FindVar(opdesc.Input("Index").front())->GetMutable(); + CHECK(param_.X) << "X is null"; + CHECK(param_.Out) << "out is null"; + CHECK(param_.Index) << "index is null"; + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(gather, paddle::lite::operators::GatherOp); diff --git a/lite/operators/gather_op.h b/lite/operators/gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..58d5a30ffbb5f563503c8934d8c9e40bb539d5df --- /dev/null +++ b/lite/operators/gather_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class GatherOp : public OpLite { + public: + GatherOp() {} + explicit GatherOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "gather"; } + + private: + mutable GatherParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/graph_op.cc b/lite/operators/graph_op.cc index 266187d6e890e3611606b0563b47eba4fe10aaee..018ce264e2f18862549a4abc0444d02dcbb573ee 100644 --- a/lite/operators/graph_op.cc +++ b/lite/operators/graph_op.cc @@ -13,6 +13,7 @@ // limitations under the License. 
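gather's InferShape above keeps X's trailing dims and replaces dim 0 with the index count: X {100, 16} gathered with Index {5} (or {5, 1}) gives Out {5, 16}. The row-wise semantics, as a reference sketch:

#include <cstdint>
#include <vector>

// Reference row gather: out.row(i) = x.row(index[i]); row_width is the
// product of X's trailing dims.
std::vector<float> GatherRowsRef(const std::vector<float>& x,
                                 int64_t row_width,
                                 const std::vector<int64_t>& index) {
  std::vector<float> out;
  out.reserve(index.size() * row_width);
  for (int64_t id : index) {
    out.insert(out.end(),
               x.begin() + id * row_width,
               x.begin() + (id + 1) * row_width);
  }
  return out;
}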
#include "lite/operators/graph_op.h" +#include #include "lite/core/op_registry.h" namespace paddle { @@ -29,19 +30,24 @@ bool GraphOpLite::InferShape() const { return CheckShape(); /* enrich me */ } bool GraphOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("Inputs"); + auto weight = op_desc.Input("Weight"); auto outputs = op_desc.Output("Outputs"); for (auto var : inputs) { CHECK(scope->FindVar(var)); - param_.inputs.push_back(scope->FindVar(var)->GetMutable()); + param_.inputs.push_back( + std::make_pair(var, scope->FindVar(var)->GetMutable())); } + param_.weight = scope->FindVar(weight.front())->GetMutable(); + CHECK(param_.weight); + for (auto var : outputs) { CHECK(scope->FindVar(var)); - param_.outputs.push_back(scope->FindVar(var)->GetMutable()); + param_.outputs.push_back( + std::make_pair(var, scope->FindVar(var)->GetMutable())); } - param_.model_name = op_desc.GetAttr("model_name"); return true; } diff --git a/lite/operators/gru_unit_op.cc b/lite/operators/gru_unit_op.cc index b1efd8d048e2803e022bde0249f4173539683286..ed33507fc3fa61fce1e718581309ae37992c0531 100644 --- a/lite/operators/gru_unit_op.cc +++ b/lite/operators/gru_unit_op.cc @@ -32,7 +32,6 @@ bool GRUUnitOpLite::CheckShape() const { auto hidden_prev_dims = param_.hidden_prev->dims(); auto weight_dims = param_.weight->dims(); - int batch_size = input_dims[0]; int input_size = input_dims[1]; int frame_size = hidden_prev_dims[1]; int weight_height = weight_dims[0]; diff --git a/lite/operators/im2sequence_op.cc b/lite/operators/im2sequence_op.cc index 1cd415bcd5a55c683f4800e79e7454f1176e1255..40ab2106af85b3386f93385785b65b9293b1c7f9 100644 --- a/lite/operators/im2sequence_op.cc +++ b/lite/operators/im2sequence_op.cc @@ -29,7 +29,6 @@ bool Im2SequenceOp::CheckShape() const { return true; } bool Im2SequenceOp::InferShape() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. - auto inputs = param_.X; auto input_dims = param_.X->dims(); int img_num = input_dims[0]; int img_channels = input_dims[1]; diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc index f29acf70a75a7ac6464d8df5da145e760fb1faa3..b98240ba4f255377c0ac661950a45bef0a7d0516 100644 --- a/lite/operators/interpolate_op.cc +++ b/lite/operators/interpolate_op.cc @@ -88,6 +88,9 @@ bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { if (op_desc.HasAttr("out_h")) { param_.out_h = op_desc.GetAttr("out_h"); } + if (op_desc.HasAttr("align_mode")) { + param_.align_mode = op_desc.GetAttr("align_mode"); + } param_.align_corners = op_desc.GetAttr("align_corners"); param_.interp_method = op_desc.GetAttr("interp_method"); return true; diff --git a/lite/operators/is_empty_op.cc b/lite/operators/is_empty_op.cc index e89c72d414e78ef6f5cf310b48494752b5c995c3..ed4c69e64eaae8fdcb8289c5389dcff1df2ea8b5 100644 --- a/lite/operators/is_empty_op.cc +++ b/lite/operators/is_empty_op.cc @@ -21,7 +21,7 @@ namespace operators { bool IsEmptyOp::CheckShape() const { return true; } -bool IsEmptyOp::InferShape() const {} +bool IsEmptyOp::InferShape() const { return true; } bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = diff --git a/lite/operators/layer_norm_op.cc b/lite/operators/layer_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..061355733c9a6722fcca4ba01af81981d2b5c9ac --- /dev/null +++ b/lite/operators/layer_norm_op.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/layer_norm_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool LayerNormOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Y); + CHECK_OR_FALSE(param_.Mean); + CHECK_OR_FALSE(param_.Variance); + return true; +} + +bool LayerNormOp::InferShape() const { + auto out_dims = param_.X->dims(); + param_.Y->Resize(out_dims); + auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[1]; + param_.Mean->Resize(std::vector({inner_size})); + param_.Variance->Resize(std::vector({inner_size})); + + auto out_lod = param_.Y->mutable_lod(); + *out_lod = param_.X->lod(); + return true; +} + +bool LayerNormOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.X = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.Y = + scope->FindVar(opdesc.Output("Y").front())->GetMutable(); + param_.Mean = + scope->FindVar(opdesc.Output("Mean").front())->GetMutable(); + param_.Variance = scope->FindVar(opdesc.Output("Variance").front()) + ->GetMutable(); + CHECK(param_.X); + CHECK(param_.Y); + CHECK(param_.Mean); + CHECK(param_.Variance); + if (opdesc.HasInput("Scale")) { + param_.Scale = scope->FindVar(opdesc.Input("Scale").front()) + ->GetMutable(); + } + if (opdesc.HasInput("Bias")) { + param_.Bias = scope->FindVar(opdesc.Input("Bias").front()) + ->GetMutable(); + } + param_.begin_norm_axis = opdesc.GetAttr("begin_norm_axis"); + param_.epsilon = opdesc.GetAttr("epsilon"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(layer_norm, paddle::lite::operators::LayerNormOp); diff --git a/lite/operators/layer_norm_op.h b/lite/operators/layer_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..297f6bdd402b919b4baa1915135ed909c57cfa0b --- /dev/null +++ b/lite/operators/layer_norm_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
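For the layer_norm op registered above: the tensor is viewed as a 2-D matrix split at begin_norm_axis, and each outer row is normalized over the inner extent, y = (x - mean) / sqrt(var + epsilon) * scale + bias. A reference over one row, with scalar scale/bias for brevity (the op's Scale and Bias inputs are vectors over the inner size):

#include <cmath>

// Normalize one flattened row of `inner` elements in place.
inline void LayerNormRowRef(float* row, int inner, float scale, float bias,
                            float epsilon = 1e-5f) {
  float mean = 0.f, var = 0.f;
  for (int i = 0; i < inner; ++i) mean += row[i];
  mean /= inner;
  for (int i = 0; i < inner; ++i) var += (row[i] - mean) * (row[i] - mean);
  var /= inner;
  const float rstd = 1.f / std::sqrt(var + epsilon);
  for (int i = 0; i < inner; ++i)
    row[i] = (row[i] - mean) * rstd * scale + bias;
}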
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class LayerNormOp : public OpLite { + public: + LayerNormOp() {} + explicit LayerNormOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "layer_norm"; } + + private: + mutable LayerNormParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc index 192de2ecf85d5dda9bbf42b4fb1dccd28d8b02d5..3d5a71cee96adb520aeafc83156e5f37638912ad 100644 --- a/lite/operators/lookup_table_op.cc +++ b/lite/operators/lookup_table_op.cc @@ -50,6 +50,7 @@ bool LookupTableOpLite::InferShape() const { } out_dims.push_back(table_dims[1]); param_.Out->Resize(lite::DDim{out_dims}); + param_.Out->set_lod(param_.Ids->lod()); return true; } diff --git a/lite/operators/mul_op.cc b/lite/operators/mul_op.cc index 43048f29963c5746d5e93366aef7aa98c7fb8ce5..6067be5315220ec8b2f75265982e55f874e4b23a 100644 --- a/lite/operators/mul_op.cc +++ b/lite/operators/mul_op.cc @@ -23,6 +23,7 @@ bool MulOpLite::CheckShape() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.y); CHECK_OR_FALSE(param_.output); + // bias is optional. const auto x_dims = param_.x->dims(); @@ -54,17 +55,15 @@ bool MulOpLite::InferShape() const { const auto y_dims = param_.y->dims(); // Set output dims - std::vector out_dims( - param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0); + std::vector out_dims; for (int i = 0; i < param_.x_num_col_dims; ++i) { - out_dims[i] = x_dims[i]; + out_dims.push_back(x_dims[i]); } for (auto i = static_cast(param_.y_num_col_dims); i < y_dims.size(); ++i) { - out_dims[i] = y_dims[i]; + out_dims.push_back(y_dims[i]); } - param_.output->Resize(lite::DDim(out_dims)); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x->lod(); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 9d2cea030f85c583affea94b367d216f276c5e87..5ae22e6039bf55bb57f4e90a49b4eca835b879ea 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -14,9 +14,12 @@ #pragma once #include +#include #include +#include "lite/api/paddle_place.h" #include "lite/core/scope.h" #include "lite/core/tensor.h" +#include "lite/core/types.h" #include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/desc_apis.h" #include "lite/utils/all.h" @@ -33,7 +36,8 @@ using param_t = Any; bool enable_int8{false}; \ float input_scale{1.0}; \ std::vector weight_scale{}; \ - float output_scale{1.0}; + float output_scale{1.0}; \ + int bit_length{8}; /// ----------------------- Functional operators ------------------------------ struct FeedParam { @@ -66,9 +70,9 @@ struct CalibParam { }; struct GraphParam { - std::vector inputs{}; - std::vector outputs{}; - std::string model_name{"model"}; + std::vector> inputs{}; + lite::Tensor* weight{}; + std::vector> outputs{}; }; /// -------------------------- NN operators ------------------------------------ @@ -80,7 +84,7 @@ struct FcParam { lite::Tensor* output{nullptr}; lite::DDim in_mat_dims; int in_num_col_dims{1}; - bool weight_transposed{false}; + std::string activation_type{""}; 
// for int8 WITH_INT8_CONFIG }; @@ -95,6 +99,7 @@ struct InterpolateParam { int out_h{-1}; int out_w{-1}; bool align_corners{true}; + int align_mode{1}; std::string interp_method{"Nearest"}; }; @@ -188,11 +193,12 @@ struct SoftmaxParam { // For Reshape and Reshape2 Op struct ReshapeParam { const lite::Tensor* x{}; - const lite::Tensor* actual_shape{nullptr}; + std::vector shape_tensor_vct{}; + const lite::Tensor* shape_tensor{}; + std::vector shape_vct{}; lite::Tensor* output{}; - lite::Tensor* xshape{}; - std::vector shape{}; + lite::Tensor* xshape{}; bool inplace{false}; }; @@ -203,6 +209,30 @@ struct ConcatParam { int axis{0}; }; +/// ----------------------- activation operators ---------------------- +struct ActivationParam { + const lite::Tensor* X{}; + float Leaky_relu_alpha{0}; // leaky_relu param + float Relu_clipped_coef{6}; // relu_clipped param + std::string Prelu_mode{ + "channel"}; // prelu param, can be "all", "channel" or "element" + lite::Tensor* Prelu_alpha{}; // prelu param + float Swish_beta; // swish param + float hard_sigmoid_slope{0.2}; + float hard_sigmoid_offset{0.5}; + lite::Tensor* Out{}; + bool has_active{false}; + lite_api::ActivationType active_type; +}; + +struct ActivationGradParam { + const lite::Tensor* X{}; + const lite::Tensor* Out{}; + // for backward + lite::Tensor* X_grad{}; + const lite::Tensor* Out_grad{}; +}; + // For Convolution op struct ConvParam { lite::Tensor* x{}; @@ -226,6 +256,8 @@ struct ConvParam { float scale_weights{1.0f}; // only used with mkl-dnn int8 bool force_fp32_output{false}; // only used in mkl-dnn int8 std::string data_format{"Anylayout"}; + // for activation + ActivationParam activation_param; // for int8 WITH_INT8_CONFIG }; @@ -264,6 +296,8 @@ struct PoolParam { bool ceil_mode{false}; bool use_quantizer{false}; std::string data_format{"AnyLayout"}; + // for int8 + WITH_INT8_CONFIG }; // For Dropout op @@ -291,6 +325,8 @@ struct SplitParam { struct TransposeParam { const lite::Tensor* x{}; lite::Tensor* output{}; + lite::Tensor* xshape{}; + std::vector axis; bool use_mkldnn{false}; std::string data_format{"AnyLayout"}; @@ -302,6 +338,10 @@ struct ElementwiseParam { const lite::Tensor* Y{}; lite::Tensor* Out{}; int axis{-1}; // for broadcasting. 
+ // for int8 + WITH_INT8_CONFIG + float x_input_scale{1.0}; + float y_input_scale{1.0}; }; struct ElementwiseGradParam { @@ -320,26 +360,6 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam { std::string act_type; }; -/// ----------------------- activation operators ---------------------- -struct ActivationParam { - const lite::Tensor* X{}; - float Leaky_relu_alpha{0}; // leaky_relu param - float Relu_clipped_coef{6}; // relu_clipped param - std::string Prelu_mode{ - "channel"}; // prelu param, can be "all", "channel" or "element" - lite::Tensor* Prelu_alpha{}; // prelu param - float Swish_beta; // swish param - lite::Tensor* Out{}; -}; - -struct ActivationGradParam { - const lite::Tensor* X{}; - const lite::Tensor* Out{}; - // for backward - lite::Tensor* X_grad{}; - const lite::Tensor* Out_grad{}; -}; - /// ----------------------- mean operators ---------------------- struct MeanParam { const lite::Tensor* X{}; @@ -362,6 +382,28 @@ struct FillConstantParam { bool force_cpu{false}; lite::Tensor* Out{}; }; +struct FillConstantBatchLikeParam { + int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; + std::vector shape{}; + float value{0.0f}; + // useless for x86, keep it for compatibility + bool force_cpu{false}; + lite::Tensor* out{}; + const lite::Tensor* input{}; + int input_dim_idx{0}; + int output_dim_idx{0}; +}; + +struct FillConstantBatchSizeLikeParam { + lite::Tensor* Input; + lite::Tensor* Out; + + std::vector shape; + int input_dim_idx{0}; + int output_dim_idx{0}; + int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; + float value{0.0f}; +}; // struct FakeQuantizeMovingAvgMaxAbsParam { @@ -532,6 +574,7 @@ struct PriorBoxParam { int prior_num{0}; // priortype: prior_min, prior_max, prior_com std::vector order; + bool min_max_aspect_ratios_order{false}; }; struct DensityPriorBoxParam : public PriorBoxParam { @@ -592,9 +635,20 @@ struct SequenceSoftmaxParam { struct NormParam { const lite::Tensor* X{}; lite::Tensor* Out{}; + lite::Tensor* Norm{}; int axis{1}; float epsilon{1e-10}; }; +struct LayerNormParam { + const lite::Tensor* X{}; + const lite::Tensor* Scale{}; + const lite::Tensor* Bias{}; + lite::Tensor* Y{}; + lite::Tensor* Mean{}; + lite::Tensor* Variance{}; + int begin_norm_axis{1}; + float epsilon{1e-5}; +}; struct LogicalParam { const lite::Tensor* X{}; @@ -660,7 +714,17 @@ struct BeamSearchParam { struct SequencePoolParam { const lite::Tensor* X{}; lite::Tensor* Out{}; - std::string pool_type; + std::string pool_type{"AVERAGE"}; +#ifdef LITE_WITH_X86 + float pad_value{0.0}; + lite::Tensor* MaxIndex{}; +#endif +}; + +struct SequenceReshapeParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + int new_dim; }; struct SequenceExpandParam { @@ -670,6 +734,12 @@ struct SequenceExpandParam { int ref_level{-1}; }; +struct SequenceExpandAsParam { + const lite::Tensor* x{nullptr}; + const lite::Tensor* y{nullptr}; + lite::Tensor* out{nullptr}; +}; + struct ReduceMaxParam { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -689,6 +759,15 @@ struct IsEmptyParam { const lite::Tensor* X{}; lite::Tensor* Out{}; }; + +struct ReduceParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + std::vector dim{0}; + bool keep_dim{false}; + bool reduce_all{false}; +}; + /// ----------------------- shape operators ---------------------- struct ShapeParam { const lite::Tensor* X{}; @@ -750,7 +829,6 @@ struct GenerateProposalsParam { lite::Tensor* RpnRois{}; lite::Tensor* RpnRoiProbs{}; }; -/// ----------------------- shape operators ---------------------- /// 
----------------------- squeeze operators ---------------------- struct SqueezeParam { const lite::Tensor* X{}; @@ -759,6 +837,13 @@ struct SqueezeParam { std::vector axes{}; }; +struct UnsqueezeParam { + const lite::Tensor* X{}; + lite::Tensor* Out{}; + lite::Tensor* XShape{}; + std::vector axes{}; +}; + /// ----------------------- expand operators ---------------------- struct ExpandParam { const lite::Tensor* X{}; @@ -776,12 +861,19 @@ struct MatMulParam { float alpha{1.0f}; }; +struct GatherParam { + const lite::Tensor* X{}; + const lite::Tensor* Index{}; + lite::Tensor* Out{}; +}; + /// ----------------------- assign operators ----------------------- struct AssignParam { const lite::Tensor* X{}; lite::Tensor* Out{}; }; +/// ----------------------- roi_align operators ----------------------- struct RoiAlignParam { lite::Tensor* X{}; lite::Tensor* ROIs{}; @@ -792,12 +884,29 @@ struct RoiAlignParam { int sampling_ratio{-1}; }; +/// ----------------------- box_clip operators ----------------------- struct BoxClipParam { const lite::Tensor* Input{}; const lite::Tensor* ImInfo{}; lite::Tensor* Output{}; }; +struct RangeParam { + const lite::Tensor* Start; + const lite::Tensor* End; + const lite::Tensor* Step; + lite::Tensor* Out; +}; + +/// ----------------------- assign_value operators ----------------------- +struct AssignValueParam { + std::vector shape{}; + int dtype{}; + std::vector fp32_values{}; + std::vector int32_values{}; + lite::Tensor* Out{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index 3cc8938f4eb3ffc5720a6e1cfc1746e1defd048e..c4717c8185b24cfd9f6a551dcb932dc325a502d2 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -67,6 +67,10 @@ bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { if (opdesc.HasAttr("order")) { param_.order = opdesc.GetAttr>("order"); } + if (opdesc.HasAttr("min_max_aspect_ratios_order")) { + param_.min_max_aspect_ratios_order = + opdesc.GetAttr("min_max_aspect_ratios_order"); + } return true; } diff --git a/lite/operators/range_op.cc b/lite/operators/range_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a179d8ffe7abc1665b13f7d0dfeaa8b3c18cf1d5 --- /dev/null +++ b/lite/operators/range_op.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/range_op.h" +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool RangeOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.Start); + CHECK_OR_FALSE(param_.End); + CHECK_OR_FALSE(param_.Step); + CHECK_OR_FALSE(param_.Out); + return true; +} + +template +void GetSize(T start, T end, T step, int64_t* size) { + CHECK(!std::equal_to()(step, 0)) + << "The step of range op should not be 0."; + CHECK(((start < end) && (step > 0)) || ((start > end) && (step < 0))) + << "The step should be greater than 0 while start < end. And the " + "step should be less than 0 while start > end."; + *size = std::is_integral::value + ? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step)) + : std::ceil(std::abs((end - start) / step)); +} + +bool RangeOpLite::InferShape() const { + int start = param_.Start->data()[0]; + int end = param_.End->data()[0]; + int step = param_.Step->data()[0]; + int64_t size = 0; + GetSize(start, end, step, &size); + param_.Out->Resize(std::vector({size})); + return true; +} + +bool RangeOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + auto start = opdesc.Input("Start").front(); + auto end = opdesc.Input("End").front(); + auto step = opdesc.Input("Step").front(); + auto out = opdesc.Output("Out").front(); + + param_.Start = scope->FindVar(start)->GetMutable(); + param_.End = scope->FindVar(end)->GetMutable(); + param_.Step = scope->FindVar(step)->GetMutable(); + param_.Out = scope->FindVar(out)->GetMutable(); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(range, paddle::lite::operators::RangeOpLite); diff --git a/lite/operators/range_op.h b/lite/operators/range_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a1c7d4d4cc43d72001ac3519cb1c4f85ab8196ff --- /dev/null +++ b/lite/operators/range_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class RangeOpLite : public OpLite {
+ public:
+  RangeOpLite() {}
+  explicit RangeOpLite(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "range"; }
+
+ private:
+  mutable RangeParam param_;
+};
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e986b0ca5412f8380cccc9f981e5e4069ffcdabc
--- /dev/null
+++ b/lite/operators/reduce_ops.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/reduce_ops.h"
+#include <algorithm>
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool ReduceOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.output);
+  auto x_dims = param_.x->dims();
+  auto x_rank = x_dims.size();
+  CHECK_LE(x_rank, 6UL) << "Tensors with rank at most 6 are supported.";
+  return true;
+}
+
+bool ReduceOp::InferShape() const {
+  auto x_dims = param_.x->dims();
+  auto x_rank = x_dims.size();
+  auto dims = param_.dim;
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] < 0) dims[i] = x_rank + dims[i];
+    CHECK_LT(dims[i], x_rank)
+        << "The dim should be in the range [-rank(input), rank(input)).";
+  }
+  sort(dims.begin(), dims.end());
+  bool reduce_all = param_.reduce_all;
+  bool keep_dim = param_.keep_dim;
+
+  if (reduce_all) {
+    if (keep_dim)
+      param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
+    else
+      param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
+  } else {
+    auto dims_vector = x_dims.Vectorize();
+    if (keep_dim) {
+      for (size_t i = 0; i < dims.size(); ++i) {
+        dims_vector[dims[i]] = 1;
+      }
+    } else {
+      const int kDelFlag = -2;
+      for (size_t i = 0; i < dims.size(); ++i) {
+        dims_vector[dims[i]] = kDelFlag;
+      }
+      dims_vector.erase(
+          remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
+          dims_vector.end());
+    }
+    auto out_dims = lite::DDim(dims_vector);
+    param_.output->Resize(out_dims);
+    if (dims[0] != 0) {
+      param_.output->set_lod(param_.x->lod());
+    }
+  }
+  return true;
+}
+
+bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  param_.x =
+      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.output =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+
+  param_.dim = opdesc.GetAttr<std::vector<int>>("dim");
+  param_.reduce_all = opdesc.GetAttr<bool>("reduce_all");
+  param_.keep_dim = opdesc.GetAttr<bool>("keep_dim");
+  return true;
+}
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
+
+REGISTER_LITE_OP(reduce_sum, paddle::lite::operators::ReduceOp);
diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h
new file mode 100644
index 0000000000000000000000000000000000000000..0063aba1fa606c6228e7dcb1197bfb36f57aa33c
--- /dev/null
+++ b/lite/operators/reduce_ops.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class ReduceOp : public OpLite {
+ public:
+  ReduceOp() {}
+  explicit ReduceOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "reduce"; }
+
+ private:
+  mutable ReduceParam param_;
+};
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc
index 0e7059d66d2f934e862662c5e6bd234ff4e33a64..89cf698f8e3d6c8f04bf1100f30742712615fe2f 100644
--- a/lite/operators/reshape_op.cc
+++ b/lite/operators/reshape_op.cc
@@ -14,6 +14,7 @@
 #include "lite/operators/reshape_op.h"
 #include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
 
 namespace paddle {
 namespace lite {
@@ -22,13 +23,31 @@ namespace operators {
 bool ReshapeOp::CheckShape() const {
   CHECK_OR_FALSE(param_.x);
   CHECK_OR_FALSE(param_.output);
-  CHECK_OR_FALSE(!param_.shape.empty());
   return true;
 }
 
 bool ReshapeOp::InferShape() const {
+  auto shape_tensor_vct = param_.shape_tensor_vct;
+  auto *shape_tensor = param_.shape_tensor;
+  auto shape_vct = param_.shape_vct;
+  std::vector<int> final_shape;
+
+  if (shape_tensor_vct.size() > 0) {
+    for (int i = 0; i < shape_tensor_vct.size(); i++) {
+      final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
+    }
+  } else if (shape_tensor != nullptr) {
+    auto *shape_tensor_data = shape_tensor->data<int>();
+    final_shape = std::vector<int>(shape_tensor_data,
+                                   shape_tensor_data + shape_tensor->numel());
+  } else if (!shape_vct.empty()) {
+    final_shape = shape_vct;
+  } else {
+    LOG(FATAL) << "input shape error";
+  }
+
   auto x_dims = param_.x->dims();
-  auto output_dims = ValidateShape(param_.shape, x_dims);
+  auto output_dims = ValidateShape(final_shape, x_dims);
   param_.output->Resize(output_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x->lod();
@@ -36,31 +55,33 @@
 }
 
 bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
-  auto x_var = scope->FindVar(opdesc.Input("X").front());
-  auto output_var = scope->FindVar(opdesc.Output("Out").front());
-  CHECK(x_var);
-  CHECK(output_var);
-  param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
-  param_.output = output_var->GetMutable<lite::Tensor>();
-  std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
-  if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") !=
-      input_arg_names.end()) {
-    if (opdesc.Input("Shape").size() > 0) {
-      auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
-      if (actual_shape_var != nullptr) {
-        param_.actual_shape = const_cast<lite::Tensor *>(
-            &(actual_shape_var->Get<lite::Tensor>()));
+  param_.x =
+      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.output =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+
+  if (opdesc.HasInput("ShapeTensor") &&
+      opdesc.Input("ShapeTensor").size() > 0) {
+    auto args = opdesc.Input("ShapeTensor");
+    for (auto arg : args) {
+      auto *var = scope->FindVar(arg);
+      if (var != nullptr) {
+        param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
+      }
     }
   }
-  param_.shape = (opdesc.GetAttr<std::vector<int>>("shape"));
+  if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
+    auto var = scope->FindVar(opdesc.Input("Shape").front());
+    if (var != nullptr) {
+      param_.shape_tensor = var->GetMutable<lite::Tensor>();
+    }
+  }
+  if (opdesc.HasAttr("shape")) {
+    param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
+  }
   if (opdesc.HasAttr("inplace")) {
     param_.inplace = opdesc.GetAttr<bool>("inplace");
   }
-  CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
-  CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
-  CHECK(!param_.shape.empty())
-      << "The shape information must be set by Attr(shape).";
   return true;
 }
 
@@ -73,36 +94,37 @@ bool Reshape2Op::CheckShape() const {
 bool Reshape2Op::InferShape() const {
   ReshapeOp::InferShape();
   auto x_dims = param_.x->dims();
-  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
-  param_.xshape->Resize(DDim(xshape_dims));
+  param_.xshape->Resize(xshape_dims);
+  auto xshape_lod = param_.xshape->mutable_lod();
+  *xshape_lod = param_.x->lod();
   return true;
 }
 
 bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   ReshapeOp::AttachImpl(opdesc, scope);
   auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
-  CHECK(xshape_var);
   param_.xshape = xshape_var->GetMutable<lite::Tensor>();
-  CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null.";
   return true;
 }
 
 DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
-  const DDim::value_type input_size = input_dims.production();
+  const lite::DDim::value_type input_size = input_dims.production();
   auto input_shape = input_dims.Vectorize();
-  bool all_positive = std::all_of(input_shape.cbegin(),
-                                  input_shape.cend(),
-                                  [](DDim::value_type i) { return i > 0; });
+  bool all_positive = std::all_of(
+      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
+        return i > 0;
+      });
   // only one dimension can be set to -1, whose size will be automatically
   // inferred.
   const int unk_dim_val = -1;
   const int copy_dim_val = 0;
-  std::vector<DDim::value_type> output_shape(shape.size(), 0);
-  DDim::value_type capacity = 1;
+  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
+  lite::DDim::value_type capacity = 1;
   int unk_dim_idx = -1;
   for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == unk_dim_val) {
@@ -118,10 +140,10 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
           "be negative except one unknown dimension.";
     }
 
-    capacity *=
-        (shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
-    output_shape[i] =
-        (shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
+    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                          : input_shape[i]);
+    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                                : input_shape[i]);
   }
 
   if (unk_dim_idx != -1) {
@@ -139,7 +161,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
   } else {
     CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
   }
-  return DDim(output_shape);
+  return lite::DDim(output_shape);
 }
 
 } // namespace operators
diff --git a/lite/operators/sequence_expand_as_op.cc b/lite/operators/sequence_expand_as_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..22a4743103fd4b188357d067a062ea827de7aaa0
--- /dev/null
+++ b/lite/operators/sequence_expand_as_op.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/sequence_expand_as_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SequenceExpandAsOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.x)
+  CHECK_OR_FALSE(param_.y)
+  CHECK_OR_FALSE(param_.out)
+
+  auto x_dims = param_.x->dims();
+  CHECK_EQ_OR_FALSE(x_dims.size(), 2)
+  auto y_lod = param_.y->lod();
+  CHECK_EQ_OR_FALSE(y_lod.size(), 1)
+  CHECK_EQ_OR_FALSE(static_cast<size_t>(x_dims[0]), y_lod[0].size() - 1)
+
+  return true;
+}
+
+bool SequenceExpandAsOpLite::InferShape() const {
+  auto x_dims = param_.x->dims();
+  auto y_lod = param_.y->lod();
+  auto out_dims = x_dims;
+
+  int64_t out_first_dim = 0;
+  if (y_lod[0].size() <= 1) {
+    out_first_dim = x_dims[0];
+  } else {
+    for (size_t i = 1; i < y_lod[0].size(); ++i) {
+      out_first_dim += (y_lod[0][i] - y_lod[0][i - 1]);
+    }
+  }
+  out_dims[0] = out_first_dim;
+
+  param_.out->Resize(out_dims);
+  param_.out->set_lod(y_lod);
+
+  return true;
+}
+
+bool SequenceExpandAsOpLite::AttachImpl(const cpp::OpDesc &op_desc,
+                                        lite::Scope *scope) {
+  auto x = op_desc.Input("X").front();
+  auto y = op_desc.Input("Y").front();
+  auto out = op_desc.Output("Out").front();
+
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.y = scope->FindVar(y)->GetMutable<lite::Tensor>();
+  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+  return true;
+}
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
+
+REGISTER_LITE_OP(sequence_expand_as,
+                 paddle::lite::operators::SequenceExpandAsOpLite)
diff --git a/lite/operators/sequence_expand_as_op.h b/lite/operators/sequence_expand_as_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..2eae8a26da31eb2937ab88f15d70bd44515e6a5f
--- /dev/null
+++ b/lite/operators/sequence_expand_as_op.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
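SequenceExpandAsOpLite::InferShape above derives the output height from Y's level-0 LoD: the accumulated gaps between adjacent offsets, which for a well-formed LoD equals the last offset minus the first. A standalone restatement (ExpandedRows is a hypothetical name, standard library only):

#include <cassert>
#include <cstdint>
#include <vector>

// Sum of adjacent LoD gaps; X's height is kept for a degenerate LoD,
// mirroring the y_lod[0].size() <= 1 branch above.
int64_t ExpandedRows(const std::vector<uint64_t>& lod0, int64_t x_rows) {
  if (lod0.size() <= 1) return x_rows;
  int64_t rows = 0;
  for (size_t i = 1; i < lod0.size(); ++i) rows += lod0[i] - lod0[i - 1];
  return rows;
}

int main() {
  // Three input rows expanded by Y's LoD {0, 2, 5, 9} yield 9 output rows.
  assert(ExpandedRows({0, 2, 5, 9}, 3) == 9);
  return 0;
}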
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceExpandAsOpLite : public OpLite { + public: + SequenceExpandAsOpLite() {} + explicit SequenceExpandAsOpLite(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_expand_as"; } + + private: + mutable SequenceExpandAsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c7e86af65033205bcb389cecff8db14721507142 --- /dev/null +++ b/lite/operators/sequence_reshape_op.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_reshape_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequenceReshapeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + CHECK_EQ_OR_FALSE(x_dims.size(), 2U); + return true; +} + +bool SequenceReshapeOp::InferShape() const { + int new_dim = param_.new_dim; + auto x_numel = param_.x->dims().production(); + std::vector out_shape{x_numel / new_dim, + static_cast(new_dim)}; + param_.output->Resize(lite::DDim(out_shape)); + return true; +} + +bool SequenceReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, + lite::Scope *scope) { + param_.x = + scope->FindVar(opdesc.Input("X").front())->GetMutable(); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + + param_.new_dim = opdesc.GetAttr("new_dim"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sequence_reshape, paddle::lite::operators::SequenceReshapeOp); diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c8378aebc44acf22017eee17f5b58d6ff4dd65bf --- /dev/null +++ b/lite/operators/sequence_reshape_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
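SequenceReshapeOp::InferShape above preserves the element count: the output shape is {numel / new_dim, new_dim}. A minimal check of that arithmetic (a sketch, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> SequenceReshapeShape(int64_t numel, int new_dim) {
  return {numel / new_dim, static_cast<int64_t>(new_dim)};
}

int main() {
  // A (6 x 4) input with new_dim = 8 becomes (3 x 8): 24 elements either way.
  auto shape = SequenceReshapeShape(6 * 4, 8);
  assert(shape[0] == 3 && shape[1] == 8);
  return 0;
}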
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceReshapeOp : public OpLite { + public: + SequenceReshapeOp() {} + explicit SequenceReshapeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_reshape"; } + + private: + mutable SequenceReshapeParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 4ab42d4d2129313220598d3ebc5f3cf7757308b2..18280616aa00b734596b620727f6dcfd5beb67d7 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -69,6 +69,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { auto input = opdesc.Input("X").front(); auto outs = opdesc.Output("Out"); param_.x = scope->FindVar(input)->GetMutable(); + param_.output.clear(); for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index 19bd20f1ac0ee5c02b4fde6f6ec7bf9bcf75237c..01f96c28ff6be38e426030aa3c580f28f73b3a38 100644 --- a/lite/operators/squeeze_op.cc +++ b/lite/operators/squeeze_op.cc @@ -121,7 +121,7 @@ bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); CHECK(xshape_var); param_.XShape = xshape_var->GetMutable(); - CHECK(param_.XShape) << "Output(XShape) of ReshapeOp should not be null."; + CHECK(param_.XShape) << "Output(XShape) of SqueezeOp should not be null."; return true; } diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index 80e1c2f87b1b70579540a9bb404962b1589ec797..ce850be5334d596104cf545dc82abd44c62c88cc 100644 --- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -154,6 +154,10 @@ bool Transpose2Op::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { if (op_desc.HasAttr("data_format")) { param_.data_format = op_desc.GetAttr("data_format"); } + if (op_desc.HasOutput("XShape")) { + auto xshape_var = scope->FindVar(op_desc.Output("XShape").front()); + param_.xshape = xshape_var->GetMutable(); + } return true; } diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..aca9a9c0e8bb2693d80c70d384489193ec94758c --- /dev/null +++ b/lite/operators/unsqueeze_op.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
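The param_.output.clear() added to SplitOp::AttachImpl above matters because AttachImpl may run more than once on the same op instance; without the clear, each run would append another full set of output pointers. A contrived sketch of the failure mode (FakeParam and Attach are hypothetical stand-ins):

#include <cassert>
#include <vector>

struct FakeParam {
  std::vector<int*> output;
};

// Mimics AttachImpl: rebinding outputs must start from an empty list.
void Attach(FakeParam* p, const std::vector<int*>& outs) {
  p->output.clear();  // the fix added in split_op.cc
  for (int* t : outs) p->output.push_back(t);
}

int main() {
  FakeParam p;
  int a = 0, b = 0;
  Attach(&p, {&a, &b});
  Attach(&p, {&a, &b});          // a second attach of the same op
  assert(p.output.size() == 2);  // would be 4 without clear()
  return 0;
}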
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/unsqueeze_op.h"
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+static DDim GetOutputShape(const std::vector<int> &unsqz_dims,
+                           const DDim &in_dims) {
+  int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
+  int cur_output_size = in_dims.size();
+  std::vector<int64_t> output_shape(output_size, 0);
+
+  // Validate Check: rank range.
+  CHECK_LE(output_size, 6)
+      << "The output tensor's rank should be no more than 6.";
+
+  for (int axis : unsqz_dims) {
+    int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
+    // Validate Check: the axis bound
+    CHECK((cur >= 0) && (cur <= cur_output_size))
+        << "The unsqueeze dims must be within range of current rank.";
+    // Move old axis, and insert new axis
+    for (int i = cur_output_size; i >= cur; --i) {
+      if (output_shape[i] == 1) {
+        // Move axis
+        output_shape[i + 1] = 1;
+        output_shape[i] = 0;
+      }
+    }
+
+    output_shape[cur] = 1;
+    // Add the output size.
+    cur_output_size++;
+  }
+
+  // Make output shape
+  for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
+    if (output_shape[out_idx] == 0) {
+      output_shape[out_idx] = in_dims[in_idx++];
+    }
+  }
+
+  return DDim(output_shape);
+}
+
+bool UnsqueezeOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool UnsqueezeOp::InferShape() const {
+  std::vector<int> unsqueeze_dims = param_.axes;
+  DDim in_dims = param_.X->dims();
+  DDim out_dim = GetOutputShape(unsqueeze_dims, in_dims);
+  param_.Out->Resize(out_dim);
+  return true;
+}
+
+bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  auto x_var = scope->FindVar(opdesc.Input("X").front());
+  auto output_var = scope->FindVar(opdesc.Output("Out").front());
+  CHECK(x_var);
+  CHECK(output_var);
+  param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
+  param_.Out = output_var->GetMutable<lite::Tensor>();
+
+  if (opdesc.HasAttr("axes")) {
+    param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
+  }
+  CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
+  CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
+  return true;
+}
+
+bool Unsqueeze2Op::CheckShape() const {
+  UnsqueezeOp::CheckShape();
+  CHECK_OR_FALSE(param_.XShape);
+  return true;
+}
+
+bool Unsqueeze2Op::InferShape() const {
+  UnsqueezeOp::InferShape();
+  auto x_dims = param_.X->dims();
+  std::vector<int64_t> xshape_dims(x_dims.size() + 1, 1);
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    xshape_dims[i + 1] = x_dims[i];
+  }
+  param_.XShape->Resize(DDim(xshape_dims));
+  return true;
+}
+
+bool Unsqueeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  UnsqueezeOp::AttachImpl(opdesc, scope);
+  auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
+  CHECK(xshape_var);
+  param_.XShape = xshape_var->GetMutable<lite::Tensor>();
+  CHECK(param_.XShape) << "Output(XShape) of Unsqueeze2Op should not be null.";
+  return true;
+}
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
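GetOutputShape above threads each new axis through a 0/1 marker array and then fills the zero slots from the input dims. Under the same assumption the code makes, that axes are applied left to right against the already-grown rank, a plain insert is an equivalent reference formulation (a sketch, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> UnsqueezeShape(std::vector<int64_t> dims,
                                    const std::vector<int>& axes) {
  for (int axis : axes) {
    int cur = axis < 0 ? axis + static_cast<int>(dims.size()) + 1 : axis;
    dims.insert(dims.begin() + cur, 1);  // insert a unit dimension at `cur`
  }
  return dims;
}

int main() {
  // {3, 4} with axes {0, 2} -> {1, 3, 1, 4}
  auto out = UnsqueezeShape({3, 4}, {0, 2});
  assert((out == std::vector<int64_t>{1, 3, 1, 4}));
  return 0;
}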
+ +REGISTER_LITE_OP(unsqueeze, paddle::lite::operators::UnsqueezeOp); +REGISTER_LITE_OP(unsqueeze2, paddle::lite::operators::Unsqueeze2Op); diff --git a/lite/operators/unsqueeze_op.h b/lite/operators/unsqueeze_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1e88828c6c5fdef767850909c0dae8ec65e9d1e0 --- /dev/null +++ b/lite/operators/unsqueeze_op.h @@ -0,0 +1,61 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class UnsqueezeOp : public OpLite { + public: + UnsqueezeOp() {} + explicit UnsqueezeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "unsqueeze"; } + + protected: + mutable UnsqueezeParam param_; +}; + +class Unsqueeze2Op : public UnsqueezeOp { + public: + Unsqueeze2Op() : UnsqueezeOp() {} + explicit Unsqueeze2Op(const std::string &op_type) : UnsqueezeOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "unsqueeze2"; } +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/yolo_box_op.cc b/lite/operators/yolo_box_op.cc index 068cdf043193d3771334f0e7bac33ea190edf5e1..c8186d3f3182e21856919c46b83fe96a6e2bef93 100644 --- a/lite/operators/yolo_box_op.cc +++ b/lite/operators/yolo_box_op.cc @@ -31,11 +31,23 @@ bool YoloBoxOp::CheckShape() const { CHECK_OR_FALSE(ImgSize); CHECK_OR_FALSE(Boxes); CHECK_OR_FALSE(Scores); + + auto dim_x = X->dims(); + auto dim_imgsize = ImgSize->dims(); + std::vector anchors = param_.anchors; + int anchor_num = anchors.size() / 2; + auto class_num = param_.class_num; + CHECK_OR_FALSE(dim_x.size() == 4); + CHECK_OR_FALSE(dim_x[1] == anchor_num * (5 + class_num)); + CHECK_OR_FALSE(dim_imgsize[0] == dim_x[0]); + CHECK_OR_FALSE(dim_imgsize[1] == 2); + CHECK_OR_FALSE(anchors.size() > 0 && anchors.size() % 2 == 0); + CHECK_OR_FALSE(class_num > 0); + return true; } bool YoloBoxOp::InferShape() const { auto* X = param_.X; - auto* ImgSize = param_.ImgSize; auto anchors = param_.anchors; int anchor_num = anchors.size() / 2; auto class_num = param_.class_num; diff --git a/lite/tests/CMakeLists.txt b/lite/tests/CMakeLists.txt index 94e1eba1a5c894a6d8badc183fc1582a0e182a44..11fa2f0cb6d26a2c2739cc2e90aadf61b58001d2 100644 --- a/lite/tests/CMakeLists.txt +++ b/lite/tests/CMakeLists.txt @@ -1 +1,2 @@ 
add_subdirectory(kernels) +add_subdirectory(math) diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 9df2f93acfd8148c1813bb26e3a934b33d5c5051..f2c2c9a71666b539248c955c6e75470c5933b5c9 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -1,4 +1,4 @@ -if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) +if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) @@ -15,6 +15,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) @@ -30,8 +31,6 @@ if(LITE_BUILD_EXTRA) lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() - - lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) @@ -41,11 +40,13 @@ endif() lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} 
${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/activation_compute_test.cc b/lite/tests/kernels/activation_compute_test.cc index 6f1d1cdcf0debcadbac0fc5389f2f19bbdd9db23..5aaca9083aea5afabf5171d13f666e7bd41d00c1 100644 --- a/lite/tests/kernels/activation_compute_test.cc +++ b/lite/tests/kernels/activation_compute_test.cc @@ -33,7 +33,8 @@ enum activation_type_test { RELU6, LOG, EXP, - FLOOR + FLOOR, + RSQRT }; class ActivationComputeTester : public arena::TestCase { @@ -177,6 +178,12 @@ class ActivationComputeTester : public arena::TestCase { } break; } + case RSQRT: { + for (int i = 0; i < dims_.production(); i++) { + output_data[i] = 1.0 / std::sqrt(x_data[i]); + } + break; + } default: LOG(INFO) << "the type of activation is unknow."; } @@ -205,7 +212,7 @@ class ActivationComputeTester : public arena::TestCase { std::vector data(dims_.production()); for (int i = 0; i < dims_.production(); i++) { float sign = i % 3 == 0 ? -1.0f : 1.0f; - sign = type_ == "log" ? 1 : sign; + sign = (type_ == "log" || type_ == "rsqrt") ? 
1 : sign; data[i] = sign * static_cast(i % 128) * 0.013f + 0.001; } SetCommonTensor(input_, dims_, data.data()); @@ -553,5 +560,31 @@ TEST(Activation_floor, precision) { #endif } +TEST(Activation_rsqrt, precision) { + LOG(INFO) << "test rsqrt op"; +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + for (auto n : {2}) { + for (auto c : {2}) { + for (auto h : {2}) { + for (auto w : {2}) { + std::unique_ptr tester(new ActivationComputeTester( + place, + "def", + 0.01, + 6., + "all", + 0., + DDim(std::vector({n, c, h, w})), + "rsqrt", + RSQRT)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } +#endif +} } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/affine_channel_compute_test.cc b/lite/tests/kernels/affine_channel_compute_test.cc index 0e0c044e5628960be85eb8486f3d1d205f67ee9c..9fac0d9379a535a20cf20f4bc5002d193eeeb2b9 100644 --- a/lite/tests/kernels/affine_channel_compute_test.cc +++ b/lite/tests/kernels/affine_channel_compute_test.cc @@ -64,8 +64,6 @@ class AffineChannelComputeTester : public arena::TestCase { if (data_layout_ == "NCHW") { int channel = x_dims_[1]; - int height = x_dims_[2]; - int width = x_dims_[3]; int size = x_dims_[2] * x_dims_[3]; int in_channel = channel * size; for (int n = 0; n < num; n++) { diff --git a/lite/tests/kernels/argmax_compute_test.cc b/lite/tests/kernels/argmax_compute_test.cc index 49cbd910718c095f8c704bf501ad31bc2cdf5517..9163e4bdaf5ab1da71b565dbd435b1a31ea9dcce 100644 --- a/lite/tests/kernels/argmax_compute_test.cc +++ b/lite/tests/kernels/argmax_compute_test.cc @@ -25,7 +25,7 @@ class ArgmaxComputeTester : public arena::TestCase { // common attributes for this op. std::string input_ = "x"; std::string output_ = "out"; - int axis_ = 0.; + int64_t axis_ = 0.; DDim dims_{{2, 5, 20, 30}}; public: @@ -82,10 +82,10 @@ class ArgmaxComputeTester : public arena::TestCase { } void PrepareOpDesc(cpp::OpDesc* op_desc) { - op_desc->SetType("argmax"); + op_desc->SetType("arg_max"); op_desc->SetInput("X", {input_}); op_desc->SetOutput("Out", {output_}); - op_desc->SetAttr("Axis", axis_); + op_desc->SetAttr("axis", axis_); } void PrepareData() override { diff --git a/lite/tests/kernels/assign_value_compute_test.cc b/lite/tests/kernels/assign_value_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..96959e507d21b52a56dddfa45eaf7e773f770967 --- /dev/null +++ b/lite/tests/kernels/assign_value_compute_test.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class AssignValueComputeTester : public arena::TestCase { + protected: + // common attributes for this op. 
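The RSQRT branch added to the activation tester above computes out = 1 / sqrt(x), and PrepareData keeps rsqrt inputs strictly positive so the reference stays finite. A tiny standalone check of the reference math (not part of the patch):

#include <cassert>
#include <cmath>

float RsqrtRef(float x) { return 1.0f / std::sqrt(x); }

int main() {
  assert(std::abs(RsqrtRef(4.0f) - 0.5f) < 1e-6f);
  assert(std::abs(RsqrtRef(16.0f) - 0.25f) < 1e-6f);
  return 0;
}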
+  std::string out_ = "out";
+  int dtype_{};
+  std::vector<int> shape_{};
+  std::vector<int> int32_values_{};
+  std::vector<float> fp32_values_{};
+  size_t num_ = 1;
+
+ public:
+  AssignValueComputeTester(const Place& place,
+                           const std::string& alias,
+                           int dtype,
+                           int n,
+                           int c,
+                           int h,
+                           int w)
+      : TestCase(place, alias) {
+    dtype_ = dtype;
+    shape_.push_back(n);
+    shape_.push_back(c);
+    shape_.push_back(h);
+    shape_.push_back(w);
+    num_ = n * c * h * w;
+  }
+
+  void RunBaseline(Scope* scope) override {
+    auto* out = scope->NewTensor(out_);
+    CHECK(out);
+    std::vector<int64_t> out_shape(shape_.begin(), shape_.end());
+    out->Resize(out_shape);
+    if (dtype_ == 2) {
+      auto* out_data = out->mutable_data<int>();
+      for (int i = 0; i < out->numel(); i++) {
+        out_data[i] = int32_values_[i];
+      }
+    } else if (dtype_ == 5) {
+      auto* out_data = out->mutable_data<float>();
+      for (int i = 0; i < out->numel(); i++) {
+        out_data[i] = fp32_values_[i];
+      }
+    } else {
+      LOG(FATAL) << "unsupported dtype_:" << dtype_;
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("assign_value");
+    op_desc->SetAttr("shape", shape_);
+    op_desc->SetAttr("dtype", dtype_);
+    op_desc->SetAttr("fp32_values", fp32_values_);
+    op_desc->SetAttr("int32_values", int32_values_);
+    op_desc->SetOutput("Out", {out_});
+  }
+
+  void PrepareData() override {
+    // int32
+    if (dtype_ == 2) {
+      int32_values_.resize(num_);
+      for (int i = 0; i < num_; i++) {
+        int32_values_[i] = i;
+      }
+    } else if (dtype_ == 5) {
+      fp32_values_.resize(num_);
+      for (int i = 0; i < num_; i++) {
+        fp32_values_[i] = i / 1.23f;
+      }
+    } else {
+      LOG(FATAL) << "unsupported dtype_:" << dtype_;
+    }
+  }
+};
+
+TEST(AssignValue, precision) {
+  LOG(INFO) << "test assign_value op";
+#ifdef LITE_WITH_ARM
+  LOG(INFO) << "test assign_value arm";
+  Place place(TARGET(kARM));
+
+  for (int dtype : {2, 5}) {
+    for (int n : {1}) {
+      for (int c : {2}) {
+        for (int h : {1}) {
+          for (int w : {2}) {
+            std::unique_ptr<arena::TestCase> tester(
+                new AssignValueComputeTester(place, "def", dtype, n, c, h, w));
+            arena::Arena arena(std::move(tester), place, 2e-5);
+            arena.TestPrecision();
+          }
+        }
+      }
+    }
+  }
+#endif
+}
+
+} // namespace lite
+} // namespace paddle
diff --git a/lite/tests/kernels/bilinear_interp_compute_test.cc b/lite/tests/kernels/bilinear_interp_compute_test.cc
index b80861217b3669fc071e82605946d665232b7903..0779caf67aac907e6f8ccde8b3e65d413cf65db9 100644
--- a/lite/tests/kernels/bilinear_interp_compute_test.cc
+++ b/lite/tests/kernels/bilinear_interp_compute_test.cc
@@ -156,26 +156,30 @@ class BilinearInterpComputeTester : public arena::TestCase {
   float width_scale_ = 0.f;
   int out_height_ = -1;
   int out_width_ = -1;
+  int outsize_height_ = -1;
+  int outsize_width_ = -1;
   bool align_corners_ = true;
   std::string interp_method_ = "Bilinear";
-  DDim dims_{{1, 1}};
   DDim _dims0_{{1, 1, 16, 16}};
   DDim _dims1_{{2}};
 
  public:
   BilinearInterpComputeTester(const Place& place,
                               const std::string& alias,
-                              float height_scale,
-                              float width_scale,
+                              float scale,
                               int out_height,
                               int out_width,
+                              int outsize_height,
+                              int outsize_width,
                               bool align_corners,
                               std::string interp_method)
       : TestCase(place, alias),
-        height_scale_(height_scale),
-        width_scale_(width_scale),
+        height_scale_(scale),
+        width_scale_(scale),
         out_height_(out_height),
         out_width_(out_width),
+        outsize_height_(outsize_height),
+        outsize_width_(outsize_width),
         align_corners_(align_corners),
         interp_method_(interp_method) {}
 
@@ -183,8 +187,9 @@
     width_scale_ = height_scale_;
     std::vector<const lite::Tensor*> inputs;
inputs.emplace_back(scope->FindTensor(input0_)); - inputs.emplace_back(scope->FindTensor(input1_)); - auto outsize_data = inputs[1]->data(); + if (outsize_height_ > 0 && outsize_width_ > 0) { + inputs.emplace_back(scope->FindTensor(input1_)); + } if (out_width_ != -1 && out_height_ != -1) { height_scale_ = static_cast(out_height_ / inputs[0]->dims()[2]); width_scale_ = static_cast(out_width_ / inputs[0]->dims()[3]); @@ -192,6 +197,7 @@ class BilinearInterpComputeTester : public arena::TestCase { auto* outputs = scope->NewTensor(output_); CHECK(outputs); if (inputs.size() > 1) { + auto outsize_data = inputs[1]->data(); int h_out = outsize_data[0]; // HW int w_out = outsize_data[1]; // HW int num_cout = inputs[0]->dims()[0]; @@ -221,7 +227,9 @@ class BilinearInterpComputeTester : public arena::TestCase { void PrepareOpDesc(cpp::OpDesc* op_desc) { op_desc->SetType("bilinear_interp"); op_desc->SetInput("X", {input0_}); - op_desc->SetInput("OutSize", {input1_}); + if (outsize_height_ > 0 && outsize_width_ > 0) { + op_desc->SetInput("OutSize", {input1_}); + } op_desc->SetOutput("Out", {output_}); op_desc->SetAttr("scale", height_scale_); op_desc->SetAttr("out_h", out_height_); @@ -237,32 +245,58 @@ class BilinearInterpComputeTester : public arena::TestCase { } SetCommonTensor(input0_, _dims0_, data0.data()); - std::vector data1(_dims1_.production()); - for (int i = 0; i < _dims1_.production(); i++) { - data1[i] = 16; + if (outsize_height_ > 0 && outsize_width_ > 0) { + std::vector data1(2); + data1[0] = outsize_height_; + data1[1] = outsize_width_; + SetCommonTensor(input1_, _dims1_, data1.data()); } - SetCommonTensor(input1_, _dims1_, data1.data()); } }; void test_bilinear_interp(Place place) { std::string interp_method = "Bilinear"; - for (float scale : {1., 0.5, 0.3}) { - for (int out_height : {8, 16}) { - for (int out_width : {8, 16}) { - for (bool align_corners : {true, false}) { - std::unique_ptr tester( - new BilinearInterpComputeTester(place, - "def", - scale, - scale, - out_height, - out_width, - align_corners, - interp_method)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } + for (float scale : {2., 1., 0.3}) { + for (bool align_corners : {true, false}) { + std::unique_ptr tester(new BilinearInterpComputeTester( + place, "def", scale, -1, -1, -1, -1, align_corners, interp_method)); + arena::Arena arena(std::move(tester), place, 5e-5); + arena.TestPrecision(); + } + } + for (int out_height : {8, 16, 24}) { + for (int out_width : {8, 16, 24}) { + for (bool align_corners : {true, false}) { + std::unique_ptr tester( + new BilinearInterpComputeTester(place, + "def", + 0, + out_height, + out_width, + -1, + -1, + align_corners, + interp_method)); + arena::Arena arena(std::move(tester), place, 5e-5); + arena.TestPrecision(); + } + } + } + for (int outsize_height : {8, 16, 24}) { + for (int outsize_width : {8, 16, 24}) { + for (bool align_corners : {true, false}) { + std::unique_ptr tester( + new BilinearInterpComputeTester(place, + "def", + 0, + -1, + -1, + outsize_height, + outsize_width, + align_corners, + interp_method)); + arena::Arena arena(std::move(tester), place, 5e-5); + arena.TestPrecision(); } } } diff --git a/lite/tests/kernels/box_coder_compute_test.cc b/lite/tests/kernels/box_coder_compute_test.cc index f3f9b7e0ab9339161a5b227db1a04a22d98833e5..9a833db31db7a6a53a4d29ed208b67e5dc77af12 100644 --- a/lite/tests/kernels/box_coder_compute_test.cc +++ b/lite/tests/kernels/box_coder_compute_test.cc @@ -121,16 +121,10 @@ class BoxCoderComputeTester : 
public arena::TestCase { auto* output_box = scope->NewTensor(output_box_); CHECK(output_box); output_box->Resize(target_box_dims_); - auto* output_box_data = output_box->mutable_data(); auto* prior_box = scope->FindTensor(prior_box_); - const auto* prior_box_data = prior_box->data(); - auto* prior_box_var = scope->FindTensor(prior_box_var_); - const auto* prior_box_var_data = prior_box_var->data(); - auto* target_box = scope->FindTensor(target_box_); - const auto* target_box_data = target_box->data(); box_coder_ref(output_box, prior_box, diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index f000ea1d719bfc1389ce4656d688a31de67346d6..e738b67a71755c0c051d2741638cc22d55287e93 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -25,34 +25,49 @@ class CastComputeTester : public arena::TestCase { // common attributes for this op. std::string input_ = "x"; std::string output_ = "out"; - int in_dtype_ = 21; - int out_dtype_ = 5; - DDim x_dims_{{2, 2, 2, 2}}; + int in_dtype_; + int out_dtype_; + DDim x_dims_{{2, 2}}; public: - CastComputeTester(const Place& place, const std::string& alias) - : TestCase(place, alias) {} + CastComputeTester(const Place& place, + const std::string& alias, + int in_dtype, + int out_dtype) + : TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {} void RunBaseline(Scope* scope) override { auto* out = scope->NewTensor(output_); CHECK(out); out->Resize(x_dims_); - auto* output_data = out->mutable_data(); - auto* x = scope->FindTensor(input_); - const auto* x_data = x->data(); - - int num = x_dims_[0]; - int channel = x_dims_[1]; - int size = x_dims_[2] * x_dims_[3]; - int in_channel = channel * size; - - auto* output_data_tmp = output_data; - auto* x_data_tmp = x_data; - for (int i = 0; i < x_dims_.production(); i++) { - *output_data_tmp = static_cast(*x_data_tmp); - output_data_tmp++; - x_data_tmp++; + if (out_dtype_ == 5 && in_dtype_ == 20) { + auto* x = scope->FindTensor(input_); + auto* x_data = x->data(); + auto* output_data = out->mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + *output_data = static_cast(*x_data); + output_data++; + x_data++; + } + } else if (out_dtype_ == 5 && in_dtype_ == 21) { + auto* output_data = out->mutable_data(); + auto* x = scope->FindTensor(input_); + auto* x_data = x->data(); + for (int i = 0; i < x_dims_.production(); i++) { + *output_data = static_cast(*x_data); + output_data++; + x_data++; + } + } else if (out_dtype_ == 5 && in_dtype_ == 2) { + auto* output_data = out->mutable_data(); + auto* x = scope->FindTensor(input_); + auto* x_data = x->data(); + for (int i = 0; i < x_dims_.production(); i++) { + *output_data = static_cast(*x_data); + output_data++; + x_data++; + } } } @@ -65,12 +80,29 @@ class CastComputeTester : public arena::TestCase { } void PrepareData() override { - std::vector x_data(x_dims_.production()); - for (int i = 0; i < x_dims_.production(); i++) { - float sign = i % 3 == 0 ? -1.0f : 1.0f; - x_data[i] = sign * static_cast(i % 128); + if (in_dtype_ == 20) { + std::vector x_data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = static_cast(i % 128); + } + SetCommonTensor(input_, x_dims_, x_data.data()); + } else if (in_dtype_ == 21) { + std::vector x_data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + float sign = i % 3 == 0 ? 
-1.0f : 1.0f; + x_data[i] = sign * static_cast(i % 128); + } + SetCommonTensor(input_, x_dims_, x_data.data()); + } else if (in_dtype_ == 2) { + std::vector x_data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + int sign = i % 3 == 0 ? -1 : 1; + x_data[i] = sign * static_cast(i % 128); + } + SetCommonTensor(input_, x_dims_, x_data.data()); + } else { + LOG(FATAL) << "not implemented!"; } - SetCommonTensor(input_, x_dims_, x_data.data()); } }; @@ -79,9 +111,15 @@ TEST(Cast, precision) { #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - std::unique_ptr tester(new CastComputeTester(place, "def")); + std::unique_ptr tester( + new CastComputeTester(place, "def", 20, 5)); arena::Arena arena(std::move(tester), place, 2e-5); arena.TestPrecision(); + +// std::unique_ptr tester1( +// new CastComputeTester(place, "def", 2, 5)); +// arena::Arena arena1(std::move(tester1), place, 2e-5); +// arena1.TestPrecision(); #endif } diff --git a/lite/tests/kernels/conv2d_transpose_compute_test.cc b/lite/tests/kernels/conv2d_transpose_compute_test.cc index c44259022d19cd67ca437292c80487a2274bca5f..a287f0bb6610921e0f048fcc4d46f8729dd177c1 100644 --- a/lite/tests/kernels/conv2d_transpose_compute_test.cc +++ b/lite/tests/kernels/conv2d_transpose_compute_test.cc @@ -190,7 +190,6 @@ bool deconv_basic(const Dtype1* din, auto* workspace_ptr = workspace_tensor.mutable_data(); int group_size_in = win * hin * chin / group; - int group_size_out = wout * hout * chout / group; int group_size_coldata = m * n; int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && diff --git a/lite/tests/kernels/elementwise_compute_test.cc b/lite/tests/kernels/elementwise_compute_test.cc index ceceb6394a383e797561e10a2856fe8bc6ceb334..635f6e7c080c0565299ca416fc445637254d8a4e 100644 --- a/lite/tests/kernels/elementwise_compute_test.cc +++ b/lite/tests/kernels/elementwise_compute_test.cc @@ -43,7 +43,6 @@ class ElementwiseComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data(); for (int i = 0; i < dims_.production(); i++) { @@ -71,6 +70,56 @@ class ElementwiseComputeTester : public arena::TestCase { } }; +class ElementwiseSubComputeTester : public arena::TestCase { + protected: + // common attributes for this op. 
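The integer dtype codes threaded through the cast tester above come from Paddle's framework.proto VarType enum; assuming the upstream values, the ones exercised here decode as follows (a reference sketch only):

#include <iostream>

// framework.proto VarType codes used by the cast test above.
enum class VarType : int {
  INT32 = 2,
  FP32 = 5,
  UINT8 = 20,  // unsigned char inputs
  INT8 = 21,   // char inputs
};

int main() {
  std::cout << "cast " << static_cast<int>(VarType::UINT8) << " -> "
            << static_cast<int>(VarType::FP32)
            << " reads uint8, writes float\n";
  return 0;
}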
+ std::string inputx_ = "x"; + std::string inputy_ = "y"; + std::string output_ = "out"; + int axis_; + DDim dims_{{1, 2, 3, 4}}; + + public: + ElementwiseSubComputeTester(const Place& place, + const std::string& alias, + int axis) + : TestCase(place, alias), axis_(axis) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + out->Resize(dims_); + auto* out_data = out->mutable_data<float>(); + + auto* x = scope->FindTensor(inputx_); + const auto* x_data = x->data<float>(); + const auto* y_data = scope->FindTensor(inputy_)->data<float>(); + + for (int i = 0; i < dims_.production(); i++) { + out_data[i] = x_data[i] - y_data[i]; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("elementwise_sub"); + op_desc->SetInput("X", {inputx_}); + op_desc->SetInput("Y", {inputy_}); + op_desc->SetOutput("Out", {output_}); + op_desc->SetAttr("axis", axis_); + } + + void PrepareData() override { + std::vector<float> data(dims_.production()); + + for (int i = 0; i < dims_.production(); i++) { + data[i] = i * 1.1f; + } + + SetCommonTensor(inputx_, dims_, data.data()); + SetCommonTensor(inputy_, dims_, data.data()); + } +}; + class ElementwiseMulComputeTester : public arena::TestCase { protected: // common attributes for this op. @@ -94,7 +143,6 @@ class ElementwiseMulComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data<float>(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data<float>(); for (int i = 0; i < dims_.production(); i++) { @@ -145,7 +193,6 @@ class ElementwiseMaxComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data<float>(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data<float>(); for (int i = 0; i < dims_.production(); i++) { @@ -198,7 +245,6 @@ class FusionElementwiseAddActivationComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data<float>(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data<float>(); for (int i = 0; i < dims_.production(); i++) { @@ -232,6 +278,64 @@ class FusionElementwiseAddActivationComputeTester : public arena::TestCase { } }; +class FusionElementwiseSubActivationComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string inputx_ = "x"; + std::string inputy_ = "y"; + std::string output_ = "out"; + int axis_; + std::string act_type_; + DDim dims_{{1, 2, 3, 4}}; + + public: + FusionElementwiseSubActivationComputeTester(const Place& place, + const std::string& alias, + int axis, + std::string act_type) + : TestCase(place, alias), axis_(axis), act_type_(act_type) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + out->Resize(dims_); + auto* out_data = out->mutable_data<float>(); + + auto* x = scope->FindTensor(inputx_); + const auto* x_data = x->data<float>(); + const auto* y_data = scope->FindTensor(inputy_)->data<float>(); + + for (int i = 0; i < dims_.production(); i++) { + out_data[i] = x_data[i] - y_data[i]; + if (act_type_ == "relu") { + out_data[i] = out_data[i] > 0 ? out_data[i] : 0; + } else { + LOG(FATAL) << "unsupported Activation type: " << act_type_; + } + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("fusion_elementwise_sub_activation"); + op_desc->SetInput("X", {inputx_}); + op_desc->SetInput("Y", {inputy_}); + op_desc->SetOutput("Out", {output_}); + op_desc->SetAttr("axis", axis_); + op_desc->SetAttr("act_type", act_type_); + } + + void PrepareData() override { + std::vector<float> data(dims_.production()); + + for (int i = 0; i < dims_.production(); i++) { + data[i] = i * 1.1f; + } + + SetCommonTensor(inputx_, dims_, data.data()); + SetCommonTensor(inputy_, dims_, data.data()); + } +}; + class FusionElementwiseMulActivationComputeTester : public arena::TestCase { protected: // common attributes for this op. @@ -257,7 +361,6 @@ class FusionElementwiseMulActivationComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data<float>(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data<float>(); for (int i = 0; i < dims_.production(); i++) { @@ -316,7 +419,6 @@ class FusionElementwiseMaxActivationComputeTester : public arena::TestCase { auto* x = scope->FindTensor(inputx_); const auto* x_data = x->data<float>(); - auto* y = scope->FindTensor(inputy_); const auto* y_data = x->data<float>(); for (int i = 0; i < dims_.production(); i++) { @@ -441,7 +543,6 @@ class FusionElementwiseDivActivationComputeTester : public arena::TestCase { } else { LOG(FATAL) << "unsupported Activation type: " << act_type_; } - LOG(INFO) << "fusion div resul:" << out_data[i]; } } @@ -476,6 +577,11 @@ void test_elementwise(Place place) { arena::Arena arena(std::move(tester), place, 2e-5); arena.TestPrecision(); + std::unique_ptr<arena::TestCase> tester_sub( + new ElementwiseSubComputeTester(place, "def", axis)); + arena::Arena arena_sub(std::move(tester_sub), place, 2e-5); + arena_sub.TestPrecision(); + std::unique_ptr<arena::TestCase> tester_mul( new ElementwiseMulComputeTester(place, "def", axis)); arena::Arena arena_mul(std::move(tester_mul), place, 2e-5); @@ -511,6 +617,12 @@ void test_fusion_elementwise(Place place) { arena::Arena arena_add_act(std::move(tester_add_act), place, 2e-5); arena_add_act.TestPrecision(); + std::unique_ptr<arena::TestCase> tester_sub_act( + new FusionElementwiseSubActivationComputeTester( + place, "def", axis, "relu")); + arena::Arena arena_sub_act(std::move(tester_sub_act), place, 2e-5); + arena_sub_act.TestPrecision(); + std::unique_ptr<arena::TestCase> tester_mul_act( new FusionElementwiseMulActivationComputeTester( place, "def", axis, "relu"));
diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index 95a8167701aa72dcc992f3ba829182bea6f3d143..ef5baa81853294b45aa73c7911ad4e1b993a07d5 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -16,8 +16,8 @@ #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/core/arena/framework.h" -#include "lite/tests/kernels/fill_data.h" -#include "lite/tests/kernels/test_funcs.h" +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" namespace paddle { namespace lite { @@ -51,10 +51,10 @@ class FcOPTest : public arena::TestCase { std::string weight_ = "w"; std::string bias_ = "b"; std::string out_ = "out"; - int in_num_col_dims_{1}; DDim dims_{{1, 128}}; DDim wdims_{{128, 4}}; DDim bdims_{{4}}; + int in_num_col_dims_{1}; public: FcOPTest(const Place& place,
diff --git a/lite/tests/kernels/gru_unit_test.cc b/lite/tests/kernels/gru_unit_test.cc index bf4b7dd5e285d30a3227ee463653186cd3b42953..98ce7ebc198a13abca86a7e2f40a61330eebb9be 100644 --- a/lite/tests/kernels/gru_unit_test.cc +++ b/lite/tests/kernels/gru_unit_test.cc @@ -14,8 +14,8 @@ #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/core/arena/framework.h" -#include "lite/tests/kernels/fill_data.h" -#include "lite/tests/kernels/test_funcs.h" +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" namespace paddle { namespace lite { @@ -243,11 +243,11 @@ class GRUUnitTester : public arena::TestCase { std::string reset_hidden_prev_ = "reset_hidden_prev"; std::string hidden_ = "hidden"; - DDim dims_{{16, 256 * 3}}; // 0: identity; 1: sigmoid; 2: tanh; 3: relu int gate_activation_{1}; int activation_{2}; bool origin_mode_{false}; + DDim dims_{{16, 256 * 3}}; public: GRUUnitTester(const Place& place,
diff --git a/lite/tests/kernels/lrn_compute_test.cc b/lite/tests/kernels/lrn_compute_test.cc index cd0931fcc5ea4324a218e2107550bc047e2268cc..9ee43c5c60b4703f64e7a2575ec15ba59b618052 100644 --- a/lite/tests/kernels/lrn_compute_test.cc +++ b/lite/tests/kernels/lrn_compute_test.cc @@ -123,7 +123,6 @@ class LrnComputeTester : public arena::TestCase { int H = dims_[2]; int W = dims_[3]; - int pre_pad = (local_size_ - 1) / 2; int offset_num = 0; int offset_within_channel = 0; int dst_id;
diff --git a/lite/tests/kernels/matmul_compute_test.cc b/lite/tests/kernels/matmul_compute_test.cc index 8b70f59d4756c47ceee039ab7797a66e8f695c2e..4915614b345c23119af37aa575bc07d4174fdcde 100644 --- a/lite/tests/kernels/matmul_compute_test.cc +++ b/lite/tests/kernels/matmul_compute_test.cc @@ -120,12 +120,12 @@ class MatMulComputeTester : public arena::TestCase { // common attributes for this op. std::string x_ = "X"; std::string y_ = "Y"; - std::string out_ = "Out"; - DDim x_dims_; - DDim y_dims_; bool x_transpose_; bool y_transpose_; float alpha_; + std::string out_ = "Out"; + DDim x_dims_; + DDim y_dims_; public: MatMulComputeTester(const Place& place,
diff --git a/lite/tests/kernels/nearest_interp_compute_test.cc b/lite/tests/kernels/nearest_interp_compute_test.cc index a81557774c759852b83cde0e46955ae63bfa7535..3256ababcab639cd31ef51294a890b7fbdb54d5d 100644 --- a/lite/tests/kernels/nearest_interp_compute_test.cc +++ b/lite/tests/kernels/nearest_interp_compute_test.cc @@ -51,9 +51,11 @@ void resize_nearest_align(std::vector<lite::Tensor*> inputs, int src_index = n * src_stride_batch + c * src_stride_c; for (int h = 0; h < hout; ++h) { for (int w = 0; w < wout; ++w) { - dtype fw = scale_w * w + 0.5; + int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5) + : static_cast<int>(scale_w * w); fw = (fw < 0) ? 0 : fw; - dtype fh = scale_h * h + 0.5; + int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5) + : static_cast<int>(scale_h * h); fh = (fh < 0) ? 0 : fh; int w_start = static_cast<int>(fw); int h_start = static_cast<int>(fh);
diff --git a/lite/tests/kernels/norm_compute_test.cc b/lite/tests/kernels/norm_compute_test.cc index 830bac062784a8c16752f4e43a23ed8157cc6c0f..6aee1758c19cd793de709921c4733b8892e5f3d9 100644 --- a/lite/tests/kernels/norm_compute_test.cc +++ b/lite/tests/kernels/norm_compute_test.cc @@ -46,7 +46,7 @@ class NormComputeTester : public arena::TestCase { auto* x = scope->FindTensor(input_); const auto* x_data = x->data<float>(); - int axis = axis_ < 0 ? axis + dims_.size() : axis_; + int axis = axis_ < 0 ? axis_ + dims_.size() : axis_; int pre_n = dims_.count(0, axis); int n = dims_[axis]; int post_n = dims_.count(axis + 1, dims_.size());
diff --git a/lite/tests/kernels/pad2d_compute_test.cc b/lite/tests/kernels/pad2d_compute_test.cc index 78afbd97ae71fdc15c31b1f2dd1664805e541cc0..818e7d2e3b2bf7ba59f658d0545fbc255e332eaa 100644 --- a/lite/tests/kernels/pad2d_compute_test.cc +++ b/lite/tests/kernels/pad2d_compute_test.cc @@ -26,8 +26,8 @@ class Pad2dComputeTester : public arena::TestCase { std::string input_ = "X"; std::string output_ = "Out"; DDim dims_{{1, 1, 14, 14}}; - std::vector<int> paddings_; std::string mode_{"constant"}; + std::vector<int> paddings_; float pad_value_ = 0.f; std::string data_format_{"NCHW"};
diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc index 47f7bc9447b1b33b57c4bc4a495a106f49d6abbc..73fd612c3a03c0a15ddaf3ce6c08ff0ed1a5a95b 100644 --- a/lite/tests/kernels/prior_box_compute_test.cc +++ b/lite/tests/kernels/prior_box_compute_test.cc @@ -125,7 +125,6 @@ void prior_box_compute_ref(const lite::Tensor* input, if (fixed_size_.size() > 0) { for (int s = 0; s < fixed_size_.size(); ++s) { int fixed_size = fixed_size_[s]; - int com_idx = 0; box_width = fixed_size; box_height = fixed_size;
diff --git a/lite/tests/kernels/range_compute_test.cc b/lite/tests/kernels/range_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d98e882c88aa05395facc7c0afcf023b0fd8ccde --- /dev/null +++ b/lite/tests/kernels/range_compute_test.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cmath> +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class RangeComputeTester : public arena::TestCase { + protected: + // common attributes for this op.
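+  // Example of the size rule used in RunBaseline() below, with the values
+  // of the second case in test_range(): start = 10, end = 1, step = -2
+  // gives size = ceil(|(1 - 10) / -2|) = ceil(4.5) = 5, producing
+  // {10, 8, 6, 4, 2}.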
+ std::string start = "Start"; + std::string end = "End"; + std::string step = "Step"; + std::string out = "Out"; + float st_, ed_, sp_; + + public: + RangeComputeTester(const Place& place, + const std::string& alias, + float st, + float ed, + float sp) + : TestCase(place, alias), st_(st), ed_(ed), sp_(sp) {} + + void RunBaseline(Scope* scope) override { + auto* output = scope->NewTensor(out); + CHECK(output); + int64_t size; + auto* st = scope->FindMutableTensor(start); + auto* ed = scope->FindMutableTensor(end); + auto* sp = scope->FindMutableTensor(step); + float st_val = st->data<float>()[0]; + float ed_val = ed->data<float>()[0]; + float sp_val = sp->data<float>()[0]; + // size = (std::abs(ed_val - st_val) + std::abs(sp_val) - 1) / + // std::abs(sp_val); + size = std::ceil(std::abs((ed_val - st_val) / sp_val)); + output->Resize(DDim(std::vector<int64_t>({static_cast<int64_t>(size)}))); + auto* out_data = output->mutable_data<float>(); + + float val = st_; + for (int i = 0; i < size; i++) { + out_data[i] = val; + val += sp_; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("range"); + op_desc->SetInput("Start", {start}); + op_desc->SetInput("End", {end}); + op_desc->SetInput("Step", {step}); + op_desc->SetOutput("Out", {out}); + } + + void PrepareData() override { + std::vector<float> st(1); + std::vector<float> ed(1); + std::vector<float> sp(1); + + st[0] = st_; + ed[0] = ed_; + sp[0] = sp_; + DDim dim(std::vector<int64_t>({1})); + + SetCommonTensor(start, dim, st.data()); + SetCommonTensor(end, dim, ed.data()); + SetCommonTensor(step, dim, sp.data()); + } +}; + +void test_range(Place place) { + std::unique_ptr<arena::TestCase> tester1( + new RangeComputeTester(place, "def", 1, 10, 1)); + arena::Arena arena(std::move(tester1), place, 2e-5); + arena.TestPrecision(); + + std::unique_ptr<arena::TestCase> tester2( + new RangeComputeTester(place, "def", 10, 1, -2)); + arena::Arena arena2(std::move(tester2), place, 2e-5); + arena2.TestPrecision(); +} + +TEST(Range, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_range(place); +#endif +} + +} // namespace lite +} // namespace paddle
diff --git a/lite/tests/kernels/reduce_max_compute_test.cc b/lite/tests/kernels/reduce_max_compute_test.cc index 2a1116d65f48ce8f519c2fde566ecd3061577eb1..a6d66846d595035a9954195f3e452d71ed22aa89 100644 --- a/lite/tests/kernels/reduce_max_compute_test.cc +++ b/lite/tests/kernels/reduce_max_compute_test.cc @@ -28,7 +28,7 @@ void reduce_n(const float* src, int width_in) { int hw_size = height_in * width_in; int chw_size = channel_in * hw_size; - int data_index, src_index, src_index0; +
int data_index, src_index; for (int c = 0; c < channel_in; ++c) { for (int h = 0; h < height_in; ++h) { for (int w = 0; w < width_in; ++w) { @@ -195,8 +195,8 @@ class ReduceMeanComputeTester : public arena::TestCase { std::string input_ = "x"; std::string output_ = "out"; std::vector dim_{0}; - DDim x_dims_{{3, 2, 3, 4}}; bool keep_dim_ = false; + DDim x_dims_{{3, 2, 3, 4}}; bool reduce_all_ = false; public: diff --git a/lite/tests/kernels/sequence_expand_compute_test.cc b/lite/tests/kernels/sequence_expand_compute_test.cc index c110f52793e2c79386c477f4a3ccdaa674572efa..05d814979624943f72a5ecdf480c7eafc0dba160 100644 --- a/lite/tests/kernels/sequence_expand_compute_test.cc +++ b/lite/tests/kernels/sequence_expand_compute_test.cc @@ -25,10 +25,10 @@ class SequenceExpandComputeTester : public arena::TestCase { const std::string input_x_ = "x"; const std::string input_y_ = "y"; const std::string output_ = "out"; - int ref_level_ = -1; - DDim dims_{{4, 1}}; LoD lod_x_{{0, 2, 4}}; LoD lod_y_{{0, 1, 4}}; + int ref_level_ = -1; + DDim dims_{{4, 1}}; public: SequenceExpandComputeTester(const Place& place, @@ -50,7 +50,6 @@ class SequenceExpandComputeTester : public arena::TestCase { const auto* x_data = x->data(); (x->mutable_lod())->clear(); (x->mutable_lod())->push_back(lod_x_[0]); - int x_rank = dims_.size(); auto width = x->numel() / dims_[0]; auto lod_x = x->lod(); @@ -59,7 +58,6 @@ class SequenceExpandComputeTester : public arena::TestCase { for (int i = 0; i < lod_y_.size(); i++) { (y->mutable_lod())->push_back(lod_y_[i]); } - const auto* y_data = y->data(); if (ref_level_ == -1) { ref_level_ = lod_y_.size() - 1; } diff --git a/lite/tests/kernels/sequence_pool_compute_test.cc b/lite/tests/kernels/sequence_pool_compute_test.cc index 717b468721769bb19ac0395832dbfd61a2224ec2..f987fb280220c74d1a9e3377c5170580bc65d42a 100644 --- a/lite/tests/kernels/sequence_pool_compute_test.cc +++ b/lite/tests/kernels/sequence_pool_compute_test.cc @@ -25,9 +25,9 @@ class SequencePoolComputeTester : public arena::TestCase { // common attributes for this op. std::string input_ = "x"; std::string output_ = "out"; - DDim dims_{{5, 1}}; LoD lod_{{0, 2, 5}}; std::string pool_type_ = "SUM"; + DDim dims_{{5, 1}}; public: SequencePoolComputeTester(const Place& place, diff --git a/lite/tests/kernels/test_sgemm.cc b/lite/tests/kernels/test_sgemm.cc deleted file mode 100644 index 4801b3086ee74c3256a7f127ff1b64f17d674547..0000000000000000000000000000000000000000 --- a/lite/tests/kernels/test_sgemm.cc +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// -// Created by Li,Xiaoyang(SYS) on 2019-07-25. 
-// - -#include "lite/tests/kernels/fill_data.h" -#include "lite/tests/kernels/test_funcs.h" -#ifdef LITE_WITH_ARM -#include "lite/backends/arm/math/funcs.h" -#endif -#include "lite/core/context.h" -#include "lite/core/tensor.h" -int g_cluster = 0; -int g_threads = 1; - -bool g_basic_test = false; - -int g_M = 512; -int g_N = 512; -int g_K = 512; -bool g_traA = false; -bool g_traB = false; -bool g_flag_relu = false; -bool g_flag_bias = false; -int g_test_iter = 1; -int g_warmup_iter = 0; -bool g_compare_result = true; - -int g_offset_a = 10; -int g_offset_b = 10; -int g_offset_c = 10; - -float g_alpha = 1.f; -float g_beta = 0.f; - -const int MALLOC_ALIGN = 16; - -static void* fast_malloc1(size_t size) { - size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; - char* p; - p = static_cast(malloc(offset + size)); - if (!p) { - return nullptr; - } - void* r = reinterpret_cast(reinterpret_cast(p + offset) & - (~(MALLOC_ALIGN - 1))); - static_cast(r)[-1] = p; - return r; -} - -static void fast_free1(void* ptr) { - if (ptr) { - free(static_cast(ptr)[-1]); - } -} - -bool test_sgemm(bool tra, - bool trb, - int m, - int n, - int k, - int lda, - int ldb, - int ldc, - float alpha, - float beta, - bool has_bias, - bool has_relu, - int cls, - int ths) { - size_t size_a = tra ? k * lda : m * lda; - size_t size_b = trb ? n * ldb : k * ldb; - - auto da = static_cast(fast_malloc1(size_a * sizeof(float))); - auto db = static_cast(fast_malloc1(size_b * sizeof(float))); - auto dc = static_cast(fast_malloc1(m * ldc * sizeof(float))); - auto dc_basic = static_cast(fast_malloc1(m * ldc * sizeof(float))); - auto dbias = static_cast(fast_malloc1(m * sizeof(float))); - - fill_data_rand(da, -1.f, 1.f, size_a); - fill_data_rand(db, -1.f, 1.f, size_b); - fill_data_rand(dbias, -1.f, 1.f, m); - fill_data_rand(dc, -1.f, 1.f, m * ldc); - memcpy(dc_basic, dc, sizeof(float) * m * ldc); - - LOG(INFO) << "sgemm M: " << m << ", N: " << n << ", K: " << k; - LOG(INFO) << "strides, lda: " << lda << ", ldb: " << ldb << ", ldc: " << ldc; - LOG(INFO) << "alpha: " << alpha << ", beta: " << beta; - LOG(INFO) << "transA: " << (tra ? "true" : "false") - << ", transB: " << (trb ? "true" : "false"); - LOG(INFO) << "relu: " << (has_relu ? "true" : "false") - << ", bias: " << (has_bias ? "true" : "false"); - - LOG(INFO) << "basic sgemm compute"; - basic_gemm(tra, - trb, - m, - n, - k, - alpha, - da, - lda, - db, - ldb, - beta, - dc_basic, - ldc, - dbias, - has_bias, - has_relu); - - float max_error = 0.f; - float max_ratio = 0.f; -#ifdef LITE_WITH_ARM - //! 
compute - LOG(INFO) << "sgemm compute"; - double ops = 2.0 * m * n * k; - std::unique_ptr ctx1( - new paddle::lite::KernelContext); - auto& ctx = ctx1->As(); - - paddle::lite::arm::math::sgemm(tra, - trb, - m, - n, - k, - alpha, - da, - lda, - db, - ldb, - beta, - dc, - ldc, - dbias, - has_bias, - has_relu, - &ctx); - - for (int i = 0; i < m * ldc; ++i) { - auto error = fabsf(dc[i] - dc_basic[i]); - if (error > max_error) { - max_error = error; - max_ratio = error / fabsf(dc_basic[i]); - } - } - if (max_error > 2e-5f && max_ratio > 2e-5f) { - LOG(INFO) << "max ratio: " << max_ratio << ", max_error: " << max_error; - LOG(INFO) << "sgemm result:"; - for (int i = 0; i < m * ldc; ++i) { - printf("%f ", dc[i]); - if ((i + 1) % ldc == 0) { - printf("\n"); - } - } - LOG(INFO) << "basic result:"; - for (int i = 0; i < m * ldc; ++i) { - printf("%f ", dc_basic[i]); - if ((i + 1) % ldc == 0) { - printf("\n"); - } - } - } -#endif - fast_free1(da); - fast_free1(db); - fast_free1(dbias); - fast_free1(dc); - fast_free1(dc_basic); - return max_error < 2e-5f || max_ratio < 2e-5f; -} - -void test_input() { - int lda = g_K + g_offset_a; - if (g_traA) { - lda = g_M + g_offset_a; - } - int ldb = g_N + g_offset_b; - if (g_traB) { - ldb = g_K + g_offset_b; - } - int ldc = g_N + g_offset_c; - auto flag = test_sgemm(g_traA, - g_traB, - g_M, - g_N, - g_K, - lda, - ldb, - ldc, - g_alpha, - g_beta, - g_flag_bias, - g_flag_relu, - g_cluster, - g_threads); - if (!flag) { - LOG(FATAL) << "test m = " << g_M << ", n=" << g_N << ", k=" << g_K - << ", trans A: " << g_traA << ", trans B: " << g_traB - << ", bias: " << g_flag_bias << ", relu: " << g_flag_relu - << " failed!!"; - } - LOG(INFO) << "test m = " << g_M << ", n=" << g_N << ", k=" << g_K - << ", trans A: " << g_traA << ", trans B: " << g_traB - << ", bias: " << g_flag_bias << ", relu: " << g_flag_relu - << " passed!!"; -} - -void test_func_sgemm_prepacked() { - if (g_basic_test) { - LOG(INFO) << "run basic sgemm test"; - for (auto& m : {1, 8, 16, 111, 256, 397, 512, 777, 1024}) { - for (auto& n : {1, 3, 13, 141, 256, 345, 512, 789, 1024}) { - for (auto& k : {1, 4, 15, 59, 128, 234, 512, 678, 1024}) { - for (auto& tra : {false, true}) { - for (auto& trb : {false, true}) { - for (auto& alpha : {1.f, 0.5f}) { - for (auto& beta : {0.f, 0.5f}) { - for (auto& offset : {0, 10}) { - for (auto& has_bias : {false, true}) { - for (auto& has_relu : {false, true}) { - for (auto& th : {1, 2, 4}) { - int lda = k + offset; - if (tra) { - lda = m + offset; - } - int ldb = n + offset; - if (trb) { - ldb = k + offset; - } - int ldc = n + offset; - auto flag = test_sgemm(tra, - trb, - m, - n, - k, - lda, - ldb, - ldc, - alpha, - beta, - has_bias, - has_relu, - g_cluster, - th); - if (flag) { - LOG(INFO) - << "test m = " << m << ", n=" << n - << ", k=" << k - << ", bias: " << (has_bias ? "true" : "false") - << ", relu: " << (has_relu ? "true" : "false") - << ", trans A: " << (tra ? "true" : "false") - << ", trans B: " << (trb ? "true" : "false") - << " passed\n"; - } else { - LOG(FATAL) - << "test m = " << m << ", n=" << n - << ", k=" << k - << ", bias: " << (has_bias ? "true" : "false") - << ", relu: " << (has_relu ? "true" : "false") - << ", trans A: " << (tra ? "true" : "false") - << ", trans B: " << (trb ? 
"true" : "false") - << " failed\n"; - } - } - } - } - } - } - } - } - } - } - } - } - } -} - -int main(int argc, const char** argv) { -#ifdef LITE_WITH_ARM - paddle::lite::DeviceInfo::Init(); -#endif - LOG(ERROR) << "usage: ./" << argv[0] - << " [do_basic_test] [cluster] [threads] [m] [n] [k] [transA] " - "[transB] [relu] [bias] [test iter] [compare result]"; - if (argc > 1) { - g_basic_test = atoi(argv[1]) > 0; - } - if (argc > 2) { - g_cluster = atoi(argv[2]); - } - if (argc > 3) { - g_threads = atoi(argv[3]); - } - if (argc > 4) { - if (argc < 10) { - LOG(ERROR) << "usage: ./" << argv[0] << " [do_basic_test] [cluster] " - "[threads] [m] [n] [k] " - "[transA] [transB] [bias] [relu] " - "[test iter] [compare result]"; - return 0; - } - g_M = atoi(argv[4]); - g_N = atoi(argv[5]); - g_K = atoi(argv[6]); - g_traA = atoi(argv[7]) > 0; - g_traB = atoi(argv[8]) > 0; - g_flag_bias = atoi(argv[9]) > 0; - g_flag_relu = atoi(argv[10]) > 0; - } - if (argc > 11) { - g_test_iter = atoi(argv[11]); - } - if (argc > 12) { - g_compare_result = atoi(argv[12]) > 0; - } - if (argc > 13) { - g_warmup_iter = atoi(argv[13]); - } - if (argc > 14) { - g_offset_a = atoi(argv[14]); - } - if (argc > 15) { - g_offset_b = atoi(argv[15]); - } - if (argc > 16) { - g_offset_c = atoi(argv[16]); - } - if (argc > 17) { - g_alpha = atof(argv[17]); - } - if (argc > 18) { - g_beta = atof(argv[18]); - } - test_input(); - if (g_basic_test) { - test_func_sgemm_prepacked(); - } - return 0; -} diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6f35c615e8e2fba35d235d7a8ef78e0786cc11a --- /dev/null +++ b/lite/tests/kernels/unsqueeze_compute_test.cc @@ -0,0 +1,250 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class UnsqueezeComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string out_ = "Out"; + std::vector axes_; + DDim dims_; + + public: + UnsqueezeComputeTester(const Place& place, + const std::string& alias, + const std::vector& axes, + DDim dims) + : TestCase(place, alias), axes_(axes), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + const auto* input = scope->FindTensor(x_); + CHECK(input); + auto* out = scope->NewTensor(out_); + CHECK(out); + + DDim in_dims(dims_); + int output_size = in_dims.size() + static_cast(axes_.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validate Check: rank range. + CHECK_LE(output_size, 6) + << "The output tensor's rank should be less than 6."; + + for (int axis : axes_) { + int cur = axis < 0 ? 
axis + cur_output_size + 1 : axis; + // Validate Check: the axis bound + CHECK((cur >= 0) && (cur <= cur_output_size)) + << "The unsqueeze dims must be within range of current rank."; + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + out->Resize(DDim(output_shape)); + auto* input_data = input->data<float>(); + auto* out_data = out->mutable_data<float>(); + memcpy(out_data, input_data, sizeof(float) * dims_.production()); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("unsqueeze"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("axes", axes_); + } + + void PrepareData() override { + std::vector<float> in_data(dims_.production()); + for (int i = 0; i < dims_.production(); ++i) { + in_data[i] = i; + } + SetCommonTensor(x_, dims_, in_data.data()); + } +}; + +class Unsqueeze2ComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string out_ = "Out"; + std::string xshape_ = "XShape"; + std::vector<int> axes_; + DDim dims_; + + public: + Unsqueeze2ComputeTester(const Place& place, + const std::string& alias, + const std::vector<int>& axes, + DDim dims) + : TestCase(place, alias), axes_(axes), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + const auto* input = scope->FindTensor(x_); + CHECK(input); + auto* out = scope->NewTensor(out_); + CHECK(out); + auto* xshape = scope->NewTensor(xshape_); + CHECK(xshape); + std::vector<int64_t> xshape_sp(dims_.size() + 1, 1); + for (size_t i = 0; i < dims_.size(); ++i) { + xshape_sp[i + 1] = dims_[i]; + } + xshape->Resize(DDim(xshape_sp)); + + DDim in_dims(dims_); + int output_size = in_dims.size() + static_cast<int>(axes_.size()); + int cur_output_size = in_dims.size(); + std::vector<int64_t> output_shape(output_size, 0); + + // Validate Check: rank range. + CHECK_LE(output_size, 6) + << "The output tensor's rank should not be greater than 6."; + + for (int axis : axes_) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Validate Check: the axis bound + CHECK((cur >= 0) && (cur <= cur_output_size)) + << "The unsqueeze dims must be within range of current rank."; + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + + output_shape[cur] = 1; + // Add the output size.
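+      // Worked example of the shape algorithm above (hypothetical values,
+      // not taken from the test data): dims_ = {3, 5} with axes_ = {0, 2}
+      // gives output_shape = {1, 0, 0, 0} after inserting axis 0, then
+      // {1, 0, 1, 0} after axis 2; the zeros are filled from the input
+      // dims below, yielding the final shape {1, 3, 1, 5}.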
+ cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + out->Resize(DDim(output_shape)); + + auto* input_data = input->data<float>(); + auto* out_data = out->mutable_data<float>(); + auto* xshape_data = xshape->mutable_data<float>(); + memcpy(out_data, input_data, sizeof(float) * dims_.production()); + memcpy(xshape_data, input_data, sizeof(float) * dims_.production()); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("unsqueeze2"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetOutput("XShape", {xshape_}); + op_desc->SetAttr("axes", axes_); + } + + void PrepareData() override { + std::vector<float> in_data(dims_.production()); + for (int i = 0; i < dims_.production(); ++i) { + in_data[i] = i; + } + SetCommonTensor(x_, dims_, in_data.data()); + } +}; + +void test_unsqueeze(Place place) { + for (std::vector<int> axes : {std::vector<int>({}), + std::vector<int>({0, 2}), + std::vector<int>({0, -2})}) { + for (int N : {1}) { + for (int C : {3}) { + for (int H : {1}) { + for (int W : {5}) { + std::unique_ptr<arena::TestCase> tester(new UnsqueezeComputeTester( + place, "def", axes, DDim({N, C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } +} + +void test_unsqueeze2(Place place) { + for (std::vector<int> axes : {std::vector<int>({}), + std::vector<int>({0, 2}), + std::vector<int>({0, -2})}) { + for (int N : {1}) { + for (int C : {3}) { + for (int H : {1}) { + for (int W : {5}) { + std::unique_ptr<arena::TestCase> tester(new Unsqueeze2ComputeTester( + place, "def", axes, DDim({N, C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } +} + +TEST(unsqueeze, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_unsqueeze(place); +#endif +} + +TEST(unsqueeze2, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_unsqueeze2(place); +#endif +} + +} // namespace lite +} // namespace paddle
diff --git a/lite/tests/kernels/yolo_box_compute_test.cc b/lite/tests/kernels/yolo_box_compute_test.cc index a051e06b6bcb23647f8b9f467b9f76a751fecec4..2e98ce96cef479d55e77acebbe464d9a56f92934 100644 --- a/lite/tests/kernels/yolo_box_compute_test.cc +++ b/lite/tests/kernels/yolo_box_compute_test.cc @@ -101,7 +101,7 @@ class YoloBoxComputeTester : public arena::TestCase { float conf_thresh_ = 0.f; int downsample_ratio_ = 0; - DDim _dims0_{{1, 2, 2, 1}}; + DDim _dims0_{{1, 255, 13, 13}}; DDim _dims1_{{1, 2}}; public: @@ -115,7 +115,10 @@ anchors_(anchors), class_num_(class_num), conf_thresh_(conf_thresh), - downsample_ratio_(downsample_ratio) {} + downsample_ratio_(downsample_ratio) { + int anchor_num = anchors_.size() / 2; + _dims0_[1] = anchor_num * (5 + class_num); + } void RunBaseline(Scope* scope) override { const lite::Tensor* X = scope->FindTensor(input0_); @@ -149,14 +152,14 @@ auto anchors_data = anchors.data(); const float* in_data = in->data<float>(); - const float* imgsize_data = imgsize->data<float>(); + const int* imgsize_data = imgsize->data<int>(); float* boxes_data = boxes->mutable_data<float>(); float* scores_data = scores->mutable_data<float>(); float box[4]; for (int i = 0; i < n; i++) { - int img_height = static_cast<int>(imgsize_data[2 * i]); - int img_width = static_cast<int>(imgsize_data[2 * i + 1]); + int img_height = imgsize_data[2 * i]; + int img_width = imgsize_data[2 * i + 1]; for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { @@ -218,7 +221,7 @@ } std::vector<int> data1(_dims1_.production()); for (int i = 0; i < _dims1_.production(); i++) { - data1[i] = i + 8; + data1[i] = 608; } SetCommonTensor(input0_, _dims0_, data0.data()); SetCommonTensor(input1_, _dims1_, data1.data()); @@ -227,10 +230,9 @@ void test_yolobox(Place place) { for (int class_num : {1, 2, 3, 4}) { - for (float conf_thresh : {0.5, 0.2, 0.7}) { - for (int downsample_ratio : {1, 2, 3}) { - std::vector<int> anchor({1, 2, 3, 4}); - + for (float conf_thresh : {0.01, 0.2, 0.7}) { + for (int downsample_ratio : {16, 32}) { + std::vector<int> anchor({10, 13, 16, 30}); std::unique_ptr<arena::TestCase> tester(new YoloBoxComputeTester( place, "def", anchor, class_num, conf_thresh, downsample_ratio)); arena::Arena arena(std::move(tester), place, 2e-5);
diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..342901f075e46c4348ea294966e179de03cc292d --- /dev/null +++ b/lite/tests/math/CMakeLists.txt @@ -0,0 +1,8 @@ +if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) + lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) +endif()
diff --git a/lite/tests/math/conv_compute_test.cc b/lite/tests/math/conv_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bfb74e6e0a6f5ea0cae199f1c7dc5f1c03e83363 --- /dev/null +++ b/lite/tests/math/conv_compute_test.cc @@ -0,0 +1,520 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
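+
+// Sanity check of the output-shape rule implemented by compute_out_dim()
+// below (hypothetical numbers matching the default flags): input
+// H = W = 112, kernel 3x3, pad 1, stride 1, dilation 1 gives
+// kernel_exten = 1 * (3 - 1) + 1 = 3 and
+// hout = (112 + 2 * 1 - 3) / 1 + 1 = 112, i.e. a "same" convolution.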
+ +#include +#include +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#include "lite/tests/utils/naive_math_impl.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +#ifdef LITE_WITH_ARM +#include "lite/kernels/arm/conv_compute.h" +#endif // LITE_WITH_ARM + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(batch, 1, "batch size"); +DEFINE_int32(in_channel, 32, "input channel"); +DEFINE_int32(in_height, 112, "input height"); +DEFINE_int32(in_width, 112, "input width"); + +DEFINE_int32(out_channel, 32, "output channel"); +DEFINE_int32(group, 1, "group"); +DEFINE_int32(kernel_h, 3, "kernel height"); +DEFINE_int32(kernel_w, 3, "kernel width"); +DEFINE_int32(pad_h, 1, "pad height"); +DEFINE_int32(pad_w, 1, "pad width"); +DEFINE_int32(stride_h, 1, "stride height"); +DEFINE_int32(stride_w, 1, "stride width"); +DEFINE_int32(dila_h, 1, "dilation height"); +DEFINE_int32(dila_w, 1, "dilation width"); + +DEFINE_bool(flag_relu, true, "do relu"); +DEFINE_bool(flag_bias, true, "with bias"); + +typedef paddle::lite::DDim DDim; +typedef paddle::lite::Tensor Tensor; +typedef paddle::lite::operators::ConvParam ConvParam; +using paddle::lite::Timer; + +DDim compute_out_dim(const DDim& dim_in, + const paddle::lite::operators::ConvParam& param) { + DDim dim_out = dim_in; + dim_out[1] = param.filter->dims()[0]; + auto kernel_h = param.filter->dims()[2]; + auto kernel_w = param.filter->dims()[3]; + auto h = dim_in[2]; + auto w = dim_in[3]; + int dila_h = param.dilations[0]; + int dila_w = param.dilations[1]; + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride_h = param.strides[0]; + int stride_w = param.strides[1]; + auto kernel_exten = dila_h * (kernel_h - 1) + 1; + auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1; + kernel_exten = dila_w * (kernel_w - 1) + 1; + auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1; + dim_out[2] = hout; + dim_out[3] = wout; + return dim_out; +} + +#ifdef LITE_WITH_ARM +void test_conv_fp32(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + ConvParam param; + param.x = new Tensor; + param.x->set_precision(PRECISION(kFloat)); + param.filter = new Tensor; + param.filter->Resize(weight_dim); + param.filter->set_precision(PRECISION(kFloat)); + if (flag_bias) { + param.bias = new Tensor; + param.bias->Resize({weight_dim[0]}); + param.bias->set_precision(PRECISION(kFloat)); + } + param.strides = strides; + param.paddings = pads; + param.dilations = dilas; + param.fuse_relu = flag_relu; + param.groups = group; + + param.output = new Tensor; + param.output->set_precision(PRECISION(kFloat)); + + paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f); + // paddle::lite::fill_tensor_const(*param.filter, 1.f); + if (flag_bias) { + paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f); + // paddle::lite::fill_tensor_const(*param.bias, 1.f); + } + auto wptr = param.filter->data(); + auto bias_ptr = 
flag_bias ? param.bias->data() : nullptr; + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::ConvCompute + conv; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + for (auto& dim_in : input_dims) { + param.x->Resize(dim_in); + DDim out_tmp_dims = compute_out_dim(dim_in, param); + if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) { + continue; + } + param.output->Resize(out_tmp_dims); + break; + } + conv.SetParam(param); + conv.SetContext(std::move(ctx1)); + /// prepare for run + conv.PrepareForRun(); + + for (auto& dim_in : input_dims) { + CHECK_EQ(weight_dim[1] * group, dim_in[1]) + << "input channel must equal to weights channel"; + DDim dim_out = compute_out_dim(dim_in, param); + if (dim_out[2] < 1 || dim_out[3] < 1) { + continue; + } + param.x->Resize(dim_in); + param.output->Resize(dim_out); + + paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f); + // paddle::lite::fill_tensor_const(*param.x, 1.f); + auto din = param.x->data(); + + Tensor tout_basic; + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kFloat)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0.f); + auto dout_basic = tout_basic.mutable_data(); + conv_basic(din, + dout_basic, + dim_in[0], + dim_out[1], + dim_out[2], + dim_out[3], + dim_in[1], + dim_in[2], + dim_in[3], + wptr, + bias_ptr, + group, + weight_dim[3], + weight_dim[2], + strides[1], + strides[0], + dilas[1], + dilas[0], + pads[1], + pads[0], + flag_bias, + flag_relu); + } + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + conv.Launch(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + conv.Launch(); + t0.end(); + } + + double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] * + weight_dim[3] / param.groups; + LOG(INFO) << "conv fp32: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.output); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kFloat)); + tensor_diff(tout_basic, *param.output, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test fp32 conv: input: " << dim_in + << ", output: " << dim_out + << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? 
"true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + } + LOG(INFO) << "test fp32 conv: input: " << dim_in + << ", output: " << dim_out << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.filter; + delete param.output; + delete param.bias; +} +#else +void test_conv_fp32(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) {} +#endif // LITE_WITH_ARM + +#if 1 /// 3x3dw +TEST(TestConv3x3DW, test_conv3x3_depthwise) { + if (FLAGS_basic_test) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + for (auto& c : {1, 3, 5, 8, 16, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + c, + {stride, stride}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// 3x3dw + +#if 1 /// 5x5dw +TEST(TestConv5x5DW, test_conv5x5_depthwise) { + if (FLAGS_basic_test) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + for (auto& c : {1, 3, 5, 8, 16, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 5, 5}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + c, + {stride, stride}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// 5x5dw + +#if 1 /// conv1x1s1 +TEST(TestConv1x1s1, test_conv1x1s1) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 11, 32}) { + for (auto& cout : {1, 5, 16, 37}) { + for (auto& g : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + if (cin % g != 0 || cout % g != 0) { + continue; + } + DDim weights_dim({cout, cin / g, 1, 1}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 28, 32, 56, 1}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + g, + {1, 1}, + {0, 0}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv1x1s1 + +#if 1 /// conv3x3s1 +TEST(TestConv3x3s1, test_conv_3x3s1) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 32, 48}) { + for (auto& cout : {1, 5, 8, 32, 48}) { + for (auto& pad : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + 1, + {1, 1}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + 
{FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv3x3s1 + +#if 1 /// conv3x3s2 +TEST(TestConv3x3s2, test_conv_3x3s2) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 32}) { + for (auto& cout : {1, 5, 8, 32}) { + for (auto& pad : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + 1, + {2, 2}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv3x3s2 + +#if 1 /// random param conv +TEST(TestConvRand, test_conv_rand) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 16}) { + for (auto& cout : {1, 5, 8, 16}) { + for (auto& g : {1, 2}) { + for (auto& kw : {1, 2, 3}) { + for (auto& kh : {1, 2, 3}) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cout, cin / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + g, + {stride, stride}, + {pad, pad}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } + } + } + } + } + } +} +#endif /// random param conv + +#if 1 /// custom +TEST(TestConvCustom, test_conv_fp32_custom_size) { + CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0) + << "input channel must be divided by group"; + CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0) + << "num_output must be divided by group"; + test_conv_fp32( + {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})}, + DDim({FLAGS_out_channel, + FLAGS_in_channel / FLAGS_group, + FLAGS_kernel_h, + FLAGS_kernel_w}), + FLAGS_group, + {FLAGS_stride_h, FLAGS_stride_w}, + {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_dila_h, FLAGS_dila_w}, + FLAGS_flag_bias, + FLAGS_flag_relu, + {FLAGS_threads}, + {FLAGS_power_mode}); +} +#endif // custom diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e15b7d22bc2a5859db73f21aa54b1bcdaabf4d2c --- /dev/null +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -0,0 +1,698 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
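+
+// Note on the quantization scales assumed by test_conv_int8() below:
+// int8 inputs are filled in [-127, 127] and dequantized with
+// input_scale = 1 / 127, so the float reference inputs lie roughly in
+// [-1, 1]; with weight_scale = 1 / 127 per output channel,
+// output_scale = weight_dim.count(1, 4) / 127 bounds the worst-case
+// accumulation of cin / group * kh * kw products.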
+ +#include +#include +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#include "lite/tests/utils/naive_math_impl.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +#ifdef LITE_WITH_ARM +#include "lite/kernels/arm/conv_compute.h" +#endif // LITE_WITH_ARM + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(batch, 1, "batch size"); +DEFINE_int32(in_channel, 32, "input channel"); +DEFINE_int32(in_height, 112, "input height"); +DEFINE_int32(in_width, 112, "input width"); + +DEFINE_int32(out_channel, 32, "output channel"); +DEFINE_int32(group, 1, "group"); +DEFINE_int32(kernel_h, 3, "kernel height"); +DEFINE_int32(kernel_w, 3, "kernel width"); +DEFINE_int32(pad_h, 1, "pad height"); +DEFINE_int32(pad_w, 1, "pad width"); +DEFINE_int32(stride_h, 1, "stride height"); +DEFINE_int32(stride_w, 1, "stride width"); +DEFINE_int32(dila_h, 1, "dilation height"); +DEFINE_int32(dila_w, 1, "dilation width"); + +DEFINE_bool(flag_relu, true, "do relu"); +DEFINE_bool(flag_bias, true, "with bias"); + +typedef paddle::lite::DDim DDim; +typedef paddle::lite::Tensor Tensor; +typedef paddle::lite::operators::ConvParam ConvParam; +using paddle::lite::Timer; + +DDim compute_out_dim(const DDim& dim_in, + const paddle::lite::operators::ConvParam& param) { + DDim dim_out = dim_in; + dim_out[1] = param.filter->dims()[0]; + auto kernel_h = param.filter->dims()[2]; + auto kernel_w = param.filter->dims()[3]; + auto h = dim_in[2]; + auto w = dim_in[3]; + int dila_h = param.dilations[0]; + int dila_w = param.dilations[1]; + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride_h = param.strides[0]; + int stride_w = param.strides[1]; + auto kernel_exten = dila_h * (kernel_h - 1) + 1; + auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1; + kernel_exten = dila_w * (kernel_w - 1) + 1; + auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1; + dim_out[2] = hout; + dim_out[3] = wout; + return dim_out; +} + +template +void get_conv_param(const DDim& dim_w, + int g, + const std::vector& strides, + const std::vector& pads, + const std::vector& dila, + bool flag_bias, + bool flag_relu, + ConvParam* param) { + param->x = new Tensor; + param->x->set_precision(PRECISION(kInt8)); + param->filter = new Tensor; + param->filter->Resize(dim_w); + param->filter->set_precision(PRECISION(kInt8)); + if (flag_bias) { + param->bias = new Tensor; + param->bias->Resize({dim_w[0]}); + param->bias->set_precision(PRECISION(kFloat)); + } + param->strides = strides; + param->paddings = pads; + param->dilations = dila; + param->fuse_relu = flag_relu; + param->groups = g; + + param->output = new Tensor; + param->output->set_precision(ptype); +} + +void release_param(ConvParam* param) { + delete param->x; + delete param->filter; + delete param->output; + delete param->bias; +} + +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +void test_conv_int8(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) { + 
paddle::lite::DeviceInfo::Init(); + ConvParam param_int8_out; + ConvParam param_fp32_out; + + get_conv_param(weight_dim, + group, + strides, + pads, + dilas, + flag_bias, + flag_relu, + ¶m_int8_out); + + get_conv_param(weight_dim, + group, + strides, + pads, + dilas, + flag_bias, + flag_relu, + ¶m_fp32_out); + Tensor weight_fp32; + Tensor bias_fp32; + weight_fp32.Resize(weight_dim); + paddle::lite::fill_tensor_rand(*param_int8_out.filter, -127, 127); + param_fp32_out.filter->CopyDataFrom(*param_int8_out.filter); + if (flag_bias) { + auto dim_b = param_int8_out.bias->dims(); + bias_fp32.Resize(dim_b); + paddle::lite::fill_tensor_rand(*param_int8_out.bias, -1.f, 1.f); + param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias); + bias_fp32.CopyDataFrom(*param_int8_out.bias); + } + + std::vector scale_in{1.f / 127}; + std::vector scale_out{weight_dim.count(1, 4) / 127.f}; + std::vector scale_w(weight_dim[0], 1.f / 127); + + param_int8_out.input_scale = scale_in[0]; + param_int8_out.output_scale = scale_out[0]; + param_int8_out.weight_scale = scale_w; + + param_fp32_out.input_scale = scale_in[0]; + param_fp32_out.output_scale = scale_out[0]; + param_fp32_out.weight_scale = scale_w; + + auto wptr_fp32 = weight_fp32.mutable_data(); + auto bptr_fp32 = flag_bias ? bias_fp32.data() : nullptr; + + paddle::lite::arm::math::int8_to_fp32(param_int8_out.filter->data(), + wptr_fp32, + scale_w.data(), + weight_dim[0], + 1, + weight_dim.count(1, 4)); + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + std::unique_ptr ctx2( + new paddle::lite::KernelContext); + auto& ctx_tmp1 = ctx1->As(); + ctx_tmp1.SetRunMode(static_cast(cls), th); + auto& ctx_tmp2 = ctx2->As(); + ctx_tmp2.SetRunMode(static_cast(cls), th); + + paddle::lite::kernels::arm::ConvCompute + conv_int8_int8; + paddle::lite::kernels::arm::ConvCompute + conv_int8_fp32; + conv_int8_int8.SetContext(std::move(ctx1)); + conv_int8_fp32.SetContext(std::move(ctx2)); + + /// set param and context + for (auto& dim_in : input_dims) { + param_int8_out.x->Resize(dim_in); + DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out); + if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) { + continue; + } + param_fp32_out.x->Resize(dim_in); + param_int8_out.output->Resize(out_tmp_dims); + param_fp32_out.output->Resize(out_tmp_dims); + break; + } + conv_int8_int8.SetParam(param_int8_out); + conv_int8_fp32.SetParam(param_fp32_out); + /// prepare for run + conv_int8_int8.PrepareForRun(); + conv_int8_fp32.PrepareForRun(); + + for (auto& dim_in : input_dims) { + CHECK_EQ(weight_dim[1] * group, dim_in[1]) + << "input channel must equal to weights channel"; + DDim dim_out = compute_out_dim(dim_in, param_int8_out); + if (dim_out[2] < 1 || dim_out[3] < 1) { + continue; + } + delete param_fp32_out.output; + param_fp32_out.output = new Tensor; + param_fp32_out.output->set_precision(PRECISION(kFloat)); + delete param_int8_out.output; + param_int8_out.output = new Tensor; + param_int8_out.output->set_precision(PRECISION(kInt8)); + + param_int8_out.x->Resize(dim_in); + param_int8_out.output->Resize(dim_out); + param_fp32_out.x->Resize(dim_in); + param_fp32_out.output->Resize(dim_out); + + Tensor tin_fp32; + tin_fp32.Resize(dim_in); + tin_fp32.set_precision(PRECISION(kFloat)); + Tensor tout_basic_fp32; + Tensor tout_basic_int8; + + paddle::lite::fill_tensor_rand(*param_int8_out.x, -127, 127); + param_fp32_out.x->CopyDataFrom(*param_int8_out.x); + + auto din_fp32 = tin_fp32.mutable_data(); + 
paddle::lite::arm::math::int8_to_fp32(param_int8_out.x->data(), + din_fp32, + scale_in.data(), + 1, + 1, + dim_in.production()); + + if (FLAGS_check_result) { + tout_basic_fp32.set_precision(PRECISION(kFloat)); + tout_basic_fp32.Resize(dim_out); + tout_basic_int8.set_precision(PRECISION(kInt8)); + tout_basic_int8.Resize(dim_out); + fill_tensor_const(tout_basic_fp32, 0.f); + auto dout_basic_fp32 = tout_basic_fp32.mutable_data(); + auto dout_basic_int8 = tout_basic_int8.mutable_data(); + conv_basic(din_fp32, + dout_basic_fp32, + dim_in[0], + dim_out[1], + dim_out[2], + dim_out[3], + dim_in[1], + dim_in[2], + dim_in[3], + wptr_fp32, + bptr_fp32, + group, + weight_dim[3], + weight_dim[2], + strides[1], + strides[0], + dilas[1], + dilas[0], + pads[1], + pads[0], + flag_bias, + flag_relu); + paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32, + dout_basic_int8, + scale_out.data(), + 1, + 1, + dim_out.production()); + } + + double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] * + weight_dim[3] / group; + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + conv_int8_int8.Launch(); + } + /// compute fp32 output + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + conv_int8_fp32.Launch(); + t0.end(); + } + LOG(INFO) << "int8 conv, fp32 output: output shape" << dim_out + << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + /// compute int8 output + t0.clear(); + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + conv_int8_int8.Launch(); + t0.end(); + } + LOG(INFO) << "int8 conv, int8 output: output shape" << dim_out + << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + /// compare result fp32 output + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host( + tout_basic_fp32, *param_fp32_out.output, max_ratio, max_diff); + LOG(INFO) << "FP32 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-5f) { + if (max_diff > 5e-5f) { + LOG(WARNING) << "basic result"; + print_tensor(tout_basic_fp32); + LOG(WARNING) << "lite result"; + print_tensor(*param_fp32_out.output); + Tensor tdiff; + tdiff.Resize(tout_basic_fp32.dims()); + tdiff.set_precision(PRECISION(kFloat)); + tensor_diff(tout_basic_fp32, *param_fp32_out.output, tdiff); + print_tensor(tdiff); + release_param(¶m_int8_out); + release_param(¶m_fp32_out); + LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in + << ", output: " << dim_out + << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + } + /// compare result int8 output + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + // ! 
int8 + tensor_cmp_host( + tout_basic_int8, *param_int8_out.output, max_ratio, max_diff); + LOG(INFO) << "int8 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (fabs(max_diff) > 0) { + Tensor tdiff; + tdiff.Resize(tout_basic_int8.dims()); + tdiff.set_precision(PRECISION(kInt8)); + tensor_diff(tout_basic_int8, *param_int8_out.output, tdiff); + auto ptr = tdiff.data(); + auto ptr_basic_fp32 = tout_basic_fp32.data(); + float count = 0; + bool check = true; + for (int i = 0; i < tdiff.numel(); ++i) { + if (abs(ptr[i]) > 1) { + check = false; + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " + << ptr_basic_fp32[i] / scale_out[0]; + break; + } + if (ptr[i] != 0) { + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " + << ptr_basic_fp32[i] / scale_out[0]; + count += 1; + } + } + check = + check && + count < std::max(10, static_cast(0.01 * tdiff.numel())); + if (!check) { + LOG(WARNING) << "int8 basic result"; + print_tensor(tout_basic_int8); + LOG(WARNING) << "int8 lite result"; + print_tensor(*param_int8_out.output); + LOG(WARNING) << "int8 diff tensor"; + print_tensor(tdiff); + release_param(¶m_int8_out); + release_param(¶m_fp32_out); + LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in + << ", output: " << dim_out + << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + } + LOG(INFO) << "test int8 conv: input: " << dim_in + << ", output: " << dim_out << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? 
"true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + release_param(¶m_int8_out); + release_param(¶m_fp32_out); +} +#else +void test_conv_int8(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) {} +#endif // LITE_WITH_ARM + +#if 1 /// 3x3dw +TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { + if (FLAGS_basic_test) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + for (auto& c : {1, 3, 5, 8, 16, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 19, 75, 32, 28}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + c, + {stride, stride}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// 3x3dw + +#if 1 /// 5x5dw +TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { + if (FLAGS_basic_test) { + for (auto& stride : {1}) { + for (auto& pad : {0, 1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + for (auto& c : {1, 3, 5, 8, 16, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 5, 5}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + c, + {stride, stride}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// 5x5dw + +#if 1 /// conv1x1s1 +TEST(TestConv1x1s1Int8, test_conv1x1s1) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 11, 32}) { + for (auto& cout : {1, 5, 16, 37}) { + for (auto& g : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + if (cin % g != 0 || cout % g != 0) { + continue; + } + DDim weights_dim({cout, cin / g, 1, 1}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 28, 32, 56, 1}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + g, + {1, 1}, + {0, 0}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv1x1s1 + +#if 1 /// conv3x3s1 +TEST(TestConv3x3s1Int8, test_conv_3x3s1) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 32, 48}) { + for (auto& cout : {1, 5, 8, 32, 48}) { + for (auto& pad : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + 1, + {1, 1}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv3x3s1 + +#if 1 /// conv3x3s2 +TEST(TestConv3x3s2Int8, test_conv_3x3s2) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 32}) { + for (auto& cout : {1, 5, 8, 32}) { + for (auto& pad : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 
7, 19, 28, 75, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + 1, + {2, 2}, + {pad, pad}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } +} +#endif /// conv3x3s2 + +#if 1 /// random param conv +TEST(TestConvRandInt8, test_conv_rand) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 16}) { + for (auto& cout : {1, 5, 8, 16}) { + for (auto& g : {1, 2}) { + for (auto& kw : {1, 2, 3}) { + for (auto& kh : {1, 2, 3}) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cout, cin / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + g, + {stride, stride}, + {pad, pad}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } + } + } + } + } + } +} +#endif /// random param conv + +#if 1 /// custom +TEST(TestConvCustomInt8, test_conv_custom_size) { + CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0) + << "input channel must be divided by group"; + CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0) + << "num_output must be divided by group"; + test_conv_int8( + {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})}, + DDim({FLAGS_out_channel, + FLAGS_in_channel / FLAGS_group, + FLAGS_kernel_h, + FLAGS_kernel_w}), + FLAGS_group, + {FLAGS_stride_h, FLAGS_stride_w}, + {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_dila_h, FLAGS_dila_w}, + FLAGS_flag_bias, + FLAGS_flag_relu, + {FLAGS_threads}, + {FLAGS_power_mode}); +} +#endif // custom diff --git a/lite/tests/math/conv_transpose_compute_test.cc b/lite/tests/math/conv_transpose_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0da07a53462cf902107efc0b6daaeef819f3288 --- /dev/null +++ b/lite/tests/math/conv_transpose_compute_test.cc @@ -0,0 +1,340 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
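The expected output shape in this new conv_transpose test comes from the standard transposed-convolution size rule implemented by compute_out_dim() below. A minimal standalone sketch of that rule (an illustration, not part of the patch), using the flag defaults declared in this file (32x32 input, 2x2 kernel, stride 2, pad 0, dilation 1):

#include <cstdio>

// Dilation stretches the effective kernel; a transposed conv then inverts
// the usual (in + 2*pad - kernel_extent) / stride + 1 forward formula.
static int deconv_out_len(int in, int kernel, int stride, int pad, int dila) {
  int kernel_extent = dila * (kernel - 1) + 1;
  return (in - 1) * stride + kernel_extent - 2 * pad;
}

int main() {
  // (32 - 1) * 2 + 2 - 2 * 0 = 64: a stride-2 2x2 deconv exactly doubles
  // the spatial size, matching the in_height/in_width defaults in this file.
  std::printf("%d\n", deconv_out_len(32, 2, 2, 0, 1));
  return 0;
}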
+ +#include +#include +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#include "lite/tests/utils/naive_math_impl.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +#ifdef LITE_WITH_ARM +#include "lite/kernels/arm/conv_transpose_compute.h" +#endif // LITE_WITH_ARM + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(batch, 1, "batch size"); +DEFINE_int32(in_channel, 32, "input channel"); +DEFINE_int32(in_height, 32, "input height"); +DEFINE_int32(in_width, 32, "input width"); + +DEFINE_int32(out_channel, 64, "output channel"); +DEFINE_int32(group, 1, "group"); +DEFINE_int32(kernel_h, 2, "kernel height"); +DEFINE_int32(kernel_w, 2, "kernel width"); +DEFINE_int32(pad_h, 0, "pad height"); +DEFINE_int32(pad_w, 0, "pad width"); +DEFINE_int32(stride_h, 2, "stride height"); +DEFINE_int32(stride_w, 2, "stride width"); +DEFINE_int32(dila_h, 1, "dilation height"); +DEFINE_int32(dila_w, 1, "dilation width"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +typedef paddle::lite::DDim DDim; +typedef paddle::lite::Tensor Tensor; +typedef paddle::lite::operators::ConvParam ConvParam; +using paddle::lite::Timer; + +DDim compute_out_dim(const DDim& dim_in, + const paddle::lite::operators::ConvParam& param) { + auto filter_dims = param.filter->dims(); + DDim output_shape = dim_in; + output_shape[1] = filter_dims[1] * param.groups; + for (int i = 0; i < 2; i++) { + int kernel_extent = param.dilations[i] * (filter_dims[i + 2] - 1) + 1; + int output_len = (dim_in[i + 2] - 1) * param.strides[i] + kernel_extent - + 2 * param.paddings[i]; + output_shape[i + 2] = output_len; + } + return output_shape; +} + +#ifdef LITE_WITH_ARM +void test_conv_transpose_fp32(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + ConvParam param; + param.x = new Tensor; + param.x->set_precision(PRECISION(kFloat)); + param.filter = new Tensor; + param.filter->Resize(weight_dim); + param.filter->set_precision(PRECISION(kFloat)); + if (flag_bias) { + param.bias = new Tensor; + param.bias->Resize({weight_dim[0]}); + param.bias->set_precision(PRECISION(kFloat)); + } + param.strides = strides; + param.paddings = pads; + param.dilations = dilas; + param.fuse_relu = flag_relu; + param.groups = group; + + param.output = new Tensor; + param.output->set_precision(PRECISION(kFloat)); + + // paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f); + paddle::lite::fill_tensor_const(*param.filter, 1.f); + if (flag_bias) { + // paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f); + paddle::lite::fill_tensor_const(*param.bias, 1.f); + } + Tensor tmp_weights; + tmp_weights.Resize(weight_dim); + tmp_weights.CopyDataFrom(*param.filter); + auto wptr = tmp_weights.data(); + auto bias_ptr = flag_bias ? 
param.bias->data() : nullptr; + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::Conv2DTransposeCompute conv_t; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + for (auto& dim_in : input_dims) { + param.x->Resize(dim_in); + DDim out_tmp_dims = compute_out_dim(dim_in, param); + if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) { + continue; + } + param.output->Resize(out_tmp_dims); + break; + } + conv_t.SetParam(param); + conv_t.SetContext(std::move(ctx1)); + /// prepare for run + conv_t.PrepareForRun(); + + for (auto& dim_in : input_dims) { + CHECK_EQ(weight_dim[0], dim_in[1]) + << "input channel must equal to weights channel"; + DDim dim_out = compute_out_dim(dim_in, param); + if (dim_out[2] < 1 || dim_out[3] < 1) { + continue; + } + param.x->Resize(dim_in); + param.output->Resize(dim_out); + + // paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f); + paddle::lite::fill_tensor_const(*param.x, 1.f); + auto din = param.x->data(); + + Tensor tout_basic; + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kFloat)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0.f); + auto dout_basic = tout_basic.mutable_data(); + + deconv_basic(din, + dout_basic, + dim_in[0], + dim_out[1], + dim_out[2], + dim_out[3], + dim_in[1], + dim_in[2], + dim_in[3], + wptr, + bias_ptr, + group, + weight_dim[3], + weight_dim[2], + strides[1], + strides[0], + dilas[1], + dilas[0], + pads[1], + pads[0], + flag_bias, + flag_relu); + } + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + conv_t.Launch(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + conv_t.Launch(); + t0.end(); + } + + float gops = + 2.f * tmp_weights.numel() * dim_in[0] * dim_in[2] * dim_in[3]; + LOG(INFO) << "conv fp32: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.output); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kFloat)); + tensor_diff(tout_basic, *param.output, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test fp32 conv: input: " << dim_in + << ", output: " << dim_out + << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + } + LOG(INFO) << "test fp32 conv: input: " << dim_in + << ", output: " << dim_out << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? 
"true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.filter; + delete param.output; + delete param.bias; +} +#else +void test_conv_transpose_fp32(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& power_mode) {} +#endif // LITE_WITH_ARM + +#if 1 /// random param conv +TEST(TestConvRand, test_conv_transpose_rand) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 16}) { + for (auto& cout : {1, 5, 8, 16}) { + for (auto& g : {1, 2}) { + for (auto& kw : {1, 2, 3}) { + for (auto& kh : {1, 2, 3}) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cin, cout / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_transpose_fp32(dims, + weights_dim, + g, + {stride, stride}, + {pad, pad}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } + } + } + } + } + } + } + } + } + } + } +} +#endif /// random param conv + +#if 1 /// custom +TEST(TestConvCustom, test_conv_transpose_fp32_custom_size) { + CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0) + << "input channel must be divided by group"; + CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0) + << "num_output must be divided by group"; + test_conv_transpose_fp32( + {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})}, + DDim({FLAGS_in_channel, + FLAGS_out_channel / FLAGS_group, + FLAGS_kernel_h, + FLAGS_kernel_w}), + FLAGS_group, + {FLAGS_stride_h, FLAGS_stride_w}, + {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_dila_h, FLAGS_dila_w}, + FLAGS_flag_bias, + FLAGS_flag_relu, + {FLAGS_threads}, + {FLAGS_power_mode}); +} +#endif // custom diff --git a/lite/tests/math/gemm_int8_compute_test.cc b/lite/tests/math/gemm_int8_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..06a1a0a65e1e5d0abb4a3eef2a6bf7d1e7ce5db0 --- /dev/null +++ b/lite/tests/math/gemm_int8_compute_test.cc @@ -0,0 +1,386 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +#endif // LITE_WITH_ARM +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +typedef paddle::lite::Tensor Tensor; +using paddle::lite::Timer; + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(M, 512, "gemm: M"); +DEFINE_int32(N, 512, "gemm: N"); +DEFINE_int32(K, 512, "gemm: K"); + +DEFINE_bool(traA, false, "gemm: A transpose"); +DEFINE_bool(traB, false, "gemm: B transpose"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +bool test_gemm_int8(bool tra, + bool trb, + int m, + int n, + int k, + bool has_bias, + bool has_relu, + int cls, + int ths) { + Tensor ta; + Tensor tb; + Tensor tc_int8; + Tensor tc_fp32; + Tensor tc_basic_int8; + Tensor tc_basic_fp32; + Tensor tbias; + + ta.Resize({m, k}); + tb.Resize({k, n}); + tc_int8.Resize({m, n}); + tc_fp32.Resize({m, n}); + tc_basic_int8.Resize({m, n}); + tc_basic_fp32.Resize({m, n}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kInt8)); + tb.set_precision(PRECISION(kInt8)); + tc_int8.set_precision(PRECISION(kInt8)); + tc_fp32.set_precision(PRECISION(kFloat)); + tc_basic_int8.set_precision(PRECISION(kInt8)); + tc_basic_fp32.set_precision(PRECISION(kFloat)); + tbias.set_precision(PRECISION(kFloat)); + + fill_tensor_rand(ta, -127, 127); + fill_tensor_rand(tb, -127, 127); + fill_tensor_rand(tbias, -1.f, 1.f); + + std::vector scale_a(static_cast(m), 1.f / 127); + std::vector scale_b = {1.f / 127}; + std::vector scale_c = {k / 127.f}; + std::vector scale_merge_fp32(static_cast(m)); + std::vector scale_merge_int8(static_cast(m)); + for (int j = 0; j < m; ++j) { + scale_merge_fp32[j] = scale_a[j] * scale_b[0]; + scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0]; + } + + LOG(INFO) << "gemm_int8 M: " << m << ", N: " << n << ", K: " << k + << ", transA: " << (tra ? "true" : "false") + << ", transB: " << (trb ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? "true" : "false"); +#ifdef LITE_WITH_ARM + int lda = tra ? m : k; + int ldb = trb ? 
k : n; + int ldc = n; + + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto dc_int8 = tc_int8.mutable_data(); + auto dc_fp32 = tc_fp32.mutable_data(); + auto dc_basic_int8 = tc_basic_int8.mutable_data(); + auto dc_basic_fp32 = tc_basic_fp32.mutable_data(); + auto dbias = tbias.mutable_data(); + + if (FLAGS_check_result) { + Tensor ta_fp32; + Tensor tb_fp32; + ta_fp32.Resize({m, k}); + ta_fp32.set_precision(PRECISION(kFloat)); + tb_fp32.Resize({k, n}); + tb_fp32.set_precision(PRECISION(kFloat)); + + auto da_fp32 = ta_fp32.mutable_data(); + auto db_fp32 = tb_fp32.mutable_data(); + + paddle::lite::arm::math::int8_to_fp32( + da, da_fp32, scale_a.data(), 1, 1, ta.numel()); + paddle::lite::arm::math::int8_to_fp32( + db, db_fp32, scale_b.data(), 1, 1, tb.numel()); + basic_gemm(tra, + trb, + m, + n, + k, + 1.f, + da_fp32, + lda, + db_fp32, + ldb, + 0.f, + dc_basic_fp32, + ldc, + dbias, + has_bias, + has_relu); + paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32, + dc_basic_int8, + scale_c.data(), + 1, + 1, + tc_basic_fp32.numel()); + } + Timer t0; + //! compute + double ops = 2.0 * m * n * k; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + //! prepack + Tensor tpackedA; + int hblock = paddle::lite::arm::math::get_hblock_int8(&ctx); + int round_up_a = ((hblock + m - 1) / hblock) * hblock; + int round_up_k = 4 * ((k + 3) / 4); + tpackedA.Resize({round_up_a * round_up_k}); + paddle::lite::arm::math::prepackA_int8( + tpackedA.mutable_data(), da, lda, 0, m, 0, k, tra, &ctx); + /// warmup + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data(), + db, + dbias, + dc_fp32, + m, + n, + k, + has_bias, + has_relu, + trb, + scale_merge_fp32.data(), + &ctx); + } + + /// int8 output compute + Tensor tbias_int8; + tbias_int8.Resize(tbias.dims()); + tbias_int8.set_precision(PRECISION(kFloat)); + auto dbias_int8 = tbias_int8.mutable_data(); + for (int l = 0; l < tbias_int8.numel(); ++l) { + dbias_int8[l] = dbias[l] / scale_c[0]; + } + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data(), + db, + dbias_int8, + dc_int8, + m, + n, + k, + has_bias, + has_relu, + trb, + scale_merge_int8.data(), + &ctx); + t0.end(); + } + LOG(INFO) << "gemm_int8_int8 output: M: " << m << ", N: " << n << ", K: " << k + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + /// fp32 output compute + t0.clear(); + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data(), + db, + dbias, + dc_fp32, + m, + n, + k, + has_bias, + has_relu, + trb, + scale_merge_fp32.data(), + &ctx); + t0.end(); + } + LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n << ", K: " << k + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + /// fp32 result + tensor_cmp_host(tc_basic_fp32, tc_fp32, 
max_ratio, max_diff); + LOG(INFO) << "fp32 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kFloat)); + tdiff.Resize(tc_fp32.dims()); + tensor_diff(tc_basic_fp32, tc_fp32, tdiff); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic_fp32); + LOG(INFO) << "lite result: "; + print_tensor(tc_fp32); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + /// int8 result + max_ratio = 0; + max_diff = 0; + tensor_cmp_host(tc_basic_int8, tc_int8, max_ratio, max_diff); + LOG(INFO) << "int8 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (fabs(max_ratio) > 1e-4f) { + Tensor tdiff; + tdiff.Resize(tc_int8.dims()); + tdiff.set_precision(PRECISION(kInt8)); + tensor_diff(tc_basic_int8, tc_int8, tdiff); + auto ptr = tdiff.data(); + auto ptr_basic_fp32 = tc_basic_fp32.data(); + float count = 0; + bool check = true; + for (int i = 0; i < tdiff.numel(); ++i) { + if (abs(ptr[i]) > 1) { + check = false; + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " << ptr_basic_fp32[i] / scale_c[0]; + break; + } + if (ptr[i] != 0) { + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " << ptr_basic_fp32[i] / scale_c[0]; + count += 1; + } + } + check = + check && count < std::max(10, static_cast(0.01 * tdiff.numel())); + if (!check) { + LOG(WARNING) << "int8 basic result"; + print_tensor(tc_basic_int8); + LOG(WARNING) << "int8 lite result"; + print_tensor(tc_int8); + LOG(WARNING) << "int8 diff tensor"; + print_tensor(tdiff); + return false; + } + } + } +#endif + return true; +} + +TEST(TestLiteGemmInt8, gemm_prepacked_int8) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemm test"; + for (auto& m : {1, 3, 8, 32, 397}) { + for (auto& n : {1, 3, 13, 141, 512, 789}) { + for (auto& k : {1, 3, 8, 59, 234}) { + for (auto& tra : {false, true}) { + for (auto& trb : {false, true}) { + for (auto& has_bias : {false, true}) { + for (auto& has_relu : {false, true}) { + for (auto& th : {1, 2, 4}) { + auto flag = test_gemm_int8(tra, + trb, + m, + n, + k, + has_bias, + has_relu, + FLAGS_power_mode, + th); + if (flag) { + LOG(INFO) << "test m = " << m << ", n=" << n + << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", trans B: " << (trb ? "true" : "false") + << " passed\n"; + } else { + LOG(FATAL) << "test m = " << m << ", n=" << n + << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", trans B: " << (trb ? 
"true" : "false") + << " failed\n"; + } + } + } + } + } + } + } + } + } + } +} + +TEST(TestGemmInt8Custom, gemm_prepacked_int8_custom) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + auto flag = test_gemm_int8(FLAGS_traA, + FLAGS_traB, + FLAGS_M, + FLAGS_N, + FLAGS_K, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_power_mode, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA + << ", trans B: " << FLAGS_traB << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } + LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K + << ", trans A: " << FLAGS_traA << ", trans B: " << FLAGS_traB + << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu + << " passed!!"; +} diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c64e78d66a4193f1b20c525120d8b0281afc9a9c --- /dev/null +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -0,0 +1,337 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +#endif // LITE_WITH_ARM +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +typedef paddle::lite::Tensor Tensor; +using paddle::lite::Timer; + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(M, 512, "gemv: M"); +DEFINE_int32(N, 512, "gemv: N"); + +DEFINE_bool(traA, false, "gemv: A transpose"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +bool test_gemv_int8( + bool tra, int m, int n, bool has_bias, bool has_relu, int cls, int ths) { + Tensor ta; + Tensor tb; + Tensor tc_int8; + Tensor tc_fp32; + Tensor tc_basic_int8; + Tensor tc_basic_fp32; + Tensor tbias; + + ta.Resize({m, n}); + tb.Resize({n}); + tc_int8.Resize({m}); + tc_fp32.Resize({m}); + tc_basic_int8.Resize({m}); + tc_basic_fp32.Resize({m}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kInt8)); + tb.set_precision(PRECISION(kInt8)); + tc_int8.set_precision(PRECISION(kInt8)); + tc_fp32.set_precision(PRECISION(kFloat)); + tc_basic_int8.set_precision(PRECISION(kInt8)); + tc_basic_fp32.set_precision(PRECISION(kFloat)); + tbias.set_precision(PRECISION(kFloat)); + + fill_tensor_rand(ta, -127, 127); + fill_tensor_rand(tb, -127, 127); + fill_tensor_rand(tbias, 
-1.f, 1.f); + + std::vector scale_a(static_cast(m), 1.f / 127); + std::vector scale_b = {1.f / 127}; + std::vector scale_c = {n / 127.f}; + std::vector scale_merge_fp32(static_cast(m)); + std::vector scale_merge_int8(static_cast(m)); + for (int j = 0; j < m; ++j) { + scale_merge_fp32[j] = scale_a[j] * scale_b[0]; + scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0]; + } + + LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n + << ", transA: " << (tra ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? "true" : "false"); +#ifdef LITE_WITH_ARM + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto dc_int8 = tc_int8.mutable_data(); + auto dc_fp32 = tc_fp32.mutable_data(); + auto dc_basic_int8 = tc_basic_int8.mutable_data(); + auto dc_basic_fp32 = tc_basic_fp32.mutable_data(); + auto dbias = tbias.mutable_data(); + + if (FLAGS_check_result) { + Tensor ta_fp32; + Tensor tb_fp32; + ta_fp32.Resize({m, n}); + ta_fp32.set_precision(PRECISION(kFloat)); + tb_fp32.Resize({n}); + tb_fp32.set_precision(PRECISION(kFloat)); + + auto da_fp32 = ta_fp32.mutable_data(); + auto db_fp32 = tb_fp32.mutable_data(); + + paddle::lite::arm::math::int8_to_fp32( + da, da_fp32, scale_a.data(), 1, 1, ta.numel()); + paddle::lite::arm::math::int8_to_fp32( + db, db_fp32, scale_b.data(), 1, 1, tb.numel()); + basic_gemv(m, + n, + da_fp32, + db_fp32, + dbias, + dc_basic_fp32, + 1.f, + 0.f, + false, + has_bias, + has_relu); + paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32, + dc_basic_int8, + scale_c.data(), + 1, + 1, + tc_basic_fp32.numel()); + } + Timer t0; + //! compute + double ops = 2.0 * m * n; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + /// warmup + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::gemv_int8(da, + db, + dc_fp32, + false, + m, + n, + scale_merge_fp32.data(), + has_bias, + dbias, + has_relu, + &ctx); + } + + /// int8 output compute + Tensor tbias_int8; + tbias_int8.Resize(tbias.dims()); + tbias_int8.set_precision(PRECISION(kFloat)); + auto dbias_int8 = tbias_int8.mutable_data(); + for (int l = 0; l < tbias_int8.numel(); ++l) { + dbias_int8[l] = dbias[l] / scale_c[0]; + } + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::gemv_int8(da, + db, + dc_fp32, + false, + m, + n, + scale_merge_fp32.data(), + has_bias, + dbias, + has_relu, + &ctx); + t0.end(); + } + LOG(INFO) << "gemv_int8_int8 output: M: " << m << ", N: " << n + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + /// fp32 output compute + t0.clear(); + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::gemv_int8(da, + db, + dc_int8, + false, + m, + n, + scale_merge_int8.data(), + has_bias, + dbias_int8, + has_relu, + &ctx); + t0.end(); + } + LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; 
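+    // max_diff below tracks the largest absolute element error and max_ratio
+    // the largest relative error, |diff| / (|reference| + eps); the fp32
+    // check only fails when both tolerances are exceeded, so a big relative
+    // error on a near-zero reference value alone does not abort the test.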
+ double max_diff = 0; + /// fp32 result + tensor_cmp_host(tc_basic_fp32, tc_fp32, max_ratio, max_diff); + LOG(INFO) << "fp32 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kFloat)); + tdiff.Resize(tc_fp32.dims()); + tensor_diff(tc_basic_fp32, tc_fp32, tdiff); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic_fp32); + LOG(INFO) << "lite result: "; + print_tensor(tc_fp32); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + /// int8 result + max_ratio = 0; + max_diff = 0; + tensor_cmp_host(tc_basic_int8, tc_int8, max_ratio, max_diff); + LOG(INFO) << "int8 compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (fabs(max_ratio) > 1e-4f) { + Tensor tdiff; + tdiff.Resize(tc_int8.dims()); + tdiff.set_precision(PRECISION(kInt8)); + tensor_diff(tc_basic_int8, tc_int8, tdiff); + auto ptr = tdiff.data(); + auto ptr_basic_fp32 = tc_basic_fp32.data(); + float count = 0; + bool check = true; + for (int i = 0; i < tdiff.numel(); ++i) { + if (abs(ptr[i]) > 1) { + check = false; + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " << ptr_basic_fp32[i] / scale_c[0]; + break; + } + if (ptr[i] != 0) { + LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] + << ", after scale: " << ptr_basic_fp32[i] / scale_c[0]; + count += 1; + } + } + check = + check && count < std::max(10, static_cast(0.01 * tdiff.numel())); + if (!check) { + LOG(WARNING) << "int8 basic result"; + print_tensor(tc_basic_int8); + LOG(WARNING) << "int8 lite result"; + print_tensor(tc_int8); + LOG(WARNING) << "int8 diff tensor"; + print_tensor(tdiff); + return false; + } + } + } +#endif + return true; +} + +TEST(TestLiteGemvInt8, gemv_prepacked_int8) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemm test"; + for (auto& m : {1, 3, 8, 32, 397}) { + for (auto& n : {1, 3, 13, 141, 512, 789}) { + for (auto& tra : {false}) { + for (auto& has_bias : {false, true}) { + for (auto& has_relu : {false, true}) { + for (auto& th : {1, 2, 4}) { + auto flag = test_gemv_int8( + tra, m, n, has_bias, has_relu, FLAGS_power_mode, th); + if (flag) { + LOG(INFO) << "test m = " << m << ", n=" << n + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << " passed\n"; + } else { + LOG(FATAL) << "test m = " << m << ", n=" << n + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? 
"true" : "false") + << " failed\n"; + } + } + } + } + } + } + } + } +} + +TEST(TestGemvInt8Custom, gemv_prepacked_int8_custom) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + auto flag = test_gemv_int8(FLAGS_traA, + FLAGS_M, + FLAGS_N, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_power_mode, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } + LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " passed!!"; +} diff --git a/lite/tests/math/sgemm_compute_test.cc b/lite/tests/math/sgemm_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1621ceb9047125d0d2a4141a01111eb54892dee9 --- /dev/null +++ b/lite/tests/math/sgemm_compute_test.cc @@ -0,0 +1,340 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +#endif // LITE_WITH_ARM +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +typedef paddle::lite::Tensor Tensor; +using paddle::lite::Timer; + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(M, 512, "gemm: M"); +DEFINE_int32(N, 512, "gemm: N"); +DEFINE_int32(K, 512, "gemm: K"); + +DEFINE_bool(traA, false, "gemm: A transpose"); +DEFINE_bool(traB, false, "gemm: B transpose"); + +DEFINE_int32(offset_a, 0, "A offset"); +DEFINE_int32(offset_b, 0, "B offset"); +DEFINE_int32(offset_c, 0, "C offset"); + +DEFINE_double(alpha, 1.0, "alpha"); +DEFINE_double(beta, 0.0, "beta"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +bool test_sgemm(bool tra, + bool trb, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + float alpha, + float beta, + bool has_bias, + bool has_relu, + int cls, + int ths) { + int size_a = tra ? k * lda : m * lda; + int size_b = trb ? 
n * ldb : k * ldb; + + Tensor ta; + Tensor tb; + Tensor tc; + Tensor tc_basic; + Tensor tc_backup; + Tensor tbias; + + ta.Resize({size_a}); + tb.Resize({size_b}); + tc.Resize({m * ldc}); + tc_basic.Resize({m * ldc}); + tc_backup.Resize({m * ldc}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kFloat)); + tb.set_precision(PRECISION(kFloat)); + tc.set_precision(PRECISION(kFloat)); + tc_basic.set_precision(PRECISION(kFloat)); + tc_backup.set_precision(PRECISION(kFloat)); + tbias.set_precision(PRECISION(kFloat)); + + fill_tensor_rand(ta, -1.f, 1.f); + fill_tensor_rand(tb, -1.f, 1.f); + fill_tensor_rand(tbias, -1.f, 1.f); + fill_tensor_rand(tc, -1.f, 1.f); + + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto dc = tc.mutable_data(); + auto dc_basic = tc_basic.mutable_data(); + auto dc_backup = tc_backup.mutable_data(); + auto dbias = tbias.mutable_data(); + + memcpy(dc_basic, dc, sizeof(float) * m * ldc); + memcpy(dc_backup, dc, sizeof(float) * m * ldc); + + LOG(INFO) << "sgemm M: " << m << ", N: " << n << ", K: " << k + << ", strides, lda: " << lda << ", ldb: " << ldb << ", ldc: " << ldc + << ", alpha: " << alpha << ", beta: " << beta + << ", transA: " << (tra ? "true" : "false") + << ", transB: " << (trb ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? "true" : "false"); + if (FLAGS_check_result) { + basic_gemm(tra, + trb, + m, + n, + k, + alpha, + da, + lda, + db, + ldb, + beta, + dc_basic, + ldc, + dbias, + has_bias, + has_relu); + } + Timer t0; +#ifdef LITE_WITH_ARM + //! compute + double ops = 2.0 * m * n * k; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + //! prepack + Tensor tpackedA; + int hblock = paddle::lite::arm::math::get_hblock(&ctx); + int round_up_a = ((hblock + m - 1) / hblock) * hblock; + tpackedA.Resize({round_up_a * k}); + paddle::lite::arm::math::prepackA( + tpackedA.mutable_data(), da, alpha, lda, 0, m, 0, k, tra, &ctx); + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::sgemm_prepack(trb, + m, + n, + k, + tpackedA.data(), + db, + ldb, + beta, + dc, + ldc, + dbias, + has_bias, + has_relu, + &ctx); + } + + for (int i = 0; i < FLAGS_repeats; ++i) { + if (i == FLAGS_repeats - 1) { + memcpy(dc, dc_backup, sizeof(float) * m * ldc); + } + t0.start(); + paddle::lite::arm::math::sgemm_prepack(trb, + m, + n, + k, + tpackedA.data(), + db, + ldb, + beta, + dc, + ldc, + dbias, + has_bias, + has_relu, + &ctx); + t0.end(); + } + LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tc_basic, tc, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kFloat)); + tdiff.Resize(tc.dims()); + tensor_diff(tc_basic, tc, tdiff); + LOG(INFO) << "a: "; + print_tensor(ta); + LOG(INFO) << "b: "; + print_tensor(tb); + LOG(INFO) << "c: "; + print_tensor(tc_backup); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic); + LOG(INFO) << "lite result: "; + 
print_tensor(tc); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + } +#endif + return true; +} + +TEST(TestSgemm, test_func_sgemm_prepacked) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemm test"; + for (auto& m : {1, 3, 8, 32, 397}) { + for (auto& n : {1, 3, 13, 141, 512, 789}) { + for (auto& k : {1, 3, 8, 59, 234}) { + for (auto& tra : {false, true}) { + for (auto& trb : {false, true}) { + for (auto& alpha : {1.f, 0.5f}) { + for (auto& beta : {0.f, 0.5f}) { + for (auto& offset : {0, 10}) { + for (auto& has_bias : {false, true}) { + for (auto& has_relu : {false, true}) { + for (auto& th : {1, 2, 4}) { + int lda = k + offset; + if (tra) { + lda = m + offset; + } + int ldb = n + offset; + if (trb) { + ldb = k + offset; + } + int ldc = n + offset; + auto flag = test_sgemm(tra, + trb, + m, + n, + k, + lda, + ldb, + ldc, + alpha, + beta, + has_bias, + has_relu, + FLAGS_power_mode, + th); + if (flag) { + LOG(INFO) + << "test m = " << m << ", n=" << n + << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", trans B: " << (trb ? "true" : "false") + << " passed\n"; + } else { + LOG(FATAL) + << "test m = " << m << ", n=" << n + << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", trans B: " << (trb ? "true" : "false") + << " failed\n"; + } + } + } + } + } + } + } + } + } + } + } + } + } +} + +TEST(TestSgemmCustom, test_func_sgemm_prepacked_custom) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + int lda = FLAGS_K + FLAGS_offset_a; + if (FLAGS_traA) { + lda = FLAGS_M + FLAGS_offset_a; + } + int ldb = FLAGS_N + FLAGS_offset_b; + if (FLAGS_traB) { + ldb = FLAGS_K + FLAGS_offset_b; + } + int ldc = FLAGS_N + FLAGS_offset_c; + auto flag = test_sgemm(FLAGS_traA, + FLAGS_traB, + FLAGS_M, + FLAGS_N, + FLAGS_K, + lda, + ldb, + ldc, + FLAGS_alpha, + FLAGS_beta, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_power_mode, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA + << ", trans B: " << FLAGS_traB << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } + LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K + << ", trans A: " << FLAGS_traA << ", trans B: " << FLAGS_traB + << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu + << " passed!!"; +} diff --git a/lite/tests/kernels/fill_data.h b/lite/tests/utils/fill_data.h similarity index 100% rename from lite/tests/kernels/fill_data.h rename to lite/tests/utils/fill_data.h diff --git a/lite/tests/kernels/test_funcs.h b/lite/tests/utils/naive_math_impl.h similarity index 52% rename from lite/tests/kernels/test_funcs.h rename to lite/tests/utils/naive_math_impl.h index accbb0eeadf92e43102410519b3b2c78e09f121d..846126ac247ee685bd8772ede87635c45b52f79a 100644 --- a/lite/tests/kernels/test_funcs.h +++ b/lite/tests/utils/naive_math_impl.h @@ -189,3 +189,176 @@ static void conv_basic(const Dtype1* din, } } } + +template +static void fill_bias_relu(Dtype* tensor, + const Dtype* bias, + int channel, + int channel_size, + bool flag_bias, + bool flag_relu) { + Dtype* data = tensor; + for (int j = 0; j < channel; ++j) { + Dtype bias_c = flag_bias ? 
bias[j] : 0;
+    for (int i = 0; i < channel_size; i++) {
+      data[i] += bias_c;
+      if (flag_relu) {
+        data[i] = data[i] > 0 ? data[i] : 0.f;
+      }
+    }
+    data += channel_size;
+  }
+}
+
+template <typename Dtype>
+static void do_relu(Dtype* tensor, int size) {
+  for (int j = 0; j < size; ++j) {
+    tensor[j] = tensor[j] > 0 ? tensor[j] : (Dtype)0;
+  }
+}
+
+inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
+  return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
+}
+
+template <typename Dtype>
+static void col2im(const Dtype* data_col,
+                   const int channels,
+                   const int height,
+                   const int width,
+                   const int kernel_h,
+                   const int kernel_w,
+                   const int pad_h,
+                   const int pad_w,
+                   const int stride_h,
+                   const int stride_w,
+                   const int dilation_h,
+                   const int dilation_w,
+                   Dtype* data_im) {
+  memset(data_im, 0, height * width * channels * sizeof(Dtype));
+  const int output_h =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int output_w =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+  const int channel_size = height * width;
+
+  for (int channel = channels; channel--; data_im += channel_size) {
+    for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+      for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+        int input_row = -pad_h + kernel_row * dilation_h;
+
+        for (int output_rows = output_h; output_rows; output_rows--) {
+          if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+            data_col += output_w;
+          } else {
+            int input_col = -pad_w + kernel_col * dilation_w;
+
+            for (int output_col = output_w; output_col; output_col--) {
+              if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                data_im[input_row * width + input_col] += *data_col;
+              }
+              data_col++;
+              input_col += stride_w;
+            }
+          }
+          input_row += stride_h;
+        }
+      }
+    }
+  }
+}
+
+//! for float, dtype1 and type2 is float
+//!
for int8, dytpe1 is char, dtype2 is int +template +void deconv_basic(const Dtype1* din, + Dtype2* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const Dtype1* weights, + const Dtype2* bias, + int group, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int dila_w, + int dila_h, + int pad_w, + int pad_h, + bool flag_bias, + bool flag_relu) { + int m = chout * kernel_w * kernel_h / group; + int n = hin * win; + int k = chin / group; + + int group_size_in = win * hin * chin / group; + int group_size_out = wout * hout * chout / group; + int group_size_coldata = m * n; + int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); + bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && + (stride_w == 1) && (pad_w == 1) && (pad_h == 1) && + (dila_w == 1) && (dila_h == 1); + + Dtype2* workspace_ptr = + static_cast(malloc(sizeof(float) * m * n * group)); + + for (int i = 0; i < num; ++i) { + const Dtype1* din_batch = din + i * chin * hin * win; + Dtype2* dout_batch = dout + i * chout * hout * wout; + + Dtype2* col_data = workspace_ptr; + if (flag_1x1s1p1) { + col_data = dout_batch; + } + memset(col_data, 0, sizeof(Dtype2) * group_size_coldata); + for (int g = 0; g < group; ++g) { + const Dtype1* din_group = din_batch + g * group_size_in; + const Dtype1* weights_group = weights + g * group_size_weights; + Dtype2* coldata_group = col_data + g * group_size_coldata; + basic_gemm(true, + false, + m, + n, + k, + 1, + weights_group, + m, + din_group, + n, + 0, + coldata_group, + n, + nullptr, + false, + (!flag_bias && flag_relu)); + } + + if (!flag_1x1s1p1) { + col2im(col_data, + chout, + hout, + wout, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + dout_batch); + } + //! add bias + if (flag_bias) { + fill_bias_relu( + dout_batch, bias, chout, wout * hout, flag_bias, flag_relu); + } + } + free(workspace_ptr); +} diff --git a/lite/tests/utils/tensor_utils.h b/lite/tests/utils/tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4f8d1ad2aa70dc09ab22d0e22df2180b5da83788 --- /dev/null +++ b/lite/tests/utils/tensor_utils.h @@ -0,0 +1,331 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { + +template +void fill_tensor_host_const_impl(Dtype* dio, Dtype value, int64_t size) { + for (int64_t i = 0; i < size; ++i) { + dio[i] = value; + } +} +/** + * \brief Fill the host tensor buffer with rand value. + * \param tensor The reference of input tensor. 
+ */ +void fill_tensor_const(Tensor& tensor, float value) { // NOLINT + int64_t size = tensor.numel(); + PrecisionType type = tensor.precision(); + switch (type) { + case PRECISION(kInt8): + fill_tensor_host_const_impl( + tensor.mutable_data(), static_cast(value), size); + break; + case PRECISION(kInt32): + fill_tensor_host_const_impl( + tensor.mutable_data(), static_cast(value), size); + break; + case PRECISION(kFloat): + fill_tensor_host_const_impl( + tensor.mutable_data(), static_cast(value), size); + break; + default: + LOG(FATAL) << "data type: " << PrecisionRepr(type) + << " is unsupported now"; + } +} + +template +void fill_tensor_host_rand_impl(Dtype* dio, int64_t size) { + for (int64_t i = 0; i < size; ++i) { + Dtype rand_x = static_cast(rand() % 256); // NOLINT + dio[i] = (rand_x - 128) / 128; + } +} +template <> +void fill_tensor_host_rand_impl(signed char* dio, int64_t size) { + for (int64_t i = 0; i < size; ++i) { + dio[i] = rand() % 256 - 128; // NOLINT + } +} +template <> +void fill_tensor_host_rand_impl(unsigned char* dio, + int64_t size) { + for (int64_t i = 0; i < size; ++i) { + dio[i] = rand() % 256; // NOLINT + } +} +/** + * \brief Fill the host tensor buffer with rand value. + * \param The reference of input tensor. + */ +void fill_tensor_rand(Tensor& tensor) { // NOLINT + int64_t size = tensor.numel(); + PrecisionType type = tensor.precision(); + switch (type) { + case PRECISION(kInt8): + fill_tensor_host_rand_impl(tensor.mutable_data(), size); + break; + case PRECISION(kInt32): + fill_tensor_host_rand_impl(tensor.mutable_data(), size); + break; + case PRECISION(kFloat): + fill_tensor_host_rand_impl(tensor.mutable_data(), size); + break; + default: + LOG(FATAL) << "data type: " << PrecisionRepr(type) + << " is unsupported now"; + } +} + +template +void fill_tensor_host_rand_impl2(Dtype* dio, + Dtype vstart, + Dtype vend, + int64_t size) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(0, 1.f); + for (int64_t i = 0; i < size; ++i) { + Dtype random_num = static_cast(vstart + (vend - vstart) * dis(gen)); + dio[i] = random_num; + } +} + +/** + * \brief Fill the host tensor buffer with rand value from vstart to vend. + * \param tensor The reference of input tensor. 
+ */ +void fill_tensor_rand(Tensor& tensor, float vstart, float vend) { // NOLINT + int64_t size = tensor.numel(); + PrecisionType type = tensor.precision(); + switch (type) { + case PRECISION(kInt8): + fill_tensor_host_rand_impl2(tensor.mutable_data(), + static_cast(vstart), + static_cast(vend), + size); + break; + case PRECISION(kInt32): + fill_tensor_host_rand_impl2(tensor.mutable_data(), + static_cast(vstart), + static_cast(vend), + size); + break; + case PRECISION(kFloat): + fill_tensor_host_rand_impl2( + tensor.mutable_data(), vstart, vend, size); + break; + default: + LOG(FATAL) << "data type: " << PrecisionRepr(type) + << " is unsupported now"; + } +} + +template +void print_tensor_host_impl(const Dtype* din, int64_t size, int64_t width); + +template <> +void print_tensor_host_impl(const float* din, int64_t size, int64_t width) { + for (int i = 0; i < size; ++i) { + printf("%.6f ", din[i]); + if ((i + 1) % width == 0) { + printf("\n"); + } + } + printf("\n"); +} + +template <> +void print_tensor_host_impl(const int* din, int64_t size, int64_t width) { + for (int i = 0; i < size; ++i) { + printf("%d ", din[i]); + if ((i + 1) % width == 0) { + printf("\n"); + } + } + printf("\n"); +} + +template <> +void print_tensor_host_impl(const signed char* din, + int64_t size, + int64_t width) { + for (int i = 0; i < size; ++i) { + printf("%d ", din[i]); + if ((i + 1) % width == 0) { + printf("\n"); + } + } + printf("\n"); +} +/** + * \brief Print the data in host tensor. + * \param tensor The reference of input tensor. + */ +void print_tensor(const Tensor& tensor) { + printf("host tensor data size: %ld\n", tensor.numel()); + int64_t size = tensor.numel(); + int64_t width = tensor.dims()[tensor.dims().size() - 1]; + PrecisionType type = tensor.precision(); + switch (type) { + case PRECISION(kInt8): + print_tensor_host_impl(tensor.data(), size, width); + break; + case PRECISION(kInt32): + print_tensor_host_impl(tensor.data(), size, width); + break; + case PRECISION(kFloat): + print_tensor_host_impl(tensor.data(), size, width); + break; + default: + LOG(FATAL) << "data type: " << PrecisionRepr(type) + << " is unsupported now"; + } +} + +template +double tensor_mean_value_host_impl(const Dtype* din, int64_t size) { + double sum = 0.0; + for (int64_t i = 0; i < size; ++i) { + sum += din[i]; + } + return sum / size; +} + +double tensor_mean(const Tensor& tensor) { + int64_t size = tensor.numel(); + PrecisionType type = tensor.precision(); + switch (type) { + case PRECISION(kInt8): + return tensor_mean_value_host_impl(tensor.data(), size); + case PRECISION(kInt32): + return tensor_mean_value_host_impl(tensor.data(), size); + case PRECISION(kFloat): + return tensor_mean_value_host_impl(tensor.data(), size); + default: + LOG(FATAL) << "data type: " << PrecisionRepr(type) + << " is unsupported now"; + } + return 0.0; +} + +template +void data_diff_kernel(const dtype* src1_truth, + const dtype* src2, + int size, + double& max_ratio, // NOLINT + double& max_diff) { // NOLINT + const double eps = 1e-6f; + max_diff = fabs(src1_truth[0] - src2[0]); + max_ratio = fabs(max_diff) / (std::abs(src1_truth[0]) + eps); + for (int i = 1; i < size; ++i) { + double diff = fabs(src1_truth[i] - src2[i]); + double ratio = fabs(diff) / (std::abs(src1_truth[i]) + eps); + if (max_ratio < ratio) { + max_diff = diff; + max_ratio = ratio; + } + } +} + +void tensor_cmp_host(const Tensor& src1_basic, + const Tensor& src2, + double& max_ratio, // NOLINT + double& max_diff) { // NOLINT + max_ratio = 0.; + max_diff = 0.; + int64_t 
size = src1_basic.numel(); + CHECK_EQ(size, src2.numel()) << "ERROR: tensor_cmp_host: wrong shape"; + auto ptype1 = PrecisionRepr(src1_basic.precision()); + auto ptype2 = PrecisionRepr(src2.precision()); + CHECK_EQ(ptype1, ptype2) << "ERROR: tensor_cmp_host: wrong data type"; + if (size == 0) return; + switch (src1_basic.precision()) { + case PRECISION(kFloat): + data_diff_kernel(src1_basic.data(), + src2.data(), + size, + max_ratio, + max_diff); + return; + case PRECISION(kInt32): + data_diff_kernel( + src1_basic.data(), src2.data(), size, max_ratio, max_diff); + return; + case PRECISION(kInt8): + data_diff_kernel(src1_basic.data(), + src2.data(), + size, + max_ratio, + max_diff); + return; + default: + LOG(FATAL) << "data type: " << PrecisionRepr(src1_basic.precision()) + << " is unsupported now"; + } +} + +template +void tensor_diff_kernel(const dtype* src1, + const dtype* src2, + dtype* dst, + int64_t size) { + for (int i = 0; i < size; ++i) { + dst[i] = src1[i] - src2[i]; + } +} +void tensor_diff(const Tensor& t1, const Tensor& t2, Tensor& tdiff) { // NOLINT + int64_t size1 = t1.numel(); + int64_t size2 = t2.numel(); + int64_t size_out = tdiff.numel(); + CHECK_EQ(size1, size2) << "ERROR: tensor_diff: wrong shape"; + CHECK_EQ(size1, size_out) << "ERROR: tensor_diff: wrong shape"; + auto ptype1 = PrecisionRepr(t1.precision()); + auto ptype2 = PrecisionRepr(t2.precision()); + auto ptype3 = PrecisionRepr(tdiff.precision()); + CHECK_EQ(ptype1, ptype2) << "ERROR: tensor_diff: wrong data type"; + CHECK_EQ(ptype1, ptype3) << "ERROR: tensor_diff: wrong data type"; + switch (t1.precision()) { + case PRECISION(kFloat): + tensor_diff_kernel(t1.data(), + t2.data(), + tdiff.mutable_data(), + size1); + return; + case PRECISION(kInt32): + tensor_diff_kernel( + t1.data(), t2.data(), tdiff.mutable_data(), size1); + case PRECISION(kInt8): + tensor_diff_kernel(t1.data(), + t2.data(), + tdiff.mutable_data(), + size1); + return; + default: + LOG(FATAL) << "data type: " << ptype1 << " is unsupported now"; + } +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/utils/timer.h b/lite/tests/utils/timer.h new file mode 100644 index 0000000000000000000000000000000000000000..095f32046e0dc5b9342163e1f4f13f4e30c10670 --- /dev/null +++ b/lite/tests/utils/timer.h @@ -0,0 +1,105 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
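Taken together, the new lite/tests/utils helpers give every math test in this patch the same skeleton: fill inputs, time the kernel, then diff against a naive reference. A hedged usage sketch of that flow (illustrative only; it assumes the helpers introduced above plus timer.h below, lite::Tensor, and glog's LOG as used throughout these tests):

#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"

void sketch() {
  using paddle::lite::Tensor;
  Tensor ref, out;
  ref.Resize({4, 8});
  out.Resize({4, 8});
  ref.set_precision(PRECISION(kFloat));
  out.set_precision(PRECISION(kFloat));
  paddle::lite::fill_tensor_rand(ref, -1.f, 1.f);  // uniform in [vstart, vend)
  out.CopyDataFrom(ref);  // stand-in for the kernel under test

  paddle::lite::Timer t;
  t.start();
  // ... launch the kernel under test here ...
  t.end();

  double max_ratio = 0;
  double max_diff = 0;
  paddle::lite::tensor_cmp_host(ref, out, max_ratio, max_diff);
  LOG(INFO) << "avg: " << t.get_average_ms() << " ms, max diff: " << max_diff
            << ", max ratio: " << max_ratio;
}

On failure, print_tensor() and tensor_diff() slot into the error branch in the same way the conv and gemm tests above use them.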
diff --git a/lite/tests/utils/timer.h b/lite/tests/utils/timer.h
new file mode 100644
index 0000000000000000000000000000000000000000..095f32046e0dc5b9342163e1f4f13f4e30c10670
--- /dev/null
+++ b/lite/tests/utils/timer.h
@@ -0,0 +1,105 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <chrono>  // NOLINT
+#include <list>
+
+namespace paddle {
+namespace lite {
+
+class Timer final {
+ public:
+  Timer() {}
+
+  ~Timer() {}
+
+  void clear() { ms_time_.clear(); }
+
+  void start() { tstart_ = std::chrono::system_clock::now(); }
+
+  void end() {
+    tend_ = std::chrono::system_clock::now();
+    auto ts =
+        std::chrono::duration_cast<std::chrono::microseconds>(tend_ - tstart_);
+    latest_time_ = 1000.f * static_cast<float>(ts.count()) *
+                   std::chrono::microseconds::period::num /
+                   std::chrono::microseconds::period::den;
+    ms_time_.push_back(latest_time_);
+  }
+
+  float latest_time() const { return latest_time_; }
+
+  float get_average_ms() {
+    if (ms_time_.size() == 0) {
+      return 0.f;
+    }
+    float sum = 0.f;
+    for (auto i : ms_time_) {
+      sum += i;
+    }
+    return sum / ms_time_.size();
+  }
+
+  float get_sum_ms() {
+    if (ms_time_.size() == 0) {
+      return 0.f;
+    }
+    float sum = 0.f;
+    for (auto i : ms_time_) {
+      sum += i;
+    }
+    return sum;
+  }
+
+  // return the percentile (0-100) time.
+  float get_tile_time(float tile) {
+    if (tile < 0 || tile > 100) {
+      return -1.f;
+    }
+    int total_items = static_cast<int>(ms_time_.size());
+    if (total_items <= 0) {
+      return -2.f;
+    }
+    ms_time_.sort();
+    int pos = static_cast<int>(tile * total_items / 100);
+    auto it = ms_time_.begin();
+    for (int i = 0; i < pos; ++i) {
+      ++it;
+    }
+    return *it;
+  }
+
+  std::list<float> get_time_stat() { return ms_time_; }
+
+  float get_min_time() {
+    ms_time_.sort();
+    return *ms_time_.begin();
+  }
+
+  float get_max_time() {
+    ms_time_.sort([](float a, float b) { return a > b; });
+    return *ms_time_.begin();
+  }
+
+ private:
+  std::chrono::time_point<std::chrono::system_clock> tstart_;
+  std::chrono::time_point<std::chrono::system_clock> tend_;
+  std::list<float> ms_time_;
+  float latest_time_;
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tools/benchmark.sh b/lite/tools/benchmark.sh
index c3261c6d4409842d6821179eb8b4e404a28d4c6b..683271fa8f5c97a39099429ed003ba7414de1132 100644
--- a/lite/tools/benchmark.sh
+++ b/lite/tools/benchmark.sh
@@ -8,6 +8,7 @@ then
     echo "Usage:"
     echo "  sh benchmark.sh "
     echo "  sh benchmark.sh "
+    echo "  sh benchmark.sh "
     exit
 fi
@@ -20,6 +21,7 @@ RESULT_FILENAME=$3
 WARMUP=10
 REPEATS=30
 IS_RUN_MODEL_OPTIMIZE=false
+IS_RUN_QUANTIZED_MODEL=false
 NUM_THREADS_LIST=(1 2 4)
 
 MODELS_LIST=$(ls $MODELS_DIR)
@@ -28,6 +30,10 @@ if [ $# -gt 3 ];
 then
     IS_RUN_MODEL_OPTIMIZE=$4
 fi
+if [ $# -gt 4 ];
+then
+    IS_RUN_QUANTIZED_MODEL=$5
+fi
 
 # Adb push benchmark_bin, models
 adb push $BENCHMARK_BIN $ANDROID_DIR/benchmark_bin
@@ -46,7 +52,8 @@ for threads in ${NUM_THREADS_LIST[@]}; do
             --repeats=$REPEATS \
             --threads=$threads \
             --result_filename=$ANDROID_DIR/$RESULT_FILENAME \
-            --run_model_optimize=$IS_RUN_MODEL_OPTIMIZE"
+            --run_model_optimize=$IS_RUN_MODEL_OPTIMIZE \
+            --is_quantized_model=$IS_RUN_QUANTIZED_MODEL"
     done
     adb shell "echo >> $ANDROID_DIR/$RESULT_FILENAME"
 done
diff --git a/lite/tools/build.sh b/lite/tools/build.sh
index 0860f3d00e951d9a14cc484ca292dae19b892d2a..d1f47c149ec7d6c30767d6db19e371e5e32b865d 100755
--- a/lite/tools/build.sh
+++ b/lite/tools/build.sh
@@ -15,6 +15,10 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
 # global variables
 BUILD_EXTRA=OFF
 BUILD_JAVA=ON
+BUILD_PYTHON=OFF
+BUILD_DIR=$(pwd)
+OPTMODEL_DIR=""
+BUILD_TAILOR=OFF
 
 readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
@@ -23,16 +27,18 @@ readonly workspace=$PWD
 
 # for code gen, a source file is generated after a test, but is depended on by some targets in cmake.
 # here we fake an empty file to make cmake work.
function prepare_workspace { + local root_dir=$1 + local build_dir=$2 # in build directory # 1. Prepare gen_code file - GEN_CODE_PATH_PREFIX=lite/gen_code - mkdir -p ./${GEN_CODE_PATH_PREFIX} - touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc + GEN_CODE_PATH_PREFIX=$build_dir/lite/gen_code + mkdir -p ${GEN_CODE_PATH_PREFIX} + touch ${GEN_CODE_PATH_PREFIX}/__generated_code__.cc # 2.Prepare debug tool - DEBUG_TOOL_PATH_PREFIX=lite/tools/debug - mkdir -p ./${DEBUG_TOOL_PATH_PREFIX} - cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/ + DEBUG_TOOL_PATH_PREFIX=$build_dir/lite/tools/debug + mkdir -p ${DEBUG_TOOL_PATH_PREFIX} + cp $root_dir/lite/tools/debug/analysis_tool.py ${DEBUG_TOOL_PATH_PREFIX}/ } function prepare_thirdparty { @@ -48,6 +54,19 @@ function prepare_thirdparty { fi } +function build_model_optimize_tool { + cd $workspace + prepare_thirdparty + mkdir -p build.model_optimize_tool + cd build.model_optimize_tool + cmake .. -DWITH_LITE=ON \ + -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON \ + -DWITH_TESTING=OFF \ + -DLITE_BUILD_EXTRA=ON \ + -DWITH_MKL=OFF + make model_optimize_tool -j$NUM_PROC +} + function make_tiny_publish_so { local os=$1 local abi=$2 @@ -68,13 +87,17 @@ function make_tiny_publish_so { fi cmake .. \ + ${PYTHON_FLAGS} \ ${CMAKE_COMMON_OPTIONS} \ -DWITH_TESTING=OFF \ -DLITE_WITH_JAVA=$BUILD_JAVA \ + -DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_SHUTDOWN_LOG=ON \ -DLITE_ON_TINY_PUBLISH=ON \ -DANDROID_STL_TYPE=$android_stl \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ + -DLITE_BUILD_TAILOR=$BUILD_TAILOR \ + -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} make publish_inference -j$NUM_PROC @@ -90,27 +113,32 @@ function make_full_publish_so { #git submodule update --init --recursive prepare_thirdparty - cur_dir=$(pwd) - build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} - if [ -d $build_dir ] + root_dir=$(pwd) + build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang} + + if [ -d $build_directory ] then - rm -rf $build_dir + rm -rf $build_directory fi - mkdir -p $build_dir - cd $build_dir + mkdir -p $build_directory + cd $build_directory if [ ${os} == "armlinux" ]; then BUILD_JAVA=OFF fi - prepare_workspace - cmake .. \ + prepare_workspace $root_dir $build_directory + cmake $root_dir \ + ${PYTHON_FLAGS} \ ${CMAKE_COMMON_OPTIONS} \ -DWITH_TESTING=OFF \ -DLITE_WITH_JAVA=$BUILD_JAVA \ + -DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_SHUTDOWN_LOG=ON \ -DANDROID_STL_TYPE=$android_stl \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ + -DLITE_BUILD_TAILOR=$BUILD_TAILOR \ + -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} make publish_inference -j4 @@ -124,23 +152,23 @@ function make_all_tests { #git submodule update --init --recursive prepare_thirdparty - cur_dir=$(pwd) - build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} + root_dir=$(pwd) + build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang} if [ -d $build_dir ] then rm -rf $build_dir fi - mkdir -p $build_dir - cd $build_dir + mkdir -p $build_directory + cd $build_directory - prepare_workspace - cmake .. 
\ + prepare_workspace $root_dir $build_directory + cmake $root_dir \ ${CMAKE_COMMON_OPTIONS} \ -DWITH_TESTING=ON \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} - make lite_compile_deps -j4 + make lite_compile_deps -j$NUM_PROC cd - > /dev/null } @@ -179,6 +207,65 @@ function make_ios { cd - } +function make_cuda { + prepare_thirdparty + + root_dir=$(pwd) + build_directory=$BUILD_DIR/build_cuda + + if [ -d $build_directory ] + then + rm -rf $build_directory + fi + mkdir -p $build_directory + cd $build_directory + + prepare_workspace $root_dir $build_directory + + cmake .. -DWITH_MKL=OFF \ + -DLITE_WITH_CUDA=ON \ + -DWITH_MKLDNN=OFF \ + -DLITE_WITH_X86=OFF \ + -DLITE_WITH_PROFILE=OFF \ + -DWITH_LITE=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \ + -DWITH_TESTING=OFF \ + -DLITE_WITH_ARM=OFF \ + -DLITE_WITH_PYTHON=ON \ + -DLITE_BUILD_EXTRA=ON + + make publish_inference_python_lib -j8 + cd - +} + +function make_x86 { + prepare_thirdparty + + root_dir=$(pwd) + build_directory=$BUILD_DIR/build.lite.x86 + + if [ -d $build_directory ] + then + rm -rf $build_directory + fi + mkdir -p $build_directory + cd $build_directory + + prepare_workspace $root_dir $build_directory + + cmake .. -DWITH_MKL=ON \ + -DWITH_MKLDNN=OFF \ + -DLITE_WITH_X86=ON \ + -DLITE_WITH_PROFILE=OFF \ + -DWITH_LITE=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \ + -DLITE_WITH_ARM=OFF \ + -DWITH_GPU=OFF \ + -DLITE_BUILD_EXTRA=ON + + make publish_inference -j4 + cd - +} function print_usage { set +x @@ -199,11 +286,14 @@ function print_usage { echo echo -e "optional argument:" echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)" + echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)" + echo -e "--build_java: (OFF|ON); controls whether to publish java api lib (Only ANDROID is supported)" + echo -e "--build_dir: directory for building" echo echo -e "argument choices:" echo -e "--arm_os:\t android|ios|ios64" echo -e "--arm_abi:\t armv8|armv7" - echo -e "--arm_lang:\t gcc|clang (for android)" + echo -e "--arm_lang:\t only support gcc now, clang will be supported in future.(for android)" echo -e "--android_stl:\t c++_static|c++_shared (for android)" echo echo -e "tasks:" @@ -234,6 +324,13 @@ function main { ;; --arm_lang=*) ARM_LANG="${i#*=}" + if [ ${ARM_LANG} == "clang" ]; then + set +x + echo + echo -e "error: only support gcc now, clang will be supported in future." + echo + exit 1 + fi shift ;; --android_stl=*) @@ -244,6 +341,26 @@ function main { BUILD_EXTRA="${i#*=}" shift ;; + --build_python=*) + BUILD_PYTHON="${i#*=}" + shift + ;; + --build_java=*) + BUILD_JAVA="${i#*=}" + shift + ;; + --build_dir=*) + BUILD_DIR="${i#*=}" + shift + ;; + --opt_model_dir=*) + OPTMODEL_DIR="${i#*=}" + shift + ;; + --build_tailor=*) + BUILD_TAILOR="${i#*=}" + shift + ;; tiny_publish) make_tiny_publish_so $ARM_OS $ARM_ABI $ARM_LANG $ANDROID_STL shift @@ -260,6 +377,18 @@ function main { make_ios $ARM_OS $ARM_ABI shift ;; + build_optimize_tool) + build_model_optimize_tool + shift + ;; + cuda) + make_cuda + shift + ;; + x86) + make_x86 + shift + ;; *) # unknown option print_usage diff --git a/lite/tools/build_fpga.sh b/lite/tools/build_fpga.sh index 75d31bc9bd9169b447824bc31dbd0d355b5c3e7c..f8c186e92fc3ba23e5e09b6a139202d028e58fc6 100755 --- a/lite/tools/build_fpga.sh +++ b/lite/tools/build_fpga.sh @@ -18,7 +18,7 @@ cmake .. 
\ -DLITE_WITH_FPGA=ON \ -DLITE_WITH_OPENMP=ON \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ - -DWITH_TESTING=ON \ + -DWITH_TESTING=OFF \ -DARM_TARGET_OS=armlinux make -j8 diff --git a/lite/tools/build_npu.sh b/lite/tools/build_npu.sh index 9d9d6aceaedb1cf05911a64680417411e03c2d18..03a74046f17ad03bccc7b6d5050acae9d643686c 100755 --- a/lite/tools/build_npu.sh +++ b/lite/tools/build_npu.sh @@ -1,16 +1,29 @@ #!/bin/bash set -ex +# global variables with default value +ARM_OS="android" # android only yet +ARM_ABI="armv8" # armv8, armv7 +ARM_LANG="gcc" # gcc only yet +ANDROID_STL="c++_static" # c++_shared, c++_static +DDK_ROOT="$(pwd)/ai_ddk_lib/" # HIAI SDK from https://developer.huawei.com/consumer/cn/hiai/ +TARGET_NAME="test_npu_pass" # default target +BUILD_EXTRA=OFF # ON(with sequence ops)/OFF +WITH_JAVA=ON # ON(build jar and jni so)/OFF +WITH_TESTING=ON # ON/OFF +SHUTDOWN_LOG=OFF # ON(disable logging)/OFF +ON_TINY_PUBLISH=OFF # ON(tiny publish)/OFF(full publish) + function print_usage { echo -e "\nUSAGE:" echo echo "----------------------------------------" echo -e "--arm_os= android only yet." echo -e "--arm_abi= armv8, armv7 yet." - echo -e "--android_stl= shared or static" - echo -e "--arm_lang= " - echo -e "--ddk_root= " - echo -e "--test_name=" + echo -e "--android_stl= c++_shared or c++_static" + echo -e "--arm_lang=" + echo -e "--ddk_root=" + echo -e "--target_name=" echo "----------------------------------------" echo } @@ -47,80 +60,56 @@ function prepare_thirdparty { fi } -function cmake_npu { - prepare_workspace - # $1: ARM_TARGET_OS in "android" , "armlinux" - # $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf" - # $3: ARM_TARGET_LANG in "gcc" "clang" - # $4: ANDROID_STL_TYPE in "c++_shared" "c++_static" - # $5: DDK_ROOT path +function build_npu { + cur_dir=$(pwd) + + prepare_thirdparty + + local stl_dir + local publish_dir + # the c++ symbol is not recognized by the bundled script + if [[ "${ANDROID_STL}" == "c++_shared" ]]; then + stl_dir="cxx_shared" + fi + if [[ "${ANDROID_STL}" == "c++_static" ]]; then + stl_dir="cxx_static" + fi + if [[ "${ON_TINY_PUBLISH}" == "ON" ]]; then + WITH_TESTING=OFF + SHUTDOWN_LOG=ON + publish_dir="tiny_publish" + else + publish_dir="full_publish" + fi + build_dir=$cur_dir/build.lite.npu.${ARM_OS}.${ARM_ABI}.${ARM_LANG}.${stl_dir}.${publish_dir} + mkdir -p $build_dir + cd $build_dir # NPU libs need API LEVEL 24 above + prepare_workspace cmake .. 
\ -DWITH_GPU=OFF \ -DWITH_MKL=OFF \ -DWITH_LITE=ON \ -DLITE_WITH_CUDA=OFF \ -DLITE_WITH_X86=OFF \ - -DLITE_BUILD_EXTRA=ON \ + -DLITE_BUILD_EXTRA=${BUILD_EXTRA} \ -DLITE_WITH_ARM=ON \ -DWITH_ARM_DOTPROD=ON \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ - -DWITH_TESTING=ON \ - -DLITE_WITH_JAVA=ON \ + -DWITH_TESTING=${WITH_TESTING} \ + -DLITE_WITH_JAVA=${WITH_JAVA} \ + -DLITE_SHUTDOWN_LOG=${SHUTDOWN_LOG} \ -DLITE_WITH_NPU=ON \ + -DLITE_ON_TINY_PUBLISH=${ON_TINY_PUBLISH} \ -DANDROID_API_LEVEL=24 \ - -DARM_TARGET_OS=$1 \ - -DARM_TARGET_ARCH_ABI=$2 \ - -DARM_TARGET_LANG=$3 \ - -DANDROID_STL_TYPE=$4 \ - -DNPU_DDK_ROOT=$5 -} - -function build_npu { - # os, abi, lang, stl, ddk_root, test_name - cur_dir=$(pwd) - - local os=android - local abi=armv8 - local lang=gcc - local stl="c++_shared" - local ddk_root="${cur_dir}/ai_ddk_lib/" - local test_name=test_npu_pass - prepare_thirdparty - - if [ "x${ARM_OS}" != "x" ]; then - os=$ARM_OS - fi - if [[ "x${ARM_ABI}" != "x" ]]; then - abi=$ARM_ABI - fi - if [[ "x${ARM_LANG}" != "x" ]]; then - lang=$ARM_LANG - fi - if [[ "x${ANDROID_STL}" != "x" ]]; then - stl=$ANDROID_STL - fi - if [[ "x${DDK_ROOT}" != "x" ]]; then - ddk_root=$DDK_ROOT - fi - if [[ $# -ge 1 ]]; then - test_name=$1 - fi - - # the c++ symbol is not recognized by the bundled script - if [[ "${stl}" == "c++_shared" ]]; then - stl_dir="cxx_shared" - fi - if [[ "${stl}" == "c++_static" ]]; then - stl_dir="cxx_static" - fi - build_dir=$cur_dir/build.lite.npu.${os}.${abi}.${lang}.${stl_dir} - mkdir -p $build_dir - cd $build_dir + -DARM_TARGET_OS=${ARM_OS} \ + -DARM_TARGET_ARCH_ABI=${ARM_ABI} \ + -DARM_TARGET_LANG=${ARM_LANG} \ + -DANDROID_STL_TYPE=${ANDROID_STL} \ + -DNPU_DDK_ROOT=${DDK_ROOT} - cmake_npu ${os} ${abi} ${lang} ${stl} ${ddk_root} - make $test_name -j8 + make $TARGET_NAME -j2 cd - echo "Done" @@ -130,12 +119,8 @@ function main { # Parse command line. 
for i in "$@"; do case $i in - --tests=*) - TESTS_FILE="${i#*=}" - shift - ;; - --test_name=*) - TEST_NAME="${i#*=}" + --target_name=*) + TARGET_NAME="${i#*=}" shift ;; --arm_os=*) @@ -154,16 +139,27 @@ function main { ANDROID_STL="${i#*=}" shift ;; + --build_extra=*) + BUILD_EXTRA="${i#*=}" + shift + ;; --ddk_root=*) DDK_ROOT="${i#*=}" shift ;; build) - build_npu $TEST_NAME + build_npu shift ;; full_publish) - build_npu publish_inference + TARGET_NAME=publish_inference + build_npu + shift + ;; + tiny_publish) + ON_TINY_PUBLISH=ON + TARGET_NAME=publish_inference + build_npu shift ;; *) diff --git a/lite/tools/build_xpu.sh b/lite/tools/build_xpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..62a123c82b2945147fa8616ad8faf0af33a32302 --- /dev/null +++ b/lite/tools/build_xpu.sh @@ -0,0 +1,116 @@ +#!/bin/bash +set -ex + +# global variables with default value +XPU_SDK_ROOT="$(pwd)/../XPU_SDK" # XPU SDK +TARGET_NAME="lite_compile_deps" # default target +BUILD_EXTRA=ON # ON(with sequence ops)/OFF +WITH_TESTING=ON # ON/OFF + +function print_usage { + echo -e "\nUSAGE:" + echo + echo "----------------------------------------" + echo -e "--xpu_sdk_root=" + echo -e "--target_name=" + echo "----------------------------------------" + echo +} + +# readonly variables with default value +readonly CMAKE_COMMON_OPTIONS="-DWITH_LITE=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \ + -DWITH_PYTHON=OFF \ + -DLITE_WITH_ARM=OFF" + +readonly NUM_CORES_FOR_COMPILE=${LITE_BUILD_THREADS:-1} + +readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz +readonly workspace=$(pwd) + +function prepare_thirdparty { + if [ ! -d $workspace/third-party -o -f $workspace/third-party-05b862.tar.gz ]; then + rm -rf $workspace/third-party + + if [ ! -f $workspace/third-party-05b862.tar.gz ]; then + wget $THIRDPARTY_TAR + fi + tar xzf third-party-05b862.tar.gz + else + git submodule update --init --recursive + fi +} + +# for code gen, a source file is generated after a test, but is dependended by some targets in cmake. +# here we fake an empty file to make cmake works. +function prepare_workspace { + # in build directory + # 1. Prepare gen_code file + GEN_CODE_PATH_PREFIX=lite/gen_code + mkdir -p ./${GEN_CODE_PATH_PREFIX} + touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc + + # 2.Prepare debug tool + DEBUG_TOOL_PATH_PREFIX=lite/tools/debug + mkdir -p ./${DEBUG_TOOL_PATH_PREFIX} + cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/ + + # clone submodule + # git submodule update --init --recursive + prepare_thirdparty +} + +function build_xpu { + build_dir=${workspace}/build.lite.xpu + mkdir -p $build_dir + cd $build_dir + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib" + prepare_workspace + cmake .. \ + ${CMAKE_COMMON_OPTIONS} \ + -DWITH_GPU=OFF \ + -DWITH_MKLDNN=OFF \ + -DLITE_WITH_X86=ON \ + -DWITH_MKL=ON \ + -DLITE_BUILD_EXTRA=ON \ + -DLITE_WITH_XPU=ON \ + -DWITH_TESTING=${WITH_TESTING} \ + -DXPU_SDK_ROOT=${XPU_SDK_ROOT} + + make $TARGET_NAME -j$NUM_CORES_FOR_COMPILE + + cd - + echo "Done" +} + +function main { + # Parse command line. 
+ for i in "$@"; do + case $i in + --target_name=*) + TARGET_NAME="${i#*=}" + shift + ;; + --build_extra=*) + BUILD_EXTRA="${i#*=}" + shift + ;; + --xpu_sdk_root=*) + XPU_SDK_ROOT="${i#*=}" + shift + ;; + build) + build_xpu + shift + ;; + *) + # unknown option + print_usage + exit 1 + ;; + esac + done +} + +main $@ diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index eb91e15a6f2e6b7d4e32e434cb87b109bca33db7..8be8e6e6b6da1e2aa38b6fcbcf95b23a8543a5be 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -42,7 +42,7 @@ function prepare_workspace { cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/ # clone submodule - #git submodule update --init --recursive + # git submodule update --init --recursive prepare_thirdparty } @@ -151,7 +151,7 @@ function build_opencl { # This method is only called in CI. function cmake_x86_for_CI { prepare_workspace # fake an empty __generated_code__.cc to pass cmake. - cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON -DWITH_MKL=OFF \ + cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON -DWITH_MKL=ON \ -DLITE_BUILD_EXTRA=ON \ # Compile and execute the gen_code related test, so it will generate some code, and make the compilation reasonable. @@ -194,9 +194,9 @@ function build { function test_server { # Due to the missing of x86 kernels, we skip the following tests temporarily. # TODO(xxx) clear the skip list latter - local skip_list=("test_paddle_api" "test_cxx_api" "test_googlenet" + local skip_list=("test_paddle_api" "test_cxx_api" "test_mobilenetv1_lite_x86" "test_mobilenetv2_lite_x86" - "test_inceptionv4_lite_x86" "test_light_api" + "test_light_api" "test_apis" "test_model_bin" ) local to_skip=0 @@ -219,11 +219,12 @@ function test_server { function build_test_server { mkdir -p ./build cd ./build - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/paddle/build/third_party/install/mklml/lib" + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib" cmake_x86_for_CI build test_server + test_model_optimize_tool_compile } function build_test_train { @@ -247,6 +248,63 @@ function build_test_train { } +function cmake_xpu { + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib" + prepare_workspace + cmake .. \ + ${common_flags} \ + -DWITH_GPU=OFF \ + -DWITH_MKLDNN=OFF \ + -DLITE_WITH_X86=ON \ + -DWITH_MKL=ON \ + -DLITE_BUILD_EXTRA=ON \ + -DLITE_WITH_XPU=ON \ + -DXPU_SDK_ROOT="$(pwd)/../../XPU_SDK" +} + +function build_xpu { + make lite_compile_deps -j$NUM_CORES_FOR_COMPILE +} + +# It will eagerly test all lite related unittests. +function test_xpu { + # Due to the missing of xpu kernels, we skip the following tests temporarily. + # TODO(xxx) clear the skip list latter + local skip_list=("test_paddle_api" "test_cxx_api" "test_googlenet" + "test_mobilenetv1_lite_x86" "test_mobilenetv2_lite_x86" + "test_inceptionv4_lite_x86" "test_light_api" + "test_apis" "test_model_bin" + ) + local to_skip=0 + for _test in $(cat $TESTS_FILE); do + to_skip=0 + for skip_name in ${skip_list[@]}; do + if [ $skip_name = $_test ]; then + echo "to skip " $skip_name + to_skip=1 + fi + done + + if [ $to_skip -eq 0 ]; then + ctest -R $_test -V + fi + done +} + +# Build the code and run lite server tests. This is executed in the CI system. 
+function build_test_xpu { + cur_dir=$(pwd) + + build_dir=$cur_dir/build.lite.xpu + mkdir -p $build_dir + cd $build_dir + + cmake_xpu + build_xpu + + test_xpu +} + # test_arm_android function test_arm_android { local test_name=$1 @@ -393,20 +451,27 @@ function test_arm_model { adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path" } -function _test_model_optimize_tool { - local port=$1 - local remote_model_path=$ADB_WORK_DIR/lite_naive_model - local remote_test=$ADB_WORK_DIR/model_optimize_tool - local adb="adb -s emulator-${port}" - +# function _test_model_optimize_tool { +# local port=$1 +# local remote_model_path=$ADB_WORK_DIR/lite_naive_model +# local remote_test=$ADB_WORK_DIR/model_optimize_tool +# local adb="adb -s emulator-${port}" + +# make model_optimize_tool -j$NUM_CORES_FOR_COMPILE +# local test_path=$(find . -name model_optimize_tool | head -n1) +# local model_path=$(find . -name lite_naive_model | head -n1) +# $adb push ${test_path} ${ADB_WORK_DIR} +# $adb shell mkdir -p $remote_model_path +# $adb push $model_path/* $remote_model_path +# $adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \ +# --valid_targets "arm" +# } + +function test_model_optimize_tool_compile { + cd $workspace + cd build + cmake .. -DWITH_LITE=ON -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON -DWITH_TESTING=OFF -DLITE_BUILD_EXTRA=ON make model_optimize_tool -j$NUM_CORES_FOR_COMPILE - local test_path=$(find . -name model_optimize_tool | head -n1) - local model_path=$(find . -name lite_naive_model | head -n1) - $adb push ${test_path} ${ADB_WORK_DIR} - $adb shell mkdir -p $remote_model_path - $adb push $model_path/* $remote_model_path - $adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \ - --valid_targets "arm" } function _test_paddle_code_generator { @@ -558,8 +623,8 @@ function test_arm { # test finally test_arm_api $port - _test_model_optimize_tool $port - _test_paddle_code_generator $port + # _test_model_optimize_tool $port + # _test_paddle_code_generator $port } function prepare_emulator { @@ -842,6 +907,10 @@ function main { cmake_x86 shift ;; + cmake_xpu) + cmake_xpu + shift + ;; cmake_opencl) cmake_opencl $ARM_OS $ARM_ABI $ARM_LANG shift @@ -866,6 +935,10 @@ function main { test_server shift ;; + test_xpu) + test_xpu + shift + ;; test_arm) test_arm $ARM_OS $ARM_ABI $ARM_LANG $ARM_PORT shift @@ -882,6 +955,10 @@ function main { build_test_server shift ;; + build_test_xpu) + build_test_xpu + shift + ;; build_test_train) build_test_train shift diff --git a/lite/tools/cmake_tools/ast.py b/lite/tools/cmake_tools/ast.py new file mode 100644 index 0000000000000000000000000000000000000000..86b8a58b06665ebbc8f3131a037b37a240c103c6 --- /dev/null +++ b/lite/tools/cmake_tools/ast.py @@ -0,0 +1,364 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
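The `ast.py` module introduced next is a small hand-written recursive-descent parser; it replaces the regex and string scanning previously done in `parse_kernel_registry.py` and `parse_op_registry.py`. The input it consumes is C++ kernel-registration blocks of the following shape (an illustrative registration for context, not one added by this patch):

```cpp
// The parser tokenizes blocks like this one: the macro arguments (op type,
// target, precision, layout, kernel class, alias) followed by a chain of
// BindInput/BindOutput calls and a Finalize() call.
REGISTER_LITE_KERNEL(
    relu, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ReluCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
```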
+ +import logging + +class SyntaxParser(object): + def __init__(self, str): + self.str = str + self.cur_pos = 0 + self.N = len(self.str) + self.token = '' + + def eat_char(self): + self.cur_pos += 1 + + def eat_str(self): + ''' + "xx" + ''' + self.token = '' + assert self.cur == '"'; + self.cur_pos += 1; + + assert self.cur_pos < self.N + while self.cur != '"': + self.token += self.cur + self.cur_pos += 1 + assert self.cur_pos < self.N + assert self.cur == '"' + self.cur_pos += 1 + #logging.warning('get: %s' % self.token) + + def eat_word(self): + self.token = '' + str = '' + while self.cur.isalnum() or self.cur in ('_', ':',): + self.token += self.cur + self.forward() + + #logging.warning('get: %s' % self.token) + + def eat_left_parentheses(self): + ''' + ( + ''' + self.assert_is('(') + self.token = '(' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_right_parentheses(self): + ''' + ) + ''' + self.assert_is(')') + self.token = ')' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_left_brace(self): + ''' + { + ''' + self.assert_is('{') + self.token = '{' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_right_brace(self): + ''' + } + ''' + self.assert_is('}') + self.token = '}' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_comma(self): + ''' + , + ''' + self.assert_is(',') + self.token = ',' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_spaces(self): + ''' + eat space like string. + ''' + while self.cur_pos < len(self.str): + if self.cur in (' ', '\t', '\n'): + self.forward() + else: + break + + def eat_point(self): + ''' + . + ''' + self.assert_is('.') + self.token = '.' + self.forward() + #logging.warning('get: %s' % self.token) + + def eat_any_but_brace(self): + ''' + anything but {} + ''' + start = self.cur_pos + while self.cur not in ('{', '}'): + self.cur_pos += 1 + + self.token = self.str[start:self.cur_pos] + #logging.warning('get: %s' % self.token) + + def eat_semicolon(self): + ''' + ; + ''' + self.assert_is(';') + self.token = ';' + self.forward() + #logging.warning('get: %s' % self.token) + + def assert_is(self, w): + assert self.cur == w, "token should be %s, but get %s" % (w, self.cur) + + @property + def cur(self): + assert self.cur_pos < self.N + return self.str[self.cur_pos] + #logging.warning('get: %s' % self.token) + + def forward(self): + self.cur_pos += 1 + + +class IO: + def __init__(self): + self.name = '' + self.type = '' + + def __repr__(self): + return "- %s: %s" % (self.name, self.type) + + +class KernelRegistry: + def __init__(self): + self.op_type = '' + self.target = '' + self.precision = '' + self.data_layout = '' + self.class_ = '' + self.alias = '' + self.inputs = [] + self.outputs = [] + + def __repr__(self): + str = "Kernel({op_type}, {target}, {precision}, {data_layout}, {alias}):".format( + op_type = self.op_type, + target = self.target, + precision = self.precision, + data_layout = self.data_layout, + alias = self.alias, + ) + + str += '\n' + '\n'.join(repr(io) for io in self.inputs) + str += '\n' + '\n'.join(repr(io) for io in self.outputs) + str += '\n' + return str + + +class RegisterLiteKernelParser(SyntaxParser): + + KEYWORD = 'REGISTER_LITE_KERNEL' + + def __init__(self, str): + super(RegisterLiteKernelParser, self).__init__(str) + + self.kernels = [] + + def parse(self): + find_registry_command = False + + while self.cur_pos < len(self.str): + start = self.str.find(self.KEYWORD, self.cur_pos) + if start != -1: + #print 'str ', 
start, self.str[start-2: start] + if start != 0 and '/' in self.str[start-2: start]: + ''' + skip commented code + ''' + self.cur_pos = start + 1 + continue + self.cur_pos = start + k = KernelRegistry() + self.kernels.append(self.parse_register(k)) + else: + break + + def eat_class(self): + start = self.cur_pos + self.eat_word() + stack = '' + if self.cur == '<': + stack = stack + '<' + self.forward() + while stack: + if self.cur == '<': + stack = stack + '<' + elif self.cur == '>': + stack = stack[1:] + else: + pass + self.forward() + self.token = self.str[start:self.cur_pos] + + + def parse_register(self, k): + + self.eat_word() + assert self.token == self.KEYWORD + self.eat_spaces() + + self.eat_left_parentheses() + self.eat_spaces() + + self.eat_word() + k.op_type = self.token + self.eat_comma() + self.eat_spaces() + + + self.eat_word() + k.target = self.token + self.eat_comma() + self.eat_spaces() + + self.eat_word() + k.precision = self.token + self.eat_comma() + self.eat_spaces() + + self.eat_word() + k.data_layout = self.token + self.eat_comma() + self.eat_spaces() + + self.eat_class() + k.class_ = self.token + self.eat_comma() + self.eat_spaces() + + self.eat_word() + k.alias = self.token + self.eat_spaces() + + self.eat_right_parentheses() + self.eat_spaces() + + + def eat_io(is_input, io): + self.eat_left_parentheses() + self.eat_str() + io.name = self.token + self.eat_comma() + self.eat_spaces() + + self.eat_left_brace() + self.eat_any_but_brace() + io.type = self.token + self.eat_right_brace() + self.eat_spaces() + self.eat_right_parentheses() + self.eat_spaces() + + + # eat input and output + while self.cur_pos < len(self.str): + self.eat_point() + self.eat_spaces() + self.eat_word() + assert self.token in ('BindInput', 'BindOutput', 'SetVersion', 'Finalize') + io = IO() + + if self.token == 'BindInput': + eat_io(True, io) + k.inputs.append(io) + elif self.token == 'BindOutput': + eat_io(False, io) + k.outputs.append(io) + elif self.token == 'SetVersion': + self.eat_left_parentheses() + self.eat_str() + self.version = self.token + self.eat_right_parentheses() + self.eat_spaces() + else: + self.eat_left_parentheses() + self.eat_right_parentheses() + self.eat_semicolon() + self.eat_spaces() + return k + break + + +class RegisterLiteOpParser(SyntaxParser): + + KEYWORD = 'REGISTER_LITE_OP' + + def __init__(self, str): + super(RegisterLiteOpParser, self).__init__(str) + self.ops = [] + + def parse(self): + while self.cur_pos < len(self.str): + start = self.str.find(self.KEYWORD, self.cur_pos) + if start != -1: + #print 'str ', start, self.str[start-2: start] + if start != 0 and '/' in self.str[start-2: start]: + ''' + skip commented code + ''' + self.cur_pos = start + 1 + continue + self.cur_pos = start + self.ops.append(self.__parse_register()) + else: + break + return self.ops + + def __parse_register(self): + self.eat_word() + assert self.token == self.KEYWORD + self.eat_spaces() + + self.eat_left_parentheses() + self.eat_spaces() + + self.eat_word() + return self.token + + +if __name__ == '__main__': + with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f: + c = f.read() + kernel_parser = RegisterLiteKernelParser(c) + + kernel_parser.parse() + +# for k in kernel_parser.kernels: +# print k diff --git a/lite/tools/cmake_tools/create_fake_kernel_registry.py b/lite/tools/cmake_tools/create_fake_kernel_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..140d77320704f62dfb2492eec3ad7238fe3868ff --- /dev/null +++ 
b/lite/tools/cmake_tools/create_fake_kernel_registry.py @@ -0,0 +1,147 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +import logging +from ast import RegisterLiteKernelParser +from utils import * + +ops_list_path = sys.argv[1] +dest_path = sys.argv[2] +kernelmap_path = sys.argv[3] + +out_lines = [ + '#pragma once', + '#include "lite/core/op_registry.h"', + '#include "lite/core/kernel.h"', + '#include "lite/core/type_system.h"', + '', +] + +fake_kernel = ''' + +namespace paddle { +namespace lite { + +class %s : public KernelLite { + public: + void PrepareForRun() override {} + + void Run() override {} + + virtual ~%s() = default; +}; + +} // namespace lite +} // namespace paddle +''' + +# create .h file to store kernel&source relationship +kernel_src_map_lines = [ +''' +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +// ATTENTION This can only include in a .cc file. 
+ +const std::map kernel2path_map{ + +''' +] + + +with open(ops_list_path) as f: + paths = set([path for path in f]) + for path in paths: + print('path', path) + with open(path.strip()) as g: + c = g.read() + kernel_parser = RegisterLiteKernelParser(c) + kernel_parser.parse() + + for k in kernel_parser.kernels: + kernel_name = "{op_type}_{target}_{precision}_{data_layout}_{alias}_class".format( + op_type = k.op_type, + target = k.target, + precision = k.precision, + data_layout = k.data_layout, + alias = k.alias, + ) + + kernel_define = fake_kernel % ( + kernel_name, + k.target, + k.precision, + k.data_layout, + kernel_name, + ) + + out_lines.append(kernel_define) + out_lines.append("") + + + key = "REGISTER_LITE_KERNEL(%s, %s, %s, %s, %s, %s)" % ( + k.op_type, + k.target, + k.precision, + k.data_layout, + '::paddle::lite::' + kernel_name, + k.alias, + ) + out_lines.append(key) + + for input in k.inputs: + io = ' .BindInput("%s", {%s})' % (input.name, input.type) + out_lines.append(io) + for output in k.outputs: + io = ' .BindOutput("%s", {%s})' % (output.name, output.type) + out_lines.append(io) + out_lines.append(" .Finalize();") + out_lines.append("") + out_lines.append(gen_use_kernel_statement(k.op_type, k.target, k.precision, k.data_layout, k.alias)) + + index = path.rindex('/') + filename = path[index + 1:] + map_element = ' {"%s,%s,%s,%s,%s", "%s"},' % ( + k.op_type, + k.target, + k.precision, + k.data_layout, + k.alias, + filename.strip() + ) + kernel_src_map_lines.append(map_element) +with open(dest_path, 'w') as f: + logging.info("write kernel list to %s" % dest_path) + f.write('\n'.join(out_lines)) + +with open(kernelmap_path, 'w') as fd: + logging.info("write kernel map to %s" % dest_path) + kernel_src_map_lines.append(' {" ", " "}') + kernel_src_map_lines.append('};') + fd.write('\n'.join(kernel_src_map_lines)) diff --git a/lite/tools/cmake_tools/parse_kernel_registry.py b/lite/tools/cmake_tools/parse_kernel_registry.py index a0a123898bec18594ae12bfd1584cdd526cb1a33..f4f0b95483687d3785168c132d30ac8a4fa87c8e 100644 --- a/lite/tools/cmake_tools/parse_kernel_registry.py +++ b/lite/tools/cmake_tools/parse_kernel_registry.py @@ -14,65 +14,49 @@ import sys import logging +from ast import RegisterLiteKernelParser ops_list_path = sys.argv[1] dest_path = sys.argv[2] +minkernels_list_path = sys.argv[3] +tailored = sys.argv[4] out_lines = [ '#pragma once', '#include "paddle_lite_factory_helper.h"', '', ] - -left_pattern = 'REGISTER_LITE_KERNEL(' -right_pattern = ')' - -def find_right_pattern(context, start): - if start >= len(context): return -1 - fake_left_num = 0 - while start < len(context): - if context[start] == right_pattern: - if fake_left_num == 0: - return start - else: - fake_left_num -= 1 - elif context[start] == '(': - fake_left_num += 1 - start += 1 - return -1 - -lines = set() +minlines = set() +if tailored == "ON": + with open(minkernels_list_path) as fd: + for line in fd: + minlines.add(line.strip()) with open(ops_list_path) as f: - for line in f: - lines.add(line.strip()) - -for line in lines: - path = line.strip() - - status = '' - with open(path) as g: - context = ''.join([item.strip() for item in g]) - index = 0 - cxt_len = len(context) - while index < cxt_len and index >= 0: - left_index = context.find(left_pattern, index) - if left_index < 0: break - right_index = find_right_pattern(context, left_index+len(left_pattern)) - if right_index < 0: - raise ValueError("Left Pattern and Right Pattern does not match") - tmp = context[left_index+len(left_pattern) : 
right_index] - index = right_index + 1 - if tmp.startswith('/'): continue - fields = [item.strip() for item in tmp.split(',')] - if len(fields) < 6: - raise ValueError("Invalid REGISTER_LITE_KERNEL format") - - op, target, precision, layout = fields[:4] - alias = fields[-1] - key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( - op, target, precision, layout, alias) - out_lines.append(key) - + paths = set([path for path in f]) + for path in paths: + with open(path.strip()) as g: + c = g.read() + kernel_parser = RegisterLiteKernelParser(c) + kernel_parser.parse() + + for k in kernel_parser.kernels: + kernel = "%s, %s, %s, %s, %s" % ( + k.op_type, + k.target, + k.precision, + k.data_layout, + k.alias, + ) + if tailored == "ON": + if kernel not in minlines: continue + key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( + k.op_type, + k.target, + k.precision, + k.data_layout, + k.alias, + ) + out_lines.append(key) with open(dest_path, 'w') as f: logging.info("write kernel list to %s" % dest_path) diff --git a/lite/tools/cmake_tools/parse_op_registry.py b/lite/tools/cmake_tools/parse_op_registry.py index 6c936c899d1bd030cc7bf2c35bc8b1247608bfed..db58c455a9d5863ec0c66d7783871831c73c120f 100644 --- a/lite/tools/cmake_tools/parse_op_registry.py +++ b/lite/tools/cmake_tools/parse_op_registry.py @@ -15,34 +15,38 @@ import sys import logging +from ast import RegisterLiteOpParser ops_list_path = sys.argv[1] dest_path = sys.argv[2] - +minops_list_path = sys.argv[3] +tailored = sys.argv[4] out_lines = [ '#pragma once', '#include "paddle_lite_factory_helper.h"', '', ] -lines = set() -with open(ops_list_path) as f: - for line in f: - lines.add(line.strip()) - -for line in lines: - path = line.strip() +paths = set() +for line in open(ops_list_path): + paths.add(line.strip()) - with open(path) as g: - for line in g: - key = 'REGISTER_LITE_OP' - if line.startswith(key): - end = line.find(',') - op = line[len(key) + 1:end] - if not op: continue - if "_grad" in op: continue - out = "USE_LITE_OP(%s);" % op - out_lines.append(out) +if tailored == "ON": + minlines = set() + with open(minops_list_path) as fd: + for line in fd: + minlines.add(line.strip()) +for path in paths: + str_info = open(path.strip()).read() + op_parser = RegisterLiteOpParser(str_info) + ops = op_parser.parse() + for op in ops: + if "_grad" in op: + continue + if tailored == "ON": + if op not in minlines: continue + out = "USE_LITE_OP(%s);" % op + out_lines.append(out) with open(dest_path, 'w') as f: logging.info("write op list to %s" % dest_path) diff --git a/lite/tools/cmake_tools/utils.py b/lite/tools/cmake_tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..832ead301b6bb8d2d260e3867031582fd9b5330d --- /dev/null +++ b/lite/tools/cmake_tools/utils.py @@ -0,0 +1,18 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
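Both generator scripts above emit `USE_LITE_KERNEL`/`USE_LITE_OP` statements (the kernel variant through `gen_use_kernel_statement` defined just below). The resulting header force-links the selected registrations into the final library, which is how the tailoring option trims unused ops and kernels; a hypothetical excerpt of the generated output:

```cpp
// Hypothetical excerpt of a generated header: each USE_* line references a
// registration's symbol so the linker keeps it; anything not listed in the
// tailoring file is never pulled in and gets stripped from the binary.
#pragma once
#include "paddle_lite_factory_helper.h"

USE_LITE_OP(conv2d);
USE_LITE_OP(relu);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
```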
+ +def gen_use_kernel_statement(op_type, target, precision, layout, alias): + return 'USE_LITE_KERNEL(%s, %s, %s, %s, %s);' %( + op_type, target, precision, layout, alias + ) diff --git a/lite/tools/convert_arm_sdot_to_machine_code.py b/lite/tools/convert_arm_sdot_to_machine_code.py new file mode 100644 index 0000000000000000000000000000000000000000..66dc387118b018a481287dc34dfe8e3292f4467f --- /dev/null +++ b/lite/tools/convert_arm_sdot_to_machine_code.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import sys +import os +import re + +def compute_sdot_vec_vec(vd, vn, vm): + i = 0x4e809400 | int(vd) | (int(vn) << 5) | (int(vm) << 16) + return '".word 0x{:08x}\\n"'.format(i) + \ + ' /* sdot v{vd}.4s, v{vn}.16b, v{vm}.16b */'.format( + vd=vd, vn=vn, vm=vm) + +def compute_sdot_vec_elem(vd, vn, vm, idx): + i = 0x4f80e000 | int(vd) | (int(vn) << 5) | (int(vm) << 16) | (int(idx % 2) << 21) | (int(idx / 2) << 11) + return '".word 0x{:08x}\\n"'.format(i) + \ + ' /* sdot v{vd}.4s, v{vn}.16b, v{vm}.4b[{idx}] */\\\r\n'.format( + vd=vd, vn=vn, vm=vm, idx=idx) + +def match_sdot_patten(line): + matched = re.search(r'sdot\s+v(.*?).4s\s*,\s*v(.*?).16b\s*,\s*v(.*?).4b\[(.*?)\].*', line, re.M|re.I) + if matched: + # print('matched:', matched.group(1), matched.group(2), matched.group(3), matched.group(4)) + vd = int(matched.group(1)) + vn = int(matched.group(2)) + vm = int(matched.group(3)) + idx = int(matched.group(4)) + return compute_sdot_vec_elem(vd, vn, vm, idx) + else: + return line + +def parser_file(file_in, file_out): + out = open(file_out, 'w') + if os.path.exists(file_in): + for line in open(file_in): + new_line = match_sdot_patten(line) + # print(new_line) + out.write(new_line) + else: + print('input file {} not exist'.format(file_in)) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser('convert arm sdot to machine code') + arg_parser.add_argument('--input_file', type=str, required=True) + arg_parser.add_argument('--output_file', type=str, required=True) + args = arg_parser.parse_args() + + print('input file: ', args.input_file) + print('output file: ', args.output_file) + parser_file(args.input_file, args.output_file) diff --git a/lite/tools/debug/CMakeLists.txt b/lite/tools/debug/CMakeLists.txt index b26fd1545a439cd8faf7c4f6700b35bccb918e03..43c0812ab91f6ddcba02f93d2eea60f5a5268341 100644 --- a/lite/tools/debug/CMakeLists.txt +++ b/lite/tools/debug/CMakeLists.txt @@ -1,15 +1,19 @@ lite_cc_library(debug_utils SRCS debug_utils.cc DEPS op_params model_parser) -lite_cc_binary(lite_model_debug_tool SRCS model_debug_tool.cc +if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK OR LITE_ON_MODEL_OPTIMIZE_TOOL) + lite_cc_binary(lite_model_debug_tool SRCS model_debug_tool.cc DEPS cxx_api debug_utils target_wrapper_host mir_passes gflags + logging ${ops} ${host_kernels} X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} FPGA_DEPS 
${fpga_kernels} CL_DEPS ${opencl_kernels}) +endif() diff --git a/lite/tools/debug/model_debug_tool.cc b/lite/tools/debug/model_debug_tool.cc index a2ff37895cc16766e05dbcfeb71645a11564ec00..4b27db7a8d3a2dcf8237660b50631c71dcd4f4af 100644 --- a/lite/tools/debug/model_debug_tool.cc +++ b/lite/tools/debug/model_debug_tool.cc @@ -16,9 +16,6 @@ #include #include #include "lite/api/cxx_api.h" -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" #include "lite/core/op_registry.h" #include "lite/model_parser/model_parser.h" #include "lite/model_parser/pb/program_desc.h" @@ -38,7 +35,6 @@ void Run(DebugConfig* conf) { #endif lite::Predictor predictor; std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat)}, #ifdef LITE_WITH_ARM Place{TARGET(kARM), PRECISION(kFloat)}, #endif @@ -47,6 +43,9 @@ void Run(DebugConfig* conf) { #endif #ifdef LITE_WITH_FPGA Place{TARGET(kFPGA), PRECISION(kFloat)}, +#endif +#ifdef LITE_WITH_CUDA + Place{TARGET(kCUDA), PRECISION(kFloat)}, #endif }); @@ -60,17 +59,7 @@ void Run(DebugConfig* conf) { "runtime_context_assign_pass", }}; - predictor.Build(conf->model_dir, - "", - "", -#ifdef LITE_WITH_ARM - Place{TARGET(kARM), PRECISION(kFloat)}, -#endif -#ifdef LITE_WITH_X86 - Place{TARGET(kX86), PRECISION(kFloat)}, -#endif - valid_places, - passes); + predictor.Build(conf->model_dir, "", "", valid_places, passes); predictor.GenRuntimeProgram(); auto& instructions = predictor.runtime_program().instructions(); diff --git a/lite/utils/CMakeLists.txt b/lite/utils/CMakeLists.txt index 7ab0c61b8f022e0f2a3c91a01dbe0d5730b51c62..6337085d829b115dc6d2553473ddcef8ac5115f8 100644 --- a/lite/utils/CMakeLists.txt +++ b/lite/utils/CMakeLists.txt @@ -3,23 +3,23 @@ # else() # endif() -if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK OR LITE_ON_MODEL_OPTIMIZE_TOOL) lite_cc_library(logging SRCS logging.cc) set(utils_DEPS logging) lite_cc_test(test_logging SRCS logging_test.cc DEPS ${utils_DEPS}) else() - set(utils_DEPS glog) -endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + set(utils_DEPS glog) +endif() lite_cc_test(test_varient SRCS varient_test.cc DEPS utils) lite_cc_library(any SRCS any.cc) -if(LITE_ON_TINY_PUBLISH) -lite_cc_library(stream SRCS replace_stl/stream.cc) +if(LITE_ON_TINY_PUBLISH OR LITE_ON_MODEL_OPTIMIZE_TOOL) + lite_cc_library(stream SRCS replace_stl/stream.cc) endif() #lite_cc_library(utils SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any) -if(LITE_ON_TINY_PUBLISH) +if(LITE_ON_TINY_PUBLISH OR LITE_ON_MODEL_OPTIMIZE_TOOL) lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any stream) else() lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any) diff --git a/lite/utils/all.h b/lite/utils/all.h index b8cffb9fd11e03e4e8a3ec09dd4724ed3452dcee..a0d323aa24b36dac7858f484eb1cf1d5a7bcba50 100644 --- a/lite/utils/all.h +++ b/lite/utils/all.h @@ -21,6 +21,7 @@ #include "lite/utils/hash.h" #include "lite/utils/io.h" #include "lite/utils/macros.h" +#include "lite/utils/string.h" #include "lite/utils/varient.h" #ifdef LITE_ON_TINY_PUBLISH diff --git a/lite/utils/any.h b/lite/utils/any.h index 00c652613d995b12f7efae8c0a6971e412d6a9a1..3f7029e98c161a7c47b6db4aeec9cb18490366f0 100644 --- a/lite/utils/any.h +++ b/lite/utils/any.h @@ -22,6 +22,14 @@ namespace lite { class Any { public: + Any() = default; + explicit Any(const Any& other) { + type_ = other.type_; + data_ = other.clone_data_(other.data_); + deleter_ = other.deleter_; + clone_data_ = other.clone_data_; + } + template void 
set(const T& v) { set(); @@ -34,7 +42,16 @@ class Any { CHECK(type_ == typeid(T).hash_code()); } else { type_ = typeid(T).hash_code(); - deleter_ = [&] { delete static_cast(data_); }; + deleter_ = [&](void** data) { + delete static_cast(*data); + *data = nullptr; + }; + clone_data_ = [&](void* data) { + T* res = new T; + CHECK(data) << "data pointer is nullptr"; + *res = *static_cast(data); + return res; + }; } data_ = new T; } @@ -52,19 +69,20 @@ class Any { return static_cast(data_); } - bool valid() const { return data_; } + bool valid() const { return (data_ != nullptr); } - // ~Any() { - // if (valid()) { - // deleter_(); - // } - // } + ~Any() { + if (valid()) { + deleter_(&data_); + } + } private: static size_t kInvalidType; size_t type_{kInvalidType}; void* data_{nullptr}; - std::function deleter_; + std::function deleter_; + std::function clone_data_; }; } // namespace lite diff --git a/lite/utils/cp_logging.h b/lite/utils/cp_logging.h index c756832a873ef4051fdf68ff903be6316315bd14..cc10bece471af7a99f3b271990dd13731c08b9f8 100644 --- a/lite/utils/cp_logging.h +++ b/lite/utils/cp_logging.h @@ -13,7 +13,8 @@ // limitations under the License. #pragma once -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#if defined(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || \ + defined(LITE_ON_MODEL_OPTIMIZE_TOOL) #include "lite/utils/logging.h" #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include diff --git a/lite/utils/logging.cc b/lite/utils/logging.cc index 9a4cad34f74a6346b293cad6948f237bc1d09c75..6351be95acdb7311f7d5604d9af3cfe8945bc424 100644 --- a/lite/utils/logging.cc +++ b/lite/utils/logging.cc @@ -18,8 +18,10 @@ */ #include "lite/utils/logging.h" +#include -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#if defined(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || \ + defined(LITE_ON_MODEL_OPTIMIZE_TOOL) #ifndef LITE_SHUTDOWN_LOG namespace paddle { @@ -48,7 +50,7 @@ void gen_log(STL::ostream& log_stream_, << tv.tv_usec / 1000 << " "; if (len > kMaxLen) { - log_stream_ << "..." << file + len - kMaxLen << " " << func << ":" << lineno + log_stream_ << "..." << file + len - kMaxLen << ":" << lineno << " " << func << "] "; } else { log_stream_ << file << " " << func << ":" << lineno << "] "; diff --git a/lite/utils/logging.h b/lite/utils/logging.h index 8dbb7a9752fb5905168b9c6eb2280f6f025a7309..e85753ec301c62152ce484105d6c42ac1b69ab16 100644 --- a/lite/utils/logging.h +++ b/lite/utils/logging.h @@ -18,6 +18,9 @@ */ #pragma once +#ifndef _LOGGING_H_ +#define _LOGGING_H_ + #include #include #include @@ -81,7 +84,7 @@ void gen_log(STL::ostream& log_stream_, const char* func, int lineno, const char* level, - const int kMaxLen = 20); + const int kMaxLen = 40); // LogMessage class LogMessage { @@ -183,3 +186,4 @@ class VoidifyFatal : public Voidify { } // namespace lite } // namespace paddle +#endif diff --git a/lite/utils/paddle_enforce.h b/lite/utils/paddle_enforce.h index 8317f45a0c522e13d83ae76057918010c434438d..82534af996919ac69a8624e442f1af6a9abb2c07 100644 --- a/lite/utils/paddle_enforce.h +++ b/lite/utils/paddle_enforce.h @@ -35,5 +35,5 @@ CHECK_GT((a), (b)) << paddle::lite::string_format("" __VA_ARGS__); #ifndef PADDLE_THROW -#define PADDLE_THROW +#define PADDLE_THROW(...) 
printf("" __VA_ARGS__); #endif diff --git a/lite/utils/replace_stl/stream.cc b/lite/utils/replace_stl/stream.cc index e4867d16c09bfd8533cfe290bafffb948ceebcab..61999a79e3d9e997b23943e46a419577ee2de44c 100644 --- a/lite/utils/replace_stl/stream.cc +++ b/lite/utils/replace_stl/stream.cc @@ -32,12 +32,24 @@ ostream& ostream::operator<<(const char* obj) { return *this; } +template <> +ostream& ostream::operator<<(const char& obj) { + _data = _data + obj; + return *this; +} + template <> ostream& ostream::operator<<(const std::string& obj) { _data = _data + obj; return *this; } +template <> +ostream& ostream::operator<<(const int16_t& obj) { + ADD_DATA_AS_STRING(_data, obj); + return *this; +} + template <> ostream& ostream::operator<<(const int& obj) { ADD_DATA_AS_STRING(_data, obj); diff --git a/lite/utils/variant.h b/lite/utils/variant.h new file mode 100644 index 0000000000000000000000000000000000000000..146ea586e46db0f1145f7ca7a9c5b7bd7bfb432e --- /dev/null +++ b/lite/utils/variant.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// Because Boost 1.41.0's variadic templates has bug on nvcc, boost +// will disable variadic template support in NVCC mode. Define +// BOOST_NO_CXX11_VARIADIC_TEMPLATES on gcc/clang to generate same +// function symbols. For details, +// https://github.com/PaddlePaddle/Paddle/issues/3386 + +// some platform-independent defintion +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + +#if !defined(_WIN32) +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition) +#endif + +#if !defined(_WIN32) +#define LIKELY(condition) __builtin_expect(static_cast(condition), 1) +#else +// there is no equivalent intrinsics in msvc. 
+#define LIKELY(condition) (condition) +#endif diff --git a/mobile/.gitignore b/mobile/.gitignore index 70d0b40927d434c6108a5845faf393b84aa40d34..336f08fa8a83780b790a4114182472caa62bbc53 100644 --- a/mobile/.gitignore +++ b/mobile/.gitignore @@ -101,3 +101,4 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources metal/paddle-mobile-demo/paddle-mobile-demo/Resources/images metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models metal/MobileNetDemo/MobileNetDemo/Resources +third_party/opencl/OpenCL-Headers diff --git a/mobile/CMakeLists.txt b/mobile/CMakeLists.txt index 00a53035a118c0ee76025ab50790672c02391ae3..1883da85739f15ada96fead77a02b72b3bcceb6a 100644 --- a/mobile/CMakeLists.txt +++ b/mobile/CMakeLists.txt @@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.0.0) if(IS_IOS) option(USE_OPENMP "build with openmp support" OFF) else() - option(USE_OPENMP "build with openmp support" ON) + option(USE_OPENMP "build with openmp support" OFF) endif() option(USE_EXCEPTION "build with exception" ON) option(WITH_LOGGING "print logging for debug" OFF) diff --git a/mobile/src/common/types.cpp b/mobile/src/common/types.cpp index b74d053f5a259078e3c6b96638b61f600392ace7..42a98450a3220bfee9bea4811a9b153ce8ac5b2f 100755 --- a/mobile/src/common/types.cpp +++ b/mobile/src/common/types.cpp @@ -132,6 +132,8 @@ const char *G_OP_TYPE_WHILE = "while"; const char *G_OP_TYPE_BEAM_SEARCH_DECODE = "beam_search_decode"; const char *G_OP_TYPE_FILL_CONSTAN_BATCH_SIZE_LIKE = "fill_constant_batch_size_like"; +const char *G_OP_TYPE_FUSION_INSTANCENORM_RELU = "fusion_instancenorm_relu"; +const char *G_OP_TYPE_PIXEL_SHUFFLE = "pixel_shuffle"; std::unordered_map< std::string, std::pair, std::vector>> @@ -155,6 +157,7 @@ std::unordered_map< {G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}}, {G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}}, {G_OP_TYPE_INSTANCENORM, {{"X"}, {"Out"}}}, + {G_OP_TYPE_FUSION_INSTANCENORM_RELU, {{"X"}, {"Out"}}}, {G_OP_TYPE_LRN, {{"X"}, {"Out"}}}, {G_OP_TYPE_CONCAT, {{"X"}, {"Out"}}}, {G_OP_TYPE_SPLIT, {{"X"}, {"Out"}}}, @@ -254,5 +257,6 @@ std::unordered_map< {G_OP_TYPE_BEAM_SEARCH_DECODE, {{"Ids", "Scores"}, {"SentenceIds", "SentenceScores"}}}, {G_OP_TYPE_FILL_CONSTAN_BATCH_SIZE_LIKE, {{"Input"}, {"Out"}}}, - {G_OP_TYPE_PAD2D, {{"X"}, {"Out"}}}}; + {G_OP_TYPE_PAD2D, {{"X"}, {"Out"}}}, + {G_OP_TYPE_PIXEL_SHUFFLE, {{"X"}, {"Out"}}}}; } // namespace paddle_mobile diff --git a/mobile/src/common/types.h b/mobile/src/common/types.h index f636cd6c0cc2f61f302d6ca84dccdde661916bd1..d876f3b116cbb397ffa8019b1a8d9a637606ec10 100644 --- a/mobile/src/common/types.h +++ b/mobile/src/common/types.h @@ -87,6 +87,11 @@ enum PMStatus { PMException = 0x09 /*!< throw exception. 
*/ }; +enum PrePostType { + NONE_PRE_POST = 0, + UINT8_255 = 1, +}; + enum RoundType { ROUND_NEAREST_AWAY_ZERO = 0, ROUND_NEAREST_TOWARDS_ZERO = 1, @@ -143,6 +148,7 @@ struct PaddleMobileConfigInternal { MemoryOptimizationLevel memory_optimization_level = MemoryOptimizationWithoutFeeds; std::string model_obfuscate_key = ""; + PrePostType pre_post_type = NONE_PRE_POST; }; enum ARMArch { @@ -257,8 +263,8 @@ extern const char *G_OP_TYPE_PAD2D; extern const char *G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU; extern const char *G_OP_TYPE_FUSION_DECONV_ADD_BN; extern const char *G_OP_TYPE_FUSION_DECONV_BN_RELU; - -extern const char *G_OP_TYPE_PAD2D; +extern const char *G_OP_TYPE_FUSION_INSTANCENORM_RELU; +extern const char *G_OP_TYPE_PIXEL_SHUFFLE; extern std::unordered_map< std::string, std::pair, std::vector>> diff --git a/mobile/src/fpga/V2/api.cpp b/mobile/src/fpga/V2/api.cpp index f1d19364f89cfa7118397ab7f33db66c3a78785d..f39d012e08c124feacbd72fa2879e60b352c2785 100644 --- a/mobile/src/fpga/V2/api.cpp +++ b/mobile/src/fpga/V2/api.cpp @@ -359,7 +359,7 @@ void expand_conv_arg(ConvArgs *arg) { if (((res_win % 2) != 0) && (res_win != 1)) { res_win = res_win - 1; } - PADDLE_MOBILE_ENFORCE(res_win >= 2, "window too bigger than fpga volume"); + // PADDLE_MOBILE_ENFORCE(res_win >= 2, "window too bigger than fpga volume"); res_fit = res_win; auto block_num = (output_width + res_fit - 1) / res_fit; @@ -885,7 +885,7 @@ void fill_dwconv_arg(struct DWconvArgs *arg, framework::Tensor *input, int padding_h, int padding_w, float *bias_ptr) { auto filter_ptr = filter->data(); auto input_ptr = input->data(); - auto output_ptr = out->mutable_data(); + auto output_ptr = out->data(); arg->sub_conv_num = 1; arg->relu_enabled = relu_enabled; // arg->output.activation.activation_type = activation_enable; diff --git a/mobile/src/fpga/V2/bias_scale.cpp b/mobile/src/fpga/V2/bias_scale.cpp index ca93fe17ca61e13022e7c86137f8ab3e4b6ed6d9..e04604c587dc6ba1ea29c310a502aba8f8b7b153 100644 --- a/mobile/src/fpga/V2/bias_scale.cpp +++ b/mobile/src/fpga/V2/bias_scale.cpp @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "fpga/V2/bias_scale.h" #include +#include #include "fpga/common/fpga_common.h" namespace paddle_mobile { @@ -55,10 +56,22 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) { *data_in = ptr_aligned; } +void fixed_scale_bias_new(void*data_in, int data_len) { + int* data_tmp = static_cast(data_in); + for (int idx = 0; idx < data_len/2; ++idx) { + float tmp = (static_cast(data_in))[idx]; + data_tmp[idx] = static_cast(round(tmp*pow(2.0, 23.0))); + tmp = (static_cast(data_in))[idx+data_len/2]; + data_tmp[idx+data_len/2] = static_cast(round(tmp*pow(2.0, 30.0))); + } + return; +} + void interleave(float **data_in, int num_after_alignment) { // num_after_alignment: number of bias after alignment float *ptr_uninterleaved = *data_in; + // fixed_scale_bias_new(ptr_uninterleaved, 2 * num_after_alignment); float *ptr_interleaved = (float *)fpga_malloc(2 * num_after_alignment * sizeof(float)); // NOLINT int num = num_after_alignment / 4; diff --git a/mobile/src/fpga/V2/pe.cpp b/mobile/src/fpga/V2/pe.cpp index cc9d8d20cd0d68bb16c851aea195b943a9dd18a9..aa150e0c6cecbdf278f3d776ebba4ec81ed003a1 100644 --- a/mobile/src/fpga/V2/pe.cpp +++ b/mobile/src/fpga/V2/pe.cpp @@ -79,7 +79,8 @@ using namespace std; // NOLINT #define REG_CONVERT_CMD 0x400 #define REG_CONVERT_SRC_ADDR 0x408 #define REG_CONVERT_DST_ADDR 0x410 -#define REG_CONVERT_LENGTH 0x418 +#define REG_CONVERT_RD_LENGTH 0x418 +#define REG_CONVERT_WR_LENGTH 0x420 /*resize*/ #define REG_RESIZE_CMD 0x600 @@ -693,7 +694,8 @@ int PerformBypass(const struct BypassArgs &args) { reg_writeq(output_scale, REG_SCALE_PARAMETER); reg_writeq(input_address_phy, REG_CONVERT_SRC_ADDR); reg_writeq(output_address_phy, REG_CONVERT_DST_ADDR); - reg_writeq(datalen, REG_CONVERT_LENGTH); + reg_writeq(datalen, REG_CONVERT_RD_LENGTH); + reg_writeq(datalen, REG_CONVERT_WR_LENGTH); reg_writeq(cmd, REG_CONVERT_CMD); DLOG << "before reg poll"; if (0 != fpga_regpoll(REG_INTERRUPT, INTERRUPT_BYPASS, PE_IRQ_TIMEOUT)) { diff --git a/mobile/src/fpga/common/driver.cpp b/mobile/src/fpga/common/driver.cpp index 71e3bf9746e600703fa43f2c8616f73806f6fd60..911704965aac3b6897b70dc60cb23fb4f3e59979 100644 --- a/mobile/src/fpga/common/driver.cpp +++ b/mobile/src/fpga/common/driver.cpp @@ -134,6 +134,7 @@ int fpga_regpoll(uint64_t reg, uint64_t val, int time) { uint64_t i = 0; /*timeout精确性待确认*/ int64_t timeout = time * 6; + usleep(1); for (i = 0; i < timeout; i++) { if (val == reg_readq(reg)) { diff --git a/mobile/src/framework/cl/cl_engine.h b/mobile/src/framework/cl/cl_engine.h index f5b1e3c2d22bc224ab0a0bc738b96f8c7ef28420..2e21dd9e395354d2bd5e35a648687a6116347caf 100644 --- a/mobile/src/framework/cl/cl_engine.h +++ b/mobile/src/framework/cl/cl_engine.h @@ -133,6 +133,18 @@ class CLEngine { free(max_work_item_sizes); return localWorkSizeInfo_; } + size_t GetKernelWorkSize(cl_kernel kernel) { + cl_int status; + size_t kernel_work_size = 0; + status = + clGetKernelWorkGroupInfo(kernel, devices_[0], CL_KERNEL_WORK_GROUP_SIZE, + sizeof(size_t), &kernel_work_size, NULL); + if (status != CL_SUCCESS) { + return 0; + } + DLOG << "kernel_work_size: " << kernel_work_size; + return kernel_work_size; + } std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith( cl_context context, std::string file_name) { @@ -188,8 +200,7 @@ class CLEngine { bool BuildProgram(cl_program program, const std::string &options = "") { cl_int status; - std::string path = options + " -cl-fast-relaxed-math -I " + - CLEngine::Instance()->GetCLPath() + "/cl_kernel"; + std::string 
path = options + " -cl-fast-relaxed-math"; status = clBuildProgram(program, 0, 0, path.c_str(), 0, 0); diff --git a/mobile/src/framework/cl/cl_helper.h b/mobile/src/framework/cl/cl_helper.h index f072edd82b38883bd7a8e48b538a8f50a848cec6..893456211d0429701b49d0f0be654beaad16e0e2 100644 --- a/mobile/src/framework/cl/cl_helper.h +++ b/mobile/src/framework/cl/cl_helper.h @@ -54,6 +54,9 @@ class CLHelper { CLLocalWorkSizeInfo LocalWorkSizeInfo() { return scope_->LocalWorkSizeInfo(); } + size_t KernelWorkSize(cl_kernel kernel) { + return scope_->KernelWorkSize(kernel); + } std::vector DefaultWorkSize(const CLImage &image) { // n c h w @@ -63,9 +66,9 @@ class CLHelper { auto h = image_dim[2]; auto w = image_dim[3]; auto image_width = image.ImageWidth(); - auto work_size_0 = image_width / w; - auto work_size_1 = w; - auto work_size_2 = n * h; + size_t work_size_0 = image_width / w; + size_t work_size_1 = w; + size_t work_size_2 = n * h; return {work_size_0, work_size_1, work_size_2}; } else if (image_dim.size() == 2) { auto h = image_dim[0]; @@ -74,9 +77,9 @@ class CLHelper { } else if (image_dim.size() == 1) { return {1, image.ImageWidth(), 1}; } else if (image_dim.size() == 3) { - int c = image_dim[0]; - int h = image_dim[1]; - int w = image_dim[2]; + size_t c = image_dim[0]; + size_t h = image_dim[1]; + size_t w = image_dim[2]; return {(c + 3) / 4, w, h}; } PADDLE_MOBILE_THROW_EXCEPTION(" not support this dim, need imp "); diff --git a/mobile/src/framework/cl/cl_image.cpp b/mobile/src/framework/cl/cl_image.cpp index 4f4b0d8883586e221b9178a104a7f295fab06f83..0d4cf87db0d34953936d107b6bb6c9adbd985560 100644 --- a/mobile/src/framework/cl/cl_image.cpp +++ b/mobile/src/framework/cl/cl_image.cpp @@ -119,8 +119,8 @@ void TensorToCLImage(Tensor *tensor, CLImage *cl_image, cl_context context, #ifdef PADDLE_MOBILE_DEBUG Print &operator<<(Print &printer, const CLImage &cl_image) { - int width = cl_image.ImageDims()[0]; - int height = cl_image.ImageDims()[1]; + size_t width = cl_image.ImageDims()[0]; + size_t height = cl_image.ImageDims()[1]; half_t *image_data = new half_t[height * width * 4]; cl_int err; diff --git a/mobile/src/framework/cl/cl_image.h b/mobile/src/framework/cl/cl_image.h index d92800b170e68e6a33d680e6ae8d197a09daa2d5..6e885adca886b62099946590d52941d8de2550f0 100644 --- a/mobile/src/framework/cl/cl_image.h +++ b/mobile/src/framework/cl/cl_image.h @@ -126,13 +126,16 @@ class CLImage { void InitEmptyImage(cl_context context, cl_command_queue command_queue, const DDim &dim) { + if (image_converter_ != nullptr) { + delete image_converter_; + } PADDLE_MOBILE_ENFORCE(tensor_data_ == nullptr, " empty image tensor data shouldn't have value"); // CLImageConverterFolder *folder_converter = new // CLImageConverterFolder(); CLImageConverterNormal *normal_converter = new CLImageConverterNormal(); - + PADDLE_MOBILE_ENFORCE(!shared_mem_, "do not init mem after shared .") DLOG << " to get image dims "; image_dims_ = normal_converter->InitImageDimInfoWith(dim); DLOG << " end get image dims " << image_dims_; @@ -146,39 +149,65 @@ class CLImage { initialized_ = true; DLOG << " end init cl image"; } - // create fake size cl_mem for mem share + /** + * create fake size cl_mem for mem share + */ void InitFakeSizeImage(cl_context context, cl_command_queue command_queue, - const DDim &need_dims, const DDim &real_dims) { + const DDim &need_dims, const DDim &real_image_dims) { PADDLE_MOBILE_ENFORCE(tensor_data_ == nullptr, " empty image tensor data shouldn't have value"); - + if (image_converter_ != nullptr) { + 
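The KernelWorkSize() helper threaded through CLHelper and CLScope above bottoms out in clGetKernelWorkGroupInfo, a standard OpenCL 1.0 query. A minimal standalone sketch of the same call (kernel and device handles are assumed to already exist):

    #include <CL/cl.h>

    // Returns the maximum work-group size this kernel can be enqueued with on
    // the given device, or 0 on failure, mirroring CLEngine::GetKernelWorkSize.
    size_t QueryKernelWorkGroupSize(cl_kernel kernel, cl_device_id device) {
      size_t wg_size = 0;
      cl_int status =
          clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,
                                   sizeof(size_t), &wg_size, nullptr);
      return status == CL_SUCCESS ? wg_size : 0;
    }

The convolution kernels later in this patch use this value as the upper bound when shrinking their local work sizes.
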
delete image_converter_; + } CLImageConverterNormal *normal_converter = new CLImageConverterNormal(); - - real_image_dims = normal_converter->InitImageDimInfoWith(real_dims); - real_tensor_dims = real_dims; - + // use real image dims to create mem + real_image_dims_ = real_image_dims; + InitCLImage(context, real_image_dims_[0], real_image_dims_[1], nullptr); + // cheat cl_image they got what they wanted image_dims_ = normal_converter->InitImageDimInfoWith(need_dims); - InitCLImage(context, image_dims_[0], image_dims_[1], nullptr); - + DLOG << "InitFakeSizeImage ... "; + DLOG << "real_image_dims: " << real_image_dims_; + DLOG << "image_dims_: " << image_dims_; + PADDLE_MOBILE_ENFORCE(real_image_dims_[0] >= image_dims_[0] && + real_image_dims_[1] >= image_dims_[1], + "real image is not enough"); tensor_dims_ = need_dims; command_queue_ = command_queue; image_converter_ = normal_converter; cl_event_ = CLEngine::Instance()->CreateEvent(context); initialized_ = true; - DLOG << " end init cl image"; - } + shared_mem_ = true; - void InitWithExitedMem(cl_context context, cl_command_queue command_queue, - DDim need_dims, const CLImage &src) { + DLOG << " end init FakeSizeImage"; + } + /** + * init cl mem with a exist cl mem + */ + void InitWithExistMem(cl_context context, cl_command_queue command_queue, + DDim need_dims, const CLImage &src) { + if (image_converter_ != nullptr) { + delete image_converter_; + } CLImageConverterNormal *normal_converter = new CLImageConverterNormal(); - real_image_dims = normal_converter->InitImageDimInfoWith(src.dims()); - real_tensor_dims = src.dims(); - + real_image_dims_ = src.real_image_dims_; image_dims_ = normal_converter->InitImageDimInfoWith(need_dims); - // InitCLImage(context, image_dims_[0], image_dims_[1], nullptr); + + DLOG << "InitWithExistMem ... "; + DLOG << "real_image_dims: " << real_image_dims_; + DLOG << "image_dims_: " << image_dims_; + + if (real_image_dims_[0] < image_dims_[0] || + real_image_dims_[1] < image_dims_[1]) { + DLOG << "real image is not enough!"; + DLOG << "real_image_dims: " << real_image_dims_; + DLOG << "image_dims_: " << image_dims_; + } + PADDLE_MOBILE_ENFORCE(real_image_dims_[0] >= image_dims_[0] && + real_image_dims_[1] >= image_dims_[1], + "real image is not enough!"); if (cl_image_ != src.cl_image_) { - cl_image_.reset(src.cl_image_.get()); + cl_image_.reset(src.cl_image_.get(), CLMemDeleter()); } tensor_dims_ = need_dims; @@ -186,7 +215,9 @@ class CLImage { image_converter_ = normal_converter; cl_event_ = CLEngine::Instance()->CreateEvent(context); initialized_ = true; - DLOG << " end init cl image"; + shared_mem_ = true; + + DLOG << " end init WithExistMem"; } void InitConv2dTransposeFilterCLImage(cl_context context, @@ -205,7 +236,7 @@ class CLImage { "Tensor holds no memory. 
Call Tensor::mutable_data first.") if (cl_image_ != src.cl_image_) { - cl_image_.reset(src.cl_image_.get()); + cl_image_.reset(src.cl_image_.get(), CLMemDeleter()); } return *this; } @@ -253,7 +284,10 @@ class CLImage { CLImageConverterBase *Converter() const { return image_converter_; } private: - void InitCLImage(cl_context context, int width, int height, void *data) { + void InitCLImage(cl_context context, size_t width, size_t height, + void *data) { + PADDLE_MOBILE_ENFORCE(!shared_mem_, "do not init mem after shared .") + cl_image_format cf = {.image_channel_order = CL_RGBA, .image_channel_data_type = CL_HALF_FLOAT}; cl_image_desc cid = { @@ -276,7 +310,7 @@ class CLImage { &cid, // const cl_image_desc *image_desc data, // void *host_ptr &err); - cl_image_.reset(cl_image); + cl_image_.reset(cl_image, CLMemDeleter()); if (err != CL_SUCCESS) { CL_CHECK_ERRORS(err); PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error "); @@ -284,18 +318,17 @@ class CLImage { } bool initialized_ = false; - std::unique_ptr<_cl_mem, CLMemDeleter> cl_image_; + std::shared_ptr<_cl_mem> cl_image_; std::unique_ptr<_cl_event, CLEventDeleter> cl_event_; DDim tensor_dims_; DDim image_dims_; // real image dims usually it is same as image_dims - DDim real_image_dims; - // real tensor dims usually it is same as tensor dims - DDim real_tensor_dims; + DDim real_image_dims_; float *tensor_data_ = nullptr; cl_context context_; cl_command_queue command_queue_; CLImageConverterBase *image_converter_ = nullptr; + bool shared_mem_ = false; }; void TensorToCLImage(Tensor *tensor, CLImage *image, cl_context context, diff --git a/mobile/src/framework/cl/cl_scope.h b/mobile/src/framework/cl/cl_scope.h index ebe16b553ab8442986ab02d53328d6c7cda233ef..643ce32b57616305da0c581d6d50dfcbbc4f1b1d 100644 --- a/mobile/src/framework/cl/cl_scope.h +++ b/mobile/src/framework/cl/cl_scope.h @@ -110,6 +110,10 @@ class CLScope { } CLLocalWorkSizeInfo LocalWorkSizeInfo() { return localWorkSizeInfo_; } + size_t KernelWorkSize(cl_kernel kernel) { + size_t kernel_work_size = CLEngine::Instance()->GetKernelWorkSize(kernel); + return kernel_work_size; + } private: cl_int status_; diff --git a/mobile/src/framework/context.cpp b/mobile/src/framework/context.cpp index 36538ef50eacd6edebcab241a2dd22604bf04ae3..10f1572d030c50a2efaaf58654573ee1a3c40b3a 100644 --- a/mobile/src/framework/context.cpp +++ b/mobile/src/framework/context.cpp @@ -63,12 +63,19 @@ void fill_cpu_cache_size(std::vector *cpu_cache_sizes, int value, int num = cpu_ids.size(); if (num > 0) { for (int i = 0; i < num; i++) { - (*cpu_cache_sizes)[cpu_ids[i]] = value; + if (cpu_ids.size() > i) { + int idx = cpu_ids[i]; + if (cpu_cache_sizes->size() > idx) { + (*cpu_cache_sizes)[idx] = value; + } + } } } else { num = cpu_cache_sizes->size(); for (int i = 0; i < num; i++) { - (*cpu_cache_sizes)[i] = value; + if (cpu_cache_sizes->size() > i) { + (*cpu_cache_sizes)[i] = value; + } } } } @@ -248,9 +255,9 @@ int set_sched_affinity(const std::vector &cpu_ids) { // cpu_set_t definition // ref http://stackoverflow.com/questions/16319725/android-set-thread-affinity #define CPU_SETSIZE 1024 -#define __NCPUBITS (8 * sizeof(unsigned long)) +#define __NCPUBITS (8 * sizeof(unsigned long)) // NOLINT typedef struct { - unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; + unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; // NOLINT } cpu_set_t; #define CPU_SET(cpu, cpusetp) \ @@ -477,6 +484,10 @@ CPUContext::CPUContext() { } LOG(kLOG_INFO) << "CPU num: " << _cpu_num; for (int i = 0; i < _cpu_num; i++) { + if 
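A note on the ownership change running through cl_image.h above: cl_image_ moves from std::unique_ptr<_cl_mem, CLMemDeleter> to std::shared_ptr<_cl_mem> so that a fake-size image and an image built with InitWithExistMem can alias the same cl_mem, with clReleaseMemObject running once when the last owner goes away. A sketch of the intended pattern, with a stand-in deleter assumed to match the framework's CLMemDeleter:

    #include <CL/cl.h>
    #include <memory>

    struct CLMemDeleter {  // stand-in for the framework's deleter
      void operator()(cl_mem mem) const {
        if (mem) clReleaseMemObject(mem);
      }
    };

    void ShareImage(cl_mem raw) {
      std::shared_ptr<_cl_mem> owner(raw, CLMemDeleter());
      std::shared_ptr<_cl_mem> alias = owner;  // one control block, one release
    }

Copying the shared_ptr is what guarantees the single release; resetting a second shared_ptr from the raw pointer, as the assignment operators above do, creates an independent control block and relies on the aliasing images never both releasing the memory.
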
(!(_l1_cache_sizes.size() > i && _l2_cache_sizes.size() > i && + _l3_cache_sizes.size() > i)) { + break; + } LOG(kLOG_INFO) << i << " L1 Cache: " << _l1_cache_sizes[i] << "KB" << " L2 Cache: " << _l2_cache_sizes[i] << "KB" << " L3 Cache: " << _l3_cache_sizes[i] << "KB"; @@ -563,12 +574,25 @@ int CPUContext::get_cache_size(int level) { return 0; } if (_power_mode == PERFORMANCE_PRIORITY || _power_mode == PERFORMANCE_ONLY) { - return (*ptr)[_big_core_ids[0]]; + if (_big_core_ids.size() > 0) { + int idx = _big_core_ids[0]; + if (ptr->size() > idx) { + return (*ptr)[idx]; + } + } } else if (_power_mode == EFFICIENCY_PRIORITY || _power_mode == EFFICIENCY_ONLY) { - return (*ptr)[_little_core_ids[0]]; + if (_little_core_ids.size() > 0) { + int idx = _little_core_ids[0]; + if (ptr->size() > idx) { + return (*ptr)[idx]; + } + } } else { // AUTO - return (*ptr)[0]; + int idx = 0; + if (ptr->size() > idx) { + return (*ptr)[idx]; + } } } diff --git a/mobile/src/framework/ddim.cpp b/mobile/src/framework/ddim.cpp index 6da08bf88ea9ed04b21213b921f66002b7a78b66..4f68caad77c60e8f4a2312291e6600290860b102 100644 --- a/mobile/src/framework/ddim.cpp +++ b/mobile/src/framework/ddim.cpp @@ -27,7 +27,7 @@ Dim make_dim(const int64_t *d) { template <> Dim<0> make_dim<0>(const int64_t *d) { - return Dim<0>(*d); + return Dim<0>(0); } void make_ddim(DDim &ddim, const int64_t *dims, int n) { diff --git a/mobile/src/framework/executor.cpp b/mobile/src/framework/executor.cpp index e35c7ebe40993ba8a9babe265d720f652015478a..743dea76aef58a582810a50c3f646a8875d1cacc 100644 --- a/mobile/src/framework/executor.cpp +++ b/mobile/src/framework/executor.cpp @@ -33,7 +33,7 @@ limitations under the License. */ #include "pass/model_obfuscate.h" #ifdef PADDLE_MOBILE_CL #include "framework/cl/cl_image.h" -#include "pass/memory_optimize_super.h" +#include "pass/memory_optimize_cl.h" #endif namespace paddle_mobile { @@ -102,9 +102,9 @@ Executor::Executor(const Program &program, } int count = 0; #ifdef PADDLE_MOBILE_PROFILE - std::vector profile(ops_of_block0_.size()); - struct timespec ts; - int op_index = 0; + std::vector profile(ops_of_block0_.size()); + struct timespec ts; + int op_index = 0; #endif for (auto &op_handler : ops_of_block0_) { #ifdef PADDLE_MOBILE_PROFILE @@ -112,18 +112,38 @@ Executor::Executor(const Program &program, profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; #endif DLOG << "Initialize op[" << count++ << "]: " << op_handler->Type(); + if (op_handler->Type() == "feed" || op_handler->Type() == "fetch") { + op_handler->setPrePostType(config_.pre_post_type); + } op_handler->Init(); #ifdef PADDLE_MOBILE_PROFILE clock_gettime(CLOCK_MONOTONIC, &ts); - profile[op_index].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; - ++op_index; + profile[op_index].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; + ++op_index; #endif } #ifdef PADDLE_MOBILE_PROFILE printf("================[ op init profile ]==================\n"); PrintProfile(profile); #endif + ApplyMemoryOptimise(config, lod_mode); +} + +template +void Executor::ApplyMemoryOptimise( + const PaddleMobileConfigInternal &config, const bool lod_mode) const {} + +#ifdef PADDLE_MOBILE_CL +template <> +void Executor::ApplyMemoryOptimise( + const PaddleMobileConfigInternal &config, const bool lod_mode) const { + if (!config.load_when_predict && !lod_mode && + config_.memory_optimization_level != NoMemoryOptimization) { + pass::MemoryOptPassCl()(program_desc_.get(), program_.scope.get(), + config_.memory_optimization_level); + } } +#endif template void 
Executor::InitFeedFetchList() { @@ -153,24 +173,35 @@ void Executor::InitFeedFetchList() { } template -static void LoadMemInternal(void **data, LoDTensor *tensor, - bool quant_uint8 = false) { - char **data_buf = reinterpret_cast(data); - int64_t size = tensor->numel(); - T *tensor_data = tensor->mutable_data(); +static void LoadMemInternal(void **in_data, void *out_data, int64_t size, + bool quant_uint8 = false, int quant_fold = 1) { + char **data_buf = reinterpret_cast(in_data); + T *tensor_data = reinterpret_cast(out_data); if (quant_uint8) { - // should be moved into operator init function - float min_value; - float max_value; - memory::Copy(&min_value, *data_buf, sizeof(float)); - memory::Copy(&max_value, *data_buf + sizeof(float), sizeof(float)); - *data_buf += 2 * sizeof(float); - const float factor = (max_value - min_value) / 255.0; - const uint8_t *uint8_data = reinterpret_cast(*data_buf); - for (int k = 0; k < size; ++k) { - tensor_data[k] = uint8_data[k] * factor + min_value; + const int minimal_fold_size = 2; + quant_fold = fmin(fmax(1, size / minimal_fold_size), quant_fold); + int step = fmax(size / quant_fold, 1); + int visited_fold = 0; + while (visited_fold * step < size) { + // should be moved into operator init function + float min_value; + float max_value; + memory::Copy(&min_value, *data_buf, sizeof(float)); + memory::Copy(&max_value, *data_buf + sizeof(float), sizeof(float)); + *data_buf += 2 * sizeof(float); + const float factor = (max_value - min_value) / 255.0; + const uint8_t *uint8_data = reinterpret_cast(*data_buf); + int k = 0; + for (; k < step; ++k) { + int tensor_data_idx = visited_fold * step + k; + if (tensor_data_idx >= size) { + break; + } + tensor_data[tensor_data_idx] = uint8_data[k] * factor + min_value; + } + *data_buf += k * sizeof(uint8_t); + visited_fold++; } - *data_buf += size * sizeof(uint8_t); } else { memory::Copy(tensor_data, *data_buf, size * sizeof(T)); *data_buf += size * sizeof(T); @@ -215,14 +246,20 @@ void Executor::LoadMemory(void **data, // parse tensor from stream switch (tensor_desc.DataType()) { case VARTYPE_TYPE_FP32: - LoadMemInternal(reinterpret_cast(data_buf), tensor, - program_.quantification); + LoadMemInternal( + reinterpret_cast(data_buf), + reinterpret_cast(tensor->mutable_data()), tensor->numel(), + program_.quantification, program_.quantification_fold); break; case VARTYPE_TYPE_INT8: - LoadMemInternal(reinterpret_cast(data_buf), tensor); + LoadMemInternal( + reinterpret_cast(data_buf), + reinterpret_cast(tensor->mutable_data()), tensor->numel()); break; case VARTYPE_TYPE_INT32: - LoadMemInternal(reinterpret_cast(data_buf), tensor); + LoadMemInternal(reinterpret_cast(data_buf), + reinterpret_cast(tensor->mutable_data()), + tensor->numel()); break; default: LOG(kLOG_ERROR) << "data type is not supported"; @@ -850,10 +887,13 @@ void Executor::SetInput(const Tensor &input, DLOG << "SetInput ---- > resize1"; input_tensor->Resize(input.dims()); input_tensor->mutable_data(); - // InitNoPersistableMemory(*input_tensor); - pass::MemoryOptPassSuper()(program_desc_.get(), program_.scope.get(), - config_.memory_optimization_level, - input.dims()); + if (config_.memory_optimization_level == NoMemoryOptimization) { + InitNoPersistableMemory(*input_tensor); + } else { + pass::MemoryOptPassCl()(program_desc_.get(), program_.scope.get(), + config_.memory_optimization_level, + input.dims()); + } } } else { DLOG << "SetInput ---- > resize2"; @@ -921,31 +961,10 @@ void Executor::LoadMemory(const VarDesc var_desc, void *memory = nullptr; int 
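The reworked LoadMemInternal above decodes weights that were quantized in quantification_fold segments: each segment stores a float min and max followed by a run of uint8 codes, and each code u decodes to u * (max - min) / 255 + min. A simplified, self-contained sketch of that layout and decode (the fold-count clamping is omitted; names are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    // src: `folds` segments of [float min][float max][uint8 codes...]
    // dst: receives `size` dequantized floats.
    void DequantizeFolded(const char *src, float *dst, int64_t size, int folds) {
      const int64_t step = std::max<int64_t>(size / folds, 1);
      for (int64_t start = 0; start < size; start += step) {
        float min_v, max_v;
        std::memcpy(&min_v, src, sizeof(float));
        std::memcpy(&max_v, src + sizeof(float), sizeof(float));
        src += 2 * sizeof(float);
        const float factor = (max_v - min_v) / 255.0f;
        const int64_t count = std::min(step, size - start);
        for (int64_t k = 0; k < count; ++k) {
          dst[start + k] = static_cast<uint8_t>(src[k]) * factor + min_v;
        }
        src += count;
      }
    }

Folding trades a little file size (two extra floats per fold) for finer-grained min/max ranges, which reduces quantization error on tensors whose value range varies across the buffer.
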
type_size = 4; memory = tensorInput; - if (program_.quantification) { - float min_value; - float max_value; - - memcpy(&min_value, *data, sizeof(float)); - memcpy(&max_value, *data + sizeof(float), sizeof(float)); - *data += 2 * sizeof(float); - const float factor = (max_value - min_value) / 255.0; - uint8_t *uint8_data = reinterpret_cast(*data); - for (int k = 0; k < memory_size; ++k) { - static_cast(memory)[k] = uint8_data[k] * factor + min_value; - } - *data += (memory_size * sizeof(uint8_t)); - } else { - for (int n = 0; n < memory_size; n++) { - float value; - memcpy(&value, *data + n * type_size, type_size); - if (value < 1e-30 && value > -1e-30) { - static_cast(memory)[n] = 0.0; - } else { - static_cast(memory)[n] = value; - } - } - (*data) += (sizeof(char) * memory_size * type_size); - } + + LoadMemInternal(reinterpret_cast(data), + reinterpret_cast(memory), memory_size, + program_.quantification, program_.quantification_fold); } template <> diff --git a/mobile/src/framework/executor.h b/mobile/src/framework/executor.h index 4f108c993c0ff9bda94b11cdebc3cb13af41be03..ebb16f697b39391cd5f405c565285c1bd37dfad5 100644 --- a/mobile/src/framework/executor.h +++ b/mobile/src/framework/executor.h @@ -118,6 +118,8 @@ class Executor { void PrintProfile(const vector::ProfInfo> &profile) const; #endif + void ApplyMemoryOptimise(const PaddleMobileConfigInternal &config, + const bool lod_mode) const; }; } // namespace framework diff --git a/mobile/src/framework/load_ops.h b/mobile/src/framework/load_ops.h index ed30a45114a484f30cab9f70472e1b0aa7082e29..b871d2af140730850dfac0fd43383e48012c9ef0 100755 --- a/mobile/src/framework/load_ops.h +++ b/mobile/src/framework/load_ops.h @@ -354,7 +354,7 @@ LOAD_OP1(pad2d, CPU); LOAD_OP1(one_hot, CPU); #endif #ifdef ASSIGN_VALUE_OP -LOAD_OP1(assign_value, CPU); +LOAD_OP2(assign_value, CPU, GPU_CL); #endif #ifdef EXP_OP LOAD_OP1(exp, CPU); @@ -377,3 +377,6 @@ LOAD_OP1(range, CPU); #ifdef REDUCE_PROD_OP LOAD_OP1(reduce_prod, CPU); #endif +#ifdef PIXEL_SHUFFLE_OP +LOAD_OP1(pixel_shuffle, GPU_CL); +#endif diff --git a/mobile/src/framework/loader.cpp b/mobile/src/framework/loader.cpp index 4350fda969a01f7d672f5aebdc9a77390e175b9b..34cf6253cb4571c3b52fe61161cba3e140eb0110 100644 --- a/mobile/src/framework/loader.cpp +++ b/mobile/src/framework/loader.cpp @@ -87,7 +87,8 @@ void Loader::InitMemoryFromProgram( template <> const Program Loader::LoadCombinedMemory( size_t read_size, const uint8_t *buf, size_t combined_params_len, - uint8_t *combined_params_buf, bool optimize, bool quantification) { + uint8_t *combined_params_buf, bool optimize, bool quantification, + int quantification_fold) { bool can_add_split = false; PaddleMobile__Framework__Proto__ProgramDesc *c_program; @@ -109,6 +110,7 @@ const Program Loader::LoadCombinedMemory( program.quantification = quantification; program.combined_params_len = combined_params_len; program.combined_params_buf = combined_params_buf; + program.quantification_fold = quantification_fold; auto scope = std::make_shared(); program.scope = scope; @@ -187,9 +189,11 @@ template const Program Loader::Load(const std::string &dirname, bool optimize, bool quantification, - bool can_add_split) { - auto program = this->LoadProgram(dirname + "/__model__", optimize, - quantification, can_add_split); + bool can_add_split, + int quantification_fold) { + auto program = + this->LoadProgram(dirname + "/__model__", optimize, quantification, + can_add_split, quantification_fold); program.model_path = dirname; return program; } @@ -198,8 +202,10 @@ 
template const Program Loader::Load(const std::string &model_path, const std::string ¶_path, bool optimize, - bool quantification) { - auto program = this->LoadProgram(model_path, optimize, quantification); + bool quantification, + int quantification_fold) { + auto program = this->LoadProgram(model_path, optimize, quantification, false, + quantification_fold); program.para_path = para_path; program.combined = true; @@ -210,7 +216,7 @@ const Program Loader::Load(const std::string &model_path, template const Program Loader::LoadProgram( const std::string &model_path, bool optimize, bool quantification, - bool can_add_split) { + bool can_add_split, int quantification_fold) { std::string model_filename = model_path; PaddleMobile__Framework__Proto__ProgramDesc *c_program; uint8_t *buf = NULL; @@ -232,6 +238,7 @@ const Program Loader::LoadProgram( program.quantification = quantification; program.combined_params_len = 0; program.combined_params_buf = nullptr; + program.quantification_fold = quantification_fold; auto scope = std::make_shared(); program.scope = scope; @@ -248,7 +255,8 @@ const Program Loader::LoadProgram( template const Program Loader::LoadCombinedMemory( size_t read_size, const uint8_t *buf, size_t combined_params_len, - uint8_t *combined_params_buf, bool optimize, bool quantification) { + uint8_t *combined_params_buf, bool optimize, bool quantification, + int quantification_fold) { bool can_add_split = false; PaddleMobile__Framework__Proto__ProgramDesc *c_program; @@ -270,6 +278,7 @@ const Program Loader::LoadCombinedMemory( program.quantification = quantification; program.combined_params_len = combined_params_len; program.combined_params_buf = combined_params_buf; + program.quantification_fold = quantification_fold; auto scope = std::make_shared(); program.scope = scope; diff --git a/mobile/src/framework/loader.h b/mobile/src/framework/loader.h index bd4dfa15565dbb8e9afce769b12fe23eb7a1a970..40ded643d53396d1ba4f7964629b1580550b1895 100644 --- a/mobile/src/framework/loader.h +++ b/mobile/src/framework/loader.h @@ -32,7 +32,8 @@ class Loader { const Program Load(const std::string &dirname, bool optimize = false, bool quantification = false, - bool can_add_split = false); + bool can_add_split = false, + int quantification_fold = 1); /* * @b load combine format fluid mode @@ -41,20 +42,20 @@ class Loader { const Program Load(const std::string &model_path, const std::string ¶_path, bool optimize = false, - bool quantification = false); + bool quantification = false, + int quantification_fold = 1); - const Program LoadCombinedMemory(size_t model_len, - const uint8_t *model_buf, - size_t combined_params_len, - uint8_t *combined_params_buf, - bool optimize = false, - bool quantification = false); + const Program LoadCombinedMemory( + size_t model_len, const uint8_t *model_buf, size_t combined_params_len, + uint8_t *combined_params_buf, bool optimize = false, + bool quantification = false, int quantification_fold = 1); private: const Program LoadProgram(const std::string &model_path, bool optimize = false, bool quantification = false, - bool can_add_split = false); + bool can_add_split = false, + int quantification_fold = 1); void InitMemoryFromProgram( const std::shared_ptr &originProgramDesc, diff --git a/mobile/src/framework/operator.h b/mobile/src/framework/operator.h index c8b3a5ccf796c5cce33efb697cd6d6aef8b9d21c..baffba97c25be306970785e83bfa2d0c911dfe52 100644 --- a/mobile/src/framework/operator.h +++ b/mobile/src/framework/operator.h @@ -16,6 +16,7 @@ limitations under the 
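quantification_fold is plumbed as a trailing defaulted parameter through every Load overload, so existing callers keep compiling. A hedged usage sketch through the top-level PaddleMobile facade (the path and the fold count 8 are illustrative; the fold count must match how the parameters were quantized offline):

    #include "io/paddle_mobile.h"

    void LoadFoldedModel() {
      paddle_mobile::PaddleMobile<paddle_mobile::CPU> engine;
      // dirname, optimize, quantification, batch_size, lod_mode,
      // quantification_fold
      engine.Load("/path/to/model_dir", true, true, 1, false, 8);
    }
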
License. */ #include #include +#include #include #include #include @@ -73,6 +74,7 @@ class OperatorBase { const VariableNameMap &Outputs() const { return outputs_; } const std::string &Type() const { return type_; } const AttributeMap &Attrs() const { return attrs_; } + void setPrePostType(int prePostType) { pre_post_type_ = prePostType; } void ClearVariables(const std::vector &var_names) const { if (this->scope_) { @@ -89,6 +91,7 @@ class OperatorBase { VariableNameMap inputs_; VariableNameMap outputs_; AttributeMap attrs_; + int pre_post_type_ = 0; private: void CheckAllInputOutputSet() const; @@ -111,6 +114,9 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape() const = 0; void Init() { + if (this->pre_post_type_ != NONE_PRE_POST) { + kernel_.setPrePostType(this->pre_post_type_); + } PADDLE_MOBILE_ENFORCE(kernel_.Init(¶m_), " %s kernel init failed", this->type_.c_str()); } @@ -134,11 +140,13 @@ class OpKernelBase { virtual void Compute(const P ¶) = 0; virtual bool Init(P *para) { return true; } virtual ~OpKernelBase() = default; + virtual void setPrePostType(int prePostType) { pre_post_type_ = prePostType; } protected: #ifdef PADDLE_MOBILE_CL CLHelper cl_helper_; #endif + int pre_post_type_ = 0; private: }; diff --git a/mobile/src/framework/program/program.h b/mobile/src/framework/program/program.h index f05aba8565202557f7a26f4640fbbaa622e37f7f..b6d1d96279a517056ccfda1b358625aa7c4987f5 100644 --- a/mobile/src/framework/program/program.h +++ b/mobile/src/framework/program/program.h @@ -34,6 +34,7 @@ class Program { bool quantification = false; size_t combined_params_len; uint8_t *combined_params_buf; + int quantification_fold = 1; }; } // namespace framework diff --git a/mobile/src/framework/tensor_base.h b/mobile/src/framework/tensor_base.h index a7f4aa1b8acadb8cd15676b3584b431b00d383a3..97135bda3960a6a9714141c359980156ffd5d968 100644 --- a/mobile/src/framework/tensor_base.h +++ b/mobile/src/framework/tensor_base.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "common/enforce.h" #include "common/type_define.h" #include "common/types.h" @@ -55,8 +56,8 @@ struct SizeOfTypeFunctor { }; static inline size_t SizeOfType(const kTypeId_t type) { - SizeOfTypeFunctor + SizeOfTypeFunctor functor; size_t size = functor(type); diff --git a/mobile/src/io/api.cc b/mobile/src/io/api.cc index 0e254aa15ac06083038773d89c23d40242847782..b9e7421b54bc4f0e092a6c743d39a81def48b09c 100644 --- a/mobile/src/io/api.cc +++ b/mobile/src/io/api.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "common/type_define.h" #include "cstring" #include "io/paddle_inference_api.h" diff --git a/mobile/src/io/api_paddle_mobile.cc b/mobile/src/io/api_paddle_mobile.cc index fd77941823d55347b6a86a545cc703fd2dfaf787..8bfc91998f600726c1bcf8fe932372928928e334 100644 --- a/mobile/src/io/api_paddle_mobile.cc +++ b/mobile/src/io/api_paddle_mobile.cc @@ -18,6 +18,7 @@ #include #include #include "common/enforce.h" +#include "common/type_define.h" #include "framework/tensor.h" #ifdef PADDLE_MOBILE_FPGA #include @@ -35,7 +36,16 @@ PaddleMobilePredictor::PaddleMobilePredictor( template bool PaddleMobilePredictor::Init(const PaddleMobileConfig &config) { - paddle_mobile_.reset(new PaddleMobile()); + PaddleMobileConfigInternal configInternal; + configInternal.load_when_predict = config.load_when_predict; + if (config.pre_post_type == PaddleMobileConfig::UINT8_255) { + configInternal.pre_post_type = PrePostType::UINT8_255; + } + + configInternal.memory_optimization_level = + config.mem_opt ? MemoryOptimizationWithoutFeeds : NoMemoryOptimization; + + paddle_mobile_.reset(new PaddleMobile(configInternal)); #ifdef PADDLE_MOBILE_CL paddle_mobile_->SetCLPath(config.cl_path); #endif @@ -83,26 +93,37 @@ bool PaddleMobilePredictor::Run( // use tensor framework::DDim ddim = framework::make_ddim(dims); - - framework::Tensor input_tensor; - framework::LoDTensor input_lod_tensor; - paddle_mobile::framework::LoD lod{{}}; - for (int i = 0; i < input.lod.size(); ++i) { - lod[0].push_back(input.lod[i]); - } - input_lod_tensor.set_lod(lod); - int input_length = framework::product(ddim); if (input.lod.size() > 0) { + framework::LoDTensor input_lod_tensor; + paddle_mobile::framework::LoD lod{{}}; + for (int i = 0; i < input.lod.size(); ++i) { + lod[0].push_back(input.lod[i]); + } + input_lod_tensor.set_lod(lod); input_lod_tensor.Resize(ddim); - memcpy(input_lod_tensor.mutable_data(), - static_cast(input.data.data()), input_length * sizeof(T)); + if (input.dtype == UINT8) { + memcpy(input_lod_tensor.mutable_data(), + static_cast(input.data.data()), + input_length * sizeof(uint8_t)); + } else { + memcpy(input_lod_tensor.mutable_data(), + static_cast(input.data.data()), input_length * sizeof(T)); + } paddle_mobile_->Predict(input_lod_tensor); } else { - input_tensor.Resize(ddim); - memcpy(input_tensor.mutable_data(), static_cast(input.data.data()), - input_length * sizeof(T)); - paddle_mobile_->Predict(input_tensor); + if (input.dtype == UINT8) { + framework::Tensor input_tensor(static_cast(input.data.data()), + ddim); + if (paddle_mobile_->Predict(input_tensor) != PMStatus::PMSuccess) { + return false; + } + } else { + framework::Tensor input_tensor(static_cast(input.data.data()), ddim); + if (paddle_mobile_->Predict(input_tensor) != PMStatus::PMSuccess) { + return false; + } + } } auto output_tensor = paddle_mobile_->Fetch(); @@ -121,28 +142,42 @@ bool PaddleMobilePredictor::Run( output.shape.push_back(static_cast(d)); } - if (output.data.length() < output_length * sizeof(T)) { - output.data.Resize(output_length * sizeof(T)); - } + if (output.dtype == UINT8) { + if (output.data.length() < output_length * sizeof(uint8_t)) { + output.data.Resize(output_length * sizeof(uint8_t)); + } + + memcpy(output.data.data(), output_tensor->template data(), + output_length * sizeof(uint8_t)); + } else { + if (output.data.length() < output_length * sizeof(T)) { + output.data.Resize(output_length * sizeof(T)); + } - memcpy(output.data.data(), output_tensor->template data(), - output_length * sizeof(T)); + 
memcpy(output.data.data(), output_tensor->template data<T>(), + output_length * sizeof(T)); + } return true; } +template <typename T, Precision P> +std::string PaddleMobilePredictor<T, P>::GetExceptionMsg() { + return paddle_mobile_->GetExceptionMsg(); +} + #ifdef PADDLE_MOBILE_FPGA void ConvertPaddleTensors(const PaddleTensor &src, framework::Tensor *des) { des->Resize(framework::make_ddim(src.shape)); des->external_data = src.data.data(); - des->set_type(src.dtypeid); + des->set_type(static_cast<kTypeId_t>(static_cast<int>(src.dtypeid))); des->layout = src.layout == LAYOUT_HWC ? framework::LAYOUT_HWC : framework::LAYOUT_CHW; } void ConvertTensors(const framework::Tensor &src, PaddleTensor *des) { des->shape = framework::vectorize2int(src.dims()); - des->dtypeid = src.type(); + des->dtypeid = static_cast<PaddlekTypeId_t>(static_cast<int>(src.type())); des->layout = src.layout == framework::LAYOUT_HWC ? LAYOUT_HWC : LAYOUT_CHW; auto num = src.numel(); @@ -164,7 +199,8 @@ void PaddleMobilePredictor<T, P>::FeedPaddleTensors( auto num = inputs.size(); std::vector<framework::Tensor> tensors(num, framework::Tensor()); for (int i = 0; i < num; i++) { - if (inputs[i].dtypeid == type_id<float>().hash_code()) { + if (static_cast<kTypeId_t>(static_cast<int>(inputs[i].dtypeid)) == + type_id<float>().hash_code()) { tensors[i].init(type_id<float>().hash_code()); } else { tensors[i].init(type_id<int8_t>().hash_code()); diff --git a/mobile/src/io/api_paddle_mobile.h b/mobile/src/io/api_paddle_mobile.h index 11c993b3f879455eb1ae5268e3d9c2fcbcfc0bc1..63718acd990de664bc06f1af973755aa4336a184 100644 --- a/mobile/src/io/api_paddle_mobile.h +++ b/mobile/src/io/api_paddle_mobile.h @@ -32,6 +32,7 @@ class PaddleMobilePredictor : public PaddlePredictor { bool Run(const std::vector<PaddleTensor>& inputs, std::vector<PaddleTensor>* output_data, int batch_size = -1) override; + std::string GetExceptionMsg(); #ifdef PADDLE_MOBILE_FPGA void Predict_From_To(int start, int end) override; void FeedPaddleTensors(const std::vector<PaddleTensor>& inputs) override; diff --git a/mobile/src/io/ios_io/PaddleMobileCPU.h b/mobile/src/io/ios_io/PaddleMobileCPU.h index 0536f513aa00a26478c16820b20af5100f3ebc62..07e10c0671bbcf8136ccadf8b019d3f2a10ca22f 100644 --- a/mobile/src/io/ios_io/PaddleMobileCPU.h +++ b/mobile/src/io/ios_io/PaddleMobileCPU.h @@ -139,6 +139,18 @@ */ - (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim means:(NSArray *)means scale:(float)scale; +/** + @b Run prediction. means, stds and scale are the preprocessing parameters used when the model was trained; if no such preprocessing was applied at training time, call predict directly. Each pixel value x is preprocessed as (x - mean) / std * scale + + @param image input image + @param dim input dimensions + @param means means used in preprocessing + @param stds stds used in preprocessing + @param scale scale used in preprocessing + @return prediction result + */ +- (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim means:(NSArray *)means stds:(NSArray *)stds scale:(float)scale; + /** @b Run prediction, with preprocessing means of 0 and scale of 1 diff --git a/mobile/src/io/ios_io/PaddleMobileCPU.mm b/mobile/src/io/ios_io/PaddleMobileCPU.mm index f3a804e713c1e3caa5d806ceeca5b3b2d52ebce3..b952ad8e601fd4e00eb98ac398a8fad40045b7fd 100644 --- a/mobile/src/io/ios_io/PaddleMobileCPU.mm +++ b/mobile/src/io/ios_io/PaddleMobileCPU.mm @@ -181,25 +181,22 @@ static std::mutex shared_mutex; for (int x = 0; x < wanted_input_width; ++x) { int in_row = (y * imageHeight) / wanted_input_height; int in_col = (x * imageWidth) / wanted_input_width; - const UInt8 *in_pixel = input + (in_row * imageWidth * imageChannels) + (in_col * imageChannels); + const UInt8 *in_pixel = input + (in_row * sourceRowBytes) + (in_col * imageChannels); float *out_pos = out_row + x; - if (c == 0) { - *out_pos = (in_pixel[c] - means[c].floatValue) * scale; - }else if (c == 1){ - *out_pos =
(in_pixel[c] - means[c].floatValue) * scale; - }else if (c == 2){ - *out_pos = (in_pixel[c] - means[c].floatValue) * scale; - } + *out_pos = (in_pixel[2 - c] - means[c].floatValue) * scale; } } } } --(void)preprocess:(const UInt8 *)input output:(float *)output imageWidth:(int)imageWidth imageHeight:(int)imageHeight imageChannels:(int)imageChannels means:(NSArray *)means scale:(float)scale dim:(std::vector)dim{ +-(void)preprocess:(const UInt8 *)input output:(float *)output bytesPerRow:(int)bytesPerRow imageWidth:(int)imageWidth imageHeight:(int)imageHeight imageChannels:(int)imageChannels means:(NSArray *)means stds:(NSArray *)stds scale:(float)scale dim:(std::vector)dim { if (means == nil) { means = @[@0, @0, @0]; } + if (stds == nil) { + stds = @[@1, @1, @1]; + } int wanted_input_width = dim[3]; int wanted_input_height = dim[2]; @@ -212,15 +209,9 @@ static std::mutex shared_mutex; for (int x = 0; x < wanted_input_width; ++x) { int in_row = (y * imageHeight) / wanted_input_height; int in_col = (x * imageWidth) / wanted_input_width; - const UInt8 *in_pixel = input + (in_row * imageWidth * imageChannels) + (in_col * imageChannels); + const UInt8 *in_pixel = input + (in_row * bytesPerRow) + (in_col * imageChannels); float *out_pos = out_row + x; - if (c == 0) { - *out_pos = (in_pixel[c] - means[c].floatValue) * scale; - }else if (c == 1){ - *out_pos = (in_pixel[c] - means[c].floatValue) * scale; - }else if (c == 2){ - *out_pos = (in_pixel[c] - means[c].floatValue) * scale; - } + *out_pos = (in_pixel[2 - c] - means[c].floatValue) / stds[c].floatValue * scale; } } } @@ -278,8 +269,7 @@ static std::mutex shared_mutex; return cpuResult; } -- (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim means:(NSArray *)means scale:(float)scale{ -// printf(" predict one "); +- (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim means:(NSArray *)means stds:(NSArray *)stds scale:(float)scale { std::lock_guard lock(shared_mutex); if (!loaded_) { printf("PaddleMobile doesn't be loaded yet"); @@ -310,7 +300,7 @@ static std::mutex shared_mutex; // sample image float *output = (float *)malloc(numel*sizeof(float)); - [self preprocess:input output:output imageWidth:image_width imageHeight:image_height imageChannels:image_channels means:means scale:scale dim:dim_vec]; + [self preprocess:input output:output bytesPerRow:sourceRowBytes imageWidth:image_width imageHeight:image_height imageChannels:image_channels means:means stds:stds scale:scale dim:dim_vec]; float *dataPointer = nullptr; if (nullptr != output) { dataPointer = output; @@ -351,7 +341,11 @@ static std::mutex shared_mutex; } - (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim { - return [self predict:image dim:dim means:nil scale:1]; + return [self predict:image dim:dim means:nil stds:nil scale:1]; +} + +- (PaddleMobileCPUResult *)predict:(CGImageRef)image dim:(NSArray *)dim means:(NSArray *)means scale:(float)scale { + return [self predict:image dim:dim means:means stds:nil scale:scale]; } - (PaddleMobileCPUResult *)fetchOutput{ diff --git a/mobile/src/io/paddle_inference_api.h b/mobile/src/io/paddle_inference_api.h index ae7d34bd51dd59de9359a471964647c020e18649..c89b998144badcf7b88dbbfcaa631a25df7892d5 100644 --- a/mobile/src/io/paddle_inference_api.h +++ b/mobile/src/io/paddle_inference_api.h @@ -25,7 +25,6 @@ limitations under the License. 
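The rewritten preprocess does three things at once: it indexes rows by the real stride (bytesPerRow) instead of assuming width * channels, swaps BGRA channel order via in_pixel[2 - c], and normalises as (x - mean) / std * scale. A C++ sketch of the per-pixel math under those assumptions:

    #include <cstdint>

    // One output value for channel c (0 = R, 1 = G, 2 = B) at a sampled input
    // position, assuming a BGRA source buffer as CGImage typically provides.
    float PreprocessPixel(const uint8_t *input, int bytes_per_row, int channels,
                          int in_row, int in_col, int c, float mean,
                          float stddev, float scale) {
      const uint8_t *in_pixel =
          input + in_row * bytes_per_row + in_col * channels;
      return (in_pixel[2 - c] - mean) / stddev * scale;
    }

Using bytesPerRow matters because CGBitmapContext rows are often padded for alignment; the old width * channels arithmetic read skewed pixels on such images.
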
*/ #include #include #include -#include "common/type_define.h" namespace paddle_mobile { @@ -49,6 +48,7 @@ enum PaddleDType { FLOAT16, INT64, INT8, + UINT8, }; enum LayoutType { @@ -86,6 +86,56 @@ class PaddleBuf { bool memory_owned_{true}; }; +typedef enum { + paddle_void = 0, + paddle_float, + paddle_int, + paddle_uint16_t, + paddle_double, + paddle_int64_t, + paddle_size_t, + paddle_int16_t, + paddle_int8_t, + paddle_uint8_t, + paddle_bool, + paddle_string, + paddle_floats = 100, + paddle_ints, + paddle_int64_ts, + paddle_size_ts, + paddle_bools, + paddle_strings, + paddle_const_float = 200, + paddle_const_int, + paddle_block = 300, + paddle_tensor, + paddle_lod_tensor, + paddle_blocks, + paddle_tensors, + paddle_lod_tensors, + paddle_p_block = 400, + paddle_p_tensor, + paddle_p_lod_tensor, + paddle_p_blocks, + paddle_p_tensors, + paddle_p_lod_tensors, + paddle_scopes = 500, + paddle_selected_rows, + paddle_dim0 = 600, + paddle_dim1, + paddle_dim2, + paddle_dim3, + paddle_dim4, + paddle_dim5, + paddle_dim6, + paddle_dim7, + paddle_dim8, + paddle_dim9, +#ifdef PADDLE_MOBILE_CL + paddle_cl_image, +#endif +} PaddlekTypeId_t; + struct PaddleTensor { PaddleTensor() = default; std::string name; // variable name. @@ -93,7 +143,7 @@ struct PaddleTensor { std::vector lod; PaddleBuf data; // blob of data. PaddleDType dtype; - kTypeId_t dtypeid; + PaddlekTypeId_t dtypeid; LayoutType layout; }; @@ -124,6 +174,7 @@ class PaddlePredictor { virtual bool Run(const std::vector& inputs, std::vector* output_data, int batch_size = -1) = 0; + virtual std::string GetExceptionMsg() { return ""; } // Destroy the Predictor. virtual ~PaddlePredictor() = default; @@ -157,15 +208,20 @@ struct PaddleModelMemoryPack { struct PaddleMobileConfig : public PaddlePredictor::Config { enum Precision { FP32 = 0 }; enum Device { kCPU = 0, kFPGA = 1, kGPU_MALI = 2, kGPU_CL = 3 }; + enum PrePostType { NONE_PRE_POST = 0, UINT8_255 = 1 }; enum Precision precision; enum Device device; + enum PrePostType pre_post_type; int batch_size = 1; bool optimize = true; bool quantification = false; + int quantification_fold = 1; bool lod_mode = false; int thread_num = 1; + bool load_when_predict = false; + bool mem_opt = true; std::string cl_path; struct PaddleModelMemoryPack memory_pack; }; diff --git a/mobile/src/io/paddle_mobile.cpp b/mobile/src/io/paddle_mobile.cpp index 95ae3763a2a6415e5fc355ec648633260a5fe411..be69ce0f63803d714b77fc6e81805cec7339f9dd 100644 --- a/mobile/src/io/paddle_mobile.cpp +++ b/mobile/src/io/paddle_mobile.cpp @@ -37,7 +37,8 @@ void PaddleMobile::SetThreadNum(int thread_num, template PMStatus PaddleMobile::Load(const std::string &dirname, bool optimize, bool quantification, - int batch_size, bool lod_mode) { + int batch_size, bool lod_mode, + int quantification_fold) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -46,8 +47,9 @@ PMStatus PaddleMobile::Load(const std::string &dirname, if (executor_.get() == nullptr) { executor_ = std::make_shared>( - loader_->Load(dirname, optimize, quantification), config_, batch_size, - optimize, lod_mode); + loader_->Load(dirname, optimize, quantification, false, + quantification_fold), + config_, batch_size, optimize, lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -59,7 +61,8 @@ template PMStatus PaddleMobile::Load(const std::string &model_path, const std::string ¶_path, bool optimize, bool quantification, - int batch_size, bool lod_mode) { + int batch_size, bool lod_mode, + int quantification_fold) { if (loader_.get() == nullptr) { 
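PaddlekTypeId_t exists so paddle_inference_api.h no longer leaks the internal common/type_define.h header; its enumerators are assumed to stay value-compatible with the internal kTypeId_t ids, and the two are bridged with plain int casts, as the FPGA conversion helpers above do. A sketch of that bridging under the value-compatibility assumption:

    // Relies on both enumerations assigning the same integers
    // (paddle_void = 0, paddle_float = 1, ...).
    PaddlekTypeId_t ToPublicTypeId(int internal_id) {
      return static_cast<PaddlekTypeId_t>(internal_id);
    }
    int ToInternalTypeId(PaddlekTypeId_t public_id) {
      return static_cast<int>(public_id);
    }
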
loader_ = std::make_shared>(); } else { @@ -69,8 +72,9 @@ PMStatus PaddleMobile::Load(const std::string &model_path, if (executor_.get() == nullptr) { executor_ = std::make_shared>( - loader_->Load(model_path, para_path, optimize, quantification), config_, - batch_size, optimize, lod_mode); + loader_->Load(model_path, para_path, optimize, quantification, + quantification_fold), + config_, batch_size, optimize, lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -82,11 +86,12 @@ template PMStatus PaddleMobile::Load(const PaddleMobileConfig &config) { if (!config.model_dir.empty()) { return this->Load(config.model_dir, config.optimize, config.quantification, - config.batch_size, config.lod_mode); + config.batch_size, config.lod_mode, + config.quantification_fold); } else if (!config.prog_file.empty() && !config.param_file.empty()) { return this->Load(config.prog_file, config.param_file, config.optimize, - config.quantification, config.batch_size, - config.lod_mode); + config.quantification, config.batch_size, config.lod_mode, + config.quantification_fold); } else { LOG(kLOG_ERROR) << "Failed to load inference model"; return PMNotInitialized; @@ -97,7 +102,7 @@ template bool PaddleMobile::LoadCombinedMemory( size_t model_len, const uint8_t *model_buf, size_t combined_params_len, uint8_t *combined_params_buf, bool optimize, bool quantification, - int batch_size, bool lod_mode) { + int batch_size, bool lod_mode, int quantification_fold) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -107,7 +112,7 @@ bool PaddleMobile::LoadCombinedMemory( executor_ = std::make_shared>( loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len, combined_params_buf, optimize, - quantification), + quantification, quantification_fold), config_, batch_size, optimize, lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; diff --git a/mobile/src/io/paddle_mobile.h b/mobile/src/io/paddle_mobile.h index e39d712447c1bc1c2fb2dd681e25bd4c130ba0e0..8b8f0683abd12d9516e2a2cb09078241c2b7944e 100644 --- a/mobile/src/io/paddle_mobile.h +++ b/mobile/src/io/paddle_mobile.h @@ -50,10 +50,11 @@ class PaddleMobile { PMStatus Load(const std::string &dirname, const bool optimize = false, const bool quantification = false, const int batch_size = 1, - const bool lod_mode = false); + const bool lod_mode = false, const int quantification_fold = 1); PMStatus Load(const std::string &model_path, const std::string ¶_path, const bool optimize = false, const bool quantification = false, - const int batch_size = 1, const bool lod_mode = false); + const int batch_size = 1, const bool lod_mode = false, + const int quantification_fold = 1); PMStatus Load(const PaddleMobileConfig &config); @@ -84,7 +85,7 @@ class PaddleMobile { size_t combined_params_len, uint8_t *combined_params_buf, bool optimize = false, bool quantification = false, int batch_size = 1, - bool lod_mode = false); + bool lod_mode = false, int quantification_fold = 1); void SetThreadNum(int thread_num, PowerMode power_mode = PERFORMANCE_PRIORITY); diff --git a/mobile/src/operators/assign_value_op.cpp b/mobile/src/operators/assign_value_op.cpp index 49494929de9b146a1a91586c8ca10302d54eedb4..5100c2246bd5a2840d503914e5f4057827e162dd 100644 --- a/mobile/src/operators/assign_value_op.cpp +++ b/mobile/src/operators/assign_value_op.cpp @@ -34,4 +34,8 @@ namespace ops = paddle_mobile::operators; REGISTER_OPERATOR_CPU(assign_value, ops::AssignValueOp); #endif +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(assign_value, ops::AssignValueOp); 
+#endif + #endif // ASSIGN_VALUE_OP diff --git a/mobile/src/operators/elementwise_mul_op.cpp b/mobile/src/operators/elementwise_mul_op.cpp index 61001ff4ec6be5bc76e5e6dd12093b2e56c12b96..48b2a4c282c3527460baa4b321badcae89783b5d 100644 --- a/mobile/src/operators/elementwise_mul_op.cpp +++ b/mobile/src/operators/elementwise_mul_op.cpp @@ -32,6 +32,9 @@ namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU REGISTER_OPERATOR_CPU(elementwise_mul, ops::ElementwiseMulOp); #endif +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(elementwise_mul, ops::ElementwiseMulOp); +#endif #ifdef PADDLE_MOBILE_FPGA REGISTER_OPERATOR_FPGA(elementwise_mul, ops::ElementwiseMulOp); #endif diff --git a/mobile/src/operators/fusion_instancenorm_relu_op.cpp b/mobile/src/operators/fusion_instancenorm_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f6299fa72db06e03f54d382cfb761580294042df --- /dev/null +++ b/mobile/src/operators/fusion_instancenorm_relu_op.cpp @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_INSTANCENORM_RELU_OP + +#include "operators/fusion_instancenorm_relu_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionInstanceNormReluOp::InferShape() const { + auto x_dims = this->param_.InputX()->dims(); + this->param_.Out()->Resize(x_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_instancenorm_relu, + ops::FusionInstanceNormReluMatcher); + +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(fusion_instancenorm_relu, ops::FusionInstanceNormReluOp); +#endif + +#endif diff --git a/mobile/src/operators/fusion_instancenorm_relu_op.h b/mobile/src/operators/fusion_instancenorm_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ce2623e4dda46a0952fede3e1a25012ed5da4394 --- /dev/null +++ b/mobile/src/operators/fusion_instancenorm_relu_op.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_INSTANCENORM_RELU_OP + +#pragma once + +#include +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/instancenorm_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionInstanceNormReluMatcher : public framework::FusionOpMatcher { + public: + FusionInstanceNormReluMatcher() { + node_ = framework::Node(G_OP_TYPE_INSTANCENORM); + node_ > std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), {}, removed_nodes); + } + std::string Type() { return G_OP_TYPE_FUSION_INSTANCENORM_RELU; } +}; + +template +class FusionInstanceNormReluOp + : public framework::OperatorWithKernel< + DeviceType, InstanceNormParam, + operators::InstanceNormReluKernel> { + public: + FusionInstanceNormReluOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + framework::Scope *scope) + : framework::OperatorWithKernel< + DeviceType, InstanceNormParam, + operators::InstanceNormReluKernel>( + type, inputs, outputs, attrs, scope) {} + + void InferShape() const override; + + protected: +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/arm/assign_value_kernel.cpp b/mobile/src/operators/kernel/arm/assign_value_kernel.cpp index 7390f77ed1428425ecbdaaa9cbd494f847af5de3..2e98b9f77712936bb8601065fa401c5e41df18ce 100644 --- a/mobile/src/operators/kernel/arm/assign_value_kernel.cpp +++ b/mobile/src/operators/kernel/arm/assign_value_kernel.cpp @@ -67,6 +67,20 @@ void AssignValueKernel::Compute( param.int32_values_)); } +template <> +bool AssignValueKernel::Init(AssignValueParam* param) { + return true; +} + +template <> +void AssignValueKernel::Compute( + const AssignValueParam& param) { + framework::VisitDataType( + framework::ToDataType(param.dtype_), + AssignValueOpFunctor(param.output_, param.shape_, param.fp32_values_, + param.int32_values_)); +} + } // namespace operators } // namespace paddle_mobile diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp index 08cae42762a148b13dc16e4be11b520da9ffb82b..5c92cdbfd0001cc277d59fc9a6d5c526a43b61ed 100644 --- a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp +++ b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp @@ -20,6 +20,8 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { bool use_lws = true; +int preferred_lws = 0; +int preferred_lws_divisor = 2; template <> void winograd_transform_weight<4, 3>(framework::CLHelper *cl_helper, @@ -32,6 +34,165 @@ void WinogradConv3x3<4, 3>(framework::CLHelper *cl_helper, const framework::CLImage *new_scale, const framework::CLImage *new_bias) {} +void ConvAddBnReluPt1x2(framework::CLHelper *cl_helper, + const ConvParam ¶m, bool ifRelu, + const framework::CLImage *biase, + const framework::CLImage *new_scale, + const framework::CLImage *new_bias) { + auto kernel = cl_helper->KernelAt(0); + auto default_work_size = cl_helper->DefaultWorkSize(*param.Output()); + default_work_size[1] = (default_work_size[1] + 1) / 2; + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + auto input = param.Input()->GetCLImage(); + auto filter = param.Filter()->GetCLImage(); + + auto output = param.Output()->GetCLImage(); + int stride = param.Strides()[0]; + int offset = param.Offset(); + int input_c = reinterpret_cast( + param.Input()->Converter()) + ->GetCBlock(); + int dilation = param.Dilations()[0]; + int input_width = param.Input()->dims()[3]; + int input_height = param.Input()->dims()[2]; + int output_width = param.Output()->dims()[3]; + int output_height = param.Output()->dims()[2]; + int output_c = param.Output()->dims()[1]; + int filter_channel = param.Filter()->dims()[1]; + int input_channel = param.Input()->dims()[1]; + // + // DLOG << " c block " << c_block; + // DLOG << " w " << w; + // DLOG << " nh " << nh; + // DLOG << " stride " << stride; + // DLOG << " offset " << offset; + // DLOG << " input_c " << input_c; + // DLOG << " dilation " << dilation; + // DLOG << " input width " << input_width; + // DLOG << " input height " << input_height; + // DLOG << " output width " << output_width; + // DLOG << " output height " << output_height; + // DLOG << " input dim " << param.Input()->dims(); + // DLOG << " output dim " << param.Output()->dims(); + // DLOG << " filter dim " << param.Filter()->dims(); + + cl_int status; + int index = 0; + + status = clSetKernelArg(kernel, index++, sizeof(int), &c_block); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &w); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &nh); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &input); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &filter); + CL_CHECK_ERRORS(status); + + if (biase) { + auto bias_mem = biase->GetCLImage(); + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &bias_mem); + CL_CHECK_ERRORS(status); + } + + if (new_scale && new_bias) { + auto new_scale_mem = new_scale->GetCLImage(); + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_scale_mem); + CL_CHECK_ERRORS(status); + + auto new_bias_mem = new_bias->GetCLImage(); + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_bias_mem); + CL_CHECK_ERRORS(status); + } + + status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &output); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &stride); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &offset); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &input_c); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, index++, sizeof(int), &dilation); + 
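ConvAddBnReluPt1x2 takes its name from processing one row by two columns per work item: the W component of the global work size is halved with rounding up, and each work item is assumed to write output columns 2*wi and 2*wi + 1. A sketch of that index math (names are illustrative):

    #include <cstddef>

    // Global W size for a kernel writing 2 horizontally adjacent pixels per
    // work item; (w + 1) / 2 rounds up so odd widths keep their last column.
    size_t GlobalWidth(size_t output_width) { return (output_width + 1) / 2; }

    // Columns covered by work item wi; the second must be bounds-checked
    // inside the kernel when output_width is odd.
    void CoveredColumns(size_t wi, size_t cols[2]) {
      cols[0] = 2 * wi;
      cols[1] = 2 * wi + 1;
    }

Halving the W dimension amortises the per-work-item filter reads across two outputs, which is the usual motivation for this kind of tiling.
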
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
+  CL_CHECK_ERRORS(status);
+
+  if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) {
+    if (filter_channel != input_channel) {
+      if (filter_channel != 1) {
+        status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
+        CL_CHECK_ERRORS(status);
+        int has_group = 1;
+        status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
+        CL_CHECK_ERRORS(status);
+      }
+    } else {
+      status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
+      CL_CHECK_ERRORS(status);
+      int has_group = 0;
+      status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
+      CL_CHECK_ERRORS(status);
+    }
+  }
+  // DLOG << "default_work_size" << default_work_size;
+  auto kernel_work_size = cl_helper->KernelWorkSize(kernel);
+  auto tmp0 = default_work_size.data()[0];
+  auto tmp1 = default_work_size.data()[1];
+  auto tmp2 = default_work_size.data()[2];
+  int max_work_size = static_cast<int>(kernel_work_size);
+  if (preferred_lws_divisor > 1) {
+    max_work_size /= preferred_lws_divisor;
+  }
+  if (preferred_lws > 0 && preferred_lws <= max_work_size) {
+    max_work_size = preferred_lws;
+  }
+  while (tmp1 > max_work_size && max_work_size > 0) {
+    tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1;
+  }
+  while (tmp2 * tmp1 > max_work_size && max_work_size > 0) {
+    tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1;
+  }
+  while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) {
+    tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1;
+  }
+  const size_t local_work_size[3] = {static_cast<size_t>(tmp0),
+                                     static_cast<size_t>(tmp1),
+                                     static_cast<size_t>(tmp2)};
+  if (max_work_size > 0 && use_lws) {
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), local_work_size, 0, NULL, NULL);
+  } else {
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), NULL, 0, NULL, NULL);
+  }
+  CL_CHECK_ERRORS(status);
+}
+
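The work-group sizing heuristic above recurs inline, essentially verbatim, in ConvAddBnRelu, DWConvAddBnRelu, SWConvAddBnRelu and ConvTranspose3x3s2AddBnRelu below. As a reading aid only (this helper is not part of the patch), the logic distills to:

#include <cstddef>

// Shrink a 3-D local work size until its volume fits under the kernel's
// work-group limit: dimension 1 (width) is halved first, then dimension 2
// (height), then dimension 0 (channel blocks); an odd size collapses to 1.
void ChooseLocalWorkSize(size_t kernel_work_size, int preferred_lws,
                         int preferred_lws_divisor, size_t lws[3]) {
  int max_work_size = static_cast<int>(kernel_work_size);
  if (preferred_lws_divisor > 1) max_work_size /= preferred_lws_divisor;
  if (preferred_lws > 0 && preferred_lws <= max_work_size)
    max_work_size = preferred_lws;
  const size_t cap =
      max_work_size > 0 ? static_cast<size_t>(max_work_size) : 0;
  if (cap == 0) return;  // caller falls back to a NULL local size
  while (lws[1] > cap) lws[1] = lws[1] % 2 == 0 ? lws[1] / 2 : 1;
  while (lws[2] * lws[1] > cap) lws[2] = lws[2] % 2 == 0 ? lws[2] / 2 : 1;
  while (lws[0] * lws[1] * lws[2] > cap)
    lws[0] = lws[0] % 2 == 0 ? lws[0] / 2 : 1;
}

For example, a global size of {4, 128, 96} under a 256 work-item kernel limit with preferred_lws_divisor = 2 (cap 128) yields a local size of {1, 128, 1}.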
 void ConvAddBnRelu(framework::CLHelper *cl_helper,
                    const ConvParam<GPU_CL> &param, bool ifRelu,
                    const framework::CLImage *biase,
@@ -51,11 +212,13 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
   int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
                     param.Input()->Converter())
                     ->GetCBlock();
+  int input_c_origin = param.Input()->dims()[1];
   int dilation = param.Dilations()[0];
   int input_width = param.Input()->dims()[3];
   int input_height = param.Input()->dims()[2];
   int output_width = param.Output()->dims()[3];
   int output_height = param.Output()->dims()[2];
+  int output_c = param.Output()->dims()[1];
   int filter_channel = param.Filter()->dims()[1];
   int input_channel = param.Input()->dims()[1];
 
@@ -122,6 +285,9 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
   status = clSetKernelArg(kernel, index++, sizeof(int), &input_c);
   CL_CHECK_ERRORS(status);
 
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_c_origin);
+  CL_CHECK_ERRORS(status);
+
   status = clSetKernelArg(kernel, index++, sizeof(int), &dilation);
   CL_CHECK_ERRORS(status);
 
@@ -145,10 +311,30 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
                                    static_cast<size_t>(maped_w),
                                    static_cast<size_t>(default_work_size.data()[2])};
 
-    if (work_size[1] % 60 == 0 && use_lws) {
-      const size_t local_work_size[3] = {static_cast<size_t>(1),
-                                         static_cast<size_t>(60),
-                                         static_cast<size_t>(1)};
+    auto kernel_work_size = cl_helper->KernelWorkSize(kernel);
+    auto tmp0 = work_size[0];
+    auto tmp1 = work_size[1];
+    auto tmp2 = work_size[2];
+    int max_work_size = static_cast<int>(kernel_work_size);
+    if (preferred_lws_divisor > 1) {
+      max_work_size /= preferred_lws_divisor;
+    }
+    if (preferred_lws > 0 && preferred_lws <= max_work_size) {
+      max_work_size = preferred_lws;
+    }
+    while (tmp1 > max_work_size && max_work_size > 0) {
+      tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1;
+    }
+    while (tmp2 * tmp1 > max_work_size && max_work_size > 0) {
+      tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1;
+    }
+    while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) {
+      tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1;
+    }
+    const size_t local_work_size[3] = {static_cast<size_t>(tmp0),
+                                       static_cast<size_t>(tmp1),
+                                       static_cast<size_t>(tmp2)};
+    if (max_work_size > 0 && use_lws) {
       status = clEnqueueNDRangeKernel(cl_helper->CLCommandQueue(), kernel,
                                       default_work_size.size(), NULL, work_size,
                                       local_work_size, 0, NULL, NULL);
@@ -218,20 +404,24 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
   CL_CHECK_ERRORS(status);
 
   if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) {
-    if (filter_channel != input_channel) {
-      if (filter_channel != 1) {
-        status =
-            clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
-        CL_CHECK_ERRORS(status);
-        int has_group = 1;
-        status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
-        CL_CHECK_ERRORS(status);
-      }
-    } else {
+    // normal conv
+    if (param.Filter()->dims()[0] == param.Output()->dims()[1] &&
+        param.Filter()->dims()[1] == param.Input()->dims()[1]) {
+      status = clSetKernelArg(kernel, index++, sizeof(int), &output_c);
+      CL_CHECK_ERRORS(status);
       status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
       CL_CHECK_ERRORS(status);
-      int has_group = 0;
-      status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
+      int group = 1;
+      status = clSetKernelArg(kernel, index++, sizeof(int), &group);
+      CL_CHECK_ERRORS(status);
+    } else if (!(param.Filter()->dims()[0] == param.Input()->dims()[1] &&
+                 param.Filter()->dims()[1] == 1)) {  // not depthwise
+      status = clSetKernelArg(kernel, index++, sizeof(int), &output_c);
+      CL_CHECK_ERRORS(status);
+      status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
+      CL_CHECK_ERRORS(status);
+      int group = input_channel / filter_channel;
+      status = clSetKernelArg(kernel, index++, sizeof(int), &group);
       CL_CHECK_ERRORS(status);
     }
   }
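The rewritten 3x3 branch above switches conv_3x3 from a has_group flag to an explicit group count, with the depthwise case routed to a separate kernel and given no group arguments at all. A compact restatement of that case analysis, for illustration only (hypothetical helper, not part of the patch):

// Given filter dims [N, C, 3, 3], input channels in_c, output channels out_c:
// returns the group count passed to conv_3x3, or -1 for the depthwise path.
int GroupCountForConv3x3(int filter_n, int filter_c, int in_c, int out_c) {
  if (filter_n == out_c && filter_c == in_c) {
    return 1;              // normal convolution
  }
  if (filter_n == in_c && filter_c == 1) {
    return -1;             // depthwise: handled by a separate kernel
  }
  return in_c / filter_c;  // grouped convolution
}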
@@ -345,10 +535,30 @@ void DWConvAddBnRelu(framework::CLHelper *cl_helper,
   status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
   CL_CHECK_ERRORS(status);
 
-  if (default_work_size.data()[1] % 60 == 0 && use_lws) {
-    const size_t local_work_size[3] = {static_cast<size_t>(1),
-                                       static_cast<size_t>(60),
-                                       static_cast<size_t>(1)};
+  auto kernel_work_size = cl_helper->KernelWorkSize(kernel);
+  auto tmp0 = default_work_size.data()[0];
+  auto tmp1 = default_work_size.data()[1];
+  auto tmp2 = default_work_size.data()[2];
+  int max_work_size = static_cast<int>(kernel_work_size);
+  if (preferred_lws_divisor > 1) {
+    max_work_size /= preferred_lws_divisor;
+  }
+  if (preferred_lws > 0 && preferred_lws <= max_work_size) {
+    max_work_size = preferred_lws;
+  }
+  while (tmp1 > max_work_size && max_work_size > 0) {
+    tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1;
+  }
+  while (tmp2 * tmp1 > max_work_size && max_work_size > 0) {
+    tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1;
+  }
+  while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) {
+    tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1;
+  }
+  const size_t local_work_size[3] = {static_cast<size_t>(tmp0),
+                                     static_cast<size_t>(tmp1),
+                                     static_cast<size_t>(tmp2)};
+  if (max_work_size > 0 && use_lws) {
     status = clEnqueueNDRangeKernel(
         cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
         default_work_size.data(), local_work_size, 0, NULL, NULL);
@@ -391,7 +601,6 @@ void SWConvAddBnRelu(framework::CLHelper *cl_helper,
   int input_channel = param.Input()->dims()[1];
   int input_height = param.Input()->dims()[2];
   int input_width = param.Input()->dims()[3];
-
   int output_height = param.Output()->dims()[2];
   int output_width = param.Output()->dims()[3];
 
@@ -454,10 +663,30 @@ void SWConvAddBnRelu(framework::CLHelper *cl_helper,
   status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
   CL_CHECK_ERRORS(status);
 
-  if (default_work_size.data()[1] % 60 == 0 && use_lws) {
-    const size_t local_work_size[3] = {static_cast<size_t>(1),
-                                       static_cast<size_t>(60),
-                                       static_cast<size_t>(1)};
+  auto kernel_work_size = cl_helper->KernelWorkSize(kernel);
+  auto tmp0 = default_work_size.data()[0];
+  auto tmp1 = default_work_size.data()[1];
+  auto tmp2 = default_work_size.data()[2];
+  int max_work_size = static_cast<int>(kernel_work_size);
+  if (preferred_lws_divisor > 1) {
+    max_work_size /= preferred_lws_divisor;
+  }
+  if (preferred_lws > 0 && preferred_lws <= max_work_size) {
+    max_work_size = preferred_lws;
+  }
+  while (tmp1 > max_work_size && max_work_size > 0) {
+    tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1;
+  }
+  while (tmp2 * tmp1 > max_work_size && max_work_size > 0) {
+    tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1;
+  }
+  while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) {
+    tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1;
+  }
+  const size_t local_work_size[3] = {static_cast<size_t>(tmp0),
+                                     static_cast<size_t>(tmp1),
+                                     static_cast<size_t>(tmp2)};
+  if (max_work_size > 0 && use_lws) {
     status = clEnqueueNDRangeKernel(
         cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
         default_work_size.data(), local_work_size, 0, NULL, NULL);
@@ -587,11 +816,11 @@ void DWConvTransposeAddBnRelu(framework::CLHelper *cl_helper,
   CL_CHECK_ERRORS(status);
 }
 
-void ConvTransposeAddBnRelu(framework::CLHelper *cl_helper,
-                            const ConvTransposeParam<GPU_CL> &param,
-                            bool ifRelu, const framework::CLImage *biase,
-                            const framework::CLImage *new_scale,
-                            const framework::CLImage *new_bias) {
+void ConvTransposeAddBnRelu_b(framework::CLHelper *cl_helper,
+                              const ConvTransposeParam<GPU_CL> &param,
+                              bool ifRelu, const framework::CLImage *biase,
+                              const framework::CLImage *new_scale,
+                              const framework::CLImage *new_bias) {
   auto kernel = cl_helper->KernelAt(0);
   const auto *input = param.Input();
   auto *output = param.Output();
@@ -638,5 +867,259 @@ void ConvTransposeAddBnRelu(framework::CLHelper *cl_helper,
   clEnqueueNDRangeKernel(cl_helper->CLCommandQueue(), kernel, 3, NULL,
                          work_size, NULL, 0, NULL, NULL);
 }
+void ConvTransposeAddBnRelu(framework::CLHelper *cl_helper,
+                            const ConvTransposeParam<GPU_CL> &param,
+                            bool ifRelu, const framework::CLImage *biase,
+                            const framework::CLImage *new_scale,
+                            const framework::CLImage *new_bias) {
+  auto kernel = cl_helper->KernelAt(0);
+  auto default_work_size = cl_helper->DefaultWorkSize(*param.Output());
+  int c_block = default_work_size[0];
+  int w = default_work_size[1];
+  int nh = default_work_size[2];
+
+  int w_blk_size = 1;
+  int w_blk = (w + w_blk_size - 1) / w_blk_size;
+  default_work_size[1] = w_blk;
+
+  int h_blk_size = 1;
+  int h_blk = (nh + h_blk_size - 1) / h_blk_size;
+  default_work_size[2] = h_blk;
+
+  auto input = param.Input()->GetCLImage();
+  auto filter = param.Filter()->GetCLImage();
+
+  auto output = param.Output()->GetCLImage();
+  int stride = param.Strides()[0];
+  int pad = param.Paddings()[0];
+  int dilation = param.Dilations()[0];
+
+  int input_channel = param.Input()->dims()[1];
+  int input_height = param.Input()->dims()[2];
+  int input_width = param.Input()->dims()[3];
+
+  int output_height = param.Output()->dims()[2];
+  int output_width = param.Output()->dims()[3];
+
+  int filter_height = param.Filter()->dims()[2];
+  int filter_width = param.Filter()->dims()[3];
+
+  cl_int status;
+  int index = 0;
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &c_block);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &w_blk);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &h_blk);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &input);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &filter);
+  CL_CHECK_ERRORS(status);
+
+  if (biase) {
+    auto bias_mem = biase->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &bias_mem);
+    CL_CHECK_ERRORS(status);
+  }
+
+  if (new_scale && new_bias) {
+    auto new_scale_mem = new_scale->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_scale_mem);
+    CL_CHECK_ERRORS(status);
+
+    auto new_bias_mem = new_bias->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_bias_mem);
+    CL_CHECK_ERRORS(status);
+  }
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &output);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &stride);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &pad);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, index++, sizeof(int), &dilation);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_channel);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &filter_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &filter_height);
+  CL_CHECK_ERRORS(status);
+
+  if (default_work_size.data()[1] % 60 == 0 && use_lws) {
+    const size_t local_work_size[3] = {static_cast<size_t>(1),
+                                       static_cast<size_t>(60),
+                                       static_cast<size_t>(1)};
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), local_work_size, 0, NULL, NULL);
+  } else {
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), NULL, 0, NULL, NULL);
+  }
+  CL_CHECK_ERRORS(status);
+}
+void ConvTranspose3x3s2AddBnRelu(framework::CLHelper *cl_helper,
+                                 const ConvTransposeParam<GPU_CL> &param,
+                                 bool ifRelu, const framework::CLImage *biase,
+                                 const framework::CLImage *new_scale,
+                                 const framework::CLImage *new_bias) {
+  auto kernel = cl_helper->KernelAt(0);
+  auto default_work_size = cl_helper->DefaultWorkSize(*param.Output());
+  int c_block = default_work_size[0];
+  int w = default_work_size[1];
+  int nh = default_work_size[2];
+
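+  // Note on the block count below: each work-item covers w_blk_size = 5
+  // output columns, and the expression rounds the number of width blocks,
+  // ceil(w / 5), up to an even value (e.g. w = 13 -> 4 blocks, w = 21 -> 6),
+  // apparently so that dimension 1 of the NDRange halves cleanly in the
+  // local-work-size selection at the end of this function.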
+  int w_blk_size = 5;
+  int w_blk = (w + w_blk_size - 1 + 5) / w_blk_size / 2 * 2;
+  default_work_size[1] = w_blk;
+
+  int h_blk_size = 1;
+  int h_blk = (nh + h_blk_size - 1) / h_blk_size;
+  default_work_size[2] = h_blk;
+
+  auto input = param.Input()->GetCLImage();
+  auto filter = param.Filter()->GetCLImage();
+
+  auto output = param.Output()->GetCLImage();
+  int stride = param.Strides()[0];
+  int pad = param.Paddings()[0];
+  int dilation = param.Dilations()[0];
+
+  int input_channel = param.Input()->dims()[1];
+  int input_height = param.Input()->dims()[2];
+  int input_width = param.Input()->dims()[3];
+
+  int output_height = param.Output()->dims()[2];
+  int output_width = param.Output()->dims()[3];
+
+  int filter_height = param.Filter()->dims()[2];
+  int filter_width = param.Filter()->dims()[3];
+
+  cl_int status;
+  int index = 0;
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &c_block);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &w_blk);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &h_blk);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &input);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &filter);
+  CL_CHECK_ERRORS(status);
+
+  if (biase) {
+    auto bias_mem = biase->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &bias_mem);
+    CL_CHECK_ERRORS(status);
+  }
+
+  if (new_scale && new_bias) {
+    auto new_scale_mem = new_scale->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_scale_mem);
+    CL_CHECK_ERRORS(status);
+
+    auto new_bias_mem = new_bias->GetCLImage();
+    status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &new_bias_mem);
+    CL_CHECK_ERRORS(status);
+  }
+
+  status = clSetKernelArg(kernel, index++, sizeof(cl_mem), &output);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &stride);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &pad);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, index++, sizeof(int), &dilation);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_channel);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &input_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &filter_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, index++, sizeof(int), &filter_height);
+  CL_CHECK_ERRORS(status);
+
+  auto kernel_work_size = cl_helper->KernelWorkSize(kernel);
+  auto tmp0 = default_work_size.data()[0];
+  auto tmp1 = default_work_size.data()[1];
+  auto tmp2 = default_work_size.data()[2];
+  int max_work_size = static_cast<int>(kernel_work_size);
+  if (preferred_lws_divisor > 1) {
+    max_work_size /= preferred_lws_divisor;
+  }
+  if (preferred_lws > 0 && preferred_lws <= max_work_size) {
+    max_work_size = preferred_lws;
+  }
+  while (tmp1 > max_work_size && max_work_size > 0) {
+    tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1;
+  }
+  while (tmp2 * tmp1 > max_work_size && max_work_size > 0) {
+    tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1;
+  }
+  while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) {
+    tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1;
+  }
+  const size_t local_work_size[3] = {static_cast<size_t>(tmp0),
+                                     static_cast<size_t>(tmp1),
+                                     static_cast<size_t>(tmp2)};
+  if (max_work_size > 0 && use_lws) {
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), local_work_size, 0, NULL, NULL);
+  } else {
+    status = clEnqueueNDRangeKernel(
+        cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
+        default_work_size.data(), NULL, 0, NULL, NULL);
+  }
+  CL_CHECK_ERRORS(status);
+}
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.h b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.h
index a0dcd99d9e10a22ce040317ef242d69722c6b9ca..a2488aaa2def03eb3d2165c8720177651ff5e5e5 100644
--- a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.h
+++ b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.h
@@ -41,6 +41,12 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
                    const framework::CLImage *new_scale = nullptr,
                    const framework::CLImage *new_bias = nullptr);
 
+void ConvAddBnReluPt1x2(framework::CLHelper *cl_helper,
+                        const ConvParam<GPU_CL> &param, bool ifRelu = false,
+                        const framework::CLImage *biase = nullptr,
+                        const framework::CLImage *new_scale = nullptr,
+                        const framework::CLImage *new_bias = nullptr);
+
 void DWConvAddBnRelu(framework::CLHelper *cl_helper,
                      const ConvParam<GPU_CL> &param, bool ifRelu = false,
                      const framework::CLImage *biase = nullptr,
@@ -64,6 +70,18 @@ void ConvTransposeAddBnRelu(framework::CLHelper *cl_helper,
                             const framework::CLImage *biase = nullptr,
                             const framework::CLImage *new_scale = nullptr,
                             const framework::CLImage *new_bias = nullptr);
+void ConvTransposeAddBnRelu_b(framework::CLHelper *cl_helper,
+                              const ConvTransposeParam<GPU_CL> &param,
+                              bool ifRelu = false,
+                              const framework::CLImage *biase = nullptr,
+                              const framework::CLImage *new_scale = nullptr,
+                              const framework::CLImage *new_bias = nullptr);
+void ConvTranspose3x3s2AddBnRelu(framework::CLHelper *cl_helper,
+                                 const ConvTransposeParam<GPU_CL> &param,
+                                 bool ifRelu = false,
+                                 const framework::CLImage *biase = nullptr,
+                                 const framework::CLImage *new_scale = nullptr,
+                                 const framework::CLImage *new_bias = nullptr);
 
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.cpp b/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..84c3230d82bd2bfb54210e3e57ecf95bb43b7ff9
--- /dev/null
+++ b/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.cpp
@@ -0,0 +1,77 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "operators/kernel/cl/cl-kernel-func/instancenorm_func.h"
+#include <algorithm>
+namespace paddle_mobile {
+namespace operators {
+void InstanceNorm(framework::CLHelper *cl_helper,
+                  const InstanceNormParam<GPU_CL> &param) {
+  auto kernel = cl_helper->KernelAt(0);
+
+  auto &dims = param.Out()->dims();
+  const int n = dims[0];
+  const int c_group = (dims[1] + 3) / 4;
+  const int h = dims[2];
+  const int w = dims[3];
+  auto epsilon = param.Epsilon();
+  auto input = param.InputX()->GetCLImage();
+  auto out = param.Out()->GetCLImage();
+
+  // DLOG << "Epsilon: " << epsilon;
+
+  auto local_work_size_info = cl_helper->LocalWorkSizeInfo();
+  //
+  // DLOG << local_work_size_info.max_work_group_size;
+  // DLOG << local_work_size_info.max_work_item_size0;
+  // DLOG << local_work_size_info.max_work_item_size1;
+  // DLOG << local_work_size_info.max_work_item_size2;
+  int maxTotal =
+      std::min(static_cast<int>(local_work_size_info.max_work_group_size), 256);
+  int local_work_size1 =
+      std::min(static_cast<int>(local_work_size_info.max_work_item_size1),
+               std::min(256, w));
+  int local_work_size2 = 1;
+  const size_t work_size[3] = {(size_t)(n * c_group), (size_t)local_work_size1,
+                               (size_t)local_work_size2};
+  const size_t local_work_size[3] = {(size_t)1, (size_t)local_work_size1,
+                                     (size_t)local_work_size2};
+
+  // DLOG << "work_size" << work_size[0] << " " << work_size[1] << " "
+  //      << work_size[2];
+  // DLOG << "local_work_size" << local_work_size[0] << " "
+  //      << local_work_size[1] << " " << local_work_size[2];
+  cl_int status;
+  status = clSetKernelArg(kernel, 0, sizeof(cl_int), &w);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 1, sizeof(cl_int), &h);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 2, sizeof(cl_int), &c_group);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 3, sizeof(cl_int), &local_work_size1);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 4, sizeof(cl_int), &local_work_size2);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 5, sizeof(cl_float), &epsilon);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &input);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &out);
+  CL_CHECK_ERRORS(status);
+  clEnqueueNDRangeKernel(cl_helper->CLCommandQueue(), kernel, 3, NULL,
+                         work_size, local_work_size, 0, NULL, NULL);
+}
+}  // namespace operators
+}  // namespace paddle_mobile
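The launch above dedicates one work-group per (batch, channel-block) pair, with up to 256 work-items strided across the image width. A worked example (not part of the patch) for a 1 x 32 x 64 x 64 output on a device reporting 256 for the relevant limits:

#include <algorithm>
#include <cstdio>

int main() {
  const int n = 1, c = 32, w = 64;  // NCHW output, H = W = 64
  const int c_group = (c + 3) / 4;  // 8 channel blocks of 4
  const int lws1 = std::min(256, std::min(256, w));  // 64 lanes along width
  // global: one group per (n, c_group) pair; local: 1 x 64 x 1
  const size_t work_size[3] = {(size_t)(n * c_group), (size_t)lws1, 1};
  const size_t local_work_size[3] = {1, (size_t)lws1, 1};
  std::printf("global = {%zu, %zu, %zu}, local = {%zu, %zu, %zu}\n",
              work_size[0], work_size[1], work_size[2], local_work_size[0],
              local_work_size[1], local_work_size[2]);
}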
diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.h b/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..45c0bcd4e8e8ea0d6c24904b4fa7fc763d3e9bc1
--- /dev/null
+++ b/mobile/src/operators/kernel/cl/cl-kernel-func/instancenorm_func.h
@@ -0,0 +1,27 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#if defined(INSTANCENORM_OP) || defined(FUSION_INSTANCENORM_RELU_OP)
+
+#pragma once
+
+#include "framework/cl/cl_helper.h"
+#include "operators/op_param.h"
+namespace paddle_mobile {
+namespace operators {
+void InstanceNorm(framework::CLHelper *cl_helper,
+                  const InstanceNormParam<GPU_CL> &param);
+}  // namespace operators
+}  // namespace paddle_mobile
+#endif
diff --git a/mobile/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl b/mobile/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
index b91be321a6c85c045c5b26f4a99ae70a756368f9..2232cdc0a43d1bcda093007f7ef167a34b246e21 100755
--- a/mobile/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
+++ b/mobile/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
@@ -48,13 +48,14 @@ __kernel void conv_3x3(__private const int global_size_dim0,
                        __private const int input_height,/* of one block */
                        __private const int output_width,
                        __private const int output_height,
+                       __private const int output_c,
                        __private const int filter_channel,
-                       __private const int has_group) {
+                       __private const int group) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
   const int out_nh = get_global_id(2);
-  
+
   int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
 
   if (out_c >= global_size_dim0 ||
@@ -90,7 +91,7 @@ __kernel void conv_3x3(__private const int global_size_dim0,
 #endif
 
   half4 input[9];
-  if (has_group == 0) {
+  if (group == 1) {
     for (int i = 0; i < input_c; ++i) {
       int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
       input[0] = select(read_imageh(input_image, sampler,
@@ -326,7 +327,7 @@ __kernel void conv_3x3(__private const int global_size_dim0,
     }
   } else {
     for (int i = 0; i < 4; i++) {
-      int used_input_channel_num = (out_c * 4 + i) * filter_channel;
+      int used_input_channel_num = (out_c * 4 + i) / (output_c / group) * filter_channel;
       for (int f_c = 0; f_c < filter_channel; ++f_c) {
         int input_c = used_input_channel_num + f_c;
         int input_block = input_c / 4;
@@ -424,8 +425,8 @@ __kernel void conv_3x3(__private const int global_size_dim0,
   write_imageh(output_image, output_pos, output);
 }
 
- // dilation == 1 && stride == 1 && ou_nh == ou_h
-__kernel void conv_3x3s1(__private const int item_ch,
+ // dilation == 1
+__kernel void conv_3x3spl(__private const int item_ch,
                          __private const int item_w,
                          __private const int item_h,
                          __read_only image2d_t input_image,
@@ -456,14 +457,8 @@ __read_only image2d_t new_scale,
   const int item_w_id = get_global_id(1);
   const int item_h_id = get_global_id(2);
 
-  // in_width_id_per_blk
-  int in_w_id0 = item_w_id - pad;
-  int in_w_id1 = in_w_id0 + item_w;
-  int in_w_id2 = in_w_id1 + item_w;
-  int in_w_id3 = in_w_id2 + item_w;
-  int in_w_id4 = in_w_id3 + item_w;
-
-  // out_width_id_per_blk
+  // out_width_id_per_blk and out_batch_id
+  int out_batch_id = item_h_id / in_h;
   int out_w_base_id = item_ch_id * out_w;
   int out_w_id0 = item_w_id;
  int out_w_id1 = out_w_id0 + item_w;
@@ -471,6 +466,14 @@ __read_only image2d_t new_scale,
   int out_w_id3 = out_w_id2 + item_w;
   int out_w_id4 = out_w_id3 + item_w;
 
+  // in_width_id_per_blk and in_height_id_per_batch
+  int in_h_id = (item_h_id % out_h) * stride - pad;
+  int in_w_id0 = item_w_id * stride - pad;
+  int in_w_id1 = in_w_id0 + item_w * stride;
+  int in_w_id2 = in_w_id1 + item_w * stride;
+  int in_w_id3 = in_w_id2 + item_w * stride;
+  int in_w_id4 = in_w_id3 + item_w * stride;
+
 #ifdef BIASE_CH
   half4 output[5];
@@ -518,8 +521,8 @@
   for (int h = 0; h < 3; h++) {
-      int in_h_val = select(item_h_id + h - pad, -1,
-                          (item_h_id +
h - pad < 0 || item_h_id + h - pad >= in_h)); + int in_h_val = select(out_batch_id * in_h + in_h_id + h, -1, + (out_batch_id * in_h + in_h_id + h < 0 || out_batch_id * in_h + in_h_id + h >= in_h)); for (int w = 0; w < 3; w++) { @@ -539,7 +542,6 @@ __read_only image2d_t new_scale, filter[2] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val2 + h)); // in_ch:0-3,out_ch:2 filter[3] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val3 + h)); // in_ch:0-3,out_ch:3 - filter_trans[0] = (half4)(filter[0].x, filter[1].x, filter[2].x, filter[3].x); // in_ch:0,out_ch:0-3 filter_trans[1] = (half4)(filter[0].y, filter[1].y, filter[2].y, filter[3].y); // in_ch:1,out_ch:0-3 filter_trans[2] = (half4)(filter[0].z, filter[1].z, filter[2].z, filter[3].z); // in_ch:2,out_ch:0-3 @@ -954,7 +956,7 @@ __kernel void conv_1x1(__private const int global_size_dim0, const int out_nh = get_global_id(2); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); - + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; @@ -1016,7 +1018,7 @@ __kernel void conv_1x1_spl( __read_only image2d_t new_scale, __read_only image2d_t new_biase, #endif __write_only image2d_t output_image, __private const int stride, - __private const int offset, __private const int input_c, + __private const int offset, __private const int input_c,__private const int input_c_origin, __private const int dilation, __private const int input_width, /* of one block */ __private const int input_height, /* of one block */ @@ -1034,10 +1036,6 @@ __kernel void conv_1x1_spl( int out_w2 = out_w + global_size_dim1 * 2; int out_w3 = out_w + global_size_dim1 * 3; -// int out_w1 = out_w + global_size_dim1; -// int out_w2 = out_w + global_size_dim1 * 2; -// int out_w3 = out_w + global_size_dim1 * 3; - int outpos_main = mul24(out_c , old_w); int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); @@ -1082,6 +1080,9 @@ __kernel void conv_1x1_spl( half4 output2 = 0.0f; half4 output3 = 0.0f; #endif + + int max_w_bound = input_c * input_width; + int burndary_index = input_c * 4 - input_c_origin; for (int i = 0; i < input_c; ++i) { // ------------0--------------- int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, in_pos_in_one_block0.y); @@ -1092,21 +1093,44 @@ __kernel void conv_1x1_spl( half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); + int bound_gap = max_w_bound - pos_in.x - 1; + if (bound_gap < input_width && bound_gap >= 0){ + if (burndary_index==0){ + // do nothing + } else if (burndary_index==1){ + input0.w = 0.0f; + } else if (burndary_index==2){ + input0.z = 0.0f; + input0.w = 0.0f; + } else if (burndary_index==3){ + input0.y = 0.0f; + input0.z = 0.0f; + input0.w = 0.0f; + } + } output0 = mad(input0.x, weight0, output0); output0 = mad(input0.y, weight1, output0); output0 = mad(input0.z, weight2, output0); output0 = mad(input0.w, weight3, output0); - // -------------1-------------- pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, in_pos_in_one_block1.y); half4 input1 = read_imageh(input_image, sampler, pos_in); - // - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, 
(int2)(out_c, i - // * 4 + 3)); + bound_gap = max_w_bound - pos_in.x - 1; + if (bound_gap < input_width && bound_gap >= 0){ + if (burndary_index==0){ + // do nothing + } else if (burndary_index==1){ + input1.w = 0.0f; + } else if (burndary_index==2){ + input1.z = 0.0f; + input1.w = 0.0f; + } else if (burndary_index==3){ + input1.y = 0.0f; + input1.z = 0.0f; + input1.w = 0.0f; + } + } output1 = mad(input1.x, weight0, output1); output1 = mad(input1.y, weight1, output1); output1 = mad(input1.z, weight2, output1); @@ -1116,12 +1140,21 @@ __kernel void conv_1x1_spl( pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, in_pos_in_one_block2.y); half4 input2 = read_imageh(input_image, sampler, pos_in); - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); - + bound_gap = max_w_bound - pos_in.x - 1; + if (bound_gap < input_width && bound_gap >= 0){ + if (burndary_index==0){ + // do nothing + } else if (burndary_index==1){ + input2.w = 0.0f; + } else if (burndary_index==2){ + input2.z = 0.0f; + input2.w = 0.0f; + } else if (burndary_index==3){ + input2.y = 0.0f; + input2.z = 0.0f; + input2.w = 0.0f; + } + } output2 = mad(input2.x, weight0, output2); output2 = mad(input2.y, weight1, output2); output2 = mad(input2.z, weight2, output2); @@ -1130,12 +1163,21 @@ __kernel void conv_1x1_spl( // -------------3-------------- pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, in_pos_in_one_block3.y); half4 input3 = read_imageh(input_image, sampler, pos_in); - - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); + bound_gap = max_w_bound - pos_in.x - 1; + if (bound_gap < input_width && bound_gap >= 0){ + if (burndary_index==0){ + // do nothing + } else if (burndary_index==1){ + input3.w = 0.0f; + } else if (burndary_index==2){ + input3.z = 0.0f; + input3.w = 0.0f; + } else if (burndary_index==3){ + input3.y = 0.0f; + input3.z = 0.0f; + input3.w = 0.0f; + } + } output3 = mad(input3.x, weight0, output3); output3 = mad(input3.y, weight1, output3); @@ -1181,933 +1223,476 @@ __kernel void conv_1x1_spl( } } -__kernel void conv_1x1_spl2( - __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2, __read_only image2d_t input_image, - __read_only image2d_t filter, -#ifdef BIASE - __read_only image2d_t bias, +__kernel void conv_7x7(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, + +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, #endif + #ifdef BATCH_NORM - __read_only image2d_t new_scale, __read_only image2d_t new_biase, + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, #endif - __write_only image2d_t output_image, __private const int stride, - __private const int offset, __private const int input_c, - __private const int dilation, - __private const int input_width, /* of one block */ - __private const int input_height, /* of one block */ - __private 
const int output_width, - __private const int output_height, - __private const int old_w -) { - - const int out_c = get_global_id(0); - const int out_w = get_global_id(1); - const int out_nh = get_global_id(2); - int out_w0 = out_w; - int out_w1 = out_w + global_size_dim1; - int out_w2 = out_w + global_size_dim1 * 2; - int out_w3 = out_w + global_size_dim1 * 3; - int out_w4 = out_w + global_size_dim1 * 4; - int out_w5 = out_w + global_size_dim1 * 5; - int out_w6 = out_w + global_size_dim1 * 6; - int out_w7 = out_w + global_size_dim1 * 7; + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height) { -// int out_w1 = out_w + global_size_dim1; -// int out_w2 = out_w + global_size_dim1 * 2; -// int out_w3 = out_w + global_size_dim1 * 3; + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); - const sampler_t sampler = - CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); - int2 stride_xy = (int2)(stride, stride); + if (out_c >= global_size_dim0 || + out_w >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + const int filter_n0 = 4 * out_c + 0; + const int filter_n1 = 4 * out_c + 1; + const int filter_n2 = 4 * out_c + 2; + const int filter_n3 = 4 * out_c + 3; - int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); - int2 in_pos_in_one_block0 = - ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; - int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); - int2 in_pos_in_one_block1 = - ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; - int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); - int2 in_pos_in_one_block2 = - ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); - int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); - int2 in_pos_in_one_block3 = - ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; - int2 ouput_pos_in_one_block4 = (int2)(out_w4, out_nh); - int2 in_pos_in_one_block4 = - ouput_pos_in_one_block4 * stride_xy + (int2)(offset, offset); - - int2 ouput_pos_in_one_block5 = (int2)(out_w5, out_nh); - int2 in_pos_in_one_block5 = - ouput_pos_in_one_block5 * stride_xy + (int2)(offset, offset); - - int2 ouput_pos_in_one_block6 = (int2)(out_w6, out_nh); - int2 in_pos_in_one_block6 = - ouput_pos_in_one_block6 * stride_xy + (int2)(offset, offset); - - int2 ouput_pos_in_one_block7 = (int2)(out_w7, out_nh); - int2 in_pos_in_one_block7 = - ouput_pos_in_one_block7 * stride_xy + (int2)(offset, offset); - -#ifdef BIASE - half4 output0 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output1 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output2 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output3 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output4 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output5 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output6 = read_imageh(bias, sampler, 
(int2)(out_c, 0)); - half4 output7 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output0 = 0.0f; -// half4 output1 = 0.0f; -// half4 output2 = 0.0f; -// half4 output3 = 0.0f; + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; +#ifdef BIASE_CH + half4 output = read_imageh(bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + half4 output = read_imageh(bias, sampler, output_pos); #else - half4 output0 = 0.0f; - half4 output1 = 0.0f; - half4 output2 = 0.0f; - half4 output3 = 0.0f; - half4 output4 = 0.0f; - half4 output5 = 0.0f; - half4 output6 = 0.0f; - half4 output7 = 0.0f; + half4 output = 0.0f; #endif - for (int i = 0; i < input_c; ++i) { - // ------------0--------------- - int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, in_pos_in_one_block0.y); - half4 input0 = read_imageh(input_image, sampler, pos_in); - - half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 0)); - half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 1)); - half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); - half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); - output0 = mad(input0.x, weight0, output0); - output0 = mad(input0.y, weight1, output0); - output0 = mad(input0.z, weight2, output0); - output0 = mad(input0.w, weight3, output0); + half4 input; + half4 filter[4]; + int2 filter_pos0; + int2 filter_pos1; + int2 filter_pos2; + int2 filter_pos3; + for (int i = 0; i < input_c; ++i) { + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + for(int j = 0; j < 7; j++){ + for(int k = 0; k < 7; k++){ + input = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15)); + int filter_h = k; + int filter_w = j; + int filter_c = i; - // -------------1-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, in_pos_in_one_block1.y); - half4 input1 = read_imageh(input_image, sampler, pos_in); - // - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); + filter_pos0.x = filter_c * 7 + filter_w; + filter_pos0.y = filter_n0 * 7 + filter_h; - output1 = mad(input1.x, weight0, output1); - output1 = mad(input1.y, weight1, output1); - output1 = mad(input1.z, weight2, output1); - output1 = mad(input1.w, weight3, output1); + filter_pos1.x = filter_c * 7 + filter_w; + filter_pos1.y = filter_n1 * 7 + filter_h; - // -------------2-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, in_pos_in_one_block2.y); - half4 input2 = read_imageh(input_image, sampler, pos_in); + filter_pos2.x = filter_c * 7 + filter_w; + filter_pos2.y = filter_n2 * 7 + filter_h; - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 
weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); + filter_pos3.x = filter_c * 7 + filter_w; + filter_pos3.y = filter_n3 * 7 + filter_h; - output2 = mad(input2.x, weight0, output2); - output2 = mad(input2.y, weight1, output2); - output2 = mad(input2.z, weight2, output2); - output2 = mad(input2.w, weight3, output2); + filter[0] = read_imageh(filter_image, sampler, filter_pos0); + filter[1] = read_imageh(filter_image, sampler, filter_pos1); + filter[2] = read_imageh(filter_image, sampler, filter_pos2); + filter[3] = read_imageh(filter_image, sampler, filter_pos3); - // -------------3-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, in_pos_in_one_block3.y); - half4 input3 = read_imageh(input_image, sampler, pos_in); + output.x += dot(input, filter[0]); + output.y += dot(input, filter[1]); + output.z += dot(input, filter[2]); + output.w += dot(input, filter[3]); + } + } + } - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); +#ifdef BATCH_NORM + output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); +#endif - output3 = mad(input3.x, weight0, output3); - output3 = mad(input3.y, weight1, output3); - output3 = mad(input3.z, weight2, output3); - output3 = mad(input3.w, weight3, output3); +#ifdef RELU + output = activation(output); +#endif + write_imageh(output_image, output_pos, output); +} - // -------------4-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block4.x, in_pos_in_one_block4.y); - half4 input4 = read_imageh(input_image, sampler, pos_in); +__kernel void conv_7x7Pt1x2(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif - output4 = mad(input4.x, weight0, output4); - output4 = mad(input4.y, weight1, output4); - output4 = mad(input4.z, weight2, output4); - output4 = mad(input4.w, weight3, output4); +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height) { + const int out_c = get_global_id(0); + const int out_w1 = get_global_id(1); + const int out_nh = get_global_id(2); - // -------------5-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block5.x, in_pos_in_one_block5.y); - half4 input5 = read_imageh(input_image, sampler, pos_in); + if (out_c >= global_size_dim0 || + out_w1 >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + const int out_w = out_w1 * 2; - // half4 weight0 
= read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); + int2 output_pos = (int2)(out_c * output_width + out_w, out_nh); - output5= mad(input5.x, weight0, output5); - output5 = mad(input5.y, weight1, output5); - output5 = mad(input5.z, weight2, output5); - output5 = mad(input5.w, weight3, output5); + const int filter_n0 = 4 * out_c + 0; + const int filter_n1 = 4 * out_c + 1; + const int filter_n2 = 4 * out_c + 2; + const int filter_n3 = 4 * out_c + 3; + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; - // -------------6-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block6.x, in_pos_in_one_block6.y); - half4 input6 = read_imageh(input_image, sampler, pos_in); - - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); - - output6 = mad(input6.x, weight0, output6); - output6 = mad(input6.y, weight1, output6); - output6 = mad(input6.z, weight2, output6); - output6 = mad(input6.w, weight3, output6); - - - // -------------7-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block7.x, in_pos_in_one_block7.y); - half4 input7 = read_imageh(input_image, sampler, pos_in); - - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); - - output7 = mad(input7.x, weight0, output7); - output7 = mad(input7.y, weight1, output7); - output7 = mad(input7.z, weight2, output7); - output7 = mad(input7.w, weight3, output7); - } - -#ifdef BATCH_NORM - output0 = output0 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output1 = output1 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output2 = output2 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output3 = output3 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output4 = output4 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output5 = output5 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output6 = output6 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output7 = output7 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - -#endif - -#ifdef RELU - output0 = activation(output0); - output1 = activation(output1); - output2 = activation(output2); - output3 = activation(output3); - output4 = activation(output4); - output5 = activation(output5); - output6 = activation(output6); - output7 = activation(output7); -#endif - int outpos_main = mul24(out_c , old_w); - int2 output_pos0 = (int2)(outpos_main + 
out_w0, out_nh); - - if (out_w0 < old_w) { - write_imageh(output_image, output_pos0, output0); - } - int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); - if (out_w1 < old_w){ - write_imageh(output_image, output_pos1, output1); - } - - int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); - if (out_w2 < old_w){ - write_imageh(output_image, output_pos2, output2); - } - - int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); - if (out_w3 < old_w){ - write_imageh(output_image, output_pos3, output3); - } - - int2 output_pos4 = (int2)(outpos_main + out_w4, out_nh); - if (out_w4 < old_w){ - write_imageh(output_image, output_pos4, output4); - } - - int2 output_pos5 = (int2)(outpos_main + out_w5, out_nh); - if (out_w5 < old_w){ - write_imageh(output_image, output_pos5, output5); - - } - int2 output_pos6 = (int2)(outpos_main + out_w6, out_nh); - if (out_w6 < old_w){ - write_imageh(output_image, output_pos6, output6); - } - - int2 output_pos7 = (int2)(outpos_main + out_w7, out_nh); - if (out_w7 < old_w){ - write_imageh(output_image, output_pos7, output7); - } - -} -__kernel void conv_1x1_spl3( - __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2, __read_only image2d_t input_image, - __read_only image2d_t filter, -#ifdef BIASE - __read_only image2d_t bias, -#endif -#ifdef BATCH_NORM - __read_only image2d_t new_scale, __read_only image2d_t new_biase, -#endif - __write_only image2d_t output_image, __private const int stride, - __private const int offset, __private const int input_c, - __private const int dilation, - __private const int input_width, /* of one block */ - __private const int input_height, /* of one block */ - __private const int output_width, - __private const int output_height, - __private const int old_w -) { - - const int out_c = get_global_id(0); - const int out_w = get_global_id(1); - const int out_nh = get_global_id(2); - - int out_w0 = out_w; - int out_w1 = out_w + global_size_dim1; - int out_w2 = out_w + global_size_dim1 * 2; -// int out_w3 = out_w + global_size_dim1 * 3; -// int out_w4 = out_w + global_size_dim1 * 4; -// int out_w5 = out_w + global_size_dim1 * 5; -// int out_w6 = out_w + global_size_dim1 * 6; -// int out_w7 = out_w + global_size_dim1 * 7; - -// int out_w1 = out_w + global_size_dim1; -// int out_w2 = out_w + global_size_dim1 * 2; -// int out_w3 = out_w + global_size_dim1 * 3; - - const sampler_t sampler = - CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - - int2 stride_xy = (int2)(stride, stride); - - int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); - int2 in_pos_in_one_block0 = - ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); - - int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); - int2 in_pos_in_one_block1 = - ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); - -// int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); -// int2 in_pos_in_one_block2 = -// ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); -// -// int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); -// int2 in_pos_in_one_block3 = -// ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); -// -// int2 ouput_pos_in_one_block4 = (int2)(out_w4, out_nh); -// int2 in_pos_in_one_block4 = -// ouput_pos_in_one_block4 * stride_xy + (int2)(offset, offset); -// -// int2 ouput_pos_in_one_block5 = (int2)(out_w5, out_nh); -// int2 in_pos_in_one_block5 = -// ouput_pos_in_one_block5 * stride_xy + (int2)(offset, offset); -// -// int2 ouput_pos_in_one_block6 = 
(int2)(out_w6, out_nh); -// int2 in_pos_in_one_block6 = -// ouput_pos_in_one_block6 * stride_xy + (int2)(offset, offset); -// -// int2 ouput_pos_in_one_block7 = (int2)(out_w7, out_nh); -// int2 in_pos_in_one_block7 = -// ouput_pos_in_one_block7 * stride_xy + (int2)(offset, offset); - -#ifdef BIASE - half4 output0 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output1 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output2 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output3 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output4 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output5 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output6 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output7 = read_imageh(bias, sampler, (int2)(out_c, 0)); -// half4 output0 = 0.0f; -// half4 output1 = 0.0f; -// half4 output2 = 0.0f; -// half4 output3 = 0.0f; - -#else - half4 output0 = 0.0f; - half4 output1 = 0.0f; -// half4 output2 = 0.0f; -// half4 output3 = 0.0f; -// half4 output4 = 0.0f; -// half4 output5 = 0.0f; -// half4 output6 = 0.0f; -// half4 output7 = 0.0f; -#endif - for (int i = 0; i < input_c; ++i) { - // ------------0--------------- - int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, in_pos_in_one_block0.y); - half4 input0 = read_imageh(input_image, sampler, pos_in); - - half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 0)); - half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 1)); - half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); - half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); - - output0 = mad(input0.x, weight0, output0); - output0 = mad(input0.y, weight1, output0); - output0 = mad(input0.z, weight2, output0); - output0 = mad(input0.w, weight3, output0); - - // -------------1-------------- - pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, in_pos_in_one_block1.y); - half4 input1 = read_imageh(input_image, sampler, pos_in); - // - // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + - // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 - // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * - // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i - // * 4 + 3)); - - output1 = mad(input1.x, weight0, output1); - output1 = mad(input1.y, weight1, output1); - output1 = mad(input1.z, weight2, output1); - output1 = mad(input1.w, weight3, output1); -// -// // -------------2-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, in_pos_in_one_block2.y); -// half4 input2 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output2 = mad(input2.x, weight0, output2); -// output2 = mad(input2.y, weight1, output2); -// output2 = mad(input2.z, weight2, output2); -// output2 = mad(input2.w, weight3, output2); -// -// // -------------3-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, in_pos_in_one_block3.y); -// half4 input3 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, 
sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output3 = mad(input3.x, weight0, output3); -// output3 = mad(input3.y, weight1, output3); -// output3 = mad(input3.z, weight2, output3); -// output3 = mad(input3.w, weight3, output3); -// -// -// // -------------4-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block4.x, in_pos_in_one_block4.y); -// half4 input4 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output4 = mad(input4.x, weight0, output4); -// output4 = mad(input4.y, weight1, output4); -// output4 = mad(input4.z, weight2, output4); -// output4 = mad(input4.w, weight3, output4); -// -// -// -// // -------------5-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block5.x, in_pos_in_one_block5.y); -// half4 input5 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output5= mad(input5.x, weight0, output5); -// output5 = mad(input5.y, weight1, output5); -// output5 = mad(input5.z, weight2, output5); -// output5 = mad(input5.w, weight3, output5); -// -// -// // -------------6-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block6.x, in_pos_in_one_block6.y); -// half4 input6 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output6 = mad(input6.x, weight0, output6); -// output6 = mad(input6.y, weight1, output6); -// output6 = mad(input6.z, weight2, output6); -// output6 = mad(input6.w, weight3, output6); -// -// -// // -------------7-------------- -// pos_in = (int2)(i * input_width + in_pos_in_one_block7.x, in_pos_in_one_block7.y); -// half4 input7 = read_imageh(input_image, sampler, pos_in); -// -// // half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + -// // 0)); half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 -// // + 1)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * -// // 4 + 2)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i -// // * 4 + 3)); -// -// output7 = mad(input7.x, weight0, output7); -// output7 = mad(input7.y, weight1, output7); -// output7 = mad(input7.z, weight2, output7); -// output7 = mad(input7.w, weight3, output7); - } - -#ifdef BATCH_NORM - output0 = output0 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); - - output1 = output1 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + - read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output2 = output2 * 
read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output3 = output3 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output4 = output4 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output5 = output5 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output6 = output6 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); -// -// output7 = output7 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + -// read_imageh(new_biase, sampler, (int2)(out_c, 0)); - -#endif - -#ifdef RELU - output0 = activation(output0); - output1 = activation(output1); -// output2 = activation(output2); -// output3 = activation(output3); -// output4 = activation(output4); -// output5 = activation(output5); -// output6 = activation(output6); -// output7 = activation(output7); -#endif - int outpos_main = mul24(out_c , old_w); - int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); - - if (out_w0 < old_w) { - write_imageh(output_image, output_pos0, output0); - } - int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); - if (out_w1 < old_w){ - write_imageh(output_image, output_pos1, output1); - } -// -// int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); -// if (out_w2 < old_w){ -// write_imageh(output_image, output_pos2, output2); -// } -// -// int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); -// if (out_w3 < old_w){ -// write_imageh(output_image, output_pos3, output3); -// } -// -// int2 output_pos4 = (int2)(outpos_main + out_w4, out_nh); -// if (out_w4 < old_w){ -// write_imageh(output_image, output_pos4, output4); -// } -// -// int2 output_pos5 = (int2)(outpos_main + out_w5, out_nh); -// if (out_w5 < old_w){ -// write_imageh(output_image, output_pos5, output5); -// -// } -// int2 output_pos6 = (int2)(outpos_main + out_w6, out_nh); -// if (out_w6 < old_w){ -// write_imageh(output_image, output_pos6, output6); -// } -// -// int2 output_pos7 = (int2)(outpos_main + out_w7, out_nh); -// if (out_w7 < old_w){ -// write_imageh(output_image, output_pos7, output7); -// } - -} -//__kernel void conv_1x1_c( -// __private const int global_size_dim0, -// __private const int global_size_dim1, -// __private const int global_size_dim2, -// __read_only image2d_t input_image, -// __read_only image2d_t filter, -//#ifdef BIASE -// __read_only image2d_t bias, -//#endif -//#ifdef BATCH_NORM -// __read_only image2d_t new_scale, -// __read_only image2d_t new_biase, -//#endif -// __write_only image2d_t output_image, -// __private const int stride, -// __private const int offset, -// __private const int input_c, -// __private const int dilation, -// __private const int input_width, /* of one block */ -// __private const int input_height, /* of one block */ -// __private const int output_width, -// __private const int output_height, -// __private const int old_w) { -// -// const int out_c = get_global_id(0); -// const int out_w = get_global_id(1); -// const int out_nh = get_global_id(2); -// -// const sampler_t sampler = -// CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -// const int2 stride_xy = (int2)(stride, stride); -// -// for (int i = 0; i < input_c; ++i) { -// half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 0)); -// half4 weight1 = read_imageh(filter, sampler, 
(int2)(out_c, i * 4 + 1)); -// half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); -// half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); -// -//#pragma unroll -// for (int j = 0; j < 4; ++j) { -// int out_w0 = out_w + global_size_dim1 * j; -// int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); -// int2 in_pos_in_one_block0 = ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); -// -//#ifdef BIASE -// half4 output0 = read_imageh(bias, sampler, (int2)(out_c, 0)); -//#else -// half4 output0 = 0.0f; -//#endif -// int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, in_pos_in_one_block0.y); -// half4 input0 = read_imageh(input_image, sampler, pos_in); -// -// output0 = mad(input0.x, weight0, output0); -// output0 = mad(input0.y, weight1, output0); -// output0 = mad(input0.z, weight2, output0); -// output0 = mad(input0.w, weight3, output0); -// -//#ifdef BATCH_NORM -// output0 = output0 * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); -//#endif -// -//#ifdef RELU -// output0 = activation(output0); -//#endif -// int outpos_main = mul24(out_c, old_w); -// int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); -// -// if (out_w0 < old_w) { -// write_imageh(output_image, output_pos0, output0); -// } -// } -// } -//} + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; -/* -__kernel void conv_1x1_4(__private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2, - __read_only image2d_t input_image, - __read_only image2d_t filter, -#ifdef BIASE - __read_only image2d_t bias, -#endif -#ifdef BATCH_NORM - __read_only image2d_t new_scale, - __read_only image2d_t new_biase, -#endif - __write_only image2d_t output_image, - __private const int stride, - __private const int offset, - __private const int input_c, - __private const int dilation, - __private const int input_width, - __private const int input_height, - __private const int output_width, - __private const int output_height) { - const int out_c = get_global_id(0) * 4; - const int out_w = get_global_id(1); - const int out_nh = get_global_id(2); - - const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | - CLK_ADDRESS_CLAMP | - CLK_FILTER_NEAREST; + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; - int2 stride_xy = (int2)(stride, stride); - int2 ouput_pos_in_one_block = (int2)(out_w, out_nh); - int2 in_pos_in_one_block = ouput_pos_in_one_block * stride_xy + (int2)(offset, offset); + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; -#ifdef BIASE - half4 output0 = read_imageh(bias, sampler, (int2)(out_c, 0)); - half4 output1 = read_imageh(bias, sampler, (int2)(out_c + 1, 0)); - half4 output2 = read_imageh(bias, sampler, (int2)(out_c + 2, 0)); - half4 output3 = read_imageh(bias, sampler, (int2)(out_c + 3, 0)); -#else half4 output0 = 0.0f; half4 output1 = 0.0f; - half4 output2 = 0.0f; - half4 output3 = 0.0f; +#ifdef BIASE_CH + output0 = read_imageh(bias, sampler, (int2)(out_c, 0)); + output1 = output0; +#elif defined(BIASE_ELE) + output0 = read_imageh(bias, sampler, output_pos); + output1 = read_imageh(bias, sampler, (int2)(output_pos.x + 1, output_pos.y)); +#else + output0 = 0.0f; + output1 = 0.0f; #endif + half4 input[8]; + half4 filter0[4]; + half4 filter1[4]; + 
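The 1x1 paths above pack four consecutive channels into one RGBA half4 texel, which is why a single set of four filter reads per input-channel block can be shared by every output pixel a work-item produces. A minimal host-side model of that accumulation, in plain C with hypothetical names (the real kernel runs the same mad() chain on half4 image reads), is:

#include <stdio.h>

/* Host-side model of the conv_1x1 accumulation above: in[k] is input
 * channel k of one RGBA block, w[k][o] the weight from input channel k
 * to output channel o. */
static void conv1x1_block(const float in[4], float w[4][4], float out[4]) {
  for (int o = 0; o < 4; ++o)
    for (int k = 0; k < 4; ++k)
      out[o] += in[k] * w[k][o]; /* output = mad(input.k, weightk, output) */
}

int main(void) {
  float in[4] = {1, 2, 3, 4};
  float w[4][4] = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}};
  float out[4] = {0};
  conv1x1_block(in, w, out);
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); /* 1 2 3 4 */
  return 0;
}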
half4 filter2[4]; + half4 filter3[4]; + int2 filter_pos0; + int2 filter_pos1; + int2 filter_pos2; + int2 filter_pos3; for (int i = 0; i < input_c; ++i) { int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); - half4 input = read_imageh(input_image, sampler, pos_in); - - half4 weight0_0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 0)); - half4 weight0_1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 1)); - half4 weight0_2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); - half4 weight0_3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); - - output0 = mad(input.x, weight0_0, output0); - output0 = mad(input.y, weight0_1, output0); - output0 = mad(input.z, weight0_2, output0); - output0 = mad(input.w, weight0_3, output0); - - half4 weight1_0 = read_imageh(filter, sampler, (int2)(out_c + 1, i * 4 + 0)); - half4 weight1_1 = read_imageh(filter, sampler, (int2)(out_c + 1, i * 4 + 1)); - half4 weight1_2 = read_imageh(filter, sampler, (int2)(out_c + 1, i * 4 + 2)); - half4 weight1_3 = read_imageh(filter, sampler, (int2)(out_c + 1, i * 4 + 3)); - - output1 = mad(input.x, weight1_0, output1); - output1 = mad(input.y, weight1_1, output1); - output1 = mad(input.z, weight1_2, output1); - output1 = mad(input.w, weight1_3, output1); - - half4 weight2_0 = read_imageh(filter, sampler, (int2)(out_c + 2, i * 4 + 0)); - half4 weight2_1 = read_imageh(filter, sampler, (int2)(out_c + 2, i * 4 + 1)); - half4 weight2_2 = read_imageh(filter, sampler, (int2)(out_c + 2, i * 4 + 2)); - half4 weight2_3 = read_imageh(filter, sampler, (int2)(out_c + 2, i * 4 + 3)); - - output2 = mad(input.x, weight2_0, output2); - output2 = mad(input.y, weight2_1, output2); - output2 = mad(input.z, weight2_2, output2); - output2 = mad(input.w, weight2_3, output2); - - half4 weight3_0 = read_imageh(filter, sampler, (int2)(out_c + 3, i * 4 + 0)); - half4 weight3_1 = read_imageh(filter, sampler, (int2)(out_c + 3, i * 4 + 1)); - half4 weight3_2 = read_imageh(filter, sampler, (int2)(out_c + 3, i * 4 + 2)); - half4 weight3_3 = read_imageh(filter, sampler, (int2)(out_c + 3, i * 4 + 3)); - - output3 = mad(input.x, weight3_0, output3); - output3 = mad(input.y, weight3_1, output3); - output3 = mad(input.z, weight3_2, output3); - output3 = mad(input.w, weight3_3, output3); - + for(int k = 0; k < 7; k++){ + for (int j = 0; j < 8; j++) { + input[j] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15)); + + int filter_h = k; + int filter_w = j; + int filter_c = i; + + if (j < 7) { + filter_pos0.x = filter_c * 7 + filter_w; + filter_pos0.y = filter_n0 * 7 + filter_h; + + filter_pos1.x = filter_c * 7 + filter_w; + filter_pos1.y = filter_n1 * 7 + filter_h; + + filter_pos2.x = filter_c * 7 + filter_w; + filter_pos2.y = filter_n2 * 7 + filter_h; + + filter_pos3.x = filter_c * 7 + filter_w; + filter_pos3.y = filter_n3 * 7 + filter_h; + + filter0[0] = read_imageh(filter_image, sampler, filter_pos0); + filter0[1] = read_imageh(filter_image, sampler, filter_pos1); + filter0[2] = read_imageh(filter_image, sampler, filter_pos2); + filter0[3] = read_imageh(filter_image, sampler, filter_pos3); + + output0.x += dot(input[j], filter0[0]); + output0.y += dot(input[j], filter0[1]); + output0.z 
+= dot(input[j], filter0[2]); + output0.w += dot(input[j], filter0[3]); + } + + if (j > 0) { + output1.x += dot(input[j], filter1[0]); + output1.y += dot(input[j], filter1[1]); + output1.z += dot(input[j], filter1[2]); + output1.w += dot(input[j], filter1[3]); + } + + filter1[0] = filter0[0]; + filter1[1] = filter0[1]; + filter1[2] = filter0[2]; + filter1[3] = filter0[3]; + } + } } #ifdef BATCH_NORM - output0 = output0 * read_imageh(new_scale, sampler, (int2)(out_c + 0, 0)) + read_imageh(new_biase, sampler, (int2)(out_c + 0, 0)); - - output1 = output1 * read_imageh(new_scale, sampler, (int2)(out_c + 1, 0)) + read_imageh(new_biase, sampler, (int2)(out_c + 1, 0)); - - output2 = output2 * read_imageh(new_scale, sampler, (int2)(out_c + 2, 0)) + read_imageh(new_biase, sampler, (int2)(out_c + 2, 0)); - - output3 = output3 * read_imageh(new_scale, sampler, (int2)(out_c + 3, 0)) + read_imageh(new_biase, sampler, (int2)(out_c + 3, 0)); - + half4 s = read_imageh(new_scale, sampler, (int2)(out_c, 0)); + half4 b = read_imageh(new_biase, sampler, (int2)(out_c, 0)); + output0 = output0 * s + b; + output1 = output1 * s + b; #endif #ifdef RELU - output0 = activation(output0); - output1 = activation(output1); - output2 = activation(output2); - output3 = activation(output3); + output0 = activation(output0); + output1 = activation(output1); #endif - - int2 output_pos0 = (int2)(out_c * global_size_dim1 + out_w, out_nh); - write_imageh(output_image, output_pos0, output0); - - - int2 output_pos1 = (int2)((out_c + 1) * global_size_dim1 + out_w, out_nh); - write_imageh(output_image, output_pos1, output1); - - - int2 output_pos2 = (int2)((out_c + 2) * global_size_dim1 + out_w, out_nh); - write_imageh(output_image, output_pos2, output2); - - - int2 output_pos3 = (int2)((out_c + 3) * global_size_dim1 + out_w, out_nh); - write_imageh(output_image, output_pos3, output3); + write_imageh(output_image, output_pos, output0); + if ((output_pos.x + 1) % output_width != 0) { + write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output1); + } } -*/ - -__kernel void conv_7x7(__private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2, - __read_only image2d_t input_image, - __read_only image2d_t filter_image, - +// dilation == 1 +__kernel void conv_7x7spl(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, #if defined(BIASE_CH) || defined(BIASE_ELE) - __read_only image2d_t bias, + __read_only image2d_t bias, #endif - #ifdef BATCH_NORM - __read_only image2d_t new_scale, +__read_only image2d_t new_scale, __read_only image2d_t new_biase, #endif - - __write_only image2d_t output_image, - __private const int stride, - __private const int offset, - __private const int input_c, - __private const int dilation, - __private const int input_width,/* of one block */ - __private const int input_height,/* of one block */ - __private const int output_width, - __private const int output_height) { - - const int out_c = get_global_id(0); - const int out_w = get_global_id(1); - const int out_nh = get_global_id(2); - - int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); - - if (out_c >= global_size_dim0 || - out_w >= global_size_dim1 || - out_nh >= global_size_dim2) { - return; - } - const int filter_n0 = 4 * out_c + 0; - const int filter_n1 = 4 * out_c + 1; - const int filter_n2 = 4 * out_c + 2; - const int filter_n3 = 4 * out_c + 3; - - int2 
stride_xy; - stride_xy.x = stride; - stride_xy.y = stride; - - int2 ouput_pos_in_one_block; - ouput_pos_in_one_block.x = out_w; - ouput_pos_in_one_block.y = out_nh; - + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h) { const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + // filter + const int filter_w = 7; + const int filter_h = 7; - int2 in_pos_in_one_block; - in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; - in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_width_id_per_blk and out_batch_id + int out_batch_id = item_h_id / in_h; + int out_w_base_id = item_ch_id * out_w; + int out_w_id0 = item_w_id; + int out_w_id1 = out_w_id0 + item_w; + int out_w_id2 = out_w_id1 + item_w; + int out_w_id3 = out_w_id2 + item_w; + int out_w_id4 = out_w_id3 + item_w; + + // in_width_id_per_blk and in_height_id_per_batch + int in_h_id = (item_h_id % out_h) * stride - pad; + int in_w_id0 = item_w_id * stride - pad; + int in_w_id1 = in_w_id0 + item_w * stride; + int in_w_id2 = in_w_id1 + item_w * stride; + int in_w_id3 = in_w_id2 + item_w * stride; + int in_w_id4 = in_w_id3 + item_w * stride; #ifdef BIASE_CH - half4 output = read_imageh(bias, sampler, (int2)(out_c, 0)); + + half4 output[5]; + output[0] = read_imageh(bias, sampler, (int2)(item_ch_id, 0)); + output[1] = output[0]; + output[2] = output[0]; + output[3] = output[0]; + output[4] = output[0]; + #elif defined(BIASE_ELE) - half4 output = read_imageh(bias, sampler, output_pos); + + half4 output[5]; + output[0] = read_imageh(bias, sampler, (int2)(out_w_base_id + out_w_id0, item_h_id)); + if (out_w_id1 < out_w) { + output[1] = read_imageh(bias, sampler, (int2)(out_w_base_id + out_w_id1, item_h_id)); + } + if (out_w_id2 < out_w) { + output[2] = read_imageh(bias, sampler, (int2)(out_w_base_id + out_w_id2, item_h_id)); + } + if (out_w_id3 < out_w) { + output[3] = read_imageh(bias, sampler, (int2)(out_w_base_id + out_w_id3, item_h_id)); + } + if (out_w_id4 < out_w) { + output[4] = read_imageh(bias, sampler, (int2)(out_w_base_id + out_w_id4, item_h_id)); + } #else - half4 output = 0.0f; + half4 output[5] = {0.0f}; #endif - half4 input; - half4 filter[4]; - int2 filter_pos0; - int2 filter_pos1; - int2 filter_pos2; - int2 filter_pos3; - for (int i = 0; i < input_c; ++i) { - int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); - for(int j = 0; j < 7; j++){ - for(int k = 0; k < 7; k++){ - input = select(read_imageh(input_image, sampler, - (int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15)); - int filter_h = k; - int filter_w = j; - int filter_c = i; + half4 filter[4] = {0.0f}; + half4 filter_trans[4] = {0.0f}; + half4 input[5] = {0.0f}; - filter_pos0.x = filter_c * 7 + filter_w; - filter_pos0.y = filter_n0 * 7 + filter_h; + int filter_h_val0 = item_ch_id * 4 * filter_h; + int filter_h_val1 = filter_h_val0 + 
filter_h; + int filter_h_val2 = filter_h_val1 + filter_h; + int filter_h_val3 = filter_h_val2 + filter_h; - filter_pos1.x = filter_c * 7 + filter_w; - filter_pos1.y = filter_n1 * 7 + filter_h; + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0; - filter_pos2.x = filter_c * 7 + filter_w; - filter_pos2.y = filter_n2 * 7 + filter_h; + const int in_w_base_id = mul24(ch, in_w); - filter_pos3.x = filter_c * 7 + filter_w; - filter_pos3.y = filter_n3 * 7 + filter_h; + int filter_w_val = ch * filter_w; - filter[0] = read_imageh(filter_image, sampler, filter_pos0); - filter[1] = read_imageh(filter_image, sampler, filter_pos1); - filter[2] = read_imageh(filter_image, sampler, filter_pos2); - filter[3] = read_imageh(filter_image, sampler, filter_pos3); + for (int h = 0; h < filter_h; h++) { - output.x += dot(input, filter[0]); - output.y += dot(input, filter[1]); - output.z += dot(input, filter[2]); - output.w += dot(input, filter[3]); - } + int in_h_val = select(out_batch_id * in_h + in_h_id + h, -1, + (out_batch_id * in_h + in_h_id + h < 0 || out_batch_id * in_h + in_h_id + h >= in_h)); + + for (int w = 0; w < filter_w; w++) { + + int in_w_val0 = select(in_w_base_id + in_w_id0 + w, -1, + (in_w_id0 + w < 0 || in_w_id0 + w >= in_w)); + int in_w_val1 = select(in_w_base_id + in_w_id1 + w, -1, + (in_w_id1 + w < 0 || in_w_id1 + w >= in_w)); + int in_w_val2 = select(in_w_base_id + in_w_id2 + w, -1, + (in_w_id2 + w < 0 || in_w_id2 + w >= in_w)); + int in_w_val3 = select(in_w_base_id + in_w_id3 + w, -1, + (in_w_id3 + w < 0 || in_w_id3 + w >= in_w)); + int in_w_val4 = select(in_w_base_id + in_w_id4 + w, -1, + (in_w_id4 + w < 0 || in_w_id4 + w >= in_w)); + + filter[0] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val0 + h)); // in_ch:0-3,out_ch:0 + filter[1] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val1 + h)); // in_ch:0-3,out_ch:1 + filter[2] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val2 + h)); // in_ch:0-3,out_ch:2 + filter[3] = read_imageh(filter_image, sampler,(int2)(filter_w_val + w,filter_h_val3 + h)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (half4)(filter[0].x, filter[1].x, filter[2].x, filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (half4)(filter[0].y, filter[1].y, filter[2].y, filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (half4)(filter[0].z, filter[1].z, filter[2].z, filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (half4)(filter[0].w, filter[1].w, filter[2].w, filter[3].w); // in_ch:3,out_ch:0-3 + + input[0] = read_imageh(input_image, sampler, (int2)(in_w_val0, in_h_val)); + input[1] = read_imageh(input_image, sampler, (int2)(in_w_val1, in_h_val)); + input[2] = read_imageh(input_image, sampler, (int2)(in_w_val2, in_h_val)); + input[3] = read_imageh(input_image, sampler, (int2)(in_w_val3, in_h_val)); + input[4] = read_imageh(input_image, sampler, (int2)(in_w_val4, in_h_val)); + + output[0] = mad(input[0].x, filter_trans[0], output[0]); + output[1] = mad(input[1].x, filter_trans[0], output[1]); + output[2] = mad(input[2].x, filter_trans[0], output[2]); + output[3] = mad(input[3].x, filter_trans[0], output[3]); + output[4] = mad(input[4].x, filter_trans[0], output[4]); + + if (ch_surplus < 3) { + output[0] = mad(input[0].y, filter_trans[1], output[0]); + output[1] = mad(input[1].y, filter_trans[1], output[1]); + output[2] = mad(input[2].y, filter_trans[1], output[2]); + output[3] = mad(input[3].y, filter_trans[1], output[3]); + 
output[4] = mad(input[4].y, filter_trans[1], output[4]); + } + if (ch_surplus < 2) { + output[0] = mad(input[0].z, filter_trans[2], output[0]); + output[1] = mad(input[1].z, filter_trans[2], output[1]); + output[2] = mad(input[2].z, filter_trans[2], output[2]); + output[3] = mad(input[3].z, filter_trans[2], output[3]); + output[4] = mad(input[4].z, filter_trans[2], output[4]); + } + if (ch_surplus < 1) { + output[0] = mad(input[0].w, filter_trans[3], output[0]); + output[1] = mad(input[1].w, filter_trans[3], output[1]); + output[2] = mad(input[2].w, filter_trans[3], output[2]); + output[3] = mad(input[3].w, filter_trans[3], output[3]); + output[4] = mad(input[4].w, filter_trans[3], output[4]); + } + } } } - #ifdef BATCH_NORM - output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); + half4 scale = read_imageh(new_scale, sampler, (int2)(item_ch_id, 0)); + half4 biase = read_imageh(new_biase, sampler, (int2)(item_ch_id, 0)); + output[0] = mad(scale, output[0], biase); + if (out_w_id1 < out_w) { + output[1] = mad(scale, output[1], biase); + } + if (out_w_id2 < out_w) { + output[2] = mad(scale, output[2], biase); + } + if (out_w_id3 < out_w) { + output[3] = mad(scale, output[3], biase); + } + if (out_w_id4 < out_w) { + output[4] = mad(scale, output[4], biase); + } #endif #ifdef RELU - output = activation(output); + output[0] = activation(output[0]); + output[1] = activation(output[1]); + output[2] = activation(output[2]); + output[3] = activation(output[3]); + output[4] = activation(output[4]); #endif - - write_imageh(output_image, output_pos, output); + write_imageh(output_image, (int2)(out_w_base_id + out_w_id0, item_h_id), output[0]); + if (out_w_id1 < out_w) { + write_imageh(output_image, (int2)(out_w_base_id + out_w_id1, item_h_id), output[1]); + } + if (out_w_id2 < out_w) { + write_imageh(output_image, (int2)(out_w_base_id + out_w_id2, item_h_id), output[2]); + } + if (out_w_id3 < out_w) { + write_imageh(output_image, (int2)(out_w_base_id + out_w_id3, item_h_id), output[3]); + } + if (out_w_id4 < out_w) { + write_imageh(output_image, (int2)(out_w_base_id + out_w_id4, item_h_id), output[4]); + } } __kernel void conv_5x5(__private const int global_size_dim0, @@ -2138,7 +1723,7 @@ __kernel void conv_5x5(__private const int global_size_dim0, const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); - + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); if (out_c >= global_size_dim0 || @@ -2258,7 +1843,7 @@ __kernel void convBNAdd_3x3(__private const int global_size_dim0, const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); - + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); if (out_c >= global_size_dim0 || @@ -2567,7 +2152,7 @@ __kernel void convBNAdd_1x1(__private const int global_size_dim0, const int out_nh = get_global_id(2); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); - + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; @@ -2675,7 +2260,7 @@ __kernel void convBNAdd_1x1_spl( int2 in_pos_in_one_block3 = ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); - + half4 output0 = 0.0f; half4 output1 = 0.0f; half4 output2 = 0.0f; @@ -2767,7 +2352,7 @@ __kernel void convBNAdd_1x1_spl( output2 += read_imageh(bias, sampler, output_pos2); output3 += read_imageh(bias, sampler, output_pos3); #endif - + #ifdef RELU 
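conv_7x7spl above consumes input channels in RGBA blocks of four, so the last block may contain up to three padding lanes; ch_surplus counts them, and the ch_surplus < 3 / < 2 / < 1 guards skip the multiply-adds for lanes that correspond to no real channel. A plain-C sketch of which lanes survive, under the same block-of-4 assumption:

#include <stdio.h>

/* Model of the ch_surplus logic above: lane k (x,y,z,w) of channel
 * block ch contributes only while ch_surplus < 4 - k. */
int main(void) {
  int in_ch = 6;                /* e.g. a 6-channel input            */
  int blocks = (in_ch + 3) / 4; /* two RGBA blocks                   */
  const char *lane = "xyzw";
  for (int ch = 0; ch < blocks; ++ch) {
    int ch_surplus = (ch + 1) * 4 - in_ch;
    if (ch_surplus < 0) ch_surplus = 0;
    printf("block %d: ch_surplus %d -> lanes", ch, ch_surplus);
    for (int k = 0; k < 4; ++k)
      if (ch_surplus < 4 - k) printf(" %c", lane[k]);
    printf("\n"); /* block 1 keeps only x,y: channels 4 and 5 */
  }
  return 0;
}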
output0 = activation(output0); output1 = activation(output1); diff --git a/mobile/src/operators/kernel/cl/cl_kernel/conv_transpose_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/conv_transpose_kernel.cl index e13f5debba48bbcfbfcdf0b77e14b8932929b509..96044b575e980cd1fcb4d2785c8adc4e83712196 100644 --- a/mobile/src/operators/kernel/cl/cl_kernel/conv_transpose_kernel.cl +++ b/mobile/src/operators/kernel/cl/cl_kernel/conv_transpose_kernel.cl @@ -14,7 +14,7 @@ limitations under the License. */ #include "cl_common.h" -__kernel void conv_transpose(__private const int input_c_block, +__kernel void conv_transpose_b(__private const int input_c_block, __private const int input_width,/* of one block */ __private const int input_height,/* of one block */ __private const int output_width, @@ -242,7 +242,312 @@ __read_only image2d_t new_scale, } +/* batch == 1 pad(output) == 1 out_w % 2 == 0 */ +__kernel void conv_transpose3x3s2(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM +__read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h, + __private const int filter_w, + __private const int filter_h) { + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_id + int out_w_id_per_ch_blk = item_w_id / 2 * 10 + item_w_id % 2; + int out_h_id = item_h_id; + int out_w_id0 = item_ch_id * out_w + out_w_id_per_ch_blk; + int out_w_id1 = out_w_id0 + 2; + int out_w_id2 = out_w_id1 + 2; + int out_w_id3 = out_w_id2 + 2; + int out_w_id4 = out_w_id3 + 2; + + // in_id + int in_w_id_per_ch_blk = (out_w_id_per_ch_blk) / 2; + in_w_id_per_ch_blk = in_w_id_per_ch_blk > 0 ? in_w_id_per_ch_blk : 0; + int in_h_id_per_batch = (out_h_id) / 2; + in_h_id_per_batch = in_h_id_per_batch > 0 ? in_h_id_per_batch : 0; + + // filter_id + int align_w_i = out_w_id_per_ch_blk - 1; + int align_w = align_w_i % 2 > 0 ? + align_w_i % 2 - 2 : align_w_i % 2; + int filter_w_id_per_ch_blk = out_w_id_per_ch_blk + 1 < 3 ? out_w_id_per_ch_blk + 1 : 2 + align_w; + + int align_h_i = out_h_id - 1; + int align_h = align_h_i % 2 > 0 ? + align_h_i % 2 - 2 : align_h_i % 2; + int filter_h_id_per_out_ch = out_h_id + 1 < 3 ? 
out_h_id + 1 : 2 + align_h; + +#ifdef BIASE_CH + half4 output[5]; + output[0] = read_imageh(bias, sampler, (int2)(item_ch_id, 0)); + output[1] = output[0]; + output[2] = output[0]; + output[3] = output[0]; + output[4] = output[0]; + +#elif defined(BIASE_ELE) + half4 output[5]; + output[0] = read_imageh(bias, sampler, (int2)(out_w_id0, item_h_id)); + if (out_w_id_per_ch_blk + 2 < out_w) { + output[1] = read_imageh(bias, sampler, (int2)(out_w_id1, item_h_id)); + } + if (out_w_id_per_ch_blk + 4 < out_w) { + output[2] = read_imageh(bias, sampler, (int2)(out_w_id2, item_h_id)); + } + if (out_w_id_per_ch_blk + 6 < out_w) { + output[3] = read_imageh(bias, sampler, (int2)(out_w_id3, item_h_id)); + } + if (out_w_id_per_ch_blk + 8 < out_w) { + output[4] = read_imageh(bias, sampler, (int2)(out_w_id4, item_h_id)); + } + +#else + half4 output[5] = {0.0f}; +#endif + half4 filter[4] = {0.0f}; + half4 filter_trans[4] = {0.0f}; + + half4 input[5] = {0.0f}; + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int filter_w_id = ch * 3; + int h_idx = 0; + for (int h = filter_h_id_per_out_ch; h >= 0; h -= 2) { + int in_h_id = select(in_h_id_per_batch + h_idx, -1, + in_h_id_per_batch + h_idx < 0 || in_h_id_per_batch + h_idx >= in_h); + int filter_h_id = item_ch_id * 12 + h; + int w_idx = 0; + for (int w = filter_w_id_per_ch_blk; w >= 0; w -= 2) { + int in_w_id0 = select(ch * in_w + in_w_id_per_ch_blk + w_idx, -1, + in_w_id_per_ch_blk + w_idx < 0 || in_w_id_per_ch_blk + w_idx >= in_w); + int in_w_id1 = select(ch * in_w + in_w_id_per_ch_blk + 1 + w_idx, -1, + in_w_id_per_ch_blk + 1 + w_idx < 0 || in_w_id_per_ch_blk + 1 + w_idx >= in_w); + int in_w_id2 = select(ch * in_w + in_w_id_per_ch_blk + 2 + w_idx, -1, + in_w_id_per_ch_blk + 2 + w_idx < 0 || in_w_id_per_ch_blk + 2 + w_idx >= in_w); + int in_w_id3 = select(ch * in_w + in_w_id_per_ch_blk + 3 + w_idx, -1, + in_w_id_per_ch_blk + 3 + w_idx < 0 || in_w_id_per_ch_blk + 3 + w_idx >= in_w); + int in_w_id4 = select(ch * in_w + in_w_id_per_ch_blk + 4 + w_idx, -1, + in_w_id_per_ch_blk + 4 + w_idx < 0 || in_w_id_per_ch_blk + 4 + w_idx >= in_w); + + input[0] = read_imageh(input_image, sampler, (int2)(in_w_id0, in_h_id)); + input[1] = read_imageh(input_image, sampler, (int2)(in_w_id1, in_h_id)); + input[2] = read_imageh(input_image, sampler, (int2)(in_w_id2, in_h_id)); + input[3] = read_imageh(input_image, sampler, (int2)(in_w_id3, in_h_id)); + input[4] = read_imageh(input_image, sampler, (int2)(in_w_id4, in_h_id)); + + filter[0] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id)); // in_ch:0-3,out_ch:0 + filter[1] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + 3)); // in_ch:0-3,out_ch:1 + filter[2] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + 6)); // in_ch:0-3,out_ch:2 + filter[3] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + 9)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (half4)(filter[0].x, filter[1].x, filter[2].x, filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (half4)(filter[0].y, filter[1].y, filter[2].y, filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (half4)(filter[0].z, filter[1].z, filter[2].z, filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (half4)(filter[0].w, filter[1].w, filter[2].w, filter[3].w); // in_ch:3,out_ch:0-3 + + output[0] = mad(input[0].x, filter_trans[0], output[0]); + output[0] = mad(input[0].y, filter_trans[1], output[0]); + output[0] = mad(input[0].z, filter_trans[2], output[0]); + output[0] = mad(input[0].w, 
filter_trans[3], output[0]); + + output[1] = mad(input[1].x, filter_trans[0], output[1]); + output[1] = mad(input[1].y, filter_trans[1], output[1]); + output[1] = mad(input[1].z, filter_trans[2], output[1]); + output[1] = mad(input[1].w, filter_trans[3], output[1]); + + output[2] = mad(input[2].x, filter_trans[0], output[2]); + output[2] = mad(input[2].y, filter_trans[1], output[2]); + output[2] = mad(input[2].z, filter_trans[2], output[2]); + output[2] = mad(input[2].w, filter_trans[3], output[2]); + + output[3] = mad(input[3].x, filter_trans[0], output[3]); + output[3] = mad(input[3].y, filter_trans[1], output[3]); + output[3] = mad(input[3].z, filter_trans[2], output[3]); + output[3] = mad(input[3].w, filter_trans[3], output[3]); + + output[4] = mad(input[4].x, filter_trans[0], output[4]); + output[4] = mad(input[4].y, filter_trans[1], output[4]); + output[4] = mad(input[4].z, filter_trans[2], output[4]); + output[4] = mad(input[4].w, filter_trans[3], output[4]); + w_idx++; + } + h_idx++; + } + } +#ifdef BATCH_NORM + half4 scale = read_imageh(new_scale, sampler, (int2)(item_ch_id, 0)); + half4 biase = read_imageh(new_biase, sampler, (int2)(item_ch_id, 0)); + output[0] = mad(scale, output[0], biase); + if (out_w_id_per_ch_blk + 2 < out_w) { + output[1] = mad(scale, output[1], biase); + } + if (out_w_id_per_ch_blk + 4 < out_w) { + output[2] = mad(scale, output[2], biase); + } + if (out_w_id_per_ch_blk + 6 < out_w) { + output[3] = mad(scale, output[3], biase); + } + if (out_w_id_per_ch_blk + 8 < out_w) { + output[4] = mad(scale, output[4], biase); + } +#endif + +#ifdef RELU + output[0] = activation(output[0]); + output[1] = activation(output[1]); + output[2] = activation(output[2]); + output[3] = activation(output[3]); + output[4] = activation(output[4]); + +#endif + + write_imageh(output_image, (int2)(out_w_id0, item_h_id), output[0]); + + if (out_w_id_per_ch_blk + 2 < out_w) { + write_imageh(output_image, (int2)(out_w_id1, item_h_id), output[1]); + } + if (out_w_id_per_ch_blk + 4 < out_w) { + write_imageh(output_image, (int2)(out_w_id2, item_h_id), output[2]); + } + if (out_w_id_per_ch_blk + 6 < out_w) { + write_imageh(output_image, (int2)(out_w_id3, item_h_id), output[3]); + } + if (out_w_id_per_ch_blk + 8 < out_w) { + write_imageh(output_image, (int2)(out_w_id4, item_h_id), output[4]); + } +} + +__kernel void conv_transpose(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h, + __private const int filter_w, + __private const int filter_h) { + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_id + int out_b_id = item_h_id / out_h; + int out_w_id_per_ch_blk = item_w_id; + int out_h_id_per_batch = item_h_id % out_h; + int out_w_id = item_ch_id * out_w + out_w_id_per_ch_blk; + + // in_id + int in_w_id_per_ch_blk = (out_w_id_per_ch_blk + pad 
- filter_w + stride) / stride; + in_w_id_per_ch_blk = in_w_id_per_ch_blk > 0 ? in_w_id_per_ch_blk : 0; + int in_h_id_per_batch = (out_h_id_per_batch + pad - filter_h + stride) / stride; + in_h_id_per_batch = in_h_id_per_batch > 0 ? in_h_id_per_batch : 0; + + // filter_id + int align_w_i = out_w_id_per_ch_blk + pad - filter_w + 1; + int align_w = align_w_i % stride > 0 ? + align_w_i % stride - stride : align_w_i % stride; + int filter_w_id_per_ch_blk = out_w_id_per_ch_blk + pad < filter_w ? out_w_id_per_ch_blk + pad : filter_w + align_w - 1; + int align_h_i = out_h_id_per_batch + pad - filter_h + 1; + int align_h = align_h_i % stride > 0 ? + align_h_i % stride - stride : align_h_i % stride; + int filter_h_id_per_out_ch = out_h_id_per_batch + pad < filter_h ? out_h_id_per_batch + pad : filter_h + align_h - 1; +#ifdef BIASE_CH + half4 output; + output = read_imageh(bias, sampler, (int2)(item_ch_id, 0)); +#elif defined(BIASE_ELE) + half4 output; + output = read_imageh(bias, sampler, (int2)(out_w_id, item_h_id)); +#else + half4 output = 0.0f; +#endif + half4 filter[4] = {0.0f}; + half4 filter_trans[4] = {0.0f}; + half4 input = 0.0f; + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int filter_w_id = ch * filter_w; + int h_idx = 0; + for (int h = filter_h_id_per_out_ch; h >= 0; h -= stride) { + int in_h_id = select(in_h_id_per_batch + h_idx, -1, + in_h_id_per_batch + h_idx < 0 || in_h_id_per_batch + h_idx >= in_h); + int filter_h_id = item_ch_id * filter_h * 4 + h; + int w_idx = 0; + for (int w = filter_w_id_per_ch_blk; w >= 0; w -= stride) { + int in_w_id = select(ch * in_w + in_w_id_per_ch_blk + w_idx, -1, + in_w_id_per_ch_blk + w_idx < 0 || in_w_id_per_ch_blk + w_idx >= in_w); + input = read_imageh(input_image, sampler, (int2)(in_w_id, in_h_id)); + filter[0] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id)); // in_ch:0-3,out_ch:0 + filter[1] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + filter_h)); // in_ch:0-3,out_ch:1 + filter[2] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + 2 * filter_h)); // in_ch:0-3,out_ch:2 + filter[3] = read_imageh(filter_image, sampler, (int2)(filter_w_id + w, filter_h_id + 3 * filter_h)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (half4)(filter[0].x, filter[1].x, filter[2].x, filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (half4)(filter[0].y, filter[1].y, filter[2].y, filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (half4)(filter[0].z, filter[1].z, filter[2].z, filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (half4)(filter[0].w, filter[1].w, filter[2].w, filter[3].w); // in_ch:3,out_ch:0-3 + + output = mad(input.x, filter_trans[0], output); + output = mad(input.y, filter_trans[1], output); + output = mad(input.z, filter_trans[2], output); + output = mad(input.w, filter_trans[3], output); + w_idx++; + } + h_idx++; + } + } +#ifdef BATCH_NORM + half4 scale = read_imageh(new_scale, sampler, (int2)(item_ch_id, 0)); + half4 biase = read_imageh(new_biase, sampler, (int2)(item_ch_id, 0)); + output = mad(scale, output, biase); +#endif + +#ifdef RELU + output = activation(output); +#endif + write_imageh(output_image, (int2)(out_w_id, item_h_id), output); +} diff --git a/mobile/src/operators/kernel/cl/cl_kernel/elementwise_mul_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/elementwise_mul_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..b975eb405633b3d7252aea30671818066459b3ea --- /dev/null +++ 
b/mobile/src/operators/kernel/cl/cl_kernel/elementwise_mul_kernel.cl
@@ -0,0 +1,45 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void elementwise_mul(__read_only image2d_t input, __read_only image2d_t bias, __write_only image2d_t outputImage) {
+  int x = get_global_id(0);
+  int y = get_global_id(1);
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+  int2 coords;
+  coords.x = x;
+  coords.y = y;
+  half4 in = read_imageh(input, sampler, coords);
+  half4 biase = read_imageh(bias, sampler, coords);
+  half4 output = in * biase;
+  write_imageh(outputImage, coords, output);
+}
+
+
+__kernel void channel_mul(__read_only image2d_t input, __read_only image2d_t bias, __write_only
+image2d_t outputImage, int w) {
+  int x = get_global_id(0);
+  int y = get_global_id(1);
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+  int2 coords;
+  coords.x = x;
+  coords.y = y;
+  int2 coords_bias;
+  coords_bias.x = x / w;
+  coords_bias.y = 0;
+  half4 in = read_imageh(input, sampler, coords);
+  half4 biase = read_imageh(bias, sampler, coords_bias);
+  half4 output = in * biase;
+  write_imageh(outputImage, coords, output);
+}
diff --git a/mobile/src/operators/kernel/cl/cl_kernel/feed_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/feed_kernel.cl
index bb661f3cf7102d5ef35b57f2167face0957129bc..27ca4d296e786715345540c44b6691dc8a69cefe 100644
--- a/mobile/src/operators/kernel/cl/cl_kernel/feed_kernel.cl
+++ b/mobile/src/operators/kernel/cl/cl_kernel/feed_kernel.cl
@@ -60,3 +60,51 @@ __kernel void feed(__global float *in,
   write_imageh(output_image, output_pos, output);
 }
+
+__kernel void feed_with_pre(__global uchar *in,
+                            __write_only image2d_t output_image,
+                            __private const int out_H,
+                            __private const int out_W,
+                            __private const int out_C,
+                            __private const int Stride0,
+                            __private const int Stride1,
+                            __private const int Stride2) {
+
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+  const int out_n = out_nh / out_H;
+  const int out_h = out_nh % out_H;
+
+  const int in_n = out_n;
+  const int in_c0 = out_c * 4 + 0;
+  const int in_c1 = out_c * 4 + 1;
+  const int in_c2 = out_c * 4 + 2;
+  const int in_c3 = out_c * 4 + 3;
+  const int in_h = out_h;
+  const int in_w = out_w;
+
+
+  int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w;
+  int input_pos3 = in_n * Stride2 + in_c3 * Stride1 + in_h * Stride0 + in_w;
+
+  int2 output_pos;
+  output_pos.x = out_c * out_W + out_w;
+  output_pos.y = out_nh;
+
+  half4 output = (half4)0.0f;
+  output.x = convert_half(in[input_pos0]) / 255;
+  if (out_C - 4 * out_c >= 2) {
+    output.y = convert_half(in[input_pos1]) / 255;
+  }
+  if (out_C - 4 * out_c >= 3) {
+    output.z =
convert_half(in[input_pos2]) / 255; + } + if(out_C - 4 * out_c>=4){ + output.w = convert_half(in[input_pos3]) / 255; + } + write_imageh(output_image, output_pos, output); + +} diff --git a/mobile/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl index f6014b732398cccd025a39cfb4a824b3154fcd66..f6b8e23cc43f512d5112d6bc80c6e1199d7c8c5e 100644 --- a/mobile/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl +++ b/mobile/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl @@ -67,3 +67,38 @@ __kernel void fetch_2d(__private const int in_height, out[index + 2] = convert_float(in.z); out[index + 3] = convert_float(in.w); } + +__kernel void fetch_with_post(__private const int in_height, + __private const int in_width, + __read_only image2d_t input, + __global uchar* out, + __private const int size_ch, + __private const int size_block, + __private const int size_batch, + __private const int C) { + const int in_c = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + const int in_n = in_nh / in_height; + const int in_h = in_nh % in_height; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + const int pos_x = mad24(in_c, in_width, in_w); + half4 in = read_imageh(input, sampler, (int2)(pos_x, in_nh)); + + const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w; + out[index] = convert_uchar_sat(in.x * 255); + if(C - 4 * in_c>=2){ + out[index + size_ch] = convert_uchar_sat(in.y * 255); + } + if(C - 4 * in_c>=3){ + out[index + size_ch * 2] = convert_uchar_sat(in.z * 255); + } + + if(C - 4 * in_c>=4){ + out[index + size_ch * 3] = convert_uchar_sat(in.w * 255); + } + +} diff --git a/mobile/src/operators/kernel/cl/cl_kernel/instancenorm_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/instancenorm_kernel.cl index 30e248f7f6241899ffd7d8aee45d06d1c61dbf4a..f78de05f766e12a54e35cd8cd59102435e1d950a 100644 --- a/mobile/src/operators/kernel/cl/cl_kernel/instancenorm_kernel.cl +++ b/mobile/src/operators/kernel/cl/cl_kernel/instancenorm_kernel.cl @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
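feed_with_pre and fetch_with_post above move the uint8 pre/post-processing onto the GPU: the feed divides each byte by 255 into a half4 image, and the fetch multiplies by 255 and saturates back to uchar, guarding the tail lanes when C is not a multiple of 4. A host-side C model of that round trip follows; it assumes OpenCL's default round-toward-zero conversion for convert_uchar_sat, and at half precision a few values can land one step off:

#include <stdint.h>
#include <stdio.h>

/* Models fetch_with_post's convert_uchar_sat(x * 255): saturate to
 * [0, 255], then convert with round-toward-zero (truncation). */
static uint8_t uchar_sat(float x) {
  if (x <= 0.0f) return 0;
  if (x >= 255.0f) return 255;
  return (uint8_t)x;
}

int main(void) {
  for (int u = 0; u <= 255; u += 51) {
    float normalized = (float)u / 255.0f;          /* feed_with_pre   */
    uint8_t back = uchar_sat(normalized * 255.0f); /* fetch_with_post */
    printf("%3d -> %f -> %3d\n", u, normalized, back);
  }
  return 0;
}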
*/ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#include "cl_common.h" __kernel void instancenorm(__private const int in_width, __private const int in_height, @@ -32,13 +32,19 @@ __kernel void instancenorm(__private const int in_width, const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - +#ifdef LOCAL_MEM_128 + __local float4 shared_mem[128]; +#elif defined(LOCAL_MEM_64) + __local float4 shared_mem[64]; +#else __local float4 shared_mem[256]; - +#endif + int xOffset = c * in_width; + int yOffset = n * in_height; float4 sum = 0.0f; for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) { for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) { - sum += read_imagef(input, sampler, (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex))); + sum += read_imagef(input, sampler, (int2)(xOffset + xIndex, yOffset + yIndex)); } } shared_mem[local_id] = sum; @@ -73,7 +79,8 @@ __kernel void instancenorm(__private const int in_width, sum = 0.0f; for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) { for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) { - sum += pow(read_imagef(input, sampler, (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex))) - mean_val, 2); + float4 temp = read_imagef(input, sampler, (int2)(xOffset + xIndex, yOffset + yIndex)) - mean_val; + sum += temp * temp; } } shared_mem[local_id] = sum; @@ -107,9 +114,13 @@ __kernel void instancenorm(__private const int in_width, for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) { for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) { - int2 intout_pos = (int2)(mad24(c, in_width, xIndex), mad24(n, in_height, yIndex)); + int2 intout_pos = (int2)(xOffset + xIndex, yOffset + yIndex); float4 in_val = read_imagef(input, sampler, intout_pos); - write_imageh(output, intout_pos, convert_half4((in_val - mean_val) * s)); + half4 out_val = convert_half4((in_val - mean_val) * s); +#ifdef RELU + out_val = activation(out_val); +#endif + write_imageh(output, intout_pos, out_val); } } } diff --git a/mobile/src/operators/kernel/cl/cl_kernel/pixel_shuffle_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/pixel_shuffle_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..a38c1ceae0a0dd502bd4c133c1ce229006e6eba3 --- /dev/null +++ b/mobile/src/operators/kernel/cl/cl_kernel/pixel_shuffle_kernel.cl @@ -0,0 +1,114 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
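The instancenorm changes above keep the same two-pass scheme, a workgroup reduction for the mean and then one for the variance, but compute the squared deviation as temp * temp instead of the slower pow(..., 2) call and fuse the optional RELU into the final write. A scalar C model of the math is below; the reduction is elided, and the epsilon term and the scale s are assumptions, since their computation sits outside this hunk:

#include <math.h>
#include <stdio.h>

/* Two-pass instance norm over one small "plane", mirroring the kernel. */
int main(void) {
  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float eps = 1e-5f, sum = 0.0f, sum_sq = 0.0f;
  for (int i = 0; i < 8; ++i) sum += x[i];
  float mean = sum / 8;
  for (int i = 0; i < 8; ++i) {
    float temp = x[i] - mean; /* temp * temp, replacing pow(..., 2) */
    sum_sq += temp * temp;
  }
  float s = 1.0f / sqrtf(sum_sq / 8 + eps); /* assumed form of the scale */
  for (int i = 0; i < 8; ++i) printf("%f ", (x[i] - mean) * s);
  printf("\n");
  return 0;
}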
*/ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void pixel_shuffle(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int in_N, + __private const int in_C, + __private const int in_H, + __private const int in_W, + __private const int out_N, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int upscale_factor) { + + const int out_c4 = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int out_h = out_nh % out_H; + int out_n = out_nh / out_H; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + int in_h = out_h / upscale_factor; + int in_w = out_w / upscale_factor; + int in_nh = out_n * in_H + in_h; + + half4 res; + int out_c; + int in_c; + half4 in; + int2 in_pos; + + out_c = out_c4 * 4 + 0; + in_c = out_c * upscale_factor * upscale_factor + (out_h % upscale_factor) * upscale_factor + (out_w % upscale_factor); + in_pos.x = (in_c / 4) * in_W + in_w; + in_pos.y = in_nh; + in = read_imageh(input_image, sampler, in_pos); + if (in_c % 4 == 0) { + res.x = in.x; + } else if (in_c % 4 == 1) { + res.x = in.y; + } else if (in_c % 4 == 2) { + res.x = in.z; + } else if (in_c % 4 == 3) { + res.x = in.w; + } + + out_c = out_c4 * 4 + 1; + in_c = out_c * upscale_factor * upscale_factor + (out_h % upscale_factor) * upscale_factor + (out_w % upscale_factor); + in_pos.x = (in_c / 4) * in_W + in_w; + in_pos.y = in_nh; + in = read_imageh(input_image, sampler, in_pos); + if (in_c % 4 == 0) { + res.y = in.x; + } else if (in_c % 4 == 1) { + res.y = in.y; + } else if (in_c % 4 == 2) { + res.y = in.z; + } else if (in_c % 4 == 3) { + res.y = in.w; + } + + out_c = out_c4 * 4 + 2; + in_c = out_c * upscale_factor * upscale_factor + (out_h % upscale_factor) * upscale_factor + (out_w % upscale_factor); + in_pos.x = (in_c / 4) * in_W + in_w; + in_pos.y = in_nh; + in = read_imageh(input_image, sampler, in_pos); + if (in_c % 4 == 0) { + res.z = in.x; + } else if (in_c % 4 == 1) { + res.z = in.y; + } else if (in_c % 4 == 2) { + res.z = in.z; + } else if (in_c % 4 == 3) { + res.z = in.w; + } + + out_c = out_c4 * 4 + 3; + in_c = out_c * upscale_factor * upscale_factor + (out_h % upscale_factor) * upscale_factor + (out_w % upscale_factor); + in_pos.x = (in_c / 4) * in_W + in_w; + in_pos.y = in_nh; + in = read_imageh(input_image, sampler, in_pos); + if (in_c % 4 == 0) { + res.w = in.x; + } else if (in_c % 4 == 1) { + res.w = in.y; + } else if (in_c % 4 == 2) { + res.w = in.z; + } else if (in_c % 4 == 3) { + res.w = in.w; + } + + int2 out_pos; + out_pos.x = out_c4 * out_W + out_w; + out_pos.y = out_nh; + write_imageh(output_image, out_pos, res); +} diff --git a/mobile/src/operators/kernel/cl/cl_kernel/pre_post_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/pre_post_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..edb6138919d3025c176f6dda540f683912633fce --- /dev/null +++ b/mobile/src/operators/kernel/cl/cl_kernel/pre_post_kernel.cl @@ -0,0 +1,22 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
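pixel_shuffle above rearranges a (C*r*r) x H x W tensor into C x (H*r) x (W*r) by reading input channel c*r*r + (h % r)*r + (w % r) at (h / r, w / r); the four nearly identical blocks just resolve that channel within its RGBA texel. The index mapping can be sanity-checked on the host with a tiny NCHW array:

#include <stdio.h>

/* Verifies the pixel_shuffle mapping for upscale factor r = 2. */
int main(void) {
  enum { r = 2, C = 1, H = 2, W = 2 }; /* output: C x (H*r) x (W*r) */
  float in[C * r * r][H][W];
  for (int c = 0; c < C * r * r; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        in[c][h][w] = 100 * c + 10 * h + w; /* encode (c, h, w) */
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H * r; ++h) {
      for (int w = 0; w < W * r; ++w) {
        int in_c = c * r * r + (h % r) * r + (w % r); /* as in the kernel */
        printf("%5.0f ", in[in_c][h / r][w / r]);
      }
      printf("\n");
    }
  return 0;
}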
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void pre(__global const uchar *input,
+                  __global float *output) {
+
+  int index = get_global_id(0);
+  output[index] = convert_float(input[index]) / 255;
+
+}
diff --git a/mobile/src/operators/kernel/cl/conv_add_kernel.cpp b/mobile/src/operators/kernel/cl/conv_add_kernel.cpp
index 8e21480b412affe2910142ae746b42682871859e..74225142283ae842237e2f1e84644a02addfcb40 100644
--- a/mobile/src/operators/kernel/cl/conv_add_kernel.cpp
+++ b/mobile/src/operators/kernel/cl/conv_add_kernel.cpp
@@ -82,17 +82,11 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
       //      winograd_transform_weight<4, 3>(&this->cl_helper_, param->Filter());
       //
       //    } else {
-      if (param->Strides()[0] == 1 && param->Dilations()[0] == 1) {
-        param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT;
-        param->Filter()->InitCLImage(cl_helper_.CLContext(),
-                                     cl_helper_.CLCommandQueue());
-        this->cl_helper_.AddKernel("conv_3x3s1", conv_kernel_file, build_options);
-      } else {
-        param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3_FLOAT;
-        param->Filter()->InitCLImage(cl_helper_.CLContext(),
-                                     cl_helper_.CLCommandQueue());
-        this->cl_helper_.AddKernel("conv_3x3", conv_kernel_file, build_options);
-      }
+
+      param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3_FLOAT;
+      param->Filter()->InitCLImage(cl_helper_.CLContext(),
+                                   cl_helper_.CLCommandQueue());
+      this->cl_helper_.AddKernel("conv_3x3spl", conv_kernel_file, build_options);
       //    }
 
 } else if (param->Filter()->dims()[2] == 7 &&
@@ -101,7 +95,7 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
     param->Filter()->InitCLImage(cl_helper_.CLContext(),
                                  cl_helper_.CLCommandQueue());
-    this->cl_helper_.AddKernel("conv_7x7", conv_kernel_file, build_options);
+    this->cl_helper_.AddKernel("conv_7x7spl", conv_kernel_file, build_options);
   } else if (param->Filter()->dims()[2] == 5 &&
              param->Filter()->dims()[3] == 5) {
@@ -123,16 +117,17 @@ void ConvAddKernel<GPU_CL, float>::Compute(
       WinogradConv3x3<4, 3>(&this->cl_helper_, param, false, param.Bias());
       break;
     case ConvParam::EXEC_SLIDINGWINDOW1x1_FLOAT:
-    case ConvParam::EXEC_SLIDINGWINDOW3x3_FLOAT:
     case ConvParam::EXEC_SLIDINGWINDOW5x5_FLOAT:
-    case ConvParam::EXEC_SLIDINGWINDOW7x7_FLOAT:
     case ConvParam::EXEC_DEPTHWISE3x3_FLOAT:
       ConvAddBnRelu(&this->cl_helper_, param, false, param.Bias());
       break;
+    case ConvParam::EXEC_SLIDINGWINDOW7x7_FLOAT:
+      SWConvAddBnRelu(&this->cl_helper_, param, false, param.Bias());
+      break;
     case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT:
       DWConvAddBnRelu(&this->cl_helper_, param, false, param.Bias());
       break;
-    case ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
+    case ConvParam::EXEC_SLIDINGWINDOW3x3_FLOAT:
       SWConvAddBnRelu(&this->cl_helper_, param, false, param.Bias());
       break;
     default:
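The host-side changes above reroute every 3x3 conv_add to the new conv_3x3spl kernel and 7x7 to conv_7x7spl, with the Compute switch sending those exec modes through SWConvAddBnRelu. A condensed, hypothetical C model of the filter-size dispatch (the real Init also branches for 1x1, depthwise, and winograd cases, builds the filter CLImage, and records an ExecMode):

#include <stdio.h>

/* Simplified sketch of the kernel selection in ConvAddKernel::Init. */
static const char *pick_kernel(int filter_h, int filter_w) {
  if (filter_h == 3 && filter_w == 3) return "conv_3x3spl";
  if (filter_h == 5 && filter_w == 5) return "conv_5x5";
  if (filter_h == 7 && filter_w == 7) return "conv_7x7spl";
  return "unsupported";
}

int main(void) {
  printf("3x3 -> %s\n", pick_kernel(3, 3)); /* conv_3x3spl */
  printf("7x7 -> %s\n", pick_kernel(7, 7)); /* conv_7x7spl */
  return 0;
}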
diff --git a/mobile/src/operators/kernel/cl/conv_relu_kernel.cpp b/mobile/src/operators/kernel/cl/conv_relu_kernel.cpp
index 585b68f5326ec7990279287d33edbad38cb3d4fb..1aedbeec7acdb5eafa4ffe088cee1f8a7adf8230 100644
--- a/mobile/src/operators/kernel/cl/conv_relu_kernel.cpp
+++ b/mobile/src/operators/kernel/cl/conv_relu_kernel.cpp
@@ -86,7 +86,8 @@
       param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT;
       param->Filter()->InitCLImage(cl_helper_.CLContext(),
                                    cl_helper_.CLCommandQueue());
-      this->cl_helper_.AddKernel("conv_3x3s1", conv_kernel_file, build_options);
+      this->cl_helper_.AddKernel("conv_3x3spl", conv_kernel_file,
+                                 build_options);
     } else {
       param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3_FLOAT;
       param->Filter()->InitCLImage(cl_helper_.CLContext(),
diff --git a/mobile/src/operators/kernel/cl/conv_transpose_kernel.cpp b/mobile/src/operators/kernel/cl/conv_transpose_kernel.cpp
index f5be81eefd8284741819cadee8f0ea5e30100772..8d66b50a99a6cd07de8dcf32867f1cb3c28d2232 100644
--- a/mobile/src/operators/kernel/cl/conv_transpose_kernel.cpp
+++ b/mobile/src/operators/kernel/cl/conv_transpose_kernel.cpp
@@ -40,7 +40,8 @@ bool ConvTransposeKernel<GPU_CL, float>::Init(
     param->ExecMode() = ConvTransposeParam::EXEC_CONVTRANS3x3s2_FLOAT;
     param->Filter()->InitConv2dTransposeFilterCLImage(
         cl_helper_.CLContext(), cl_helper_.CLCommandQueue());
-    this->cl_helper_.AddKernel("conv_transpose", "conv_transpose_kernel.cl");
+    this->cl_helper_.AddKernel("conv_transpose3x3s2",
+                               "conv_transpose_kernel.cl");
   } else {
     PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
   }
@@ -55,7 +56,7 @@ void ConvTransposeKernel<GPU_CL, float>::Compute(
       DWConvTransposeAddBnRelu(&this->cl_helper_, param);
       break;
     case ConvTransposeParam::EXEC_CONVTRANS3x3s2_FLOAT:
-      ConvTransposeAddBnRelu(&this->cl_helper_, param);
+      ConvTranspose3x3s2AddBnRelu(&this->cl_helper_, param);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION(
diff --git a/mobile/src/operators/kernel/cl/density_prior_box_kernel.cpp b/mobile/src/operators/kernel/cl/density_prior_box_kernel.cpp
index 0a281ed1039232722b624fb42ef575ec48f835aa..1a5cf0f061606d82076ce0f231e03ba3b36753a0 100644
--- a/mobile/src/operators/kernel/cl/density_prior_box_kernel.cpp
+++ b/mobile/src/operators/kernel/cl/density_prior_box_kernel.cpp
@@ -25,6 +25,35 @@ bool DensityPriorBoxKernel<GPU_CL, float>::Init(DensityPriorBoxParam<GPU_CL> *param) {
   this->cl_helper_.AddKernel("density_prior_box",
                              "density_prior_box_kernel.cl");
+  vector<float> fixed_sizes = param->FixedSizes();
+  vector<float> fixed_ratios = param->FixedRatios();
+  vector<int> densities = param->Densities();
+  vector<float> variances = param->Variances();
+  int fix_ratio_size = fixed_ratios.size();
+  int total_size = densities.size() + fixed_sizes.size() + fix_ratio_size;
+  float *densities_data = new float[total_size];
+  for (int i = 0; i < densities.size(); ++i) {
+    float density = densities[i];
+    densities_data[i] = density;
+  }
+
+  for (int k = 0; k < fixed_sizes.size(); ++k) {
+    densities_data[k + densities.size()] = fixed_sizes[k];
+  }
+
+  for (int j = 0; j < fixed_ratios.size(); ++j) {
+    float sqrt_ratios = sqrt(fixed_ratios[j]);
+    densities_data[j + densities.size() + fixed_sizes.size()] = sqrt_ratios;
+  }
+
+  framework::CLImage *new_density = new framework::CLImage();
+  new_density->SetTensorData(densities_data, {1, 1, 1, total_size});
+  new_density->InitCLImage(this->cl_helper_.CLContext(),
+                           this->cl_helper_.CLCommandQueue());
+  param->setNewDensity(new_density);
+
+  delete[] densities_data;
+
   return true;
 }
@@ -39,6 +68,7 @@ void DensityPriorBoxKernel<GPU_CL, float>::Compute(
   auto output_boxes = param.OutputBoxes()->GetCLImage();
   auto output_var = param.OutputVariances()->GetCLImage();
+  auto new_density = param.getNewDensity()->GetCLImage();
   float step_w = param.StepW();
   float step_h = param.StepH();
@@ -73,43 +103,17 @@
   auto default_work = this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
+ fixed_sizes.size() + fix_ratio_size]; - - int status; - - for (int i = 0; i < densities.size(); ++i) { - float density = densities[i]; - densities_data[i] = &density; - } - - for (int k = 0; k < fixed_sizes.size(); ++k) { - densities_data[k + densities.size()] = &fixed_sizes[k]; - } - - for (int j = 0; j < fixed_ratios.size(); ++j) { - float sqrt_ratios = sqrt(fixed_ratios[j]); - densities_data[j + densities.size() + fixed_sizes.size()] = &sqrt_ratios; - } - - cl_mem densities_memobj = clCreateBuffer( - this->cl_helper_.CLContext(), CL_MEM_READ_WRITE, - sizeof(float) * (densities.size() * 2 + fix_ratio_size), NULL, &status); - status = clEnqueueWriteBuffer( - this->cl_helper_.CLCommandQueue(), densities_memobj, CL_FALSE, 0, - (densities.size() * 2 + fix_ratio_size) * sizeof(float), densities_data, - 0, NULL, NULL); - CL_CHECK_ERRORS(status); - float variances0 = variances[0]; float variances1 = variances[1]; float variances2 = variances[2]; float variances3 = variances[3]; + cl_int status; status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &output_boxes); CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_var); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &densities_memobj); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &new_density); CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 3, sizeof(float), &step_h); CL_CHECK_ERRORS(status); diff --git a/mobile/src/operators/kernel/cl/elementwise_add_kernel.cpp b/mobile/src/operators/kernel/cl/elementwise_add_kernel.cpp index 1506956280a489ebf9631099101c1ac0f0bf03ec..06d718601cc885ac100dc29a4879b88ce9384736 100644 --- a/mobile/src/operators/kernel/cl/elementwise_add_kernel.cpp +++ b/mobile/src/operators/kernel/cl/elementwise_add_kernel.cpp @@ -77,8 +77,8 @@ void ElementwiseAddKernel::Compute( status = clSetKernelArg(kernel, 2, sizeof(cl_mem), reinterpret_cast(&output_image)); CL_CHECK_ERRORS(status); - int width = input->ImageWidth(); - int height = input->ImageHeight(); + auto width = input->ImageWidth(); + auto height = input->ImageHeight(); size_t global_work_size[2] = {width, height}; status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, @@ -103,8 +103,8 @@ void ElementwiseAddKernel::Compute( status = clSetKernelArg(kernel, 3, sizeof(cl_int), reinterpret_cast(&tensor_w)); CL_CHECK_ERRORS(status); - int width = input->ImageWidth(); - int height = input->ImageHeight(); + auto width = input->ImageWidth(); + auto height = input->ImageHeight(); DLOG << "dede:" << width << "," << height; size_t global_work_size[2] = {width, height}; cl_event out_event = param.Out()->GetClEvent(); diff --git a/mobile/src/operators/kernel/cl/elementwise_mul_kernel.cpp b/mobile/src/operators/kernel/cl/elementwise_mul_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f2aca78509ea45525f1dcd39a7a8154ca75060e --- /dev/null +++ b/mobile/src/operators/kernel/cl/elementwise_mul_kernel.cpp @@ -0,0 +1,103 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
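The Init-time packing added above also repairs a latent bug in the per-Compute path removed here: the deleted code filled a pointer array with addresses of loop-local variables (densities_data[i] = &density), which dangle as soon as each iteration ends, and re-created an OpenCL buffer on every Compute call. A minimal host-side sketch of the new packing order, assuming plain std::vector inputs; the function name is illustrative, not part of the kernel API:

#include <cmath>
#include <vector>

// Pack densities, then fixed sizes, then sqrt(fixed ratios) into one
// contiguous float buffer, mirroring the {1, 1, 1, total_size} CLImage
// that Init now uploads once instead of building a cl_mem per Compute.
std::vector<float> PackDensityPriorData(const std::vector<float> &densities,
                                        const std::vector<float> &fixed_sizes,
                                        const std::vector<float> &fixed_ratios) {
  std::vector<float> packed;
  packed.reserve(densities.size() + fixed_sizes.size() + fixed_ratios.size());
  for (float d : densities) packed.push_back(d);
  for (float s : fixed_sizes) packed.push_back(s);
  for (float r : fixed_ratios) packed.push_back(std::sqrt(r));  // kernel consumes sqrt(ratio)
  return packed;
}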
+See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISEMUL_OP + +#include "operators/kernel/elementwise_mul_kernel.h" +#include "framework/cl/cl_image.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ElementwiseMulKernel::Init( + ElementwiseMulParam *param) { + DLOG << "-----init add-----"; + framework::CLImage *bias = reinterpret_cast( + const_cast(param->InputY())); + if (bias->dims() == param->InputX()->dims()) { + this->cl_helper_.AddKernel("elementwise_mul", "elementwise_mul_kernel.cl"); + } else if (bias->dims().size() == 4) { + this->cl_helper_.AddKernel("channel_mul", "elementwise_mul_kernel.cl"); + } else { + DLOG << "error:bias dims is error"; + } + return true; +} + +template <> +void ElementwiseMulKernel::Compute( + const ElementwiseMulParam ¶m) { + auto input = param.InputX(); + auto bias = param.InputY(); + auto output = param.Out(); + cl_int status; + auto kernel = this->cl_helper_.KernelAt(0); + if (bias->dims() == input->dims()) { + cl_mem input_image = input->GetCLImage(); + cl_mem bias_image = bias->GetCLImage(); + cl_mem output_image = output->GetCLImage(); + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), + reinterpret_cast(&input_image)); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), + reinterpret_cast(&bias_image)); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), + reinterpret_cast(&output_image)); + CL_CHECK_ERRORS(status); + auto width = input->ImageWidth(); + auto height = input->ImageHeight(); + size_t global_work_size[2] = {width, height}; + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, + NULL, global_work_size, NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + } else if (bias->dims().size() == 4) { + DLOG << "zp7 444"; + cl_mem input_image = input->GetCLImage(); + cl_mem bias_image = bias->GetCLImage(); + cl_mem output_image = output->GetCLImage(); + int tensor_w = input->dims()[input->dims().size() - 1]; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), + reinterpret_cast(&input_image)); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), + reinterpret_cast(&bias_image)); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), + reinterpret_cast(&output_image)); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), + reinterpret_cast(&tensor_w)); + CL_CHECK_ERRORS(status); + auto width = input->ImageWidth(); + auto height = input->ImageHeight(); + DLOG << "dede:" << width << "," << height; + size_t global_work_size[2] = {width, height}; + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, + NULL, global_work_size, NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + } else { + DLOG << "error:bias dims is error"; + } +} + +template class ElementwiseMulKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/cl/feed_kernel.cpp b/mobile/src/operators/kernel/cl/feed_kernel.cpp index 0522905fee91fd466b2c334677acce0d25cfac7e..f96059593459d7fd95e236473b3ca3c5cd1420fc 100644 --- a/mobile/src/operators/kernel/cl/feed_kernel.cpp +++ b/mobile/src/operators/kernel/cl/feed_kernel.cpp @@ -21,7 +21,11 @@ namespace operators { template <> bool FeedKernel::Init(FeedParam *param) { DLOG << "Init feed"; - this->cl_helper_.AddKernel("feed", "feed_kernel.cl"); + if (this->pre_post_type_ == UINT8_255) { + 
this->cl_helper_.AddKernel("feed_with_pre", "feed_kernel.cl"); + } else { + this->cl_helper_.AddKernel("feed", "feed_kernel.cl"); + } return true; } @@ -34,7 +38,7 @@ void FeedKernel::Compute(const FeedParam ¶m) { auto output = param.Out(); const Tensor *input = ¶m.InputX()->at(col); // DLOG << *input; - const float *input_data = input->data(); + int numel = input->numel(); cl_mem output_image = output->GetCLImage(); const int out_C = output->dims()[1]; @@ -46,7 +50,14 @@ void FeedKernel::Compute(const FeedParam ¶m) { framework::CLTensor input_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); input_cl_tensor.Resize(input->dims()); - cl_mem inputBuffer = input_cl_tensor.mutable_with_data(input_data); + cl_mem inputBuffer; + if (this->pre_post_type_ == UINT8_255) { + inputBuffer = + input_cl_tensor.mutable_with_data(input->data()); + } else { + inputBuffer = + input_cl_tensor.mutable_with_data(input->data()); + } status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); CL_CHECK_ERRORS(status); diff --git a/mobile/src/operators/kernel/cl/fetch_kernel.cpp b/mobile/src/operators/kernel/cl/fetch_kernel.cpp index e1e1522a449685902dd64369bcc15798d1376a72..df2c2e1f5c2df08897c4d00db1f80d79f4c13c25 100644 --- a/mobile/src/operators/kernel/cl/fetch_kernel.cpp +++ b/mobile/src/operators/kernel/cl/fetch_kernel.cpp @@ -20,7 +20,11 @@ namespace operators { template <> bool FetchKernel::Init(FetchParam *param) { - this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); + if (this->pre_post_type_ == UINT8_255) { + this->cl_helper_.AddKernel("fetch_with_post", "fetch_kernel.cl"); + } else { + this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); + } return true; } @@ -33,7 +37,6 @@ void FetchKernel::Compute(const FetchParam ¶m) { auto input = param.InputX()->GetCLImage(); auto *out = ¶m.Out()->at(col); out->Resize(param.InputX()->dims()); - out->mutable_data(); DLOG << "fetch kernel out dims = " << out->dims(); DLOG << "fetch kernel out memory size = " << out->memory_size(); @@ -57,7 +60,14 @@ void FetchKernel::Compute(const FetchParam ¶m) { framework::CLTensor out_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); out_cl_tensor.Resize(out->dims()); - cl_mem outBuffer = out_cl_tensor.mutable_data(); + cl_mem outBuffer; + if (this->pre_post_type_ == UINT8_255) { + out->mutable_data(); + outBuffer = out_cl_tensor.mutable_data(); + } else { + out->mutable_data(); + outBuffer = out_cl_tensor.mutable_data(); + } cl_int status; status = clSetKernelArg(kernel, 0, sizeof(int), &in_height); @@ -91,8 +101,13 @@ void FetchKernel::Compute(const FetchParam ¶m) { DLOG << "fetch kernel out_cl_tensor dims = " << out_cl_tensor.dims(); DLOG << "fetch kernel out_cl_tensor memery size = " << out_cl_tensor.memory_size(); - memcpy(out->data(), out_cl_tensor.Data(), - sizeof(float) * out->numel()); + if (this->pre_post_type_ == UINT8_255) { + memcpy(out->data(), out_cl_tensor.Data(), + sizeof(uint8_t) * out->numel()); + } else { + memcpy(out->data(), out_cl_tensor.Data(), + sizeof(float) * out->numel()); + } } template class FetchKernel; diff --git a/mobile/src/operators/kernel/cl/fusion_fc_kernel.cpp b/mobile/src/operators/kernel/cl/fusion_fc_kernel.cpp index a9d6b806080e797d4aab9b9f46315a9e34808c8b..de6a0455b9763890ca3fbf00af7bc25a43bb5a42 100644 --- a/mobile/src/operators/kernel/cl/fusion_fc_kernel.cpp +++ b/mobile/src/operators/kernel/cl/fusion_fc_kernel.cpp @@ -98,7 +98,7 @@ void FusionFcCompute(const FusionFcParam ¶m, cl_context context, static_cast(1), out, 
static_cast(1), false); - out_image->InitEmptyImage(context, commandQueue, out->dims()); + // out_image->InitEmptyImage(context, commandQueue, out->dims()); framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1); delete (input_x); diff --git a/mobile/src/operators/kernel/cl/instancenorm_kernel.cpp b/mobile/src/operators/kernel/cl/instancenorm_kernel.cpp index a8307d05d5b493a983e33cebdb331bdc09c27fd9..f068d36133e826e8caa79d8f4852bbaac4415cdd 100644 --- a/mobile/src/operators/kernel/cl/instancenorm_kernel.cpp +++ b/mobile/src/operators/kernel/cl/instancenorm_kernel.cpp @@ -16,86 +16,32 @@ limitations under the License. */ #include "operators/kernel/instancenorm_kernel.h" #include +#include "operators/kernel/cl/cl-kernel-func/instancenorm_func.h" namespace paddle_mobile { namespace operators { template <> bool InstanceNormKernel::Init(InstanceNormParam *param) { - this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl"); + auto &dims = param->Out()->dims(); + const int h = dims[2]; + std::string build_options = ""; + if (h == 128) { + build_options = "-DLOCAL_MEM_128"; + } else if (h == 64) { + build_options = "-DLOCAL_MEM_64"; + } else if (h > 256) { + PADDLE_MOBILE_THROW_EXCEPTION("instance norm unsupported input height"); + } + this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl", + build_options); return true; } template <> void InstanceNormKernel::Compute( const InstanceNormParam ¶m) { - auto kernel = this->cl_helper_.KernelAt(0); - auto &dims = param.Out()->dims(); - - const int n = dims[0]; - const int c_group = (dims[1] + 3) / 4; - const int h = dims[2]; - const int w = dims[3]; - auto epsilon = param.Epsilon(); - auto input = param.InputX()->GetCLImage(); - auto out = param.Out()->GetCLImage(); - - DLOG << "Epsilon: " << epsilon; - - auto local_work_size_info = this->cl_helper_.LocalWorkSizeInfo(); - - DLOG << local_work_size_info.max_work_group_size; - DLOG << local_work_size_info.max_work_item_size0; - DLOG << local_work_size_info.max_work_item_size1; - DLOG << local_work_size_info.max_work_item_size2; - - const int max_work_group_size = - std::min(256, static_cast(local_work_size_info.max_work_group_size)); - int local_work_size1 = 1; - int local_work_size2 = 1; - for (int i = 1; i <= local_work_size_info.max_work_item_size1 && i <= w; - i++) { - for (int j = 1; j <= local_work_size_info.max_work_item_size2 && j <= h; - j++) { - if (i * j <= max_work_group_size) { - if (i * j > local_work_size1 * local_work_size2) { - local_work_size1 = i; - local_work_size2 = j; - } - } - } - } - const size_t work_size[3] = {(size_t)(n * c_group), (size_t)local_work_size1, - (size_t)local_work_size2}; - const size_t local_work_size[3] = {(size_t)1, (size_t)local_work_size1, - (size_t)local_work_size2}; - - DLOG << "work_size" << work_size[0] << " " << work_size[1] << " " - << work_size[2]; - DLOG << "local_work_size" << local_work_size[0] << " " << local_work_size[1] - << " " << local_work_size[2]; - - cl_int status; - status = clSetKernelArg(kernel, 0, sizeof(cl_int), &w); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 1, sizeof(cl_int), &h); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 2, sizeof(cl_int), &c_group); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 3, sizeof(cl_int), &local_work_size1); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 4, sizeof(cl_int), &local_work_size2); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 5, sizeof(cl_float), &epsilon); - 
CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &input); - CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &out); - CL_CHECK_ERRORS(status); - status = - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - work_size, local_work_size, 0, NULL, NULL); - CL_CHECK_ERRORS(status); + InstanceNorm(&this->cl_helper_, param); } template class InstanceNormKernel; diff --git a/mobile/src/operators/kernel/cl/instancenorm_relu_kernel.cpp b/mobile/src/operators/kernel/cl/instancenorm_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c265454d0ea67c7a6aec8f1017bc5455d328a756 --- /dev/null +++ b/mobile/src/operators/kernel/cl/instancenorm_relu_kernel.cpp @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_INSTANCENORM_RELU_OP + +#include "operators/kernel/instancenorm_relu_kernel.h" +#include +#include "operators/kernel/cl/cl-kernel-func/instancenorm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool InstanceNormReluKernel::Init( + InstanceNormParam *param) { + auto &dims = param->Out()->dims(); + const int h = dims[2]; + std::string build_options = "-DRELU"; + if (h == 128) { + build_options += " -DLOCAL_MEM_128"; + } else if (h == 64) { + build_options += " -DLOCAL_MEM_64"; + } else if (h > 256) { + PADDLE_MOBILE_THROW_EXCEPTION("instance norm unsupported input height"); + } + this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl", + build_options); + return true; +} + +template <> +void InstanceNormReluKernel::Compute( + const InstanceNormParam ¶m) { + InstanceNorm(&this->cl_helper_, param); +} + +template class InstanceNormReluKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/cl/mul_kernel.cpp b/mobile/src/operators/kernel/cl/mul_kernel.cpp index d021aa6d7458d2b20d7ec51f095ef8aca6caeef2..3a45babee062ac415c1903e901488a73731f2e22 100644 --- a/mobile/src/operators/kernel/cl/mul_kernel.cpp +++ b/mobile/src/operators/kernel/cl/mul_kernel.cpp @@ -63,7 +63,7 @@ void MulCompute(const MulParam ¶m, cl_context context, static_cast(1), output_tensor, static_cast(0)); - output->InitEmptyImage(context, commandQueue, output_tensor->dims()); + // output->InitEmptyImage(context, commandQueue, output_tensor->dims()); framework::TensorToCLImage(output_tensor, output, context, commandQueue, kernel1); diff --git a/mobile/src/operators/kernel/cl/pixel_shuffle_kernel.cpp b/mobile/src/operators/kernel/cl/pixel_shuffle_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..faa90f9c4329d2450e15c220a68e3d675fb2eacc --- /dev/null +++ b/mobile/src/operators/kernel/cl/pixel_shuffle_kernel.cpp @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
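Both instance norm Init paths above now compile variants of a single instancenorm kernel: a local-memory tiling keyed to the input height (64 or 128), a hard failure above 256, and an optional fused ReLU selected with -DRELU, while the shared host logic moved into instancenorm_func. A standalone sketch of the build-option assembly the two Init functions perform:

#include <stdexcept>
#include <string>

// Build the -D option string for instancenorm_kernel.cl: pick a local-memory
// variant for h == 64 or h == 128, reject h > 256, optionally fuse ReLU.
std::string InstanceNormBuildOptions(int height, bool fuse_relu) {
  std::string opts = fuse_relu ? "-DRELU" : "";
  if (height == 128) {
    opts += opts.empty() ? "-DLOCAL_MEM_128" : " -DLOCAL_MEM_128";
  } else if (height == 64) {
    opts += opts.empty() ? "-DLOCAL_MEM_64" : " -DLOCAL_MEM_64";
  } else if (height > 256) {
    throw std::runtime_error("instance norm unsupported input height");
  }
  return opts;
}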
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PIXEL_SHUFFLE_OP + +#include "operators/kernel/pixel_shuffle_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool PixelShuffleKernel::Init(PixelShuffleParam *param) { + this->cl_helper_.AddKernel("pixel_shuffle", "pixel_shuffle_kernel.cl"); + return true; +} + +template <> +void PixelShuffleKernel::Compute( + const PixelShuffleParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out()); + + auto input_image = param.InputX()->GetCLImage(); + auto output_image = param.Out()->GetCLImage(); + auto upscale_factor = param.upscale_factor(); + + int input_n = param.InputX()->dims()[0]; + int input_c = param.InputX()->dims()[1]; + int input_h = param.InputX()->dims()[2]; + int input_w = param.InputX()->dims()[3]; + int output_n = param.Out()->dims()[0]; + int output_c = param.Out()->dims()[1]; + int output_h = param.Out()->dims()[2]; + int output_w = param.Out()->dims()[3]; + + cl_int status; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &input_n); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(int), &input_c); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(int), &input_h); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(int), &input_w); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(int), &output_n); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(int), &output_c); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 8, sizeof(int), &output_h); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 9, sizeof(int), &output_w); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 10, sizeof(int), &upscale_factor); + CL_CHECK_ERRORS(status); + + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/cl/pool_kernel.cpp b/mobile/src/operators/kernel/cl/pool_kernel.cpp index ed0731c31b01728592336abcd1e282fc74f5ca11..990f6ea67572043b4d09332ab0a1c82cdb8765f9 100644 --- a/mobile/src/operators/kernel/cl/pool_kernel.cpp +++ b/mobile/src/operators/kernel/cl/pool_kernel.cpp @@ -50,6 +50,14 @@ void PoolKernel::Compute(const PoolParam ¶m) { std::vector ksize = param.Ksize(); std::vector strides = param.Strides(); std::vector paddings = param.Paddings(); + + if (param.isGlobalPooling()) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(param.Input()->dims()[i + 2]); + } + } + const int pad_top = paddings[0]; const int pad_left = paddings[1]; const int stride_h = strides[0]; diff --git a/mobile/src/operators/kernel/cl/prior_box_kernel.cpp b/mobile/src/operators/kernel/cl/prior_box_kernel.cpp index 
92764b379e8dad8070407fcf012b4bad73fd19a1..c10bfed8d1a21d6578258a28259e883422342085 100644 --- a/mobile/src/operators/kernel/cl/prior_box_kernel.cpp +++ b/mobile/src/operators/kernel/cl/prior_box_kernel.cpp @@ -121,9 +121,9 @@ void PriorBoxKernel::Compute( auto kernel = this->cl_helper_.KernelAt(0); auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.OutputBoxes()); - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; + auto c_block = default_work_size[0]; + auto w = default_work_size[1]; + auto nh = default_work_size[2]; std::vector box_shape({num_priors}); framework::DDim ddim = framework::make_ddim(box_shape); diff --git a/mobile/src/operators/kernel/cl/transpose2_kernel.cpp b/mobile/src/operators/kernel/cl/transpose2_kernel.cpp index 371fbee7106fbc0b5a09d17d3c18628ff57c16a9..a40569574af2653f8592ee68f7f9fc2395e969db 100644 --- a/mobile/src/operators/kernel/cl/transpose2_kernel.cpp +++ b/mobile/src/operators/kernel/cl/transpose2_kernel.cpp @@ -184,6 +184,8 @@ void Transpose2Compute(const Transpose2Param ¶m, cl_context context, output->InitEmptyImage(context, commandQueue, output_tensor->dims()); framework::TensorToCLImage(output_tensor, output, context, commandQueue, kernel1); + delete (input_tensor); + delete (output_tensor); } template <> diff --git a/mobile/src/operators/kernel/fpga/V2/anchor_generator_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/anchor_generator_kernel.cpp index 6046b3d2f0a4a1d273d31aac079244ce3ec3703a..951fbb5f3708bf511bfcbbb0669fb7a56a4eb7c4 100644 --- a/mobile/src/operators/kernel/fpga/V2/anchor_generator_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/anchor_generator_kernel.cpp @@ -45,9 +45,9 @@ bool AnchorGeneratorKernel::Init( if (offset > 0.6) { memcpy(anchors_offset, anchors_offset2, sizeof(anchors_offset)); - std::cout << "anchor generator marker" << std::endl; + DLOG << "anchor generator marker"; } else { - std::cout << "anchor generator rfcn" << std::endl; + DLOG << "anchor generator rfcn"; } int num_anchors = sizeof(anchors_offset) / (sizeof(int) * 4); diff --git a/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp index ecc2577bd6ba9f8f21d4cccb94bdc27466b4a5d1..50179b9cd55d772f035e83fdc54ee98f38de1c54 100644 --- a/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/proposal_kernel.cpp @@ -30,16 +30,12 @@ bool ProposalKernel::Init(ProposalParam *param) { int64_t batch = param->scores_->dims()[0]; auto total = post_nms_top_n * batch; param->rpn_rois_->mutable_data({total, 4}); - param->rpn_probs_->mutable_data({total, 1}); + param->rpn_probs_->mutable_data({total, 1}); param->float_bbox = std::make_shared(); param->float_bbox->Resize(param->bbox_deltas_->dims()); param->float_bbox->init(type_id().hash_code()); fpga::format_fp32_ofm(param->float_bbox.get()); - param->float_score = std::make_shared(); - param->float_score->Resize(param->scores_->dims()); - param->float_score->init(type_id().hash_code()); - fpga::format_fp32_ofm(param->float_score.get()); auto input = param->scores_; param->score_index_ = std::make_shared(); @@ -88,7 +84,7 @@ void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) { template static inline void BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas, - Tensor *variances, Tensor *proposals) { + Tensor *proposals) { T *proposals_data = proposals->mutable_data(); int64_t row = all_anchors->dims()[0]; @@ -96,10 +92,6 @@ static inline void 
BoxCoder(Tensor *all_anchors, Tensor *bbox_deltas, auto *bbox_deltas_data = bbox_deltas->data(); auto *anchor_data = all_anchors->data(); - const T *variances_data = nullptr; - if (variances) { - variances_data = variances->data(); - } for (int64_t i = 0; i < row; ++i) { T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0; @@ -244,10 +236,10 @@ static inline Tensor NMS(Tensor *bbox, Tensor *scores, T nms_threshold, // 4: [xmin ymin xmax ymax] int64_t box_size = bbox->dims()[1]; - std::vector scores_data(num_boxes); - std::copy_n(scores->data(), num_boxes, scores_data.begin()); - std::vector> sorted_indices = - GetSortedScoreIndex(scores_data); + std::vector scores_data(num_boxes); + std::copy_n(scores->data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices = + GetSortedScoreIndex(scores_data); std::vector selected_indices; int selected_num = 0; @@ -284,8 +276,7 @@ std::pair ProposalForOneImage( const Tensor &scores_slice, // [N, 1] const Tensor &score_index, int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, float eta) { - auto *scores_data = scores_slice.data(); - + auto *scores_data = scores_slice.data(); // Sort index Tensor index_t; index_t.Resize({scores_slice.numel()}); @@ -306,17 +297,17 @@ std::pair ProposalForOneImage( } Tensor scores_sel, bbox_sel, anchor_sel, var_sel; - scores_sel.mutable_data({index_t.numel(), 1}); + scores_sel.mutable_data({index_t.numel(), 1}); bbox_sel.mutable_data({index_t.numel(), 4}); anchor_sel.mutable_data({index_t.numel(), 4}); var_sel.mutable_data({index_t.numel(), 4}); - CPUGather(scores_slice, index_t, &scores_sel); + CPUGather(scores_slice, index_t, &scores_sel); CPUGather(bbox_deltas_slice, index_t, &bbox_sel); CPUGather(anchors, index_t, &anchor_sel); Tensor proposals; proposals.mutable_data({index_t.numel(), 4}); - BoxCoder(&anchor_sel, &bbox_sel, nullptr, &proposals); + BoxCoder(&anchor_sel, &bbox_sel, &proposals); ClipTiledBoxes(im_info_slice, &proposals); @@ -325,10 +316,10 @@ std::pair ProposalForOneImage( Tensor scores_filter; bbox_sel.mutable_data({keep.numel(), 4}); - scores_filter.mutable_data({keep.numel(), 1}); + scores_filter.mutable_data({keep.numel(), 1}); CPUGather(proposals, keep, &bbox_sel); - CPUGather(scores_sel, keep, &scores_filter); + CPUGather(scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { return std::make_pair(bbox_sel, scores_filter); } @@ -341,10 +332,10 @@ std::pair ProposalForOneImage( } proposals.mutable_data({keep_nms.numel(), 4}); // original - scores_sel.mutable_data({keep_nms.numel(), 1}); // original + scores_sel.mutable_data({keep_nms.numel(), 1}); // original CPUGather(bbox_sel, keep_nms, &proposals); - CPUGather(scores_filter, keep_nms, &scores_sel); + CPUGather(scores_filter, keep_nms, &scores_sel); return std::make_pair(proposals, scores_sel); } @@ -368,69 +359,43 @@ void ProposalKernel::Compute(const ProposalParam ¶m) { bbox_height = (uint32_t)(input_bbox->dims()[2]); bbox_width = (uint32_t)(input_bbox->dims()[3]); - std::shared_ptr score_tmp = std::make_shared(); - score_tmp->Resize(param.scores_->dims()); - score_tmp->mutable_data(); - - std::shared_ptr bbox_tmp = std::make_shared(); - bbox_tmp->Resize(param.bbox_deltas_->dims()); - bbox_tmp->mutable_data(); - - auto score_tmp_data = score_tmp->data(); - auto bbox_tmp_data = bbox_tmp->data(); int64_t amount_per_side = score_width * score_height; - int idx = 0; + int alignedCW = fpga::align_to_x(score_width * score_channels, IMAGE_ALIGNMENT); int unalignedCW = score_width * score_channels; 
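The NMS path above now copies and sorts the raw int8 scores rather than dequantizing to float first. That is safe for ordering because dequantization x -> x / 127 * scale is strictly increasing for scale > 0, so the sorted index sequence is identical. A sketch of the score-index sort being relied on; the sort direction here is illustrative, since GetSortedScoreIndex itself is defined elsewhere in the framework:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Pair each quantized score with its box index and sort by score descending.
// No dequantization is needed: a positive monotone scale preserves order.
std::vector<std::pair<int8_t, int>> SortedScoreIndex(
    const std::vector<int8_t> &scores) {
  std::vector<std::pair<int8_t, int>> idx(scores.size());
  for (size_t i = 0; i < scores.size(); ++i)
    idx[i] = {scores[i], static_cast<int>(i)};
  std::stable_sort(idx.begin(), idx.end(),
                   [](const auto &a, const auto &b) { return a.first > b.first; });
  return idx;
}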
fpga::fpga_invalidate(input_score_data,
                        score_height * alignedCW * sizeof(int8_t));
+
+  Tensor score_tensor = *input_score;
   for (int h = 0; h < score_height; h++) {
     for (int w = 0; w < score_width; w++) {
-      for (int c = 0; c < score_channels; c++) {
-        if (alignedCW == unalignedCW) {
-          *(score_tmp_data + c * amount_per_side + score_width * h + w) =
-              (*(input_score_data++));
-        } else {
-          idx = h * alignedCW + w * score_channels + c;
-          *(score_tmp_data + c * amount_per_side + score_width * h + w) =
-              input_score_data[idx];
-        }
+      for (int c = 0; c < score_channels; ++c) {
+        int dstidx = h * unalignedCW + w * score_channels + c;
+        int srcidx = h * alignedCW + w * score_channels + c;
+        score_tensor.data<int8_t>()[dstidx] = input_score_data[srcidx];
       }
     }
   }
+
   amount_per_side = bbox_width * bbox_height;
   alignedCW = fpga::align_to_x(bbox_width * bbox_channels, IMAGE_ALIGNMENT);
   unalignedCW = bbox_width * bbox_channels;
   fpga::fpga_invalidate(input_bbox_data,
                         bbox_height * alignedCW * sizeof(int8_t));
+
+  auto bbox_tensor = param.float_bbox.get();
   for (int h = 0; h < bbox_height; h++) {
     for (int w = 0; w < bbox_width; w++) {
-      for (int c = 0; c < bbox_channels; c++) {
-        if (alignedCW == unalignedCW) {
-          *(bbox_tmp_data + c * amount_per_side + bbox_width * h + w) =
-              (*(input_bbox_data++));
-        } else {
-          idx = h * alignedCW + w * bbox_channels + c;
-          *(bbox_tmp_data + c * amount_per_side + bbox_width * h + w) =
-              input_bbox_data[idx];
-        }
+      for (int c = 0; c < bbox_channels; ++c) {
+        int dstidx = h * unalignedCW + w * bbox_channels + c;
+        int srcidx = h * alignedCW + w * bbox_channels + c;
+        bbox_tensor->data<float>()[dstidx] =
+            (static_cast<float>(input_bbox_data[srcidx])) / 127.0 *
+            input_bbox->scale[0];
       }
     }
   }
-
-  auto score_tensor = param.float_score.get();
-  for (int i = 0; i < score_height * score_width * score_channels; i++) {
-    score_tensor->data<float>()[i] =
-        score_tmp_data[i] / 127.0 * input_score->scale[0];
-  }
-  auto bbox_tensor = param.float_bbox.get();
-  for (int i = 0; i < bbox_height * bbox_width * bbox_channels; i++) {
-    bbox_tensor->data<float>()[i] =
-        bbox_tmp_data[i] / 127.0 * input_bbox->scale[0];
-  }
-  auto *scores = param.float_score.get();
-  auto *bbox_deltas = param.float_bbox.get();
   auto *im_info = param.im_info_;
   auto anchors = *param.anchors_;
   auto variances = *param.variances_;
@@ -447,37 +412,23 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
   float min_size = param.min_size_;
   float eta = param.eta_;
 
-  auto &scores_dim = scores->dims();
-  int64_t num = scores_dim[0];
-  int64_t c_score = scores_dim[1];
-  int64_t h_score = scores_dim[2];
-  int64_t w_score = scores_dim[3];
-
-  auto &bbox_dim = bbox_deltas->dims();
-  int64_t c_bbox = bbox_dim[1];
-  int64_t h_bbox = bbox_dim[2];
-  int64_t w_bbox = bbox_dim[3];
-
-  //
-  rpn_rois->mutable_data<float>({bbox_deltas->numel(), 4});
-  rpn_roi_probs->mutable_data<float>({scores->numel(), 1});
-
+  rpn_rois->mutable_data<float>({bbox_tensor->numel() / 4, 4});
+  rpn_roi_probs->mutable_data<int8_t>({input_score->numel() / 4, 1});
   framework::LoD lod;
   lod.resize(1);
   auto &lod0 = lod[0];
   lod0.push_back(0);
-  anchors.Resize({anchors.numel(), 4});
-  variances.Resize({variances.numel(), 4});
+  anchors.Resize({anchors.numel() / 4, 4});
+  variances.Resize({variances.numel() / 4, 4});
   int64_t num_proposals = 0;
-  for (int64_t i = 0; i < num; ++i) {
+  for (int64_t i = 0; i < score_n; ++i) {
     Tensor im_info_slice = im_info->Slice(i, i + 1);
     Tensor bbox_deltas_slice = (*bbox_tensor).Slice(i, i + 1);
-    Tensor scores_slice = (*score_tensor).Slice(i, i + 1);
-
-    bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox, 4});
-
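The rewritten copy loops above replace the old transposing copy (into score_tmp/bbox_tmp in channel-major order) with an in-layout copy that only strips the FPGA's row alignment, and for the bbox path they fold the int8 dequantization into the same pass instead of running a second whole-tensor conversion loop. A standalone sketch of that mapping, assuming an NHWC int8 feature map whose rows are padded out to aligned_cw elements:

#include <cstdint>

// Copy an aligned NHWC int8 map into a dense buffer, dequantizing with
// x / 127 * scale on the way, exactly one pass over the data.
void UnalignAndDequant(const int8_t *src, float *dst, int height, int width,
                       int channels, int aligned_cw, float scale) {
  const int unaligned_cw = width * channels;
  for (int h = 0; h < height; ++h)
    for (int w = 0; w < width; ++w)
      for (int c = 0; c < channels; ++c) {
        const int srcidx = h * aligned_cw + w * channels + c;
        const int dstidx = h * unaligned_cw + w * channels + c;
        dst[dstidx] = static_cast<float>(src[srcidx]) / 127.0f * scale;
      }
}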
scores_slice.Resize({h_score * w_score * c_score, 1}); + Tensor scores_slice = score_tensor.Slice(i, i + 1); + bbox_deltas_slice.Resize({bbox_height * bbox_width * bbox_channels / 4, 4}); + scores_slice.Resize({score_height * score_width * score_channels, 1}); std::pair tensor_pair = ProposalForOneImage( im_info_slice, anchors, variances, bbox_deltas_slice, scores_slice, score_index, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, eta); diff --git a/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp index b8b5202e27369a74430aa130db68501ff6891eec..87948f824e353ef3a3c341a0a1ecee5957e871d6 100644 --- a/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/psroi_pool_kernel.cpp @@ -44,14 +44,14 @@ bool PSRoiPoolKernel::Init(PSRoiPoolParam* param) { } template -void PSROIPoolingForward(const Dtype* bottom_data, const int height, +void PSROIPoolingForward(const int8_t* bottom_data, const int height, const int width, const int input_channel, Dtype* top_data, const int pooled_height, const int pooled_width, const int output_channel, const Dtype* bottom_rois, const Dtype Bin_size_h, const Dtype Bin_size_w, const Dtype roi_start_h, const Dtype roi_start_w, const int pw, const int ph, - const int roi_batch_ind) { + float scale, const int roi_batch_ind) { int hstart = floor(static_cast(ph) * Bin_size_h + roi_start_h); int wstart = floor(static_cast(pw) * Bin_size_w + roi_start_w); int hend = ceil(static_cast(ph + 1) * Bin_size_h + roi_start_h); @@ -64,11 +64,12 @@ void PSROIPoolingForward(const Dtype* bottom_data, const int height, wend = std::min(std::max(wend, 0), width); bool is_empty = (hend <= hstart) || (wend <= wstart); - float sum_pixels_c[output_channel] = {0}; - float pixels_c[output_channel] = {0}; + float avg_pixels_c[output_channel] = {0}; + int sum_pixels_c[output_channel] = {0}; + int8_t pixels_c[output_channel] = {0}; if (!is_empty) { Dtype bin_area = (hend - hstart) * (wend - wstart); - float rec_bin_area = 1 / bin_area; + float scale_fuse = scale / bin_area; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -86,27 +87,21 @@ void PSROIPoolingForward(const Dtype* bottom_data, const int height, } } for (int output_c = 0; output_c < output_channel; output_c++) { - sum_pixels_c[output_c] *= rec_bin_area; + avg_pixels_c[output_c] = sum_pixels_c[output_c] * scale_fuse; } } int output_index_base = (ph * pooled_width + pw) * output_channel; top_data += output_index_base; - memcpy(top_data, sum_pixels_c, output_channel * 4); + memcpy(top_data, avg_pixels_c, output_channel * 4); } template <> void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { auto input_tensor = param.input_x_; auto input_data = input_tensor->data(); - auto Si = input_tensor->scale[0]; - auto float_input_tensor = param.float_input.get(); - auto float_input_data = float_input_tensor->data(); - for (int i = 0; i < float_input_tensor->numel(); i++) { - float_input_data[i] = input_data[i] / 127.0 * Si; - } - - auto* in = float_input_tensor; + auto scale = input_tensor->scale[0] / 127.0; + fpga::fpga_invalidate(input_data, input_tensor->numel() * sizeof(int8_t)); auto* rois = param.input_rois_; auto* out = param.output_; @@ -115,22 +110,19 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { auto spatial_scale = param.spatial_scale_; auto output_channels = param.output_channels_; - auto in_dims = in->dims(); + auto in_dims = input_tensor->dims(); int batch_size = 
in_dims[0]; int input_channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; - auto data_nhwc = in->mutable_data(); - framework::DDim dims_out_new = framework::make_ddim( {rois_num, (param.output_)->dims()[1], (((param.output_)->dims()[2])), (param.output_)->dims()[3]}); (param.output_)->Resize(dims_out_new); - const float* input_data_tmp = data_nhwc; // in->data(); framework::Tensor rois_batch_id_list; rois_batch_id_list.Resize({rois_num}); auto rois_batch_id_data = rois_batch_id_list.mutable_data(); @@ -151,12 +143,7 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { "the channels of input X should equal the product of " "output_channels x pooled_height x pooled_width"); - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; - } - } + auto output_data = out->mutable_data(); auto input_rois = rois->data(); @@ -187,10 +174,10 @@ void PSRoiPoolKernel::Compute(const PSRoiPoolParam& param) { for (int ph = 0; ph < pooled_height; ph++) { for (int pw = 0; pw < pooled_width; pw++) { PSROIPoolingForward( - input_data_tmp, height, width, input_channels, offset_output_data, + input_data, height, width, input_channels, offset_output_data, pooled_height, pooled_width, output_channels, input_rois, bin_size_h, bin_size_w, roi_start_h, roi_start_w, pw, ph, - roi_batch_ind); + scale, roi_batch_ind); } } } diff --git a/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp index ebaf3759400c60c9ecf36467d0eeb7adad140f46..fcf0889b4a66919efc677e211a1da453fd761de4 100644 --- a/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/reshape2_kernel.cpp @@ -25,6 +25,7 @@ bool Reshape2Kernel::Init(Reshape2Param *param) { auto input = const_cast(param->InputX()); auto output = param->Out(); auto shape = param->Shape(); + output->scale[0] = input->scale[0]; auto num_in = framework::product(input->dims()); auto num_shape = framework::product(framework::make_ddim(shape)); diff --git a/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp b/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp index b7615a8891b8292dd4d65c15955a0ee640c2f770..843f249c683717789999db733a04b3da0198bdcb 100755 --- a/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp +++ b/mobile/src/operators/kernel/fpga/V2/softmax_kernel.cpp @@ -81,6 +81,7 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { auto w = 1; auto c = 1; if (dims.size() == 4) { + n = dims[0]; h = dims[1]; w = dims[2]; c = dims[3]; @@ -90,6 +91,7 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { h = 1; } } else if (dims.size() == 2) { + n = dims[0]; c = dims[1]; } if ((c == 2) && (in_x->type() == type_id())) { diff --git a/mobile/src/operators/kernel/instancenorm_relu_kernel.h b/mobile/src/operators/kernel/instancenorm_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9a4bedb564ea68e252f65372c38f3cfce13f339f --- /dev/null +++ b/mobile/src/operators/kernel/instancenorm_relu_kernel.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
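The PSROI pooling rework above removes the whole-tensor float conversion (previously every input element was converted with x / 127 * Si before pooling). Instead the kernel accumulates raw int8 values per bin and applies a single fused factor, scale / bin_area, at the end, which is algebraically the same average. A sketch of the arithmetic:

#include <cstdint>

// Average an int8 bin and dequantize in one multiply:
//   avg = (sum(x_i) / bin_area) * (Si / 127) = sum(x_i) * scale_fuse
float FusedBinAverage(const int8_t *bin, int count, float si) {
  const float scale_fuse = (si / 127.0f) / static_cast<float>(count);
  int sum = 0;
  for (int i = 0; i < count; ++i) sum += bin[i];
  return sum * scale_fuse;
}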
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef FUSION_INSTANCENORM_RELU_OP + +#include +#include "framework/operator.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using framework::OpKernelBase; + +template +class InstanceNormReluKernel + : public OpKernelBase> { + public: + void Compute(const InstanceNormParam ¶m); + bool Init(InstanceNormParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/pixel_shuffle_kernel.h b/mobile/src/operators/kernel/pixel_shuffle_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3f95c866f893f625194afe127dc83851dd874ff7 --- /dev/null +++ b/mobile/src/operators/kernel/pixel_shuffle_kernel.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef LRN_OP + +#include +#ifdef _OPENMP +#include +#endif +#ifdef __ARM_NEON +#include +#include "operators/math/math.h" +#endif +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class PixelShuffleKernel + : public framework::OpKernelBase> { + public: + void Compute(const PixelShuffleParam ¶m); + bool Init(PixelShuffleParam *param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/prior_box_kernel.h b/mobile/src/operators/kernel/prior_box_kernel.h index f691ffb83ae4a400aaa89ecc2731daf1f7048051..c5d561083d13f878b6b46ccd03d4ae3c4d1f233f 100644 --- a/mobile/src/operators/kernel/prior_box_kernel.h +++ b/mobile/src/operators/kernel/prior_box_kernel.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include #include +#include #include #include "framework/operator.h" #include "operators/math/transform.h" @@ -77,6 +78,8 @@ class DensityPriorBoxParam : public OpParam { densities_ = GetAttr>("densities", attrs); } + ~DensityPriorBoxParam() {} + const GType *Input() const { return input_; } const GType *InputImage() const { return input_image_; } GType *OutputBoxes() const { return output_boxes_; } @@ -90,6 +93,8 @@ class DensityPriorBoxParam : public OpParam { const vector &FixedRatios() const { return fixed_ratios_; } const vector &Densities() const { return densities_; } const vector &Variances() const { return variances_; } + GType *getNewDensity() const { return new_density.get(); } + void setNewDensity(GType *newDensity) { new_density.reset(newDensity); } public: GType *input_; @@ -105,6 +110,7 @@ class DensityPriorBoxParam : public OpParam { vector fixed_ratios_; vector densities_; vector variances_; + std::shared_ptr new_density; }; DECLARE_KERNEL(DensityPriorBox, DensityPriorBoxParam); diff --git a/mobile/src/operators/math/gemm/gemm1x1s1.cpp b/mobile/src/operators/math/gemm/gemm1x1s1.cpp index fd997dc48d830e6612d6a3b2f55b0f929c8a7e77..2fd78fa18923248a9a8b12d3ea8bef444b664733 100644 --- a/mobile/src/operators/math/gemm/gemm1x1s1.cpp +++ b/mobile/src/operators/math/gemm/gemm1x1s1.cpp @@ -1518,10 +1518,12 @@ void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias, (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH)); x_block /= NBLOCK; x_block *= NBLOCK; - int x_num = (N + (x_block - 1)) / x_block; - x_block = (N + x_num - 1) / x_num; - x_block = (x_block + NBLOCK - 1) / NBLOCK; - x_block *= NBLOCK; + if (x_block != 0) { + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + } x_block = x_block < NBLOCK ? 
NBLOCK : x_block; int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; int tail_pre = (K & (KBLOCK - 1)); diff --git a/mobile/src/operators/op_param.h b/mobile/src/operators/op_param.h index 07413fb422759766957751e509ebf0b77d7e2364..2651a0f69766544a0ec09250248682c5b559ef01 100644 --- a/mobile/src/operators/op_param.h +++ b/mobile/src/operators/op_param.h @@ -57,6 +57,21 @@ using std::vector; using framework::DtypeTensorTrait; +template +class CLImageDeleter { + typedef typename DtypeTensorTrait::gtype GType; + + public: + void operator()(GType *ptr) { +#ifdef PADDLE_MOBILE_CL + framework::CLImage *image = dynamic_cast(ptr); + if (image) { + delete image; + } +#endif + } +}; + class OpParam { public: OpParam(const VariableNameMap &inputs, const VariableNameMap &outputs, @@ -850,6 +865,8 @@ class BatchNormParam : public OpParam { // is_test_ = GetAttr("is_test", attrs); } + ~BatchNormParam() {} + const GType *InputX() const { return input_x_; } GType *OutputY() const { return output_y_; } @@ -870,13 +887,17 @@ class BatchNormParam : public OpParam { const string &DataFormat() const { return data_format_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } private: GType *input_x_; @@ -889,8 +910,8 @@ class BatchNormParam : public OpParam { float momentum_; bool is_test_; string data_format_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2076,6 +2097,9 @@ class FusionConvAddBNReluParam : public ConvParam { momentum_ = OpParam::GetAttr("momentum", attrs); this->output_ = OpParam::OutFrom(outputs, *scope); } + + ~FusionConvAddBNReluParam() {} + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } @@ -2092,13 +2116,17 @@ class FusionConvAddBNReluParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } protected: GType *bias_; @@ -2109,8 +2137,8 @@ class FusionConvAddBNReluParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2143,6 +2171,8 @@ class FusionConvBNAddReluParam : public ConvParam { } this->output_ = OpParam::OutFrom(outputs, *scope); } + + ~FusionConvBNAddReluParam() {} GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } @@ -2159,13 +2189,17 @@ class FusionConvBNAddReluParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } 
+ void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } protected: GType *bias_; @@ -2176,8 +2210,8 @@ class FusionConvBNAddReluParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; std::string keyBNY_; std::string keyX_; std::string keyY_; @@ -2216,13 +2250,17 @@ class FusionConvBNParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } protected: GType *input_bias_; @@ -2231,8 +2269,8 @@ class FusionConvBNParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2273,13 +2311,17 @@ class FusionConvAddBNParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } protected: GType *bias_; @@ -2290,8 +2332,8 @@ class FusionConvAddBNParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2315,6 +2357,8 @@ class FusionDWConvBNReluParam : public ConvParam { this->output_ = OpParam::OutFrom(outputs, *scope); } + ~FusionDWConvBNReluParam() {} + const GType *InputBias() const { return input_bias_; } const GType *InputMean() const { return input_mean_; } @@ -2327,13 +2371,17 @@ class FusionDWConvBNReluParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } 
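The pattern applied across these param classes replaces raw new_bias_/new_scale_ pointers with std::shared_ptr plus a custom deleter (CLImageDeleter), so the fused batch-norm tensors are released exactly once when the last owner goes away, while NewScale()/NewBias() keep handing out non-owning raw pointers. A minimal standalone sketch of the idiom, with a placeholder type standing in for framework::CLImage:

#include <memory>

// Placeholder for framework::CLImage in this sketch.
struct Image { /* owns a cl_mem in the real class */ };

// Functor deleter mirroring CLImageDeleter; the real one additionally
// checks the dynamic type via dynamic_cast before deleting.
struct ImageDeleter {
  void operator()(Image *ptr) const { delete ptr; }
};

int main() {
  std::shared_ptr<Image> new_scale;
  new_scale.reset(new Image(), ImageDeleter());  // SetNewScale equivalent
  const Image *view = new_scale.get();           // NewScale() returns a raw view
  (void)view;
  // no manual delete: the last shared_ptr owner runs ImageDeleter
}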
protected: GType *input_bias_; @@ -2342,8 +2390,8 @@ class FusionDWConvBNReluParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2384,6 +2432,8 @@ class FusionConvBNReluParam : public ConvParam { this->output_ = OpParam::OutFrom(outputs, *scope); } + ~FusionConvBNReluParam() {} + const GType *InputBias() const { return input_bias_; } const GType *InputMean() const { return input_mean_; } @@ -2396,13 +2446,17 @@ class FusionConvBNReluParam : public ConvParam { const float &Momentum() const { return momentum_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const GType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_.get(); } - const GType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_.get(); } protected: GType *input_bias_; @@ -2411,8 +2465,8 @@ class FusionConvBNReluParam : public ConvParam { GType *input_variance_; float epsilon_; float momentum_; - GType *new_bias_; - GType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -2637,13 +2691,17 @@ class FusionDeconvAddBNParam : public ConvTransposeParam { const bool &IsTest() const { return is_test_; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(RType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(RType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const RType *NewScale() const { return new_scale_; } + const RType *NewScale() const { return new_scale_.get(); } - const RType *NewBias() const { return new_bias_; } + const RType *NewBias() const { return new_bias_.get(); } protected: RType *output_; @@ -2654,8 +2712,8 @@ class FusionDeconvAddBNParam : public ConvTransposeParam { float epsilon_; float momentum_; bool is_test_; - RType *new_bias_; - RType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif #ifdef FUSION_DECONVBNRELU_OP @@ -2693,13 +2751,17 @@ class FusionDeconvBNReluParam : public ConvTransposeParam { const bool &IsTest() const { return is_test_; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(RType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(RType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const RType *NewScale() const { return new_scale_; } + const RType *NewScale() const { return new_scale_.get(); } - const RType *NewBias() const { return new_bias_; } + const RType *NewBias() const { return new_bias_.get(); } protected: RType *output_; @@ -2710,8 +2772,8 @@ class FusionDeconvBNReluParam : public ConvTransposeParam { float epsilon_; float momentum_; bool is_test_; - RType *new_bias_; - RType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif #ifdef FUSION_DECONVADDBNRELU_OP @@ -2750,13 +2812,17 @@ class FusionDeconvAddBNReluParam : public ConvTransposeParam { const bool &IsTest() const { return is_test_; } - void SetNewScale(RType 
*new_scale) { new_scale_ = new_scale; } + void SetNewScale(RType *new_scale) { + new_scale_.reset(new_scale, CLImageDeleter()); + } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(RType *new_bias) { + new_bias_.reset(new_bias, CLImageDeleter()); + } - const RType *NewScale() const { return new_scale_; } + const RType *NewScale() const { return new_scale_.get(); } - const RType *NewBias() const { return new_bias_; } + const RType *NewBias() const { return new_bias_.get(); } protected: RType *output_; @@ -2767,8 +2833,8 @@ class FusionDeconvAddBNReluParam : public ConvTransposeParam { float epsilon_; float momentum_; bool is_test_; - RType *new_bias_; - RType *new_scale_; + std::shared_ptr new_bias_; + std::shared_ptr new_scale_; }; #endif @@ -3562,5 +3628,35 @@ class EXPParam : public OpParam { GType *out_; }; #endif + +#ifdef PIXEL_SHUFFLE_OP +template +class PixelShuffleParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + PixelShuffleParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + Scope *scope) + : OpParam(inputs, outputs, attrs, scope) { + input_x_ = InputXFrom(inputs, *scope); + out_ = OutFrom(outputs, *scope); + upscale_factor_ = GetAttr("upscale_factor", attrs); + } + + const GType *InputX() const { return input_x_; } + + GType *Out() const { return out_; } + + const int &upscale_factor() const { return upscale_factor_; } + + private: + GType *input_x_; + GType *out_; + int upscale_factor_; +}; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/mobile/src/operators/pixel_shuffle_op.cpp b/mobile/src/operators/pixel_shuffle_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9105a72cfbddddbe39ecbbe2f35da204ba118f18 --- /dev/null +++ b/mobile/src/operators/pixel_shuffle_op.cpp @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
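For the pixel_shuffle operator registered in the new files here, the InferShape just below maps (n, c, h, w) to (n, c / r^2, h * r, w * r), and the OpenCL kernel added earlier consumes the same upscale_factor attribute. A CPU reference of the rearrangement on dense NCHW float data, useful for checking the GPU output against; it is a model of the op's semantics, not code from the kernel:

#include <vector>

// CPU reference for pixel shuffle:
//   out[n][c][h*r + i][w*r + j] = in[n][c*r*r + i*r + j][h][w]
std::vector<float> PixelShuffleRef(const std::vector<float> &in, int n,
                                   int c_in, int h, int w, int r) {
  const int c_out = c_in / (r * r), h_out = h * r, w_out = w * r;
  std::vector<float> out(static_cast<size_t>(n) * c_out * h_out * w_out);
  for (int b = 0; b < n; ++b)
    for (int oc = 0; oc < c_out; ++oc)
      for (int oh = 0; oh < h_out; ++oh)
        for (int ow = 0; ow < w_out; ++ow) {
          const int ic = oc * r * r + (oh % r) * r + (ow % r);
          const size_t src =
              ((static_cast<size_t>(b) * c_in + ic) * h + oh / r) * w + ow / r;
          const size_t dst =
              ((static_cast<size_t>(b) * c_out + oc) * h_out + oh) * w_out + ow;
          out[dst] = in[src];
        }
  return out;
}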
*/ + +#ifdef PIXEL_SHUFFLE_OP + +#include "operators/pixel_shuffle_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void PixelShuffleOp::InferShape() const { + auto x_dims = this->param_.InputX()->dims(); + int n = x_dims[0]; + int c = x_dims[1]; + int h = x_dims[2]; + int w = x_dims[3]; + int upscale_factor = this->param_.upscale_factor(); + this->param_.Out()->Resize( + framework::make_ddim({n, c / (upscale_factor * upscale_factor), + h * upscale_factor, w * upscale_factor})); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(pixel_shuffle, ops::PixelShuffleOp); +#endif + +#endif diff --git a/mobile/src/operators/pixel_shuffle_op.h b/mobile/src/operators/pixel_shuffle_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a1c6f8e1adb0c4f52e54974080aaa80e6ebe295f --- /dev/null +++ b/mobile/src/operators/pixel_shuffle_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PIXEL_SHUFFLE_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/pixel_shuffle_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +template +class PixelShuffleOp : public framework::OperatorWithKernel< + DeviceType, PixelShuffleParam, + operators::PixelShuffleKernel> { + public: + PixelShuffleOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, framework::Scope *scope) + : framework::OperatorWithKernel< + DeviceType, PixelShuffleParam, + operators::PixelShuffleKernel>(type, inputs, outputs, + attrs, scope) {} + void InferShape() const override; + + protected: +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/pass/memory_optimize_super.cpp b/mobile/src/pass/memory_optimize_cl.cpp similarity index 57% rename from mobile/src/pass/memory_optimize_super.cpp rename to mobile/src/pass/memory_optimize_cl.cpp index 344b88b02ed915570f50a4f0eebdc9949c338ddb..355123349d645075fd2ccc37144144da7d332a8f 100644 --- a/mobile/src/pass/memory_optimize_super.cpp +++ b/mobile/src/pass/memory_optimize_cl.cpp @@ -12,21 +12,21 @@ See the License for the specific language governing permissions and limitations under the License. 
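The pass renamed below from MemoryOptPassSuper to MemoryOptPassCl assigns non-persistable variables to reuse lists by reference counting: a variable may join a list only when the list's most recent tenant has no remaining uses and is not a fetch output. A simplified model of that list construction, with illustrative types rather than the pass's own:

#include <string>
#include <vector>

// Each non-persistable variable gets a node with a pending-use count.
// A list whose last node has count == 0 may be reused by the next variable.
// The real pass also excludes nodes that appear in the fetch list.
struct VarNode {
  std::string name;
  int count = 0;  // pending uses, decremented as ops are walked
  bool visited = false;
};

std::vector<std::vector<VarNode *>> BuildReuseLists(
    std::vector<VarNode *> &order) {
  std::vector<std::vector<VarNode *>> reused;
  for (VarNode *node : order) {
    if (node->visited) continue;
    bool placed = false;
    for (auto &list : reused) {
      if (list.back()->count == 0) {  // previous tenant fully consumed
        list.push_back(node);
        placed = true;
        break;
      }
    }
    if (!placed) reused.push_back({node});
    node->visited = true;
  }
  return reused;
}

In the ShareData stage that follows, each list's images are then split into two groups by whether the image width exceeds the height, and each group shares one backing image allocated via InitFakeSizeImage at the group's maximum extents.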
diff --git a/mobile/src/pass/memory_optimize_super.cpp b/mobile/src/pass/memory_optimize_cl.cpp
similarity index 57%
rename from mobile/src/pass/memory_optimize_super.cpp
rename to mobile/src/pass/memory_optimize_cl.cpp
index 344b88b02ed915570f50a4f0eebdc9949c338ddb..355123349d645075fd2ccc37144144da7d332a8f 100644
--- a/mobile/src/pass/memory_optimize_super.cpp
+++ b/mobile/src/pass/memory_optimize_cl.cpp
@@ -12,21 +12,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #ifdef PADDLE_MOBILE_CL
-#include "pass/memory_optimize_super.h"
+#include "pass/memory_optimize_cl.h"
 #include <algorithm>
 #include "framework/cl/cl_image.h"
 #include "framework/lod_tensor.h"
 
 namespace paddle_mobile {
 namespace pass {
-void MemoryOptPassSuper::AppendBlockVars(const framework::BlockDesc *block) {
+void MemoryOptPassCl::AppendBlockVars(const framework::BlockDesc *block) {
   // block_vars_.clear();
   for (const auto var : block->Vars()) {
     block_vars_[var->Name()] = var.get();
   }
 }
 
-bool MemoryOptPassSuper::IsPersistable(const std::string name) {
+bool MemoryOptPassCl::IsPersistable(const std::string name) {
   const auto it = block_vars_.find(name);
   if (it != block_vars_.end()) {
     return it->second->Persistable();
@@ -34,7 +34,7 @@ bool MemoryOptPassSuper::IsPersistable(const std::string name) {
   return false;
 }
 
-ClVarNode *MemoryOptPassSuper::CreateNode(const std::string name) {
+ClVarNode *MemoryOptPassCl::CreateNode(const std::string name) {
   auto it = created_nodes_.find(name);
   if (it != created_nodes_.end()) {
     ++(it->second->count);
@@ -48,7 +48,7 @@ ClVarNode *MemoryOptPassSuper::CreateNode(const std::string name) {
   return var;
 }
 
-void MemoryOptPassSuper::operator()(
+void MemoryOptPassCl::operator()(
     const framework::ProgramDesc *program, framework::Scope *scope,
     MemoryOptimizationLevel memory_optimization_level,
     framework::DDim target_dims) {
@@ -82,6 +82,8 @@ void MemoryOptPassSuper::operator()(
       DLOG << "op_desc->Type(): " << op->Type();
       for (const auto &outputs : op->GetOutputs()) {
         for (const auto &output : outputs.second) {
+          // not persistable and not an excluded one, so add it to
+          // analysis_nodes
           if (!IsPersistable(output) &&
              std::find(exclude_var_names.begin(), exclude_var_names.end(),
                        output) == exclude_var_names.end()) {
@@ -93,6 +95,8 @@
       }
       for (const auto &inputs : op->GetInputs()) {
         for (const auto &input : inputs.second) {
+          // not persistable and not an excluded one, so add it to
+          // analysis_nodes
           if (!IsPersistable(input) &&
              std::find(exclude_var_names.begin(), exclude_var_names.end(),
                        input) == exclude_var_names.end()) {
@@ -128,6 +132,7 @@
         bool reused = false;
         // find out a possable reuse list
         for (auto &list : reused_nodes_) {
+          // reference count = 0 and not in fetch list
          if (list.back()->count == 0 &&
              std::find(fetch_var_nodes.begin(), fetch_var_nodes.end(),
                        list.back()) == fetch_var_nodes.end()) {
@@ -146,60 +151,115 @@
       node->visited = true;
       node->count -= 1;
     }
-    // shared data within all variables in the same reused list
     ShareData(scope, memory_optimization_level, target_dims);
   }
 }
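The analysis above is plain reference counting: CreateNode bumps a node's count once per use, and a variable may join an existing reuse list only when that list's most recent member has no remaining uses and is not a fetch output. A condensed toy version of that decision, with our own types rather than the pass's, for intuition:

#include <string>
#include <vector>

struct Node {
  std::string name;
  int count;  // remaining uses, decremented as ops are walked
};

// Append to the first pool whose last occupant is dead; otherwise open a
// new pool. Pools correspond to reused_nodes_ entries in the pass.
void Assign(Node *node, std::vector<std::vector<Node *>> *pools) {
  for (auto &pool : *pools) {
    if (pool.back()->count == 0) {  // dead -> its memory can be reused
      pool.push_back(node);
      return;
    }
  }
  pools->push_back({node});  // no reusable pool, open a new one
}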
-void MemoryOptPassSuper::ShareData(
+void MemoryOptPassCl::ShareData(
     framework::Scope *scope,
     MemoryOptimizationLevel memory_optimization_level,
     framework::DDim target_dims) const {
   // shared data within all variables in the same reused list
+  cl_context context = scope->GetCLScpoe()->Context();
+  cl_command_queue command_queue = scope->GetCLScpoe()->CommandQueue();
+
   for (const auto &list : reused_nodes_) {
     DLOG << "\n";
     DLOG << "gpu . share memory within these variables";
-    // find max dims
-    int64_t max_numl = -1;
+    int64_t x_based_max_numl = -1;
+    int64_t y_based_max_numl = -1;
+    int64_t x_based_max_x = -1;
+    int64_t x_based_max_y = -1;
+    int64_t y_based_max_x = -1;
+    int64_t y_based_max_y = -1;
 
-    framework::CLImage *reuse_tensor = nullptr;
-    DLOG << "resused nodes group ----------";
+    framework::CLImage *x_based_reuse_tensor = nullptr;
+    framework::CLImage *y_based_reuse_tensor = nullptr;
     for (const auto &node : list) {
       auto *var = scope->Var(node->name);
       auto *tensor = var->template GetMutable<framework::CLImage>();
       const int64_t numl = tensor->numel();
-      if (max_numl < numl) {
-        max_numl = numl;
-        reuse_tensor = tensor;
+      auto origin_tensor_dims = tensor->dims();
+
+      // for super resolution, hack the origin dims
+      if (target_dims.size() == 4) {
+        PADDLE_MOBILE_ENFORCE(origin_tensor_dims.size() == 4,
+                              "tensor dims must be equal to 4");
+        origin_tensor_dims = {origin_tensor_dims[0], origin_tensor_dims[1],
+                              target_dims[2], target_dims[3]};
+        tensor->Resize(origin_tensor_dims);
       }
-      DLOG << node->name << " ----dims: " << tensor->dims()
-           << "----numl----: " << numl;
-    }
-    if (reuse_tensor == nullptr) {
-      return;
+      const framework::DDim &image_dims =
+          normal_converter->InitImageDimInfoWith(origin_tensor_dims);
+      int64_t image_dims_x = image_dims[0];
+      int64_t image_dims_y = image_dims[1];
+      // classify memory into two parts
+      if (image_dims_x > image_dims_y) {
+        // choose the biggest tensor for reuse
+        if (x_based_max_numl < numl) {
+          x_based_max_numl = numl;
+          x_based_reuse_tensor = tensor;
+        }
+        x_based_max_x = std::max(x_based_max_x, image_dims_x);
+        x_based_max_y = std::max(x_based_max_y, image_dims_y);
+      } else {
+        // choose the biggest tensor for reuse
+        if (y_based_max_numl < numl) {
+          y_based_max_numl = numl;
+          y_based_reuse_tensor = tensor;
+        }
+        y_based_max_x = std::max(y_based_max_x, image_dims_x);
+        y_based_max_y = std::max(y_based_max_y, image_dims_y);
+      }
     }
-    const framework::DDim &dims = reuse_tensor->dims();
-    cl_context context = scope->GetCLScpoe()->Context();
-    cl_command_queue command_queue = scope->GetCLScpoe()->CommandQueue();
-
-    framework::DDim reshaped_dim = framework::make_ddim(
-        {dims[0], dims[1], target_dims[2], target_dims[3]});
+    PADDLE_MOBILE_ENFORCE(
+        x_based_reuse_tensor != nullptr || y_based_reuse_tensor != nullptr,
+        "x_based_reuse_tensor and y_based_reuse_tensor cannot both be "
+        "null");
 
-    DLOG << "target dims : " << target_dims;
-    DLOG << "reshaped_dim : " << reshaped_dim;
-    reuse_tensor->InitFakeSizeImage(context, command_queue, reshaped_dim,
-                                    reshaped_dim);
+    // init x based shared cl mem
+    if (x_based_reuse_tensor != nullptr) {
+      const framework::DDim &x_reuse_dims = x_based_reuse_tensor->dims();
+      x_based_reuse_tensor->InitFakeSizeImage(
+          context, command_queue, x_reuse_dims, {x_based_max_x, x_based_max_y});
+    }
+    // init y based shared cl mem
+    if (y_based_reuse_tensor != nullptr) {
+      const framework::DDim &y_reuse_dims = y_based_reuse_tensor->dims();
+      y_based_reuse_tensor->InitFakeSizeImage(
+          context, command_queue, y_reuse_dims, {y_based_max_x, y_based_max_y});
+    }
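Why two reuse pools: a CLImage is a 2D image whose (x, y) extent comes from the converter's InitImageDimInfoWith, so backing a very wide and a very tall tensor with one shared image would force a max_x-by-max_y allocation that neither needs. Splitting tensors by whether width or height dominates keeps each shared image close to the shapes it actually serves. A toy version of the classification (image_dim stands in for the converter call; not the pass itself):

#include <algorithm>
#include <cstdint>

struct ImageDim { int64_t x, y; };
struct Pool { int64_t max_x = -1, max_y = -1; };

// Landscape images grow one pool's bounds, portrait images the other's.
void Classify(const ImageDim &d, Pool *x_based, Pool *y_based) {
  Pool *p = (d.x > d.y) ? x_based : y_based;
  p->max_x = std::max(p->max_x, d.x);
  p->max_y = std::max(p->max_y, d.y);
}
// Each pool is then backed by a single image of (max_x, max_y), which is
// exactly what the two InitFakeSizeImage calls above allocate.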
+    // share mem
     for (const auto &node : list) {
       auto *var = scope->Var(node->name);
       auto *tensor = var->template GetMutable<framework::CLImage>();
-      const framework::DDim &temp_dim = tensor->dims();
-      framework::DDim need_dims = framework::make_ddim(
-          {temp_dim[0], temp_dim[1], target_dims[2], target_dims[3]});
-      tensor->InitWithExitedMem(context, command_queue, need_dims,
-                                *reuse_tensor);
+      auto need_dims = tensor->dims();
+
+      // for super resolution, hack the origin dims
+      if (target_dims.size() == 4) {
+        need_dims = {need_dims[0], need_dims[1], target_dims[2],
+                     target_dims[3]};
+      }
+
+      const framework::DDim &need_image_dims =
+          normal_converter->InitImageDimInfoWith(need_dims);
+      int64_t image_dims_x = need_image_dims[0];
+      int64_t image_dims_y = need_image_dims[1];
+
+      if (image_dims_x > image_dims_y) {
+        PADDLE_MOBILE_ENFORCE(x_based_reuse_tensor != nullptr,
+                              "x_based_reuse_tensor must not be null here");
+        tensor->InitWithExistMem(context, command_queue, need_dims,
+                                 *x_based_reuse_tensor);
+      } else {
+        PADDLE_MOBILE_ENFORCE(y_based_reuse_tensor != nullptr,
+                              "y_based_reuse_tensor must not be null here");
+        tensor->InitWithExistMem(context, command_queue, need_dims,
+                                 *y_based_reuse_tensor);
+      }
     }
   }
 }
diff --git a/mobile/src/pass/memory_optimize_super.h b/mobile/src/pass/memory_optimize_cl.h
similarity index 84%
rename from mobile/src/pass/memory_optimize_super.h
rename to mobile/src/pass/memory_optimize_cl.h
index 08af29919f99253765412a2dae81fc95d9f5e62c..aafdda4b34cce4db7be1e0bc836b83401bdedde1 100644
--- a/mobile/src/pass/memory_optimize_super.h
+++ b/mobile/src/pass/memory_optimize_cl.h
@@ -19,10 +19,12 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "framework/cl/cl_image_converter.h"
 #include "framework/lod_tensor.h"
 #include "framework/program/program.h"
 #include "pass/pass_base.h"
-// use for super resulotion to be extend for all opencl
+
+// used for opencl
 
 namespace paddle_mobile {
 namespace pass {
@@ -34,19 +36,20 @@ typedef struct {
 
 // MemoryOptPass will analyze the program, and reuse memory between
 // variables as much as possible
-class MemoryOptPassSuper : public PassBase {
+class MemoryOptPassCl : public PassBase {
  public:
-  MemoryOptPassSuper() {}
-  virtual ~MemoryOptPassSuper() {
+  MemoryOptPassCl() {}
+  virtual ~MemoryOptPassCl() {
     for (auto &it : created_nodes_) {
       delete it.second;
     }
+    delete normal_converter;
   }
 
   void operator()(const framework::ProgramDesc *program,
                   framework::Scope *scope,
                   MemoryOptimizationLevel memory_optimization_level,
-                  framework::DDim dims);
+                  framework::DDim dims = {});
 
   void AppendBlockVars(const framework::BlockDesc *block);
@@ -63,6 +66,8 @@ class MemoryOptPassSuper : public PassBase {
   std::vector<std::vector<ClVarNode *>> reused_nodes_;
   std::unordered_map<std::string, ClVarNode *> created_nodes_;
   std::unordered_map<std::string, framework::VarDesc *> block_vars_;
+  paddle_mobile::framework::CLImageConverterNormal *normal_converter =
+      new paddle_mobile::framework::CLImageConverterNormal();
 };
 
 }  // namespace pass
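For orientation, the renamed pass keeps the old call shape, and the new `dims = {}` default lets non-super-resolution callers drop the target size. A hypothetical call site (the real wiring lives in the executor, which this diff does not touch; `program_desc` and `scope` are illustrative names):

// Plain OpenCL memory reuse: dims defaults to {}.
pass::MemoryOptPassCl()(program_desc, scope, MemoryOptimizationWithoutFeeds);

// Super-resolution path: pass the target input shape so intermediate
// CLImages are resized (and shared) for it.
pass::MemoryOptPassCl()(program_desc, scope, MemoryOptimizationWithoutFeeds,
                        framework::make_ddim({1, 3, 300, 300}));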
diff --git a/mobile/test/CMakeLists.txt b/mobile/test/CMakeLists.txt
index 056ede3fb9a2113ca932daee8afa0affbbe184db..ccc609ff8300b9220285d73527145379edb30b5a 100644
--- a/mobile/test/CMakeLists.txt
+++ b/mobile/test/CMakeLists.txt
@@ -534,9 +534,16 @@ if (ENABLE_ALL_TEST)
         # gen test
         ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
         target_link_libraries(test-net paddle-mobile)
+
+        # gen test
+        ADD_EXECUTABLE(test-net-performance net/test_net_performance.cpp test_helper.h test_include.h executor_for_test.h)
+        target_link_libraries(test-net-performance paddle-mobile)
     endif ()
 else()
     # gen test
     ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test-net paddle-mobile)
+
+    ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h)
+    target_link_libraries(test-net-benchmark paddle-mobile)
 endif()
diff --git a/mobile/test/fpga/test_marker_api.cpp b/mobile/test/fpga/test_marker_api.cpp
index 29cf6561df1afae1fae5494bcb8fa6606d3999df..19e051a38d2b853dda8a9364ac1ea64ae3f38acb 100644
--- a/mobile/test/fpga/test_marker_api.cpp
+++ b/mobile/test/fpga/test_marker_api.cpp
@@ -104,7 +104,7 @@ void dump_stride_float(std::string filename,
 
 void dump_stride(std::string filename,
                  paddle_mobile::PaddleTensor input_tensor) {
-  if (input_tensor.dtypeid == type_id<float>().hash_code()) {
+  if (input_tensor.dtypeid == PaddlekTypeId_t::paddle_float) {
     dump_stride_float(filename, input_tensor);
   } else {
     std::cout << "only support dumping float data" << std::endl;
@@ -156,13 +156,13 @@ int main() {
   std::cout << "Finishing initializing data" << std::endl;
 
   struct PaddleTensor t_img_info, t_img;
-  t_img_info.dtypeid = type_id<float>().hash_code();
+  t_img_info.dtypeid = PaddlekTypeId_t::paddle_float;
   t_img_info.layout = LAYOUT_HWC;
   t_img_info.shape = std::vector<int>({1, 3});
   t_img_info.name = "Image information";
   t_img_info.data.Reset(img_info, 3 * sizeof(float));
 
-  t_img.dtypeid = type_id<float>().hash_code();
+  t_img.dtypeid = PaddlekTypeId_t::paddle_float;
   // quantize(&img, img_length);
   // t_img.dtypeid = typeid(int8_t);
   t_img.layout = LAYOUT_HWC;
@@ -209,7 +209,7 @@ int main() {
   std::cout << "Finishing initializing data" << std::endl;
 
   struct PaddleTensor t_img1;
-  t_img1.dtypeid = type_id<float>().hash_code();
+  t_img1.dtypeid = PaddlekTypeId_t::paddle_float;
   t_img1.layout = LAYOUT_HWC;
   t_img1.shape = std::vector<int>({1, 14, 14, 144});
   t_img1.name = "Image information";
diff --git a/mobile/test/fpga/test_mobilenet_api.cpp b/mobile/test/fpga/test_mobilenet_api.cpp
index 09392e9d38f8a0312715da2aae74e1e039c2452e..5c0a594ca8c4692b7c0a07afb72bf260b3c6086d 100644
--- a/mobile/test/fpga/test_mobilenet_api.cpp
+++ b/mobile/test/fpga/test_mobilenet_api.cpp
@@ -96,7 +96,7 @@ void dump_stride_float(std::string filename, PaddleTensor input_tensor) {
 }
 
 void dump_stride(std::string filename, PaddleTensor input_tensor) {
-  if (input_tensor.dtypeid == type_id<float>().hash_code()) {
+  if (input_tensor.dtypeid == PaddlekTypeId_t::paddle_float) {
     dump_stride_float(filename, input_tensor);
   } else {
     std::cout << "only support dumping float data" << std::endl;
@@ -131,7 +131,7 @@ int main() {
   std::cout << "Finishing initializing data" << std::endl;
   struct PaddleTensor t_img;
   t_img.dtype = FLOAT32;
-  t_img.dtypeid = type_id<float>().hash_code();
+  t_img.dtypeid = PaddlekTypeId_t::paddle_float;
   // quantize(&img, img_length);
   // t_img.dtype = INT8;
   // t_img.dtypeid = typeid(int8_t);
diff --git a/mobile/test/fpga/test_rfcn_api.cpp b/mobile/test/fpga/test_rfcn_api.cpp
index e86743cc7e7b8fdb48a034aebd5a53224674b6f5..b8b031bf59a47b7ac9a0f71b828596aaad15e3f1 100644
--- a/mobile/test/fpga/test_rfcn_api.cpp
+++ b/mobile/test/fpga/test_rfcn_api.cpp
@@ -117,13 +117,13 @@ int main() {
   std::cout << "Finishing initializing data" << std::endl;
 
   struct PaddleTensor t_img_info, t_img;
-  t_img.dtypeid = type_id<float>().hash_code();
+  t_img.dtypeid = PaddlekTypeId_t::paddle_float;
   t_img_info.layout = LAYOUT_HWC;
   t_img_info.shape = std::vector<int>({1, 3});
   t_img_info.name = "Image information";
   t_img_info.data.Reset(img_info, 3 * sizeof(float));
 
-  t_img.dtypeid = type_id<float>().hash_code();
+  t_img.dtypeid = PaddlekTypeId_t::paddle_float;
   t_img.layout = LAYOUT_HWC;
   t_img.shape = std::vector<int>({1, 432, 1280, 3});
   t_img.name = "Image information";
diff --git a/mobile/test/fpga/test_yolo_api.cpp b/mobile/test/fpga/test_yolo_api.cpp
index f8f1a48abca74b76f77e7f8d59371b7672a8b453..161d695418654b2198a59889ee44583901d25c2b 100644
--- a/mobile/test/fpga/test_yolo_api.cpp
+++ b/mobile/test/fpga/test_yolo_api.cpp
@@ -95,7 +95,7 @@ void dump_stride_float(std::string filename, PaddleTensor input_tensor) {
 }
 
 void dump_stride(std::string filename, PaddleTensor input_tensor) {
-  if (input_tensor.dtypeid == type_id<float>().hash_code()) {
+  if (input_tensor.dtypeid == PaddlekTypeId_t::paddle_float) {
     dump_stride_float(filename, input_tensor);
   } else {
     std::cout << "only support dumping float data" << std::endl;
@@ -134,7 +134,7 @@ int main() {
   // t_img.dtypeid = type_id<float>().hash_code();
   quantize(&img, img_length);
   t_img.dtype = INT8;
-  t_img.dtypeid = type_id<int8_t>().hash_code();
+  t_img.dtypeid = PaddlekTypeId_t::paddle_int8_t;
   t_img.layout = LAYOUT_HWC;
   t_img.shape = std::vector<int>({1, 256, 416, 3});
   t_img.name = "Image information";
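All four FPGA tests make the same substitution: `type_id<T>().hash_code()` yields a hash that is only meaningful within one build of the library, while `PaddlekTypeId_t` is a plain enum that compares and serializes stably. A minimal sketch of the pattern; only `paddle_float` and `paddle_int8_t` are confirmed by this diff, the other enumerator and the surrounding types are illustrative:

// Sketch of an enum-based type tag replacing hash_code() comparisons.
enum class PaddlekTypeId_t { paddle_void, paddle_float, paddle_int8_t };

struct TaggedTensor {
  PaddlekTypeId_t dtypeid = PaddlekTypeId_t::paddle_void;
};

bool IsFloat(const TaggedTensor &t) {
  // stable across builds, unlike std::type_info-style hash codes
  return t.dtypeid == PaddlekTypeId_t::paddle_float;
}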
diff --git a/mobile/test/net/test_inference_pre_post.cpp b/mobile/test/net/test_inference_pre_post.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..39dc9429208e260b0ac1fe1edeb6dcfa1c9a4112
--- /dev/null
+++ b/mobile/test/net/test_inference_pre_post.cpp
@@ -0,0 +1,84 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <iostream>
+#include "../test_helper.h"
+#include "io/paddle_inference_api.h"
+
+using namespace paddle_mobile;  // NOLINT
+
+PaddleMobileConfig GetConfig() {
+  PaddleMobileConfig config;
+  config.precision = PaddleMobileConfig::FP32;
+  config.device = PaddleMobileConfig::kGPU_CL;
+  config.pre_post_type = PaddleMobileConfig::UINT8_255;
+
+  config.prog_file = "../models/superv2/model";
+  config.param_file = "../models/superv2/params";
+  config.lod_mode = false;
+  config.load_when_predict = true;
+  config.cl_path = "/data/local/tmp/bin";
+  return config;
+}
+
+int main() {
+  PaddleMobileConfig config = GetConfig();
+  auto predictor =
+      CreatePaddlePredictor<PaddleMobileConfig,
+                            PaddleEngineKind::kPaddleMobile>(config);
+
+  int input_length = 1 * 1 * 300 * 300;
+  int output_length = input_length;
+
+  uint8_t data_ui[300 * 300];
+  for (int i = 0; i < input_length; ++i) {
+    data_ui[i] = i % 256;
+  }
+
+  PaddleTensor input;
+  input.shape = std::vector<int>({1, 1, 300, 300});
+  input.data = PaddleBuf(data_ui, sizeof(data_ui));
+  input.dtype = PaddleDType::UINT8;
+  input.layout = LayoutType::LAYOUT_CHW;
+  std::vector<PaddleTensor> inputs(1, input);
+
+  PaddleTensor output;
+  output.shape = std::vector<int>({});
+  output.data = PaddleBuf();
+  output.dtype = PaddleDType::UINT8;
+  output.layout = LayoutType::LAYOUT_CHW;
+  std::vector<PaddleTensor> outputs(1, output);
+
+  std::cout << " print input : " << std::endl;
+  int stride = input_length / 20;
+  stride = stride > 0 ? stride : 1;
+  for (size_t j = 0; j < input_length; j += stride) {
+    std::cout << (unsigned)data_ui[j] << " ";
+  }
+  std::cout << std::endl;
+
+  predictor->Run(inputs, &outputs);
+
+  std::cout << " print output : " << std::endl;
+  uint8_t *data_o = static_cast<uint8_t *>(outputs[0].data.data());
+  int numel = outputs[0].data.length() / sizeof(uint8_t);
+  stride = numel / 20;
+  stride = stride > 0 ? stride : 1;
+  for (size_t j = 0; j < numel; j += stride) {
+    std::cout << (unsigned)data_o[j] << " ";
+  }
+  std::cout << std::endl;
+
+  return 0;
+}
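This test exercises the new `UINT8_255` pre/post path: on the API surface the tensors stay uint8, and the name suggests the framework maps them to floats scaled by 255 internally. A toy version of that conversion, for intuition only; the real transformation happens inside the predictor (possibly on-device) and its exact scaling is not shown in this diff:

#include <cstdint>
#include <vector>

// Illustrative only: maps UINT8_255 API data to the float range a model
// typically consumes, and back.
std::vector<float> ToFloat(const std::vector<uint8_t> &in) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = in[i] / 255.0f;
  return out;
}

std::vector<uint8_t> ToUint8(const std::vector<float> &in) {
  std::vector<uint8_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    out[i] = static_cast<uint8_t>(in[i] * 255.0f + 0.5f);  // round to nearest
  return out;
}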
diff --git a/mobile/test/net/test_net.cpp b/mobile/test/net/test_net.cpp
index a1c234dbca31d2211138d8e26d1af72d81debdb4..3d5386513be09adc50b153bde6335f7cac00c107 100644
--- a/mobile/test/net/test_net.cpp
+++ b/mobile/test/net/test_net.cpp
@@ -31,6 +31,10 @@ void test(int argc, char *argv[]) {
   arg_index++;
   bool enable_memory_optimization = std::stoi(argv[arg_index]) == 1;
   arg_index++;
+  bool quantification = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  int quantification_fold = std::stoi(argv[arg_index]);
+  arg_index++;
   paddle_mobile::PaddleMobileConfigInternal config;
   config.memory_optimization_level = enable_memory_optimization
                                          ? MemoryOptimizationWithoutFeeds
@@ -98,7 +102,8 @@ void test(int argc, char *argv[]) {
 
   auto time1 = time();
   if (paddle_mobile.Load("./checked_model/model", "./checked_model/params",
-                         fuse, false, 1, true)) {
+                         fuse, quantification, 1, is_lod,
+                         quantification_fold)) {
     auto time2 = time();
     std::cout << "auto-test"
               << " load-time-cost :" << time_diff(time1, time2) << "ms"
@@ -181,8 +186,8 @@ void test(int argc, char *argv[]) {
         if (len == 0) {
           continue;
         }
-        int width = cl_image->ImageDims()[0];
-        int height = cl_image->ImageDims()[1];
+        size_t width = cl_image->ImageDims()[0];
+        size_t height = cl_image->ImageDims()[1];
         paddle_mobile::framework::half_t *image_data =
             new paddle_mobile::framework::half_t[height * width * 4];
         cl_int err;
diff --git a/mobile/test/net/test_net_benchmark.cpp b/mobile/test/net/test_net_benchmark.cpp
index f874683148e95180a5c1376e8d6a3233a2cabe1b..396f293f760a3bd8c134c3e3ab34b9f1e2b34219 100644
--- a/mobile/test/net/test_net_benchmark.cpp
+++ b/mobile/test/net/test_net_benchmark.cpp
@@ -17,20 +17,26 @@ limitations under the License. */
 #include "../test_include.h"
 
 int main() {
+#ifdef PADDLE_MOBILE_CL
+  paddle_mobile::PaddleMobileConfigInternal config;
+  config.load_when_predict = false;
+  paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile(config);
+#else
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
+#endif
   paddle_mobile.SetThreadNum(1);
   auto time1 = paddle_mobile::time();
 
-  auto isok =
-      paddle_mobile.Load(std::string(g_yolo) + "/model",
-                         std::string(g_yolo) + "/params", true, false, 1, true);
+  auto isok = paddle_mobile.Load(std::string(g_mobilenet_combined) + "/model",
+                                 std::string(g_mobilenet_combined) + "/params",
+                                 true, false, 1, false);
 
   if (isok) {
     auto time2 = paddle_mobile::time();
-    std::cout << "load cost :" << paddle_mobile::time_diff(time1, time1) << "ms"
+    std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
               << std::endl;
 
     std::vector<float> input;
-    std::vector<int64_t> dims{1, 3, 64, 64};
+    std::vector<int64_t> dims{1, 3, 224, 224};
     GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
 
     paddle_mobile::framework::DDim ddim =
diff --git a/mobile/test/net/test_net_performance.cpp b/mobile/test/net/test_net_performance.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..95e72ea7a77d38f07abd391326120b136b4cc499
--- /dev/null
+++ b/mobile/test/net/test_net_performance.cpp
@@ -0,0 +1,197 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fstream>
+#include <iostream>
+#include <string>
+#include "../test_helper.h"
+#include "../test_include.h"
+
+void test(int argc, char *argv[]);
+
+int main(int argc, char *argv[]) {
+  test(argc, argv);
+  return 0;
+}
+
+void test(int argc, char *argv[]) {
+  int arg_index = 1;
+  bool fuse = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  bool enable_memory_optimization = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  bool quantification = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  int quantification_fold = std::stoi(argv[arg_index]);
+  arg_index++;
+  paddle_mobile::PaddleMobileConfigInternal config;
+  config.memory_optimization_level = enable_memory_optimization
+                                         ? MemoryOptimizationWithoutFeeds
+                                         : NoMemoryOptimization;
+
+  // save obfuscated model
+  // config.model_obfuscate_key = "asdf";
+  // std::ofstream out_file("new-params", std::ofstream::binary);
+  // char *out_data = ReadFileToBuff("./checked_model/params");
+  // int len = GetFileLength("./checked_model/params");
+  // out_file.write(out_data, len);
+  // out_file.close();
+
+#ifdef PADDLE_MOBILE_CL
+  // config.load_when_predict = true;
+  paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile(config);
+  paddle_mobile.SetCLPath("/data/local/tmp/bin");
+  std::cout << "testing opencl performance " << std::endl;
+#else
+  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile(config);
+  paddle_mobile.SetThreadNum(1);
+  std::cout << "testing cpu performance " << std::endl;
+#endif
+
+  int dim_count = std::stoi(argv[arg_index]);
+  arg_index++;
+  int size = 1;
+  std::vector<int64_t> dims;
+  for (int i = 0; i < dim_count; i++) {
+    int64_t dim = std::stoi(argv[arg_index + i]);
+    size *= dim;
+    dims.push_back(dim);
+  }
+  arg_index += dim_count;
+
+  bool is_lod = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  paddle_mobile::framework::LoD lod{{}};
+  if (is_lod) {
+    int lod_count = std::stoi(argv[arg_index]);
+    arg_index++;
+    for (int i = 0; i < lod_count; i++) {
+      int dim = std::stoi(argv[arg_index + i]);
+      lod[0].push_back(dim);
+    }
+    arg_index += lod_count;
+  }
+
+  int var_count = std::stoi(argv[arg_index]);
+  arg_index++;
+  bool is_sample_step = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+  int sample_arg = std::stoi(argv[arg_index]);
+  int sample_step = sample_arg;
+  int sample_num = sample_arg;
+  arg_index++;
+  std::vector<std::string> var_names;
+  for (int i = 0; i < var_count; i++) {
+    std::string var_name = argv[arg_index + i];
+    var_names.push_back(var_name);
+  }
+  arg_index += var_count;
+  bool check_shape = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+
+  int run_times = std::stoi(argv[arg_index]);
+  arg_index++;
+
+  bool warm_up = std::stoi(argv[arg_index]) == 1;
+  arg_index++;
+
+  auto time1 = time();
+  if (paddle_mobile.Load("./checked_model/model", "./checked_model/params",
+                         fuse, quantification, 1, is_lod,
+                         quantification_fold)) {
+    auto time2 = time();
+    std::cout << "auto-test"
+              << " load-time-cost :" << time_diff(time1, time2) << "ms"
+              << std::endl;
+
+    float *input_data_array = new float[size];
+    std::ifstream in("input.txt", std::ios::in);
+    for (int i = 0; i < size; i++) {
+      float num;
+      in >> num;
+      input_data_array[i] = num;
+    }
+    in.close();
+
+    auto time3 = time();
+
+    paddle_mobile::framework::Tensor input_tensor(
+        input_data_array, paddle_mobile::framework::make_ddim(dims));
+    auto time4 = time();
+    std::cout << "auto-test"
+              << " preprocess-time-cost :" << time_diff(time3, time4) << "ms"
+              << std::endl;
+
+    paddle_mobile::framework::LoDTensor input_lod_tensor;
+    if (is_lod) {
+      input_lod_tensor.Resize(paddle_mobile::framework::make_ddim(dims));
+      input_lod_tensor.set_lod(lod);
+      auto *tensor_data = input_lod_tensor.mutable_data<float>();
+      for (int i = 0; i < size; i++) {
+        tensor_data[i] = input_data_array[i];
+      }
+    }
+
+    // warm up: run 10 times before timing
+    if (warm_up) {
+      for (int i = 0; i < 10; i++) {
+        if (is_lod) {
+          auto out = paddle_mobile.Predict(input_lod_tensor);
+        } else {
+          paddle_mobile.Feed(var_names[0], input_tensor);
+          paddle_mobile.Predict();
+        }
+      }
+    }
+
+    // measure speed
+    double max_time = -1;
+    double min_time = 100000;
+    double all_time = 0;
+    if (is_lod) {
+      for (int i = 0; i < run_times; i++) {
+        auto time7 = time();
+        paddle_mobile.Predict(input_lod_tensor);
+        auto time8 = time();
+        const double diff_time_single = time_diff(time7, time8);
+        max_time = fmax(diff_time_single, max_time);
+        min_time = fmin(diff_time_single, min_time);
+        all_time += diff_time_single;
+      }
+    } else {
+      paddle_mobile.Feed(var_names[0], input_tensor);
+      for (int i = 0; i < run_times; i++) {
+        auto time7 = time();
+        paddle_mobile.Predict();
+        auto time8 = time();
+        const double diff_time_single = time_diff(time7, time8);
+        max_time = fmax(diff_time_single, max_time);
+        min_time = fmin(diff_time_single, min_time);
+        all_time += diff_time_single;
+      }
+    }
+
+    std::cout << "auto-test"
+              << " predict-time-cost-avg " << all_time * 1.0f / run_times
+              << "ms" << std::endl;
+    std::cout << "auto-test"
+              << " predict-time-cost-max " << double(max_time) << "ms"
+              << std::endl;
+    std::cout << "auto-test"
+              << " predict-time-cost-min " << double(min_time) << "ms"
+              << std::endl;
+
+    std::cout << std::endl;
+  }
+}
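Since the harness above takes a long positional argument list, here is the order it consumes them in, reconstructed from its std::stoi calls (names match the locals; the bracketed lod fields are only read when is_lod is 1):

// Positional argument order for test-net-performance:
//   fuse  enable_memory_optimization  quantification  quantification_fold
//   dim_count  dim[0..dim_count)
//   is_lod  [lod_count  lod[0..lod_count)]
//   var_count  is_sample_step  sample_arg  var_name[0..var_count)
//   check_shape  run_times  warm_up
// The model itself is always read from ./checked_model/{model,params}, and
// input data from input.txt in the working directory.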
diff --git a/mobile/test/net/test_op_in_net.cpp b/mobile/test/net/test_op_in_net.cpp
index 4666f4133c7f7c93b1e23be19d0a6d8343d477a7..9425c02762352ff4e1724cb95b4c9fc243a042e1 100644
--- a/mobile/test/net/test_op_in_net.cpp
+++ b/mobile/test/net/test_op_in_net.cpp
@@ -58,7 +58,7 @@ void test(int argc, char *argv[]) {
 
   auto time1 = time();
   if (paddle_mobile.Load("./checked_model/model", "./checked_model/params",
-                         fuse, false, 1, true)) {
+                         fuse, false, 1, true, 1)) {
     auto time2 = time();
     std::cout << "auto-test"
               << " load-time-cost :" << time_diff(time1, time2) << "ms"
diff --git a/mobile/third_party/opencl/.gitinore b/mobile/third_party/opencl/.gitinore
new file mode 100644
index 0000000000000000000000000000000000000000..0c27d54300a9dc71c8a11fcc1d4a5e82c09c42db
--- /dev/null
+++ b/mobile/third_party/opencl/.gitinore
@@ -0,0 +1 @@
+OpenCL-Headers
diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl.h
deleted file mode 100644
index a301ac6a003f5946bc40aa36fd99ccba1270dae9..0000000000000000000000000000000000000000
--- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl.h
+++ /dev/null
@@ -1,1782 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2008-2018 The Khronos Group Inc.
- * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - ******************************************************************************/ - -#ifndef __OPENCL_CL_H -#define __OPENCL_CL_H - -#ifdef __APPLE__ -#include -#include -#else -#include -#include -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -/******************************************************************************/ - -typedef struct _cl_platform_id * cl_platform_id; -typedef struct _cl_device_id * cl_device_id; -typedef struct _cl_context * cl_context; -typedef struct _cl_command_queue * cl_command_queue; -typedef struct _cl_mem * cl_mem; -typedef struct _cl_program * cl_program; -typedef struct _cl_kernel * cl_kernel; -typedef struct _cl_event * cl_event; -typedef struct _cl_sampler * cl_sampler; - -typedef cl_uint cl_bool; /* WARNING! Unlike cl_ types in cl_platform.h, cl_bool is not guaranteed to be the same size as the bool in kernels. 
*/ -typedef cl_ulong cl_bitfield; -typedef cl_bitfield cl_device_type; -typedef cl_uint cl_platform_info; -typedef cl_uint cl_device_info; -typedef cl_bitfield cl_device_fp_config; -typedef cl_uint cl_device_mem_cache_type; -typedef cl_uint cl_device_local_mem_type; -typedef cl_bitfield cl_device_exec_capabilities; -#ifdef CL_VERSION_2_0 -typedef cl_bitfield cl_device_svm_capabilities; -#endif -typedef cl_bitfield cl_command_queue_properties; -#ifdef CL_VERSION_1_2 -typedef intptr_t cl_device_partition_property; -typedef cl_bitfield cl_device_affinity_domain; -#endif - -typedef intptr_t cl_context_properties; -typedef cl_uint cl_context_info; -#ifdef CL_VERSION_2_0 -typedef cl_bitfield cl_queue_properties; -#endif -typedef cl_uint cl_command_queue_info; -typedef cl_uint cl_channel_order; -typedef cl_uint cl_channel_type; -typedef cl_bitfield cl_mem_flags; -#ifdef CL_VERSION_2_0 -typedef cl_bitfield cl_svm_mem_flags; -#endif -typedef cl_uint cl_mem_object_type; -typedef cl_uint cl_mem_info; -#ifdef CL_VERSION_1_2 -typedef cl_bitfield cl_mem_migration_flags; -#endif -typedef cl_uint cl_image_info; -#ifdef CL_VERSION_1_1 -typedef cl_uint cl_buffer_create_type; -#endif -typedef cl_uint cl_addressing_mode; -typedef cl_uint cl_filter_mode; -typedef cl_uint cl_sampler_info; -typedef cl_bitfield cl_map_flags; -#ifdef CL_VERSION_2_0 -typedef intptr_t cl_pipe_properties; -typedef cl_uint cl_pipe_info; -#endif -typedef cl_uint cl_program_info; -typedef cl_uint cl_program_build_info; -#ifdef CL_VERSION_1_2 -typedef cl_uint cl_program_binary_type; -#endif -typedef cl_int cl_build_status; -typedef cl_uint cl_kernel_info; -#ifdef CL_VERSION_1_2 -typedef cl_uint cl_kernel_arg_info; -typedef cl_uint cl_kernel_arg_address_qualifier; -typedef cl_uint cl_kernel_arg_access_qualifier; -typedef cl_bitfield cl_kernel_arg_type_qualifier; -#endif -typedef cl_uint cl_kernel_work_group_info; -#ifdef CL_VERSION_2_1 -typedef cl_uint cl_kernel_sub_group_info; -#endif -typedef cl_uint cl_event_info; -typedef cl_uint cl_command_type; -typedef cl_uint cl_profiling_info; -#ifdef CL_VERSION_2_0 -typedef cl_bitfield cl_sampler_properties; -typedef cl_uint cl_kernel_exec_info; -#endif - -typedef struct _cl_image_format { - cl_channel_order image_channel_order; - cl_channel_type image_channel_data_type; -} cl_image_format; - -#ifdef CL_VERSION_1_2 - -typedef struct _cl_image_desc { - cl_mem_object_type image_type; - size_t image_width; - size_t image_height; - size_t image_depth; - size_t image_array_size; - size_t image_row_pitch; - size_t image_slice_pitch; - cl_uint num_mip_levels; - cl_uint num_samples; -#ifdef __GNUC__ - __extension__ /* Prevents warnings about anonymous union in -pedantic builds */ -#endif - union { - cl_mem buffer; - cl_mem mem_object; - }; -} cl_image_desc; - -#endif - -#ifdef CL_VERSION_1_1 - -typedef struct _cl_buffer_region { - size_t origin; - size_t size; -} cl_buffer_region; - -#endif - -/******************************************************************************/ - -/* Error Codes */ -#define CL_SUCCESS 0 -#define CL_DEVICE_NOT_FOUND -1 -#define CL_DEVICE_NOT_AVAILABLE -2 -#define CL_COMPILER_NOT_AVAILABLE -3 -#define CL_MEM_OBJECT_ALLOCATION_FAILURE -4 -#define CL_OUT_OF_RESOURCES -5 -#define CL_OUT_OF_HOST_MEMORY -6 -#define CL_PROFILING_INFO_NOT_AVAILABLE -7 -#define CL_MEM_COPY_OVERLAP -8 -#define CL_IMAGE_FORMAT_MISMATCH -9 -#define CL_IMAGE_FORMAT_NOT_SUPPORTED -10 -#define CL_BUILD_PROGRAM_FAILURE -11 -#define CL_MAP_FAILURE -12 -#ifdef CL_VERSION_1_1 -#define 
CL_MISALIGNED_SUB_BUFFER_OFFSET -13 -#define CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST -14 -#endif -#ifdef CL_VERSION_1_2 -#define CL_COMPILE_PROGRAM_FAILURE -15 -#define CL_LINKER_NOT_AVAILABLE -16 -#define CL_LINK_PROGRAM_FAILURE -17 -#define CL_DEVICE_PARTITION_FAILED -18 -#define CL_KERNEL_ARG_INFO_NOT_AVAILABLE -19 -#endif - -#define CL_INVALID_VALUE -30 -#define CL_INVALID_DEVICE_TYPE -31 -#define CL_INVALID_PLATFORM -32 -#define CL_INVALID_DEVICE -33 -#define CL_INVALID_CONTEXT -34 -#define CL_INVALID_QUEUE_PROPERTIES -35 -#define CL_INVALID_COMMAND_QUEUE -36 -#define CL_INVALID_HOST_PTR -37 -#define CL_INVALID_MEM_OBJECT -38 -#define CL_INVALID_IMAGE_FORMAT_DESCRIPTOR -39 -#define CL_INVALID_IMAGE_SIZE -40 -#define CL_INVALID_SAMPLER -41 -#define CL_INVALID_BINARY -42 -#define CL_INVALID_BUILD_OPTIONS -43 -#define CL_INVALID_PROGRAM -44 -#define CL_INVALID_PROGRAM_EXECUTABLE -45 -#define CL_INVALID_KERNEL_NAME -46 -#define CL_INVALID_KERNEL_DEFINITION -47 -#define CL_INVALID_KERNEL -48 -#define CL_INVALID_ARG_INDEX -49 -#define CL_INVALID_ARG_VALUE -50 -#define CL_INVALID_ARG_SIZE -51 -#define CL_INVALID_KERNEL_ARGS -52 -#define CL_INVALID_WORK_DIMENSION -53 -#define CL_INVALID_WORK_GROUP_SIZE -54 -#define CL_INVALID_WORK_ITEM_SIZE -55 -#define CL_INVALID_GLOBAL_OFFSET -56 -#define CL_INVALID_EVENT_WAIT_LIST -57 -#define CL_INVALID_EVENT -58 -#define CL_INVALID_OPERATION -59 -#define CL_INVALID_GL_OBJECT -60 -#define CL_INVALID_BUFFER_SIZE -61 -#define CL_INVALID_MIP_LEVEL -62 -#define CL_INVALID_GLOBAL_WORK_SIZE -63 -#ifdef CL_VERSION_1_1 -#define CL_INVALID_PROPERTY -64 -#endif -#ifdef CL_VERSION_1_2 -#define CL_INVALID_IMAGE_DESCRIPTOR -65 -#define CL_INVALID_COMPILER_OPTIONS -66 -#define CL_INVALID_LINKER_OPTIONS -67 -#define CL_INVALID_DEVICE_PARTITION_COUNT -68 -#endif -#ifdef CL_VERSION_2_0 -#define CL_INVALID_PIPE_SIZE -69 -#define CL_INVALID_DEVICE_QUEUE -70 -#endif -#ifdef CL_VERSION_2_2 -#define CL_INVALID_SPEC_ID -71 -#define CL_MAX_SIZE_RESTRICTION_EXCEEDED -72 -#endif - - -/* cl_bool */ -#define CL_FALSE 0 -#define CL_TRUE 1 -#ifdef CL_VERSION_1_2 -#define CL_BLOCKING CL_TRUE -#define CL_NON_BLOCKING CL_FALSE -#endif - -/* cl_platform_info */ -#define CL_PLATFORM_PROFILE 0x0900 -#define CL_PLATFORM_VERSION 0x0901 -#define CL_PLATFORM_NAME 0x0902 -#define CL_PLATFORM_VENDOR 0x0903 -#define CL_PLATFORM_EXTENSIONS 0x0904 -#ifdef CL_VERSION_2_1 -#define CL_PLATFORM_HOST_TIMER_RESOLUTION 0x0905 -#endif - -/* cl_device_type - bitfield */ -#define CL_DEVICE_TYPE_DEFAULT (1 << 0) -#define CL_DEVICE_TYPE_CPU (1 << 1) -#define CL_DEVICE_TYPE_GPU (1 << 2) -#define CL_DEVICE_TYPE_ACCELERATOR (1 << 3) -#ifdef CL_VERSION_1_2 -#define CL_DEVICE_TYPE_CUSTOM (1 << 4) -#endif -#define CL_DEVICE_TYPE_ALL 0xFFFFFFFF - -/* cl_device_info */ -#define CL_DEVICE_TYPE 0x1000 -#define CL_DEVICE_VENDOR_ID 0x1001 -#define CL_DEVICE_MAX_COMPUTE_UNITS 0x1002 -#define CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS 0x1003 -#define CL_DEVICE_MAX_WORK_GROUP_SIZE 0x1004 -#define CL_DEVICE_MAX_WORK_ITEM_SIZES 0x1005 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_CHAR 0x1006 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_SHORT 0x1007 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_INT 0x1008 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_LONG 0x1009 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_FLOAT 0x100A -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_DOUBLE 0x100B -#define CL_DEVICE_MAX_CLOCK_FREQUENCY 0x100C -#define CL_DEVICE_ADDRESS_BITS 0x100D -#define CL_DEVICE_MAX_READ_IMAGE_ARGS 0x100E -#define CL_DEVICE_MAX_WRITE_IMAGE_ARGS 0x100F 
-#define CL_DEVICE_MAX_MEM_ALLOC_SIZE 0x1010 -#define CL_DEVICE_IMAGE2D_MAX_WIDTH 0x1011 -#define CL_DEVICE_IMAGE2D_MAX_HEIGHT 0x1012 -#define CL_DEVICE_IMAGE3D_MAX_WIDTH 0x1013 -#define CL_DEVICE_IMAGE3D_MAX_HEIGHT 0x1014 -#define CL_DEVICE_IMAGE3D_MAX_DEPTH 0x1015 -#define CL_DEVICE_IMAGE_SUPPORT 0x1016 -#define CL_DEVICE_MAX_PARAMETER_SIZE 0x1017 -#define CL_DEVICE_MAX_SAMPLERS 0x1018 -#define CL_DEVICE_MEM_BASE_ADDR_ALIGN 0x1019 -#define CL_DEVICE_MIN_DATA_TYPE_ALIGN_SIZE 0x101A -#define CL_DEVICE_SINGLE_FP_CONFIG 0x101B -#define CL_DEVICE_GLOBAL_MEM_CACHE_TYPE 0x101C -#define CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE 0x101D -#define CL_DEVICE_GLOBAL_MEM_CACHE_SIZE 0x101E -#define CL_DEVICE_GLOBAL_MEM_SIZE 0x101F -#define CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE 0x1020 -#define CL_DEVICE_MAX_CONSTANT_ARGS 0x1021 -#define CL_DEVICE_LOCAL_MEM_TYPE 0x1022 -#define CL_DEVICE_LOCAL_MEM_SIZE 0x1023 -#define CL_DEVICE_ERROR_CORRECTION_SUPPORT 0x1024 -#define CL_DEVICE_PROFILING_TIMER_RESOLUTION 0x1025 -#define CL_DEVICE_ENDIAN_LITTLE 0x1026 -#define CL_DEVICE_AVAILABLE 0x1027 -#define CL_DEVICE_COMPILER_AVAILABLE 0x1028 -#define CL_DEVICE_EXECUTION_CAPABILITIES 0x1029 -#define CL_DEVICE_QUEUE_PROPERTIES 0x102A /* deprecated */ -#ifdef CL_VERSION_2_0 -#define CL_DEVICE_QUEUE_ON_HOST_PROPERTIES 0x102A -#endif -#define CL_DEVICE_NAME 0x102B -#define CL_DEVICE_VENDOR 0x102C -#define CL_DRIVER_VERSION 0x102D -#define CL_DEVICE_PROFILE 0x102E -#define CL_DEVICE_VERSION 0x102F -#define CL_DEVICE_EXTENSIONS 0x1030 -#define CL_DEVICE_PLATFORM 0x1031 -#ifdef CL_VERSION_1_2 -#define CL_DEVICE_DOUBLE_FP_CONFIG 0x1032 -#endif -/* 0x1033 reserved for CL_DEVICE_HALF_FP_CONFIG which is already defined in "cl_ext.h" */ -#ifdef CL_VERSION_1_1 -#define CL_DEVICE_PREFERRED_VECTOR_WIDTH_HALF 0x1034 -#define CL_DEVICE_HOST_UNIFIED_MEMORY 0x1035 /* deprecated */ -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_CHAR 0x1036 -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_SHORT 0x1037 -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_INT 0x1038 -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_LONG 0x1039 -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_FLOAT 0x103A -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_DOUBLE 0x103B -#define CL_DEVICE_NATIVE_VECTOR_WIDTH_HALF 0x103C -#define CL_DEVICE_OPENCL_C_VERSION 0x103D -#endif -#ifdef CL_VERSION_1_2 -#define CL_DEVICE_LINKER_AVAILABLE 0x103E -#define CL_DEVICE_BUILT_IN_KERNELS 0x103F -#define CL_DEVICE_IMAGE_MAX_BUFFER_SIZE 0x1040 -#define CL_DEVICE_IMAGE_MAX_ARRAY_SIZE 0x1041 -#define CL_DEVICE_PARENT_DEVICE 0x1042 -#define CL_DEVICE_PARTITION_MAX_SUB_DEVICES 0x1043 -#define CL_DEVICE_PARTITION_PROPERTIES 0x1044 -#define CL_DEVICE_PARTITION_AFFINITY_DOMAIN 0x1045 -#define CL_DEVICE_PARTITION_TYPE 0x1046 -#define CL_DEVICE_REFERENCE_COUNT 0x1047 -#define CL_DEVICE_PREFERRED_INTEROP_USER_SYNC 0x1048 -#define CL_DEVICE_PRINTF_BUFFER_SIZE 0x1049 -#define CL_DEVICE_IMAGE_PITCH_ALIGNMENT 0x104A -#define CL_DEVICE_IMAGE_BASE_ADDRESS_ALIGNMENT 0x104B -#endif -#ifdef CL_VERSION_2_0 -#define CL_DEVICE_MAX_READ_WRITE_IMAGE_ARGS 0x104C -#define CL_DEVICE_MAX_GLOBAL_VARIABLE_SIZE 0x104D -#define CL_DEVICE_QUEUE_ON_DEVICE_PROPERTIES 0x104E -#define CL_DEVICE_QUEUE_ON_DEVICE_PREFERRED_SIZE 0x104F -#define CL_DEVICE_QUEUE_ON_DEVICE_MAX_SIZE 0x1050 -#define CL_DEVICE_MAX_ON_DEVICE_QUEUES 0x1051 -#define CL_DEVICE_MAX_ON_DEVICE_EVENTS 0x1052 -#define CL_DEVICE_SVM_CAPABILITIES 0x1053 -#define CL_DEVICE_GLOBAL_VARIABLE_PREFERRED_TOTAL_SIZE 0x1054 -#define CL_DEVICE_MAX_PIPE_ARGS 0x1055 -#define CL_DEVICE_PIPE_MAX_ACTIVE_RESERVATIONS 0x1056 -#define 
CL_DEVICE_PIPE_MAX_PACKET_SIZE 0x1057 -#define CL_DEVICE_PREFERRED_PLATFORM_ATOMIC_ALIGNMENT 0x1058 -#define CL_DEVICE_PREFERRED_GLOBAL_ATOMIC_ALIGNMENT 0x1059 -#define CL_DEVICE_PREFERRED_LOCAL_ATOMIC_ALIGNMENT 0x105A -#endif -#ifdef CL_VERSION_2_1 -#define CL_DEVICE_IL_VERSION 0x105B -#define CL_DEVICE_MAX_NUM_SUB_GROUPS 0x105C -#define CL_DEVICE_SUB_GROUP_INDEPENDENT_FORWARD_PROGRESS 0x105D -#endif - -/* cl_device_fp_config - bitfield */ -#define CL_FP_DENORM (1 << 0) -#define CL_FP_INF_NAN (1 << 1) -#define CL_FP_ROUND_TO_NEAREST (1 << 2) -#define CL_FP_ROUND_TO_ZERO (1 << 3) -#define CL_FP_ROUND_TO_INF (1 << 4) -#define CL_FP_FMA (1 << 5) -#ifdef CL_VERSION_1_1 -#define CL_FP_SOFT_FLOAT (1 << 6) -#endif -#ifdef CL_VERSION_1_2 -#define CL_FP_CORRECTLY_ROUNDED_DIVIDE_SQRT (1 << 7) -#endif - -/* cl_device_mem_cache_type */ -#define CL_NONE 0x0 -#define CL_READ_ONLY_CACHE 0x1 -#define CL_READ_WRITE_CACHE 0x2 - -/* cl_device_local_mem_type */ -#define CL_LOCAL 0x1 -#define CL_GLOBAL 0x2 - -/* cl_device_exec_capabilities - bitfield */ -#define CL_EXEC_KERNEL (1 << 0) -#define CL_EXEC_NATIVE_KERNEL (1 << 1) - -/* cl_command_queue_properties - bitfield */ -#define CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE (1 << 0) -#define CL_QUEUE_PROFILING_ENABLE (1 << 1) -#ifdef CL_VERSION_2_0 -#define CL_QUEUE_ON_DEVICE (1 << 2) -#define CL_QUEUE_ON_DEVICE_DEFAULT (1 << 3) -#endif - -/* cl_context_info */ -#define CL_CONTEXT_REFERENCE_COUNT 0x1080 -#define CL_CONTEXT_DEVICES 0x1081 -#define CL_CONTEXT_PROPERTIES 0x1082 -#ifdef CL_VERSION_1_1 -#define CL_CONTEXT_NUM_DEVICES 0x1083 -#endif - -/* cl_context_properties */ -#define CL_CONTEXT_PLATFORM 0x1084 -#ifdef CL_VERSION_1_2 -#define CL_CONTEXT_INTEROP_USER_SYNC 0x1085 -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_device_partition_property */ -#define CL_DEVICE_PARTITION_EQUALLY 0x1086 -#define CL_DEVICE_PARTITION_BY_COUNTS 0x1087 -#define CL_DEVICE_PARTITION_BY_COUNTS_LIST_END 0x0 -#define CL_DEVICE_PARTITION_BY_AFFINITY_DOMAIN 0x1088 - -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_device_affinity_domain */ -#define CL_DEVICE_AFFINITY_DOMAIN_NUMA (1 << 0) -#define CL_DEVICE_AFFINITY_DOMAIN_L4_CACHE (1 << 1) -#define CL_DEVICE_AFFINITY_DOMAIN_L3_CACHE (1 << 2) -#define CL_DEVICE_AFFINITY_DOMAIN_L2_CACHE (1 << 3) -#define CL_DEVICE_AFFINITY_DOMAIN_L1_CACHE (1 << 4) -#define CL_DEVICE_AFFINITY_DOMAIN_NEXT_PARTITIONABLE (1 << 5) - -#endif - -#ifdef CL_VERSION_2_0 - -/* cl_device_svm_capabilities */ -#define CL_DEVICE_SVM_COARSE_GRAIN_BUFFER (1 << 0) -#define CL_DEVICE_SVM_FINE_GRAIN_BUFFER (1 << 1) -#define CL_DEVICE_SVM_FINE_GRAIN_SYSTEM (1 << 2) -#define CL_DEVICE_SVM_ATOMICS (1 << 3) - -#endif - -/* cl_command_queue_info */ -#define CL_QUEUE_CONTEXT 0x1090 -#define CL_QUEUE_DEVICE 0x1091 -#define CL_QUEUE_REFERENCE_COUNT 0x1092 -#define CL_QUEUE_PROPERTIES 0x1093 -#ifdef CL_VERSION_2_0 -#define CL_QUEUE_SIZE 0x1094 -#endif -#ifdef CL_VERSION_2_1 -#define CL_QUEUE_DEVICE_DEFAULT 0x1095 -#endif - -/* cl_mem_flags and cl_svm_mem_flags - bitfield */ -#define CL_MEM_READ_WRITE (1 << 0) -#define CL_MEM_WRITE_ONLY (1 << 1) -#define CL_MEM_READ_ONLY (1 << 2) -#define CL_MEM_USE_HOST_PTR (1 << 3) -#define CL_MEM_ALLOC_HOST_PTR (1 << 4) -#define CL_MEM_COPY_HOST_PTR (1 << 5) -/* reserved (1 << 6) */ -#ifdef CL_VERSION_1_2 -#define CL_MEM_HOST_WRITE_ONLY (1 << 7) -#define CL_MEM_HOST_READ_ONLY (1 << 8) -#define CL_MEM_HOST_NO_ACCESS (1 << 9) -#endif -#ifdef CL_VERSION_2_0 -#define CL_MEM_SVM_FINE_GRAIN_BUFFER (1 << 10) /* used by cl_svm_mem_flags only */ -#define 
CL_MEM_SVM_ATOMICS (1 << 11) /* used by cl_svm_mem_flags only */ -#define CL_MEM_KERNEL_READ_AND_WRITE (1 << 12) -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_mem_migration_flags - bitfield */ -#define CL_MIGRATE_MEM_OBJECT_HOST (1 << 0) -#define CL_MIGRATE_MEM_OBJECT_CONTENT_UNDEFINED (1 << 1) - -#endif - -/* cl_channel_order */ -#define CL_R 0x10B0 -#define CL_A 0x10B1 -#define CL_RG 0x10B2 -#define CL_RA 0x10B3 -#define CL_RGB 0x10B4 -#define CL_RGBA 0x10B5 -#define CL_BGRA 0x10B6 -#define CL_ARGB 0x10B7 -#define CL_INTENSITY 0x10B8 -#define CL_LUMINANCE 0x10B9 -#ifdef CL_VERSION_1_1 -#define CL_Rx 0x10BA -#define CL_RGx 0x10BB -#define CL_RGBx 0x10BC -#endif -#ifdef CL_VERSION_1_2 -#define CL_DEPTH 0x10BD -#define CL_DEPTH_STENCIL 0x10BE -#endif -#ifdef CL_VERSION_2_0 -#define CL_sRGB 0x10BF -#define CL_sRGBx 0x10C0 -#define CL_sRGBA 0x10C1 -#define CL_sBGRA 0x10C2 -#define CL_ABGR 0x10C3 -#endif - -/* cl_channel_type */ -#define CL_SNORM_INT8 0x10D0 -#define CL_SNORM_INT16 0x10D1 -#define CL_UNORM_INT8 0x10D2 -#define CL_UNORM_INT16 0x10D3 -#define CL_UNORM_SHORT_565 0x10D4 -#define CL_UNORM_SHORT_555 0x10D5 -#define CL_UNORM_INT_101010 0x10D6 -#define CL_SIGNED_INT8 0x10D7 -#define CL_SIGNED_INT16 0x10D8 -#define CL_SIGNED_INT32 0x10D9 -#define CL_UNSIGNED_INT8 0x10DA -#define CL_UNSIGNED_INT16 0x10DB -#define CL_UNSIGNED_INT32 0x10DC -#define CL_HALF_FLOAT 0x10DD -#define CL_FLOAT 0x10DE -#ifdef CL_VERSION_1_2 -#define CL_UNORM_INT24 0x10DF -#endif -#ifdef CL_VERSION_2_1 -#define CL_UNORM_INT_101010_2 0x10E0 -#endif - -/* cl_mem_object_type */ -#define CL_MEM_OBJECT_BUFFER 0x10F0 -#define CL_MEM_OBJECT_IMAGE2D 0x10F1 -#define CL_MEM_OBJECT_IMAGE3D 0x10F2 -#ifdef CL_VERSION_1_2 -#define CL_MEM_OBJECT_IMAGE2D_ARRAY 0x10F3 -#define CL_MEM_OBJECT_IMAGE1D 0x10F4 -#define CL_MEM_OBJECT_IMAGE1D_ARRAY 0x10F5 -#define CL_MEM_OBJECT_IMAGE1D_BUFFER 0x10F6 -#endif -#ifdef CL_VERSION_2_0 -#define CL_MEM_OBJECT_PIPE 0x10F7 -#endif - -/* cl_mem_info */ -#define CL_MEM_TYPE 0x1100 -#define CL_MEM_FLAGS 0x1101 -#define CL_MEM_SIZE 0x1102 -#define CL_MEM_HOST_PTR 0x1103 -#define CL_MEM_MAP_COUNT 0x1104 -#define CL_MEM_REFERENCE_COUNT 0x1105 -#define CL_MEM_CONTEXT 0x1106 -#ifdef CL_VERSION_1_1 -#define CL_MEM_ASSOCIATED_MEMOBJECT 0x1107 -#define CL_MEM_OFFSET 0x1108 -#endif -#ifdef CL_VERSION_2_0 -#define CL_MEM_USES_SVM_POINTER 0x1109 -#endif - -/* cl_image_info */ -#define CL_IMAGE_FORMAT 0x1110 -#define CL_IMAGE_ELEMENT_SIZE 0x1111 -#define CL_IMAGE_ROW_PITCH 0x1112 -#define CL_IMAGE_SLICE_PITCH 0x1113 -#define CL_IMAGE_WIDTH 0x1114 -#define CL_IMAGE_HEIGHT 0x1115 -#define CL_IMAGE_DEPTH 0x1116 -#ifdef CL_VERSION_1_2 -#define CL_IMAGE_ARRAY_SIZE 0x1117 -#define CL_IMAGE_BUFFER 0x1118 -#define CL_IMAGE_NUM_MIP_LEVELS 0x1119 -#define CL_IMAGE_NUM_SAMPLES 0x111A -#endif - -#ifdef CL_VERSION_2_0 - -/* cl_pipe_info */ -#define CL_PIPE_PACKET_SIZE 0x1120 -#define CL_PIPE_MAX_PACKETS 0x1121 - -#endif - -/* cl_addressing_mode */ -#define CL_ADDRESS_NONE 0x1130 -#define CL_ADDRESS_CLAMP_TO_EDGE 0x1131 -#define CL_ADDRESS_CLAMP 0x1132 -#define CL_ADDRESS_REPEAT 0x1133 -#ifdef CL_VERSION_1_1 -#define CL_ADDRESS_MIRRORED_REPEAT 0x1134 -#endif - -/* cl_filter_mode */ -#define CL_FILTER_NEAREST 0x1140 -#define CL_FILTER_LINEAR 0x1141 - -/* cl_sampler_info */ -#define CL_SAMPLER_REFERENCE_COUNT 0x1150 -#define CL_SAMPLER_CONTEXT 0x1151 -#define CL_SAMPLER_NORMALIZED_COORDS 0x1152 -#define CL_SAMPLER_ADDRESSING_MODE 0x1153 -#define CL_SAMPLER_FILTER_MODE 0x1154 -#ifdef CL_VERSION_2_0 -#define 
CL_SAMPLER_MIP_FILTER_MODE 0x1155 -#define CL_SAMPLER_LOD_MIN 0x1156 -#define CL_SAMPLER_LOD_MAX 0x1157 -#endif - -/* cl_map_flags - bitfield */ -#define CL_MAP_READ (1 << 0) -#define CL_MAP_WRITE (1 << 1) -#ifdef CL_VERSION_1_2 -#define CL_MAP_WRITE_INVALIDATE_REGION (1 << 2) -#endif - -/* cl_program_info */ -#define CL_PROGRAM_REFERENCE_COUNT 0x1160 -#define CL_PROGRAM_CONTEXT 0x1161 -#define CL_PROGRAM_NUM_DEVICES 0x1162 -#define CL_PROGRAM_DEVICES 0x1163 -#define CL_PROGRAM_SOURCE 0x1164 -#define CL_PROGRAM_BINARY_SIZES 0x1165 -#define CL_PROGRAM_BINARIES 0x1166 -#ifdef CL_VERSION_1_2 -#define CL_PROGRAM_NUM_KERNELS 0x1167 -#define CL_PROGRAM_KERNEL_NAMES 0x1168 -#endif -#ifdef CL_VERSION_2_1 -#define CL_PROGRAM_IL 0x1169 -#endif -#ifdef CL_VERSION_2_2 -#define CL_PROGRAM_SCOPE_GLOBAL_CTORS_PRESENT 0x116A -#define CL_PROGRAM_SCOPE_GLOBAL_DTORS_PRESENT 0x116B -#endif - -/* cl_program_build_info */ -#define CL_PROGRAM_BUILD_STATUS 0x1181 -#define CL_PROGRAM_BUILD_OPTIONS 0x1182 -#define CL_PROGRAM_BUILD_LOG 0x1183 -#ifdef CL_VERSION_1_2 -#define CL_PROGRAM_BINARY_TYPE 0x1184 -#endif -#ifdef CL_VERSION_2_0 -#define CL_PROGRAM_BUILD_GLOBAL_VARIABLE_TOTAL_SIZE 0x1185 -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_program_binary_type */ -#define CL_PROGRAM_BINARY_TYPE_NONE 0x0 -#define CL_PROGRAM_BINARY_TYPE_COMPILED_OBJECT 0x1 -#define CL_PROGRAM_BINARY_TYPE_LIBRARY 0x2 -#define CL_PROGRAM_BINARY_TYPE_EXECUTABLE 0x4 - -#endif - -/* cl_build_status */ -#define CL_BUILD_SUCCESS 0 -#define CL_BUILD_NONE -1 -#define CL_BUILD_ERROR -2 -#define CL_BUILD_IN_PROGRESS -3 - -/* cl_kernel_info */ -#define CL_KERNEL_FUNCTION_NAME 0x1190 -#define CL_KERNEL_NUM_ARGS 0x1191 -#define CL_KERNEL_REFERENCE_COUNT 0x1192 -#define CL_KERNEL_CONTEXT 0x1193 -#define CL_KERNEL_PROGRAM 0x1194 -#ifdef CL_VERSION_1_2 -#define CL_KERNEL_ATTRIBUTES 0x1195 -#endif -#ifdef CL_VERSION_2_1 -#define CL_KERNEL_MAX_NUM_SUB_GROUPS 0x11B9 -#define CL_KERNEL_COMPILE_NUM_SUB_GROUPS 0x11BA -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_kernel_arg_info */ -#define CL_KERNEL_ARG_ADDRESS_QUALIFIER 0x1196 -#define CL_KERNEL_ARG_ACCESS_QUALIFIER 0x1197 -#define CL_KERNEL_ARG_TYPE_NAME 0x1198 -#define CL_KERNEL_ARG_TYPE_QUALIFIER 0x1199 -#define CL_KERNEL_ARG_NAME 0x119A - -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_kernel_arg_address_qualifier */ -#define CL_KERNEL_ARG_ADDRESS_GLOBAL 0x119B -#define CL_KERNEL_ARG_ADDRESS_LOCAL 0x119C -#define CL_KERNEL_ARG_ADDRESS_CONSTANT 0x119D -#define CL_KERNEL_ARG_ADDRESS_PRIVATE 0x119E - -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_kernel_arg_access_qualifier */ -#define CL_KERNEL_ARG_ACCESS_READ_ONLY 0x11A0 -#define CL_KERNEL_ARG_ACCESS_WRITE_ONLY 0x11A1 -#define CL_KERNEL_ARG_ACCESS_READ_WRITE 0x11A2 -#define CL_KERNEL_ARG_ACCESS_NONE 0x11A3 - -#endif - -#ifdef CL_VERSION_1_2 - -/* cl_kernel_arg_type_qualifier */ -#define CL_KERNEL_ARG_TYPE_NONE 0 -#define CL_KERNEL_ARG_TYPE_CONST (1 << 0) -#define CL_KERNEL_ARG_TYPE_RESTRICT (1 << 1) -#define CL_KERNEL_ARG_TYPE_VOLATILE (1 << 2) -#ifdef CL_VERSION_2_0 -#define CL_KERNEL_ARG_TYPE_PIPE (1 << 3) -#endif - -#endif - -/* cl_kernel_work_group_info */ -#define CL_KERNEL_WORK_GROUP_SIZE 0x11B0 -#define CL_KERNEL_COMPILE_WORK_GROUP_SIZE 0x11B1 -#define CL_KERNEL_LOCAL_MEM_SIZE 0x11B2 -#define CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE 0x11B3 -#define CL_KERNEL_PRIVATE_MEM_SIZE 0x11B4 -#ifdef CL_VERSION_1_2 -#define CL_KERNEL_GLOBAL_WORK_SIZE 0x11B5 -#endif - -#ifdef CL_VERSION_2_1 - -/* cl_kernel_sub_group_info */ -#define CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE 
0x2033
-#define CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE 0x2034
-#define CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT 0x11B8
-
-#endif
-
-#ifdef CL_VERSION_2_0
-
-/* cl_kernel_exec_info */
-#define CL_KERNEL_EXEC_INFO_SVM_PTRS 0x11B6
-#define CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM 0x11B7
-
-#endif
-
-/* cl_event_info */
-#define CL_EVENT_COMMAND_QUEUE 0x11D0
-#define CL_EVENT_COMMAND_TYPE 0x11D1
-#define CL_EVENT_REFERENCE_COUNT 0x11D2
-#define CL_EVENT_COMMAND_EXECUTION_STATUS 0x11D3
-#ifdef CL_VERSION_1_1
-#define CL_EVENT_CONTEXT 0x11D4
-#endif
-
-/* cl_command_type */
-#define CL_COMMAND_NDRANGE_KERNEL 0x11F0
-#define CL_COMMAND_TASK 0x11F1
-#define CL_COMMAND_NATIVE_KERNEL 0x11F2
-#define CL_COMMAND_READ_BUFFER 0x11F3
-#define CL_COMMAND_WRITE_BUFFER 0x11F4
-#define CL_COMMAND_COPY_BUFFER 0x11F5
-#define CL_COMMAND_READ_IMAGE 0x11F6
-#define CL_COMMAND_WRITE_IMAGE 0x11F7
-#define CL_COMMAND_COPY_IMAGE 0x11F8
-#define CL_COMMAND_COPY_IMAGE_TO_BUFFER 0x11F9
-#define CL_COMMAND_COPY_BUFFER_TO_IMAGE 0x11FA
-#define CL_COMMAND_MAP_BUFFER 0x11FB
-#define CL_COMMAND_MAP_IMAGE 0x11FC
-#define CL_COMMAND_UNMAP_MEM_OBJECT 0x11FD
-#define CL_COMMAND_MARKER 0x11FE
-#define CL_COMMAND_ACQUIRE_GL_OBJECTS 0x11FF
-#define CL_COMMAND_RELEASE_GL_OBJECTS 0x1200
-#ifdef CL_VERSION_1_1
-#define CL_COMMAND_READ_BUFFER_RECT 0x1201
-#define CL_COMMAND_WRITE_BUFFER_RECT 0x1202
-#define CL_COMMAND_COPY_BUFFER_RECT 0x1203
-#define CL_COMMAND_USER 0x1204
-#endif
-#ifdef CL_VERSION_1_2
-#define CL_COMMAND_BARRIER 0x1205
-#define CL_COMMAND_MIGRATE_MEM_OBJECTS 0x1206
-#define CL_COMMAND_FILL_BUFFER 0x1207
-#define CL_COMMAND_FILL_IMAGE 0x1208
-#endif
-#ifdef CL_VERSION_2_0
-#define CL_COMMAND_SVM_FREE 0x1209
-#define CL_COMMAND_SVM_MEMCPY 0x120A
-#define CL_COMMAND_SVM_MEMFILL 0x120B
-#define CL_COMMAND_SVM_MAP 0x120C
-#define CL_COMMAND_SVM_UNMAP 0x120D
-#endif
-
-/* command execution status */
-#define CL_COMPLETE 0x0
-#define CL_RUNNING 0x1
-#define CL_SUBMITTED 0x2
-#define CL_QUEUED 0x3
-
-#ifdef CL_VERSION_1_1
-
-/* cl_buffer_create_type */
-#define CL_BUFFER_CREATE_TYPE_REGION 0x1220
-
-#endif
-
-/* cl_profiling_info */
-#define CL_PROFILING_COMMAND_QUEUED 0x1280
-#define CL_PROFILING_COMMAND_SUBMIT 0x1281
-#define CL_PROFILING_COMMAND_START 0x1282
-#define CL_PROFILING_COMMAND_END 0x1283
-#ifdef CL_VERSION_2_0
-#define CL_PROFILING_COMMAND_COMPLETE 0x1284
-#endif
-
-/********************************************************************************************************/
-
-/* Platform API */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetPlatformIDs(cl_uint /* num_entries */,
-                 cl_platform_id * /* platforms */,
-                 cl_uint * /* num_platforms */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetPlatformInfo(cl_platform_id /* platform */,
-                  cl_platform_info /* param_name */,
-                  size_t /* param_value_size */,
-                  void * /* param_value */,
-                  size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Device APIs */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetDeviceIDs(cl_platform_id /* platform */,
-               cl_device_type /* device_type */,
-               cl_uint /* num_entries */,
-               cl_device_id * /* devices */,
-               cl_uint * /* num_devices */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetDeviceInfo(cl_device_id /* device */,
-                cl_device_info /* param_name */,
-                size_t /* param_value_size */,
-                void * /* param_value */,
-                size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clCreateSubDevices(cl_device_id /* in_device */,
-                   const cl_device_partition_property * /* properties */,
-                   cl_uint /* num_devices */,
-                   cl_device_id * /* out_devices */,
-                   cl_uint * /* num_devices_ret */) CL_API_SUFFIX__VERSION_1_2;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainDevice(cl_device_id /* device */) CL_API_SUFFIX__VERSION_1_2;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseDevice(cl_device_id /* device */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_VERSION_2_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetDefaultDeviceCommandQueue(cl_context /* context */,
-                               cl_device_id /* device */,
-                               cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_2_1;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetDeviceAndHostTimer(cl_device_id /* device */,
-                        cl_ulong* /* device_timestamp */,
-                        cl_ulong* /* host_timestamp */) CL_API_SUFFIX__VERSION_2_1;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetHostTimer(cl_device_id /* device */,
-               cl_ulong * /* host_timestamp */) CL_API_SUFFIX__VERSION_2_1;
-
-#endif
-
-/* Context APIs */
-extern CL_API_ENTRY cl_context CL_API_CALL
-clCreateContext(const cl_context_properties * /* properties */,
-                cl_uint /* num_devices */,
-                const cl_device_id * /* devices */,
-                void (CL_CALLBACK * /* pfn_notify */)(const char *, const void *, size_t, void *),
-                void * /* user_data */,
-                cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_context CL_API_CALL
-clCreateContextFromType(const cl_context_properties * /* properties */,
-                        cl_device_type /* device_type */,
-                        void (CL_CALLBACK * /* pfn_notify*/ )(const char *, const void *, size_t, void *),
-                        void * /* user_data */,
-                        cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainContext(cl_context /* context */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseContext(cl_context /* context */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetContextInfo(cl_context /* context */,
-                 cl_context_info /* param_name */,
-                 size_t /* param_value_size */,
-                 void * /* param_value */,
-                 size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Command Queue APIs */
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_command_queue CL_API_CALL
-clCreateCommandQueueWithProperties(cl_context /* context */,
-                                   cl_device_id /* device */,
-                                   const cl_queue_properties * /* properties */,
-                                   cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainCommandQueue(cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseCommandQueue(cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetCommandQueueInfo(cl_command_queue /* command_queue */,
-                      cl_command_queue_info /* param_name */,
-                      size_t /* param_value_size */,
-                      void * /* param_value */,
-                      size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Memory Object APIs */
-extern CL_API_ENTRY cl_mem CL_API_CALL
-clCreateBuffer(cl_context /* context */,
-               cl_mem_flags /* flags */,
-               size_t /* size */,
-               void * /* host_ptr */,
-               cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_mem CL_API_CALL
-clCreateSubBuffer(cl_mem /* buffer */,
-                  cl_mem_flags /* flags */,
-                  cl_buffer_create_type /* buffer_create_type */,
-                  const void * /* buffer_create_info */,
-                  cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_mem CL_API_CALL
-clCreateImage(cl_context /* context */,
-              cl_mem_flags /* flags */,
-              const cl_image_format * /* image_format */,
-              const cl_image_desc * /* image_desc */,
-              void * /* host_ptr */,
-              cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_mem CL_API_CALL
-clCreatePipe(cl_context /* context */,
-             cl_mem_flags /* flags */,
-             cl_uint /* pipe_packet_size */,
-             cl_uint /* pipe_max_packets */,
-             const cl_pipe_properties * /* properties */,
-             cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainMemObject(cl_mem /* memobj */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseMemObject(cl_mem /* memobj */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetSupportedImageFormats(cl_context /* context */,
-                           cl_mem_flags /* flags */,
-                           cl_mem_object_type /* image_type */,
-                           cl_uint /* num_entries */,
-                           cl_image_format * /* image_formats */,
-                           cl_uint * /* num_image_formats */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetMemObjectInfo(cl_mem /* memobj */,
-                   cl_mem_info /* param_name */,
-                   size_t /* param_value_size */,
-                   void * /* param_value */,
-                   size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetImageInfo(cl_mem /* image */,
-               cl_image_info /* param_name */,
-               size_t /* param_value_size */,
-               void * /* param_value */,
-               size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetPipeInfo(cl_mem /* pipe */,
-              cl_pipe_info /* param_name */,
-              size_t /* param_value_size */,
-              void * /* param_value */,
-              size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetMemObjectDestructorCallback(cl_mem /* memobj */,
-                                 void (CL_CALLBACK * /*pfn_notify*/)( cl_mem /* memobj */, void* /*user_data*/),
-                                 void * /*user_data */ ) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-/* SVM Allocation APIs */
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY void * CL_API_CALL
-clSVMAlloc(cl_context /* context */,
-           cl_svm_mem_flags /* flags */,
-           size_t /* size */,
-           cl_uint /* alignment */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY void CL_API_CALL
-clSVMFree(cl_context /* context */,
-          void * /* svm_pointer */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-/* Sampler APIs */
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_sampler CL_API_CALL
-clCreateSamplerWithProperties(cl_context /* context */,
-                              const cl_sampler_properties * /* normalized_coords */,
-                              cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainSampler(cl_sampler /* sampler */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseSampler(cl_sampler /* sampler */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetSamplerInfo(cl_sampler /* sampler */,
-                 cl_sampler_info /* param_name */,
-                 size_t /* param_value_size */,
-                 void * /* param_value */,
-                 size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Program Object APIs */
-extern CL_API_ENTRY cl_program CL_API_CALL
-clCreateProgramWithSource(cl_context /* context */,
-                          cl_uint /* count */,
-                          const char ** /* strings */,
-                          const size_t * /* lengths */,
-                          cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_program CL_API_CALL
-clCreateProgramWithBinary(cl_context /* context */,
-                          cl_uint /* num_devices */,
-                          const cl_device_id * /* device_list */,
-                          const size_t * /* lengths */,
-                          const unsigned char ** /* binaries */,
-                          cl_int * /* binary_status */,
-                          cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_program CL_API_CALL
-clCreateProgramWithBuiltInKernels(cl_context /* context */,
-                                  cl_uint /* num_devices */,
-                                  const cl_device_id * /* device_list */,
-                                  const char * /* kernel_names */,
-                                  cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_VERSION_2_1
-
-extern CL_API_ENTRY cl_program CL_API_CALL
-clCreateProgramWithIL(cl_context /* context */,
-                      const void* /* il */,
-                      size_t /* length */,
-                      cl_int* /* errcode_ret */) CL_API_SUFFIX__VERSION_2_1;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainProgram(cl_program /* program */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseProgram(cl_program /* program */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clBuildProgram(cl_program /* program */,
-               cl_uint /* num_devices */,
-               const cl_device_id * /* device_list */,
-               const char * /* options */,
-               void (CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
-               void * /* user_data */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clCompileProgram(cl_program /* program */,
-                 cl_uint /* num_devices */,
-                 const cl_device_id * /* device_list */,
-                 const char * /* options */,
-                 cl_uint /* num_input_headers */,
-                 const cl_program * /* input_headers */,
-                 const char ** /* header_include_names */,
-                 void (CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
-                 void * /* user_data */) CL_API_SUFFIX__VERSION_1_2;
-
-extern CL_API_ENTRY cl_program CL_API_CALL
-clLinkProgram(cl_context /* context */,
-              cl_uint /* num_devices */,
-              const cl_device_id * /* device_list */,
-              const char * /* options */,
-              cl_uint /* num_input_programs */,
-              const cl_program * /* input_programs */,
-              void (CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
-              void * /* user_data */,
-              cl_int * /* errcode_ret */ ) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_VERSION_2_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetProgramReleaseCallback(cl_program /* program */,
-                            void (CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
-                            void * /* user_data */) CL_API_SUFFIX__VERSION_2_2;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetProgramSpecializationConstant(cl_program /* program */,
-                                   cl_uint /* spec_id */,
-                                   size_t /* spec_size */,
-                                   const void* /* spec_value */) CL_API_SUFFIX__VERSION_2_2;
-
-#endif
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clUnloadPlatformCompiler(cl_platform_id /* platform */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetProgramInfo(cl_program /* program */,
-                 cl_program_info /* param_name */,
-                 size_t /* param_value_size */,
-                 void * /* param_value */,
-                 size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetProgramBuildInfo(cl_program /* program */,
-                      cl_device_id /* device */,
-                      cl_program_build_info /* param_name */,
-                      size_t /* param_value_size */,
-                      void * /* param_value */,
-                      size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Kernel Object APIs */
-extern CL_API_ENTRY cl_kernel CL_API_CALL
-clCreateKernel(cl_program /* program */,
-               const char * /* kernel_name */,
-               cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clCreateKernelsInProgram(cl_program /* program */,
-                         cl_uint /* num_kernels */,
-                         cl_kernel * /* kernels */,
-                         cl_uint * /* num_kernels_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_2_1
-
-extern CL_API_ENTRY cl_kernel CL_API_CALL
-clCloneKernel(cl_kernel /* source_kernel */,
-              cl_int* /* errcode_ret */) CL_API_SUFFIX__VERSION_2_1;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainKernel(cl_kernel /* kernel */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseKernel(cl_kernel /* kernel */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetKernelArg(cl_kernel /* kernel */,
-               cl_uint /* arg_index */,
-               size_t /* arg_size */,
-               const void * /* arg_value */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetKernelArgSVMPointer(cl_kernel /* kernel */,
-                         cl_uint /* arg_index */,
-                         const void * /* arg_value */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetKernelExecInfo(cl_kernel /* kernel */,
-                    cl_kernel_exec_info /* param_name */,
-                    size_t /* param_value_size */,
-                    const void * /* param_value */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetKernelInfo(cl_kernel /* kernel */,
-                cl_kernel_info /* param_name */,
-                size_t /* param_value_size */,
-                void * /* param_value */,
-                size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetKernelArgInfo(cl_kernel /* kernel */,
-                   cl_uint /* arg_indx */,
-                   cl_kernel_arg_info /* param_name */,
-                   size_t /* param_value_size */,
-                   void * /* param_value */,
-                   size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetKernelWorkGroupInfo(cl_kernel /* kernel */,
-                         cl_device_id /* device */,
-                         cl_kernel_work_group_info /* param_name */,
-                         size_t /* param_value_size */,
-                         void * /* param_value */,
-                         size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_2_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetKernelSubGroupInfo(cl_kernel /* kernel */,
-                        cl_device_id /* device */,
-                        cl_kernel_sub_group_info /* param_name */,
-                        size_t /* input_value_size */,
-                        const void* /*input_value */,
-                        size_t /* param_value_size */,
-                        void* /* param_value */,
-                        size_t* /* param_value_size_ret */ ) CL_API_SUFFIX__VERSION_2_1;
-
-#endif
-
-/* Event Object APIs */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clWaitForEvents(cl_uint /* num_events */,
-                const cl_event * /* event_list */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetEventInfo(cl_event /* event */,
-               cl_event_info /* param_name */,
-               size_t /* param_value_size */,
-               void * /* param_value */,
-               size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_event CL_API_CALL
-clCreateUserEvent(cl_context /* context */,
-                  cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clRetainEvent(cl_event /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clReleaseEvent(cl_event /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetUserEventStatus(cl_event /* event */,
-                     cl_int /* execution_status */) CL_API_SUFFIX__VERSION_1_1;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clSetEventCallback( cl_event /* event */,
-                    cl_int /* command_exec_callback_type */,
-                    void (CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *),
-                    void * /* user_data */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-/* Profiling APIs */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clGetEventProfilingInfo(cl_event /* event */,
-                        cl_profiling_info /* param_name */,
-                        size_t /* param_value_size */,
-                        void * /* param_value */,
-                        size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Flush and Finish APIs */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clFlush(cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clFinish(cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0;
-
-/* Enqueued Commands APIs */
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueReadBuffer(cl_command_queue /* command_queue */,
-                    cl_mem /* buffer */,
-                    cl_bool /* blocking_read */,
-                    size_t /* offset */,
-                    size_t /* size */,
-                    void * /* ptr */,
-                    cl_uint /* num_events_in_wait_list */,
-                    const cl_event * /* event_wait_list */,
-                    cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueReadBufferRect(cl_command_queue /* command_queue */,
-                        cl_mem /* buffer */,
-                        cl_bool /* blocking_read */,
-                        const size_t * /* buffer_offset */,
-                        const size_t * /* host_offset */,
-                        const size_t * /* region */,
-                        size_t /* buffer_row_pitch */,
-                        size_t /* buffer_slice_pitch */,
-                        size_t /* host_row_pitch */,
-                        size_t /* host_slice_pitch */,
-                        void * /* ptr */,
-                        cl_uint /* num_events_in_wait_list */,
-                        const cl_event * /* event_wait_list */,
-                        cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueWriteBuffer(cl_command_queue /* command_queue */,
-                     cl_mem /* buffer */,
-                     cl_bool /* blocking_write */,
-                     size_t /* offset */,
-                     size_t /* size */,
-                     const void * /* ptr */,
-                     cl_uint /* num_events_in_wait_list */,
-                     const cl_event * /* event_wait_list */,
-                     cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueWriteBufferRect(cl_command_queue /* command_queue */,
-                         cl_mem /* buffer */,
-                         cl_bool /* blocking_write */,
-                         const size_t * /* buffer_offset */,
-                         const size_t * /* host_offset */,
-                         const size_t * /* region */,
-                         size_t /* buffer_row_pitch */,
-                         size_t /* buffer_slice_pitch */,
-                         size_t /* host_row_pitch */,
-                         size_t /* host_slice_pitch */,
-                         const void * /* ptr */,
-                         cl_uint /* num_events_in_wait_list */,
-                         const cl_event * /* event_wait_list */,
-                         cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueFillBuffer(cl_command_queue /* command_queue */,
-                    cl_mem /* buffer */,
-                    const void * /* pattern */,
-                    size_t /* pattern_size */,
-                    size_t /* offset */,
-                    size_t /* size */,
-                    cl_uint /* num_events_in_wait_list */,
-                    const cl_event * /* event_wait_list */,
-                    cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueCopyBuffer(cl_command_queue /* command_queue */,
-                    cl_mem /* src_buffer */,
-                    cl_mem /* dst_buffer */,
-                    size_t /* src_offset */,
-                    size_t /* dst_offset */,
-                    size_t /* size */,
-                    cl_uint /* num_events_in_wait_list */,
-                    const cl_event * /* event_wait_list */,
-                    cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueCopyBufferRect(cl_command_queue /* command_queue */,
-                        cl_mem /* src_buffer */,
-                        cl_mem /* dst_buffer */,
-                        const size_t * /* src_origin */,
-                        const size_t * /* dst_origin */,
-                        const size_t * /* region */,
-                        size_t /* src_row_pitch */,
-                        size_t /* src_slice_pitch */,
-                        size_t /* dst_row_pitch */,
-                        size_t /* dst_slice_pitch */,
-                        cl_uint /* num_events_in_wait_list */,
-                        const cl_event * /* event_wait_list */,
-                        cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueReadImage(cl_command_queue /* command_queue */,
-                   cl_mem /* image */,
-                   cl_bool /* blocking_read */,
-                   const size_t * /* origin[3] */,
-                   const size_t * /* region[3] */,
-                   size_t /* row_pitch */,
-                   size_t /* slice_pitch */,
-                   void * /* ptr */,
-                   cl_uint /* num_events_in_wait_list */,
-                   const cl_event * /* event_wait_list */,
-                   cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueWriteImage(cl_command_queue /* command_queue */,
-                    cl_mem /* image */,
-                    cl_bool /* blocking_write */,
-                    const size_t * /* origin[3] */,
-                    const size_t * /* region[3] */,
-                    size_t /* input_row_pitch */,
-                    size_t /* input_slice_pitch */,
-                    const void * /* ptr */,
-                    cl_uint /* num_events_in_wait_list */,
-                    const cl_event * /* event_wait_list */,
-                    cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueFillImage(cl_command_queue /* command_queue */,
-                   cl_mem /* image */,
-                   const void * /* fill_color */,
-                   const size_t * /* origin[3] */,
-                   const size_t * /* region[3] */,
-                   cl_uint /* num_events_in_wait_list */,
-                   const cl_event * /* event_wait_list */,
-                   cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueCopyImage(cl_command_queue /* command_queue */,
-                   cl_mem /* src_image */,
-                   cl_mem /* dst_image */,
-                   const size_t * /* src_origin[3] */,
-                   const size_t * /* dst_origin[3] */,
-                   const size_t * /* region[3] */,
-                   cl_uint /* num_events_in_wait_list */,
-                   const cl_event * /* event_wait_list */,
-                   cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueCopyImageToBuffer(cl_command_queue /* command_queue */,
-                           cl_mem /* src_image */,
-                           cl_mem /* dst_buffer */,
-                           const size_t * /* src_origin[3] */,
-                           const size_t * /* region[3] */,
-                           size_t /* dst_offset */,
-                           cl_uint /* num_events_in_wait_list */,
-                           const cl_event * /* event_wait_list */,
-                           cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueCopyBufferToImage(cl_command_queue /* command_queue */,
-                           cl_mem /* src_buffer */,
-                           cl_mem /* dst_image */,
-                           size_t /* src_offset */,
-                           const size_t * /* dst_origin[3] */,
-                           const size_t * /* region[3] */,
-                           cl_uint /* num_events_in_wait_list */,
-                           const cl_event * /* event_wait_list */,
-                           cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY void * CL_API_CALL
-clEnqueueMapBuffer(cl_command_queue /* command_queue */,
-                   cl_mem /* buffer */,
-                   cl_bool /* blocking_map */,
-                   cl_map_flags /* map_flags */,
-                   size_t /* offset */,
-                   size_t /* size */,
-                   cl_uint /* num_events_in_wait_list */,
-                   const cl_event * /* event_wait_list */,
-                   cl_event * /* event */,
-                   cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY void * CL_API_CALL
-clEnqueueMapImage(cl_command_queue /* command_queue */,
-                  cl_mem /* image */,
-                  cl_bool /* blocking_map */,
-                  cl_map_flags /* map_flags */,
-                  const size_t * /* origin[3] */,
-                  const size_t * /* region[3] */,
-                  size_t * /* image_row_pitch */,
-                  size_t * /* image_slice_pitch */,
-                  cl_uint /* num_events_in_wait_list */,
-                  const cl_event * /* event_wait_list */,
-                  cl_event * /* event */,
-                  cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueUnmapMemObject(cl_command_queue /* command_queue */,
-                        cl_mem /* memobj */,
-                        void * /* mapped_ptr */,
-                        cl_uint /* num_events_in_wait_list */,
-                        const cl_event * /* event_wait_list */,
-                        cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueMigrateMemObjects(cl_command_queue /* command_queue */,
-                           cl_uint /* num_mem_objects */,
-                           const cl_mem * /* mem_objects */,
-                           cl_mem_migration_flags /* flags */,
-                           cl_uint /* num_events_in_wait_list */,
-                           const cl_event * /* event_wait_list */,
-                           cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueNDRangeKernel(cl_command_queue /* command_queue */,
-                       cl_kernel /* kernel */,
-                       cl_uint /* work_dim */,
-                       const size_t * /* global_work_offset */,
-                       const size_t * /* global_work_size */,
-                       const size_t * /* local_work_size */,
-                       cl_uint /* num_events_in_wait_list */,
-                       const cl_event * /* event_wait_list */,
-                       cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueNativeKernel(cl_command_queue /* command_queue */,
-                      void (CL_CALLBACK * /*user_func*/)(void *),
-                      void * /* args */,
-                      size_t /* cb_args */,
-                      cl_uint /* num_mem_objects */,
-                      const cl_mem * /* mem_list */,
-                      const void ** /* args_mem_loc */,
-                      cl_uint /* num_events_in_wait_list */,
-                      const cl_event * /* event_wait_list */,
-                      cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef CL_VERSION_1_2
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueMarkerWithWaitList(cl_command_queue /* command_queue */,
-                            cl_uint /* num_events_in_wait_list */,
-                            const cl_event * /* event_wait_list */,
-                            cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueBarrierWithWaitList(cl_command_queue /* command_queue */,
-                             cl_uint /* num_events_in_wait_list */,
-                             const cl_event * /* event_wait_list */,
-                             cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_VERSION_2_0
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMFree(cl_command_queue /* command_queue */,
-                 cl_uint /* num_svm_pointers */,
-                 void *[] /* svm_pointers[] */,
-                 void (CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */,
-                                                        cl_uint /* num_svm_pointers */,
-                                                        void *[] /* svm_pointers[] */,
-                                                        void * /* user_data */),
-                 void * /* user_data */,
-                 cl_uint /* num_events_in_wait_list */,
-                 const cl_event * /* event_wait_list */,
-                 cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMMemcpy(cl_command_queue /* command_queue */,
-                   cl_bool /* blocking_copy */,
-                   void * /* dst_ptr */,
-                   const void * /* src_ptr */,
-                   size_t /* size */,
-                   cl_uint /* num_events_in_wait_list */,
-                   const cl_event * /* event_wait_list */,
-                   cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMMemFill(cl_command_queue /* command_queue */,
-                    void * /* svm_ptr */,
-                    const void * /* pattern */,
-                    size_t /* pattern_size */,
-                    size_t /* size */,
-                    cl_uint /* num_events_in_wait_list */,
-                    const cl_event * /* event_wait_list */,
-                    cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMMap(cl_command_queue /* command_queue */,
-                cl_bool /* blocking_map */,
-                cl_map_flags /* flags */,
-                void * /* svm_ptr */,
-                size_t /* size */,
-                cl_uint /* num_events_in_wait_list */,
-                const cl_event * /* event_wait_list */,
-                cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMUnmap(cl_command_queue /* command_queue */,
-                  void * /* svm_ptr */,
-                  cl_uint /* num_events_in_wait_list */,
-                  const cl_event * /* event_wait_list */,
-                  cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
-
-#endif
-
-#ifdef CL_VERSION_2_1
-
-extern CL_API_ENTRY cl_int CL_API_CALL
-clEnqueueSVMMigrateMem(cl_command_queue /* command_queue */,
-                       cl_uint /* num_svm_pointers */,
-                       const void ** /* svm_pointers */,
-                       const size_t * /* sizes */,
-                       cl_mem_migration_flags /* flags */,
-                       cl_uint /* num_events_in_wait_list */,
-                       const cl_event * /* event_wait_list */,
-                       cl_event * /* event */) CL_API_SUFFIX__VERSION_2_1;
-
-#endif
-
-#ifdef CL_VERSION_1_2
-
-/* Extension function access
- *
- * Returns the extension function address for the given function name,
- * or NULL if a valid function can not be found. The client must
- * check to make sure the address is not NULL, before using or
- * calling the returned function address.
- */
-extern CL_API_ENTRY void * CL_API_CALL
-clGetExtensionFunctionAddressForPlatform(cl_platform_id /* platform */,
-                                         const char * /* func_name */) CL_API_SUFFIX__VERSION_1_2;
-
-#endif
-
-#ifdef CL_USE_DEPRECATED_OPENCL_1_0_APIS
-    /*
-     * WARNING:
-     *    This API introduces mutable state into the OpenCL implementation. It has been REMOVED
-     *  to better facilitate thread safety. The 1.0 API is not thread safe. It is not tested by the
-     *  OpenCL 1.1 conformance test, and consequently may not work or may not work dependably.
-     *  It is likely to be non-performant. Use of this API is not advised. Use at your own risk.
-     *
-     *  Software developers previously relying on this API are instructed to set the command queue
-     *  properties when creating the queue, instead.
-     */
-    extern CL_API_ENTRY cl_int CL_API_CALL
-    clSetCommandQueueProperty(cl_command_queue /* command_queue */,
-                              cl_command_queue_properties /* properties */,
-                              cl_bool /* enable */,
-                              cl_command_queue_properties * /* old_properties */) CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED;
-#endif /* CL_USE_DEPRECATED_OPENCL_1_0_APIS */
-
-/* Deprecated OpenCL 1.1 APIs */
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
-clCreateImage2D(cl_context /* context */,
-                cl_mem_flags /* flags */,
-                const cl_image_format * /* image_format */,
-                size_t /* image_width */,
-                size_t /* image_height */,
-                size_t /* image_row_pitch */,
-                void * /* host_ptr */,
-                cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
-clCreateImage3D(cl_context /* context */,
-                cl_mem_flags /* flags */,
-                const cl_image_format * /* image_format */,
-                size_t /* image_width */,
-                size_t /* image_height */,
-                size_t /* image_depth */,
-                size_t /* image_row_pitch */,
-                size_t /* image_slice_pitch */,
-                void * /* host_ptr */,
-                cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_int CL_API_CALL
-clEnqueueMarker(cl_command_queue /* command_queue */,
-                cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_int CL_API_CALL
-clEnqueueWaitForEvents(cl_command_queue /* command_queue */,
-                       cl_uint /* num_events */,
-                       const cl_event * /* event_list */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_int CL_API_CALL
-clEnqueueBarrier(cl_command_queue /* command_queue */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_int CL_API_CALL
-clUnloadCompiler(void) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED void * CL_API_CALL
-clGetExtensionFunctionAddress(const char * /* func_name */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
-
-/* Deprecated OpenCL 2.0 APIs */
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_2_DEPRECATED cl_command_queue CL_API_CALL
-clCreateCommandQueue(cl_context /* context */,
-                     cl_device_id /* device */,
-                     cl_command_queue_properties /* properties */,
-                     cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_2_DEPRECATED cl_sampler CL_API_CALL
-clCreateSampler(cl_context /* context */,
-                cl_bool /* normalized_coords */,
-                cl_addressing_mode /* addressing_mode */,
-                cl_filter_mode /* filter_mode */,
-                cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED;
-
-extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_2_DEPRECATED cl_int CL_API_CALL
-clEnqueueTask(cl_command_queue /* command_queue */,
-              cl_kernel /* kernel */,
-              cl_uint /* num_events_in_wait_list */,
-              const cl_event * /* event_wait_list */,
-              cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED;
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif /* __OPENCL_CL_H */
diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d10.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d10.h
deleted file mode 100644
index 6b6bcaf1559e6d5ee6b05973ebd75785c6ef4269..0000000000000000000000000000000000000000
--- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d10.h
+++ /dev/null
@@ -1,130 +0,0 @@
-/**********************************************************************************
- * Copyright (c) 2008-2015 The Khronos Group Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and/or associated documentation files (the
- * "Materials"), to deal in the Materials without restriction, including
- * without limitation the rights to use, copy, modify, merge, publish,
- * distribute, sublicense, and/or sell copies of the Materials, and to
- * permit persons to whom the Materials are furnished to do so, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Materials.
- *
- * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
- * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
- * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
- *    https://www.khronos.org/registry/
- *
- * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
- **********************************************************************************/
-
-/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
-
-#ifndef __OPENCL_CL_D3D10_H
-#define __OPENCL_CL_D3D10_H
-
-#include <d3d10.h>
-#include <CL/cl.h>
-#include <CL/cl_platform.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/******************************************************************************
- * cl_khr_d3d10_sharing                                                       */
-#define cl_khr_d3d10_sharing 1
-
-typedef cl_uint cl_d3d10_device_source_khr;
-typedef cl_uint cl_d3d10_device_set_khr;
-
-/******************************************************************************/
-
-/* Error Codes */
-#define CL_INVALID_D3D10_DEVICE_KHR -1002
-#define CL_INVALID_D3D10_RESOURCE_KHR -1003
-#define CL_D3D10_RESOURCE_ALREADY_ACQUIRED_KHR -1004
-#define CL_D3D10_RESOURCE_NOT_ACQUIRED_KHR -1005
-
-/* cl_d3d10_device_source_nv */
-#define CL_D3D10_DEVICE_KHR 0x4010
-#define CL_D3D10_DXGI_ADAPTER_KHR 0x4011
-
-/* cl_d3d10_device_set_nv */
-#define CL_PREFERRED_DEVICES_FOR_D3D10_KHR 0x4012
-#define CL_ALL_DEVICES_FOR_D3D10_KHR 0x4013
-
-/* cl_context_info */
-#define CL_CONTEXT_D3D10_DEVICE_KHR 0x4014
-#define CL_CONTEXT_D3D10_PREFER_SHARED_RESOURCES_KHR 0x402C
-
-/* cl_mem_info */
-#define CL_MEM_D3D10_RESOURCE_KHR 0x4015
-
-/* cl_image_info */
-#define CL_IMAGE_D3D10_SUBRESOURCE_KHR 0x4016
-
-/* cl_command_type */
-#define CL_COMMAND_ACQUIRE_D3D10_OBJECTS_KHR 0x4017
-#define CL_COMMAND_RELEASE_D3D10_OBJECTS_KHR 0x4018
-
-/******************************************************************************/
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D10KHR_fn)(
-    cl_platform_id platform,
-    cl_d3d10_device_source_khr d3d_device_source,
-    void * d3d_object,
-    cl_d3d10_device_set_khr d3d_device_set,
-    cl_uint num_entries,
-    cl_device_id * devices,
-    cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_0;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10BufferKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D10Buffer * resource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture2DKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D10Texture2D * resource,
-    UINT subresource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture3DKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D10Texture3D * resource,
-    UINT subresource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D10ObjectsKHR_fn)(
-    cl_command_queue command_queue,
-    cl_uint num_objects,
-    const cl_mem * mem_objects,
-    cl_uint num_events_in_wait_list,
-    const cl_event * event_wait_list,
-    cl_event * event) CL_API_SUFFIX__VERSION_1_0;
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D10ObjectsKHR_fn)(
-    cl_command_queue command_queue,
-    cl_uint num_objects,
-    const cl_mem * mem_objects,
-    cl_uint num_events_in_wait_list,
-    const cl_event * event_wait_list,
-    cl_event * event) CL_API_SUFFIX__VERSION_1_0;
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif /* __OPENCL_CL_D3D10_H */
diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d11.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d11.h
deleted file mode 100644
index 38cc21a2e528a5f8c9f0dd753001c2fb79f9a9d1..0000000000000000000000000000000000000000
--- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_d3d11.h
+++ /dev/null
@@ -1,130 +0,0 @@
-/**********************************************************************************
- * Copyright (c) 2008-2015 The Khronos Group Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and/or associated documentation files (the
- * "Materials"), to deal in the Materials without restriction, including
- * without limitation the rights to use, copy, modify, merge, publish,
- * distribute, sublicense, and/or sell copies of the Materials, and to
- * permit persons to whom the Materials are furnished to do so, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Materials.
- *
- * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
- * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
- * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
- *    https://www.khronos.org/registry/
- *
- * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
- **********************************************************************************/
-
-/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
-
-#ifndef __OPENCL_CL_D3D11_H
-#define __OPENCL_CL_D3D11_H
-
-#include <d3d11.h>
-#include <CL/cl.h>
-#include <CL/cl_platform.h>
-
-#ifdef __cplusplus
extern "C" {
-#endif
-
-/******************************************************************************
- * cl_khr_d3d11_sharing                                                       */
-#define cl_khr_d3d11_sharing 1
-
-typedef cl_uint cl_d3d11_device_source_khr;
-typedef cl_uint cl_d3d11_device_set_khr;
-
-/******************************************************************************/
-
-/* Error Codes */
-#define CL_INVALID_D3D11_DEVICE_KHR -1006
-#define CL_INVALID_D3D11_RESOURCE_KHR -1007
-#define CL_D3D11_RESOURCE_ALREADY_ACQUIRED_KHR -1008
-#define CL_D3D11_RESOURCE_NOT_ACQUIRED_KHR -1009
-
-/* cl_d3d11_device_source */
-#define CL_D3D11_DEVICE_KHR 0x4019
-#define CL_D3D11_DXGI_ADAPTER_KHR 0x401A
-
-/* cl_d3d11_device_set */
-#define CL_PREFERRED_DEVICES_FOR_D3D11_KHR 0x401B
-#define CL_ALL_DEVICES_FOR_D3D11_KHR 0x401C
-
-/* cl_context_info */
-#define CL_CONTEXT_D3D11_DEVICE_KHR 0x401D
-#define CL_CONTEXT_D3D11_PREFER_SHARED_RESOURCES_KHR 0x402D
-
-/* cl_mem_info */
-#define CL_MEM_D3D11_RESOURCE_KHR 0x401E
-
-/* cl_image_info */
-#define CL_IMAGE_D3D11_SUBRESOURCE_KHR 0x401F
-
-/* cl_command_type */
-#define CL_COMMAND_ACQUIRE_D3D11_OBJECTS_KHR 0x4020
-#define CL_COMMAND_RELEASE_D3D11_OBJECTS_KHR 0x4021
-
-/******************************************************************************/
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D11KHR_fn)(
-    cl_platform_id platform,
-    cl_d3d11_device_source_khr d3d_device_source,
-    void * d3d_object,
-    cl_d3d11_device_set_khr d3d_device_set,
-    cl_uint num_entries,
-    cl_device_id * devices,
-    cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11BufferKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D11Buffer * resource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture2DKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D11Texture2D * resource,
-    UINT subresource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
-
-typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture3DKHR_fn)(
-    cl_context context,
-    cl_mem_flags flags,
-    ID3D11Texture3D * resource,
-    UINT subresource,
-    cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D11ObjectsKHR_fn)(
-    cl_command_queue command_queue,
-    cl_uint num_objects,
-    const cl_mem * mem_objects,
-    cl_uint num_events_in_wait_list,
-    const cl_event * event_wait_list,
-    cl_event * event) CL_API_SUFFIX__VERSION_1_2;
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D11ObjectsKHR_fn)(
-    cl_command_queue command_queue,
-    cl_uint num_objects,
-    const cl_mem * mem_objects,
-    cl_uint num_events_in_wait_list,
-    const cl_event * event_wait_list,
-    cl_event * event) CL_API_SUFFIX__VERSION_1_2;
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif /* __OPENCL_CL_D3D11_H */
diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing.h
deleted file mode 100644
index 484f8cbc77ddeab30e021bed8d889131e3100940..0000000000000000000000000000000000000000
--- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing.h
+++ /dev/null
@@ -1,131 +0,0 @@
-/**********************************************************************************
- * Copyright (c) 2008-2015 The Khronos Group Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and/or associated documentation files (the
- * "Materials"), to deal in the Materials without restriction, including
- * without limitation the rights to use, copy, modify, merge, publish,
- * distribute, sublicense, and/or sell copies of the Materials, and to
- * permit persons to whom the Materials are furnished to do so, subject to
- * the following conditions:
- *
- * The above copyright notice and this permission notice shall be included
- * in all copies or substantial portions of the Materials.
- *
- * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
- * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
- * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
- *    https://www.khronos.org/registry/
- *
- * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
- **********************************************************************************/
-
-/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
-
-#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
-#define __OPENCL_CL_DX9_MEDIA_SHARING_H
-
-#include <CL/cl.h>
-#include <CL/cl_platform.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/******************************************************************************/
-/* cl_khr_dx9_media_sharing                                                   */
-#define cl_khr_dx9_media_sharing 1
-
-typedef cl_uint cl_dx9_media_adapter_type_khr;
-typedef cl_uint cl_dx9_media_adapter_set_khr;
-
-#if defined(_WIN32)
-#include <d3d9.h>
-typedef struct _cl_dx9_surface_info_khr
-{
-    IDirect3DSurface9 *resource;
-    HANDLE shared_handle;
-} cl_dx9_surface_info_khr;
-#endif
-
-
-/******************************************************************************/
-
-/* Error Codes */
-#define CL_INVALID_DX9_MEDIA_ADAPTER_KHR -1010
-#define CL_INVALID_DX9_MEDIA_SURFACE_KHR -1011
-#define CL_DX9_MEDIA_SURFACE_ALREADY_ACQUIRED_KHR -1012
-#define CL_DX9_MEDIA_SURFACE_NOT_ACQUIRED_KHR -1013
-
-/* cl_media_adapter_type_khr */
-#define CL_ADAPTER_D3D9_KHR 0x2020
-#define CL_ADAPTER_D3D9EX_KHR 0x2021
-#define CL_ADAPTER_DXVA_KHR 0x2022
-
-/* cl_media_adapter_set_khr */
-#define CL_PREFERRED_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2023
-#define CL_ALL_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2024
-
-/* cl_context_info */
-#define CL_CONTEXT_ADAPTER_D3D9_KHR 0x2025
-#define CL_CONTEXT_ADAPTER_D3D9EX_KHR 0x2026
-#define CL_CONTEXT_ADAPTER_DXVA_KHR 0x2027
-
-/* cl_mem_info */
-#define CL_MEM_DX9_MEDIA_ADAPTER_TYPE_KHR 0x2028
-#define CL_MEM_DX9_MEDIA_SURFACE_INFO_KHR 0x2029
-
-/* cl_image_info */
-#define CL_IMAGE_DX9_MEDIA_PLANE_KHR 0x202A
-
-/* cl_command_type */
-#define CL_COMMAND_ACQUIRE_DX9_MEDIA_SURFACES_KHR 0x202B
-#define CL_COMMAND_RELEASE_DX9_MEDIA_SURFACES_KHR 0x202C
-
-/******************************************************************************/
-
-typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromDX9MediaAdapterKHR_fn)(
-    cl_platform_id platform,
-    cl_uint num_media_adapters,
- cl_dx9_media_adapter_type_khr * media_adapter_type, - void * media_adapters, - cl_dx9_media_adapter_set_khr media_adapter_set, - cl_uint num_entries, - cl_device_id * devices, - cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceKHR_fn)( - cl_context context, - cl_mem_flags flags, - cl_dx9_media_adapter_type_khr adapter_type, - void * surface_info, - cl_uint plane, - cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9MediaSurfacesKHR_fn)( - cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem * mem_objects, - cl_uint num_events_in_wait_list, - const cl_event * event_wait_list, - cl_event * event) CL_API_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9MediaSurfacesKHR_fn)( - cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem * mem_objects, - cl_uint num_events_in_wait_list, - const cl_event * event_wait_list, - cl_event * event) CL_API_SUFFIX__VERSION_1_2; - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing_intel.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing_intel.h deleted file mode 100644 index abae0457a8e290977eb4ce0e99cf533ed7d9b094..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing_intel.h +++ /dev/null @@ -1,181 +0,0 @@ -/********************************************************************************** - * Copyright (c) 2008-2016 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - **********************************************************************************/ -/*****************************************************************************\ - -Copyright (c) 2013-2016 Intel Corporation All Rights Reserved. - -THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL INTEL OR ITS -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE -MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -File Name: cl_dx9_media_sharing_intel.h - -Abstract: - -Notes: - -\*****************************************************************************/ - -#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H -#define __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H - -#include -#include -#include -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/*************************************** -* cl_intel_dx9_media_sharing extension * -****************************************/ - -#define cl_intel_dx9_media_sharing 1 - -typedef cl_uint cl_dx9_device_source_intel; -typedef cl_uint cl_dx9_device_set_intel; - -/* error codes */ -#define CL_INVALID_DX9_DEVICE_INTEL -1010 -#define CL_INVALID_DX9_RESOURCE_INTEL -1011 -#define CL_DX9_RESOURCE_ALREADY_ACQUIRED_INTEL -1012 -#define CL_DX9_RESOURCE_NOT_ACQUIRED_INTEL -1013 - -/* cl_dx9_device_source_intel */ -#define CL_D3D9_DEVICE_INTEL 0x4022 -#define CL_D3D9EX_DEVICE_INTEL 0x4070 -#define CL_DXVA_DEVICE_INTEL 0x4071 - -/* cl_dx9_device_set_intel */ -#define CL_PREFERRED_DEVICES_FOR_DX9_INTEL 0x4024 -#define CL_ALL_DEVICES_FOR_DX9_INTEL 0x4025 - -/* cl_context_info */ -#define CL_CONTEXT_D3D9_DEVICE_INTEL 0x4026 -#define CL_CONTEXT_D3D9EX_DEVICE_INTEL 0x4072 -#define CL_CONTEXT_DXVA_DEVICE_INTEL 0x4073 - -/* cl_mem_info */ -#define CL_MEM_DX9_RESOURCE_INTEL 0x4027 -#define CL_MEM_DX9_SHARED_HANDLE_INTEL 0x4074 - -/* cl_image_info */ -#define CL_IMAGE_DX9_PLANE_INTEL 0x4075 - -/* cl_command_type */ -#define CL_COMMAND_ACQUIRE_DX9_OBJECTS_INTEL 0x402A -#define CL_COMMAND_RELEASE_DX9_OBJECTS_INTEL 0x402B -/******************************************************************************/ - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetDeviceIDsFromDX9INTEL( - cl_platform_id /* platform */, - cl_dx9_device_source_intel /* dx9_device_source */, - void* /* dx9_object */, - cl_dx9_device_set_intel /* dx9_device_set */, - cl_uint /* num_entries */, - cl_device_id* /* devices */, - cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int (CL_API_CALL* clGetDeviceIDsFromDX9INTEL_fn)( - cl_platform_id /* platform */, - cl_dx9_device_source_intel /* dx9_device_source */, - void* /* dx9_object */, - cl_dx9_device_set_intel /* dx9_device_set */, - cl_uint /* num_entries */, - cl_device_id* /* devices */, - cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1; - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromDX9MediaSurfaceINTEL( - cl_context /* context */, - cl_mem_flags /* flags */, - IDirect3DSurface9* /* resource */, - HANDLE /* sharedHandle */, - UINT /* plane */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceINTEL_fn)( - cl_context /* context */, - cl_mem_flags /* flags */, - IDirect3DSurface9* /* resource */, - HANDLE /* sharedHandle */, - UINT /* plane */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueAcquireDX9ObjectsINTEL( - cl_command_queue /* command_queue */, - cl_uint /* 
num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9ObjectsINTEL_fn)( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueReleaseDX9ObjectsINTEL( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9ObjectsINTEL_fn)( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1; - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_egl.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_egl.h deleted file mode 100644 index a765bd5266c02fc2fd2892f0257b228996d73c5f..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_egl.h +++ /dev/null @@ -1,136 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2008-2015 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. 
- ******************************************************************************/ - -#ifndef __OPENCL_CL_EGL_H -#define __OPENCL_CL_EGL_H - -#ifdef __APPLE__ - -#else -#include -#endif - -#ifdef __cplusplus -extern "C" { -#endif - - -/* Command type for events created with clEnqueueAcquireEGLObjectsKHR */ -#define CL_COMMAND_EGL_FENCE_SYNC_OBJECT_KHR 0x202F -#define CL_COMMAND_ACQUIRE_EGL_OBJECTS_KHR 0x202D -#define CL_COMMAND_RELEASE_EGL_OBJECTS_KHR 0x202E - -/* Error type for clCreateFromEGLImageKHR */ -#define CL_INVALID_EGL_OBJECT_KHR -1093 -#define CL_EGL_RESOURCE_NOT_ACQUIRED_KHR -1092 - -/* CLeglImageKHR is an opaque handle to an EGLImage */ -typedef void* CLeglImageKHR; - -/* CLeglDisplayKHR is an opaque handle to an EGLDisplay */ -typedef void* CLeglDisplayKHR; - -/* CLeglSyncKHR is an opaque handle to an EGLSync object */ -typedef void* CLeglSyncKHR; - -/* properties passed to clCreateFromEGLImageKHR */ -typedef intptr_t cl_egl_image_properties_khr; - - -#define cl_khr_egl_image 1 - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromEGLImageKHR(cl_context /* context */, - CLeglDisplayKHR /* egldisplay */, - CLeglImageKHR /* eglimage */, - cl_mem_flags /* flags */, - const cl_egl_image_properties_khr * /* properties */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; - -typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromEGLImageKHR_fn)( - cl_context context, - CLeglDisplayKHR egldisplay, - CLeglImageKHR eglimage, - cl_mem_flags flags, - const cl_egl_image_properties_khr * properties, - cl_int * errcode_ret); - - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueAcquireEGLObjectsKHR(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireEGLObjectsKHR_fn)( - cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem * mem_objects, - cl_uint num_events_in_wait_list, - const cl_event * event_wait_list, - cl_event * event); - - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueReleaseEGLObjectsKHR(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseEGLObjectsKHR_fn)( - cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem * mem_objects, - cl_uint num_events_in_wait_list, - const cl_event * event_wait_list, - cl_event * event); - - -#define cl_khr_egl_event 1 - -extern CL_API_ENTRY cl_event CL_API_CALL -clCreateEventFromEGLSyncKHR(cl_context /* context */, - CLeglSyncKHR /* sync */, - CLeglDisplayKHR /* display */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; - -typedef CL_API_ENTRY cl_event (CL_API_CALL *clCreateEventFromEGLSyncKHR_fn)( - cl_context context, - CLeglSyncKHR sync, - CLeglDisplayKHR display, - cl_int * errcode_ret); - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_EGL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext.h deleted file mode 100644 index af3ce461f3a48e7707caca966e704dfe5eb58e30..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext.h +++ /dev/null @@ -1,723 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2008-2018 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - ******************************************************************************/ - -/* cl_ext.h contains OpenCL extensions which don't have external */ -/* (OpenGL, D3D) dependencies. */ - -#ifndef __CL_EXT_H -#define __CL_EXT_H - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __APPLE__ - #include - #include -#else - #include -#endif - -/* cl_khr_fp64 extension - no extension #define since it has no functions */ -/* CL_DEVICE_DOUBLE_FP_CONFIG is defined in CL.h for OpenCL >= 120 */ - -#if CL_TARGET_OPENCL_VERSION <= 110 -#define CL_DEVICE_DOUBLE_FP_CONFIG 0x1032 -#endif - -/* cl_khr_fp16 extension - no extension #define since it has no functions */ -#define CL_DEVICE_HALF_FP_CONFIG 0x1033 - -/* Memory object destruction - * - * Apple extension for use to manage externally allocated buffers used with cl_mem objects with CL_MEM_USE_HOST_PTR - * - * Registers a user callback function that will be called when the memory object is deleted and its resources - * freed. Each call to clSetMemObjectCallbackFn registers the specified user callback function on a callback - * stack associated with memobj. The registered user callback functions are called in the reverse order in - * which they were registered. The user callback functions are called and then the memory object is deleted - * and its resources freed. This provides a mechanism for the application (and libraries) using memobj to be - * notified when the memory referenced by host_ptr, specified when the memory object is created and used as - * the storage bits for the memory object, can be reused or freed. - * - * The application may not call CL api's with the cl_mem object passed to the pfn_notify. - * - * Please check for the "cl_APPLE_SetMemObjectDestructor" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS) - * before using. 
- */ -#define cl_APPLE_SetMemObjectDestructor 1 -cl_int CL_API_ENTRY clSetMemObjectDestructorAPPLE( cl_mem /* memobj */, - void (* /*pfn_notify*/)( cl_mem /* memobj */, void* /*user_data*/), - void * /*user_data */ ) CL_EXT_SUFFIX__VERSION_1_0; - - -/* Context Logging Functions - * - * The next three convenience functions are intended to be used as the pfn_notify parameter to clCreateContext(). - * Please check for the "cl_APPLE_ContextLoggingFunctions" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS) - * before using. - * - * clLogMessagesToSystemLog forwards all log messages to the Apple System Logger - */ -#define cl_APPLE_ContextLoggingFunctions 1 -extern void CL_API_ENTRY clLogMessagesToSystemLogAPPLE( const char * /* errstr */, - const void * /* private_info */, - size_t /* cb */, - void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0; - -/* clLogMessagesToStdout sends all log messages to the file descriptor stdout */ -extern void CL_API_ENTRY clLogMessagesToStdoutAPPLE( const char * /* errstr */, - const void * /* private_info */, - size_t /* cb */, - void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0; - -/* clLogMessagesToStderr sends all log messages to the file descriptor stderr */ -extern void CL_API_ENTRY clLogMessagesToStderrAPPLE( const char * /* errstr */, - const void * /* private_info */, - size_t /* cb */, - void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0; - - -/************************ -* cl_khr_icd extension * -************************/ -#define cl_khr_icd 1 - -/* cl_platform_info */ -#define CL_PLATFORM_ICD_SUFFIX_KHR 0x0920 - -/* Additional Error Codes */ -#define CL_PLATFORM_NOT_FOUND_KHR -1001 - -extern CL_API_ENTRY cl_int CL_API_CALL -clIcdGetPlatformIDsKHR(cl_uint /* num_entries */, - cl_platform_id * /* platforms */, - cl_uint * /* num_platforms */); - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clIcdGetPlatformIDsKHR_fn)( - cl_uint /* num_entries */, - cl_platform_id * /* platforms */, - cl_uint * /* num_platforms */); - - - -/******************************* - * cl_khr_il_program extension * - *******************************/ -#define cl_khr_il_program 1 - -/* New property to clGetDeviceInfo for retrieving supported intermediate - * languages - */ -#define CL_DEVICE_IL_VERSION_KHR 0x105B - -/* New property to clGetProgramInfo for retrieving the IL of a - * program - */ -#define CL_PROGRAM_IL_KHR 0x1169 - -extern CL_API_ENTRY cl_program - CL_API_CALL clCreateProgramWithILKHR( - cl_context /* context */, - const void * /* il */, - size_t /* length */, - cl_int * /* errcode_ret */); - -typedef CL_API_ENTRY cl_program - (CL_API_CALL *clCreateProgramWithILKHR_fn)( - cl_context /* context */, - const void * /* il */, - size_t /* length */, - cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -/* Extension: cl_khr_image2D_buffer - * - * This extension allows a 2D image to be created from a cl_mem buffer without a copy. - * The type associated with a 2D image created from a buffer in an OpenCL program is image2d_t. - * Both the sampler and sampler-less read_image built-in functions are supported for 2D images - * and 2D images created from a buffer. Similarly, the write_image built-ins are also supported - * for 2D images created from a buffer. - * - * When the 2D image from buffer is created, the client must specify the width, - * height, image format (i.e. channel order and channel data type) and optionally the row pitch. - * - * The pitch specified must be a multiple of CL_DEVICE_IMAGE_PITCH_ALIGNMENT pixels.
- * The base address of the buffer must be aligned to CL_DEVICE_IMAGE_BASE_ADDRESS_ALIGNMENT pixels. - */ - -/************************************** - * cl_khr_initialize_memory extension * - **************************************/ - -#define CL_CONTEXT_MEMORY_INITIALIZE_KHR 0x2030 - - -/************************************** - * cl_khr_terminate_context extension * - **************************************/ - -#define CL_DEVICE_TERMINATE_CAPABILITY_KHR 0x2031 -#define CL_CONTEXT_TERMINATE_KHR 0x2032 - -#define cl_khr_terminate_context 1 -extern CL_API_ENTRY cl_int CL_API_CALL clTerminateContextKHR(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clTerminateContextKHR_fn)(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2; - - -/* - * Extension: cl_khr_spir - * - * This extension adds support to create an OpenCL program object from a - * Standard Portable Intermediate Representation (SPIR) instance - */ - -#define CL_DEVICE_SPIR_VERSIONS 0x40E0 -#define CL_PROGRAM_BINARY_TYPE_INTERMEDIATE 0x40E1 - - -/***************************************** - * cl_khr_create_command_queue extension * - *****************************************/ -#define cl_khr_create_command_queue 1 - -typedef cl_bitfield cl_queue_properties_khr; - -extern CL_API_ENTRY cl_command_queue CL_API_CALL -clCreateCommandQueueWithPropertiesKHR( cl_context /* context */, - cl_device_id /* device */, - const cl_queue_properties_khr* /* properties */, - cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2; -typedef CL_API_ENTRY cl_command_queue -(CL_API_CALL *clCreateCommandQueueWithPropertiesKHR_fn)( cl_context /* context */, - cl_device_id /* device */, - const cl_queue_properties_khr* /* properties */, - cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2; - - -/****************************************** -* cl_nv_device_attribute_query extension * -******************************************/ - -/* cl_nv_device_attribute_query extension - no extension #define since it has no functions */ -#define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000 -#define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001 -#define CL_DEVICE_REGISTERS_PER_BLOCK_NV 0x4002 -#define CL_DEVICE_WARP_SIZE_NV 0x4003 -#define CL_DEVICE_GPU_OVERLAP_NV 0x4004 -#define CL_DEVICE_KERNEL_EXEC_TIMEOUT_NV 0x4005 -#define CL_DEVICE_INTEGRATED_MEMORY_NV 0x4006 - - -/********************************* -* cl_amd_device_attribute_query * -*********************************/ - -#define CL_DEVICE_PROFILING_TIMER_OFFSET_AMD 0x4036 - - -/********************************* -* cl_arm_printf extension -*********************************/ - -#define CL_PRINTF_CALLBACK_ARM 0x40B0 -#define CL_PRINTF_BUFFERSIZE_ARM 0x40B1 - - -/*********************************** -* cl_ext_device_fission extension -***********************************/ -#define cl_ext_device_fission 1 - -extern CL_API_ENTRY cl_int CL_API_CALL -clReleaseDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int -(CL_API_CALL *clReleaseDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -extern CL_API_ENTRY cl_int CL_API_CALL -clRetainDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int -(CL_API_CALL *clRetainDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -typedef cl_ulong cl_device_partition_property_ext; -extern CL_API_ENTRY cl_int CL_API_CALL -clCreateSubDevicesEXT( cl_device_id /*in_device*/, - const cl_device_partition_property_ext * /* properties 
*/, - cl_uint /*num_entries*/, - cl_device_id * /*out_devices*/, - cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -typedef CL_API_ENTRY cl_int -( CL_API_CALL * clCreateSubDevicesEXT_fn)( cl_device_id /*in_device*/, - const cl_device_partition_property_ext * /* properties */, - cl_uint /*num_entries*/, - cl_device_id * /*out_devices*/, - cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1; - -/* cl_device_partition_property_ext */ -#define CL_DEVICE_PARTITION_EQUALLY_EXT 0x4050 -#define CL_DEVICE_PARTITION_BY_COUNTS_EXT 0x4051 -#define CL_DEVICE_PARTITION_BY_NAMES_EXT 0x4052 -#define CL_DEVICE_PARTITION_BY_AFFINITY_DOMAIN_EXT 0x4053 - -/* clDeviceGetInfo selectors */ -#define CL_DEVICE_PARENT_DEVICE_EXT 0x4054 -#define CL_DEVICE_PARTITION_TYPES_EXT 0x4055 -#define CL_DEVICE_AFFINITY_DOMAINS_EXT 0x4056 -#define CL_DEVICE_REFERENCE_COUNT_EXT 0x4057 -#define CL_DEVICE_PARTITION_STYLE_EXT 0x4058 - -/* error codes */ -#define CL_DEVICE_PARTITION_FAILED_EXT -1057 -#define CL_INVALID_PARTITION_COUNT_EXT -1058 -#define CL_INVALID_PARTITION_NAME_EXT -1059 - -/* CL_AFFINITY_DOMAINs */ -#define CL_AFFINITY_DOMAIN_L1_CACHE_EXT 0x1 -#define CL_AFFINITY_DOMAIN_L2_CACHE_EXT 0x2 -#define CL_AFFINITY_DOMAIN_L3_CACHE_EXT 0x3 -#define CL_AFFINITY_DOMAIN_L4_CACHE_EXT 0x4 -#define CL_AFFINITY_DOMAIN_NUMA_EXT 0x10 -#define CL_AFFINITY_DOMAIN_NEXT_FISSIONABLE_EXT 0x100 - -/* cl_device_partition_property_ext list terminators */ -#define CL_PROPERTIES_LIST_END_EXT ((cl_device_partition_property_ext) 0) -#define CL_PARTITION_BY_COUNTS_LIST_END_EXT ((cl_device_partition_property_ext) 0) -#define CL_PARTITION_BY_NAMES_LIST_END_EXT ((cl_device_partition_property_ext) 0 - 1) - - -/*********************************** - * cl_ext_migrate_memobject extension definitions - ***********************************/ -#define cl_ext_migrate_memobject 1 - -typedef cl_bitfield cl_mem_migration_flags_ext; - -#define CL_MIGRATE_MEM_OBJECT_HOST_EXT 0x1 - -#define CL_COMMAND_MIGRATE_MEM_OBJECT_EXT 0x4040 - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueMigrateMemObjectEXT( cl_command_queue /* command_queue */, - cl_uint /* num_mem_objects */, - const cl_mem * /* mem_objects */, - cl_mem_migration_flags_ext /* flags */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */ ); - -typedef CL_API_ENTRY cl_int -(CL_API_CALL *clEnqueueMigrateMemObjectEXT_fn)( cl_command_queue /* command_queue */, - cl_uint /* num_mem_objects */, - const cl_mem * /* mem_objects */, - cl_mem_migration_flags_ext /* flags */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */ ); - - -/********************************* -* cl_qcom_ext_host_ptr extension -*********************************/ -#define cl_qcom_ext_host_ptr 1 - -#define CL_MEM_EXT_HOST_PTR_QCOM (1 << 29) - -#define CL_DEVICE_EXT_MEM_PADDING_IN_BYTES_QCOM 0x40A0 -#define CL_DEVICE_PAGE_SIZE_QCOM 0x40A1 -#define CL_IMAGE_ROW_ALIGNMENT_QCOM 0x40A2 -#define CL_IMAGE_SLICE_ALIGNMENT_QCOM 0x40A3 -#define CL_MEM_HOST_UNCACHED_QCOM 0x40A4 -#define CL_MEM_HOST_WRITEBACK_QCOM 0x40A5 -#define CL_MEM_HOST_WRITETHROUGH_QCOM 0x40A6 -#define CL_MEM_HOST_WRITE_COMBINING_QCOM 0x40A7 - -typedef cl_uint cl_image_pitch_info_qcom; - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetDeviceImageInfoQCOM(cl_device_id device, - size_t image_width, - size_t image_height, - const cl_image_format *image_format, - cl_image_pitch_info_qcom param_name, - size_t param_value_size, - void *param_value, - size_t 
*param_value_size_ret); - -typedef struct _cl_mem_ext_host_ptr -{ - /* Type of external memory allocation. */ - /* Legal values will be defined in layered extensions. */ - cl_uint allocation_type; - - /* Host cache policy for this external memory allocation. */ - cl_uint host_cache_policy; - -} cl_mem_ext_host_ptr; - - -/******************************************* -* cl_qcom_ext_host_ptr_iocoherent extension -********************************************/ - -/* Cache policy specifying io-coherence */ -#define CL_MEM_HOST_IOCOHERENT_QCOM 0x40A9 - - -/********************************* -* cl_qcom_ion_host_ptr extension -*********************************/ - -#define CL_MEM_ION_HOST_PTR_QCOM 0x40A8 - -typedef struct _cl_mem_ion_host_ptr -{ - /* Type of external memory allocation. */ - /* Must be CL_MEM_ION_HOST_PTR_QCOM for ION allocations. */ - cl_mem_ext_host_ptr ext_host_ptr; - - /* ION file descriptor */ - int ion_filedesc; - - /* Host pointer to the ION allocated memory */ - void* ion_hostptr; - -} cl_mem_ion_host_ptr; - - -/********************************* -* cl_qcom_android_native_buffer_host_ptr extension -*********************************/ - -#define CL_MEM_ANDROID_NATIVE_BUFFER_HOST_PTR_QCOM 0x40C6 - -typedef struct _cl_mem_android_native_buffer_host_ptr -{ - /* Type of external memory allocation. */ - /* Must be CL_MEM_ANDROID_NATIVE_BUFFER_HOST_PTR_QCOM for Android native buffers. */ - cl_mem_ext_host_ptr ext_host_ptr; - - /* Virtual pointer to the android native buffer */ - void* anb_ptr; - -} cl_mem_android_native_buffer_host_ptr; - - -/****************************************** - * cl_img_yuv_image extension * - ******************************************/ - -/* Image formats used in clCreateImage */ -#define CL_NV21_IMG 0x40D0 -#define CL_YV12_IMG 0x40D1 - - -/****************************************** - * cl_img_cached_allocations extension * - ******************************************/ - -/* Flag values used by clCreateBuffer */ -#define CL_MEM_USE_UNCACHED_CPU_MEMORY_IMG (1 << 26) -#define CL_MEM_USE_CACHED_CPU_MEMORY_IMG (1 << 27) - - -/****************************************** - * cl_img_use_gralloc_ptr extension * - ******************************************/ -#define cl_img_use_gralloc_ptr 1 - -/* Flag values used by clCreateBuffer */ -#define CL_MEM_USE_GRALLOC_PTR_IMG (1 << 28) - -/* To be used by clGetEventInfo: */ -#define CL_COMMAND_ACQUIRE_GRALLOC_OBJECTS_IMG 0x40D2 -#define CL_COMMAND_RELEASE_GRALLOC_OBJECTS_IMG 0x40D3 - -/* Error code from clEnqueueReleaseGrallocObjectsIMG */ -#define CL_GRALLOC_RESOURCE_NOT_ACQUIRED_IMG 0x40D4 - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueAcquireGrallocObjectsIMG(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueReleaseGrallocObjectsIMG(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - - -/********************************* -* cl_khr_subgroups extension -*********************************/ -#define cl_khr_subgroups 1 - -#if !defined(CL_VERSION_2_1) -/* For OpenCL 2.1 and newer, cl_kernel_sub_group_info is declared in CL.h.
- In hindsight, there should have been a khr suffix on this type for - the extension, but keeping it un-suffixed to maintain backwards - compatibility. */ -typedef cl_uint cl_kernel_sub_group_info; -#endif - -/* cl_kernel_sub_group_info */ -#define CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR 0x2033 -#define CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR 0x2034 - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetKernelSubGroupInfoKHR(cl_kernel /* in_kernel */, - cl_device_id /*in_device*/, - cl_kernel_sub_group_info /* param_name */, - size_t /*input_value_size*/, - const void * /*input_value*/, - size_t /*param_value_size*/, - void* /*param_value*/, - size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED; - -typedef CL_API_ENTRY cl_int -(CL_API_CALL * clGetKernelSubGroupInfoKHR_fn)(cl_kernel /* in_kernel */, - cl_device_id /*in_device*/, - cl_kernel_sub_group_info /* param_name */, - size_t /*input_value_size*/, - const void * /*input_value*/, - size_t /*param_value_size*/, - void* /*param_value*/, - size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED; - - -/********************************* -* cl_khr_priority_hints extension -*********************************/ -/* This extension define is for backwards compatibility. - It shouldn't be required since this extension has no new functions. */ -#define cl_khr_priority_hints 1 - -typedef cl_uint cl_queue_priority_khr; - -/* cl_command_queue_properties */ -#define CL_QUEUE_PRIORITY_KHR 0x1096 - -/* cl_queue_priority_khr */ -#define CL_QUEUE_PRIORITY_HIGH_KHR (1<<0) -#define CL_QUEUE_PRIORITY_MED_KHR (1<<1) -#define CL_QUEUE_PRIORITY_LOW_KHR (1<<2) - - -/********************************* -* cl_khr_throttle_hints extension -*********************************/ -/* This extension define is for backwards compatibility. - It shouldn't be required since this extension has no new functions. */ -#define cl_khr_throttle_hints 1 - -typedef cl_uint cl_queue_throttle_khr; - -/* cl_command_queue_properties */ -#define CL_QUEUE_THROTTLE_KHR 0x1097 - -/* cl_queue_throttle_khr */ -#define CL_QUEUE_THROTTLE_HIGH_KHR (1<<0) -#define CL_QUEUE_THROTTLE_MED_KHR (1<<1) -#define CL_QUEUE_THROTTLE_LOW_KHR (1<<2) - - -/********************************* -* cl_khr_subgroup_named_barrier -*********************************/ -/* This extension define is for backwards compatibility. - It shouldn't be required since this extension has no new functions. */ -#define cl_khr_subgroup_named_barrier 1 - -/* cl_device_info */ -#define CL_DEVICE_MAX_NAMED_BARRIER_COUNT_KHR 0x2035 - - -/********************************** - * cl_arm_import_memory extension * - **********************************/ -#define cl_arm_import_memory 1 - -typedef intptr_t cl_import_properties_arm; - -/* Default and valid properties name for cl_arm_import_memory */ -#define CL_IMPORT_TYPE_ARM 0x40B2 - -/* Host process memory type default value for CL_IMPORT_TYPE_ARM property */ -#define CL_IMPORT_TYPE_HOST_ARM 0x40B3 - -/* DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */ -#define CL_IMPORT_TYPE_DMA_BUF_ARM 0x40B4 - -/* Secure DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */ -#define CL_IMPORT_TYPE_SECURE_ARM 0x40B5 - -/* This extension adds a new function that allows for direct memory import into - * OpenCL via the clImportMemoryARM function. - * - * Memory imported through this interface will be mapped into the device's page - * tables directly, providing zero copy access. It will never fall back to copy - * operations and aliased buffers.
- * - * Types of memory supported for import are specified as additional extension - * strings. - * - * This extension produces cl_mem allocations which are compatible with all other - * users of cl_mem in the standard API. - * - * This extension maps pages with the same properties as the normal buffer creation - * function clCreateBuffer. - */ -extern CL_API_ENTRY cl_mem CL_API_CALL -clImportMemoryARM( cl_context context, - cl_mem_flags flags, - const cl_import_properties_arm *properties, - void *memory, - size_t size, - cl_int *errcode_ret) CL_EXT_SUFFIX__VERSION_1_0; - - -/****************************************** - * cl_arm_shared_virtual_memory extension * - ******************************************/ -#define cl_arm_shared_virtual_memory 1 - -/* Used by clGetDeviceInfo */ -#define CL_DEVICE_SVM_CAPABILITIES_ARM 0x40B6 - -/* Used by clGetMemObjectInfo */ -#define CL_MEM_USES_SVM_POINTER_ARM 0x40B7 - -/* Used by clSetKernelExecInfoARM: */ -#define CL_KERNEL_EXEC_INFO_SVM_PTRS_ARM 0x40B8 -#define CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM_ARM 0x40B9 - -/* To be used by clGetEventInfo: */ -#define CL_COMMAND_SVM_FREE_ARM 0x40BA -#define CL_COMMAND_SVM_MEMCPY_ARM 0x40BB -#define CL_COMMAND_SVM_MEMFILL_ARM 0x40BC -#define CL_COMMAND_SVM_MAP_ARM 0x40BD -#define CL_COMMAND_SVM_UNMAP_ARM 0x40BE - -/* Flag values returned by clGetDeviceInfo with CL_DEVICE_SVM_CAPABILITIES_ARM as the param_name. */ -#define CL_DEVICE_SVM_COARSE_GRAIN_BUFFER_ARM (1 << 0) -#define CL_DEVICE_SVM_FINE_GRAIN_BUFFER_ARM (1 << 1) -#define CL_DEVICE_SVM_FINE_GRAIN_SYSTEM_ARM (1 << 2) -#define CL_DEVICE_SVM_ATOMICS_ARM (1 << 3) - -/* Flag values used by clSVMAllocARM: */ -#define CL_MEM_SVM_FINE_GRAIN_BUFFER_ARM (1 << 10) -#define CL_MEM_SVM_ATOMICS_ARM (1 << 11) - -typedef cl_bitfield cl_svm_mem_flags_arm; -typedef cl_uint cl_kernel_exec_info_arm; -typedef cl_bitfield cl_device_svm_capabilities_arm; - -extern CL_API_ENTRY void * CL_API_CALL -clSVMAllocARM(cl_context /* context */, - cl_svm_mem_flags_arm /* flags */, - size_t /* size */, - cl_uint /* alignment */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY void CL_API_CALL -clSVMFreeARM(cl_context /* context */, - void * /* svm_pointer */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueSVMFreeARM(cl_command_queue /* command_queue */, - cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void (CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */, - cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void * /* user_data */), - void * /* user_data */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueSVMMemcpyARM(cl_command_queue /* command_queue */, - cl_bool /* blocking_copy */, - void * /* dst_ptr */, - const void * /* src_ptr */, - size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueSVMMemFillARM(cl_command_queue /* command_queue */, - void * /* svm_ptr */, - const void * /* pattern */, - size_t /* pattern_size */, - size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueSVMMapARM(cl_command_queue /* command_queue */, - cl_bool /* blocking_map */, - 
cl_map_flags /* flags */, - void * /* svm_ptr */, - size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueSVMUnmapARM(cl_command_queue /* command_queue */, - void * /* svm_ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clSetKernelArgSVMPointerARM(cl_kernel /* kernel */, - cl_uint /* arg_index */, - const void * /* arg_value */) CL_EXT_SUFFIX__VERSION_1_2; -extern CL_API_ENTRY cl_int CL_API_CALL -clSetKernelExecInfoARM(cl_kernel /* kernel */, - cl_kernel_exec_info_arm /* param_name */, - size_t /* param_value_size */, - const void * /* param_value */) CL_EXT_SUFFIX__VERSION_1_2; - -#ifdef __cplusplus -} -#endif - - -#endif /* __CL_EXT_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext_intel.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext_intel.h deleted file mode 100644 index 53bd3107c5e7e3c6f01c1127faed59bd5e741def..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_ext_intel.h +++ /dev/null @@ -1,428 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2008-2017 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - ******************************************************************************/ -/*****************************************************************************\ - -Copyright (c) 2013-2017 Intel Corporation All Rights Reserved. - -THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL INTEL OR ITS -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE -MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -File Name: cl_ext_intel.h - -Abstract: - -Notes: - -\*****************************************************************************/ - -#ifndef __CL_EXT_INTEL_H -#define __CL_EXT_INTEL_H - -#ifdef __APPLE__ - #include <OpenCL/cl.h> - #include <OpenCL/cl_platform.h> -#else - #include <CL/cl.h> - #include <CL/cl_platform.h> -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -/*************************************** -* cl_intel_thread_local_exec extension * -****************************************/ - -#define cl_intel_thread_local_exec 1 - -#define CL_QUEUE_THREAD_LOCAL_EXEC_ENABLE_INTEL (((cl_bitfield)1) << 31) - -/*********************************************** -* cl_intel_device_partition_by_names extension * -************************************************/ - -#define cl_intel_device_partition_by_names 1 - -#define CL_DEVICE_PARTITION_BY_NAMES_INTEL 0x4052 -#define CL_PARTITION_BY_NAMES_LIST_END_INTEL -1 - -/************************************************ -* cl_intel_accelerator extension * -* cl_intel_motion_estimation extension * -* cl_intel_advanced_motion_estimation extension * -*************************************************/ - -#define cl_intel_accelerator 1 -#define cl_intel_motion_estimation 1 -#define cl_intel_advanced_motion_estimation 1 - -typedef struct _cl_accelerator_intel* cl_accelerator_intel; -typedef cl_uint cl_accelerator_type_intel; -typedef cl_uint cl_accelerator_info_intel; - -typedef struct _cl_motion_estimation_desc_intel { - cl_uint mb_block_type; - cl_uint subpixel_mode; - cl_uint sad_adjust_mode; - cl_uint search_path_type; -} cl_motion_estimation_desc_intel; - -/* error codes */ -#define CL_INVALID_ACCELERATOR_INTEL -1094 -#define CL_INVALID_ACCELERATOR_TYPE_INTEL -1095 -#define CL_INVALID_ACCELERATOR_DESCRIPTOR_INTEL -1096 -#define CL_ACCELERATOR_TYPE_NOT_SUPPORTED_INTEL -1097 - -/* cl_accelerator_type_intel */ -#define CL_ACCELERATOR_TYPE_MOTION_ESTIMATION_INTEL 0x0 - -/* cl_accelerator_info_intel */ -#define CL_ACCELERATOR_DESCRIPTOR_INTEL 0x4090 -#define CL_ACCELERATOR_REFERENCE_COUNT_INTEL 0x4091 -#define CL_ACCELERATOR_CONTEXT_INTEL 0x4092 -#define CL_ACCELERATOR_TYPE_INTEL 0x4093 - -/* cl_motion_detect_desc_intel flags */ -#define CL_ME_MB_TYPE_16x16_INTEL 0x0 -#define CL_ME_MB_TYPE_8x8_INTEL 0x1 -#define CL_ME_MB_TYPE_4x4_INTEL 0x2 - -#define CL_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0 -#define CL_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1 -#define CL_ME_SUBPIXEL_MODE_QPEL_INTEL 0x2 - -#define CL_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0 -#define CL_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x1 - -#define CL_ME_SEARCH_PATH_RADIUS_2_2_INTEL 0x0 -#define CL_ME_SEARCH_PATH_RADIUS_4_4_INTEL 0x1 -#define CL_ME_SEARCH_PATH_RADIUS_16_12_INTEL 0x5 - -#define CL_ME_SKIP_BLOCK_TYPE_16x16_INTEL 0x0 -#define CL_ME_CHROMA_INTRA_PREDICT_ENABLED_INTEL 0x1 -#define CL_ME_LUMA_INTRA_PREDICT_ENABLED_INTEL 0x2 -#define CL_ME_SKIP_BLOCK_TYPE_8x8_INTEL 0x4 - -#define CL_ME_FORWARD_INPUT_MODE_INTEL 0x1 -#define CL_ME_BACKWARD_INPUT_MODE_INTEL 0x2 -#define CL_ME_BIDIRECTION_INPUT_MODE_INTEL 0x3 - -#define CL_ME_BIDIR_WEIGHT_QUARTER_INTEL 16 -#define
CL_ME_BIDIR_WEIGHT_THIRD_INTEL 21 -#define CL_ME_BIDIR_WEIGHT_HALF_INTEL 32 -#define CL_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 43 -#define CL_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 48 - -#define CL_ME_COST_PENALTY_NONE_INTEL 0x0 -#define CL_ME_COST_PENALTY_LOW_INTEL 0x1 -#define CL_ME_COST_PENALTY_NORMAL_INTEL 0x2 -#define CL_ME_COST_PENALTY_HIGH_INTEL 0x3 - -#define CL_ME_COST_PRECISION_QPEL_INTEL 0x0 -#define CL_ME_COST_PRECISION_HPEL_INTEL 0x1 -#define CL_ME_COST_PRECISION_PEL_INTEL 0x2 -#define CL_ME_COST_PRECISION_DPEL_INTEL 0x3 - -#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0 -#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1 -#define CL_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2 -#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3 - -#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4 -#define CL_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4 -#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5 -#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6 -#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7 -#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8 - -#define CL_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0 -#define CL_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1 -#define CL_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2 -#define CL_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3 - -/* cl_device_info */ -#define CL_DEVICE_ME_VERSION_INTEL 0x407E - -#define CL_ME_VERSION_LEGACY_INTEL 0x0 -#define CL_ME_VERSION_ADVANCED_VER_1_INTEL 0x1 -#define CL_ME_VERSION_ADVANCED_VER_2_INTEL 0x2 - -extern CL_API_ENTRY cl_accelerator_intel CL_API_CALL -clCreateAcceleratorINTEL( - cl_context /* context */, - cl_accelerator_type_intel /* accelerator_type */, - size_t /* descriptor_size */, - const void* /* descriptor */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_accelerator_intel (CL_API_CALL *clCreateAcceleratorINTEL_fn)( - cl_context /* context */, - cl_accelerator_type_intel /* accelerator_type */, - size_t /* descriptor_size */, - const void* /* descriptor */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetAcceleratorInfoINTEL( - cl_accelerator_intel /* accelerator */, - cl_accelerator_info_intel /* param_name */, - size_t /* param_value_size */, - void* /* param_value */, - size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetAcceleratorInfoINTEL_fn)( - cl_accelerator_intel /* accelerator */, - cl_accelerator_info_intel /* param_name */, - size_t /* param_value_size */, - void* /* param_value */, - size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clRetainAcceleratorINTEL( - cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clRetainAcceleratorINTEL_fn)( - cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clReleaseAcceleratorINTEL( - cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clReleaseAcceleratorINTEL_fn)( - cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2; - -/****************************************** -* cl_intel_simultaneous_sharing extension * -*******************************************/ - -#define cl_intel_simultaneous_sharing 1 - -#define CL_DEVICE_SIMULTANEOUS_INTEROPS_INTEL 0x4104 -#define CL_DEVICE_NUM_SIMULTANEOUS_INTEROPS_INTEL 0x4105 - 
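clCreateAcceleratorINTEL, declared above, takes an opaque descriptor whose layout for the motion-estimation accelerator is the cl_motion_estimation_desc_intel struct from the same deleted header. A minimal sketch under the assumption that the context targets a device advertising cl_intel_motion_estimation (the variable names are illustrative):

#include <CL/cl.h>
#include <CL/cl_ext_intel.h>

/* Sketch: build a motion-estimation accelerator from the descriptor struct.
 * `ctx` is assumed to come from ordinary context creation on an Intel GPU. */
static cl_accelerator_intel make_me_accelerator(cl_context ctx, cl_int *err) {
  cl_motion_estimation_desc_intel desc = {
      CL_ME_MB_TYPE_16x16_INTEL,           /* mb_block_type */
      CL_ME_SUBPIXEL_MODE_INTEGER_INTEL,   /* subpixel_mode */
      CL_ME_SAD_ADJUST_MODE_NONE_INTEL,    /* sad_adjust_mode */
      CL_ME_SEARCH_PATH_RADIUS_16_12_INTEL /* search_path_type */
  };
  return clCreateAcceleratorINTEL(ctx,
                                  CL_ACCELERATOR_TYPE_MOTION_ESTIMATION_INTEL,
                                  sizeof(desc), &desc, err);
}

The returned handle is reference counted like other CL objects; it is passed to the built-in VME kernels as a kernel argument and released with clReleaseAcceleratorINTEL when no longer needed.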
-/*********************************** -* cl_intel_egl_image_yuv extension * -************************************/ - -#define cl_intel_egl_image_yuv 1 - -#define CL_EGL_YUV_PLANE_INTEL 0x4107 - -/******************************** -* cl_intel_packed_yuv extension * -*********************************/ - -#define cl_intel_packed_yuv 1 - -#define CL_YUYV_INTEL 0x4076 -#define CL_UYVY_INTEL 0x4077 -#define CL_YVYU_INTEL 0x4078 -#define CL_VYUY_INTEL 0x4079 - -/******************************************** -* cl_intel_required_subgroup_size extension * -*********************************************/ - -#define cl_intel_required_subgroup_size 1 - -#define CL_DEVICE_SUB_GROUP_SIZES_INTEL 0x4108 -#define CL_KERNEL_SPILL_MEM_SIZE_INTEL 0x4109 -#define CL_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL 0x410A - -/**************************************** -* cl_intel_driver_diagnostics extension * -*****************************************/ - -#define cl_intel_driver_diagnostics 1 - -typedef cl_uint cl_diagnostics_verbose_level; - -#define CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL 0x4106 - -#define CL_CONTEXT_DIAGNOSTICS_LEVEL_ALL_INTEL ( 0xff ) -#define CL_CONTEXT_DIAGNOSTICS_LEVEL_GOOD_INTEL ( 1 ) -#define CL_CONTEXT_DIAGNOSTICS_LEVEL_BAD_INTEL ( 1 << 1 ) -#define CL_CONTEXT_DIAGNOSTICS_LEVEL_NEUTRAL_INTEL ( 1 << 2 ) - -/******************************** -* cl_intel_planar_yuv extension * -*********************************/ - -#define CL_NV12_INTEL 0x410E - -#define CL_MEM_NO_ACCESS_INTEL ( 1 << 24 ) -#define CL_MEM_ACCESS_FLAGS_UNRESTRICTED_INTEL ( 1 << 25 ) - -#define CL_DEVICE_PLANAR_YUV_MAX_WIDTH_INTEL 0x417E -#define CL_DEVICE_PLANAR_YUV_MAX_HEIGHT_INTEL 0x417F - -/******************************************************* -* cl_intel_device_side_avc_motion_estimation extension * -********************************************************/ - -#define CL_DEVICE_AVC_ME_VERSION_INTEL 0x410B -#define CL_DEVICE_AVC_ME_SUPPORTS_TEXTURE_SAMPLER_USE_INTEL 0x410C -#define CL_DEVICE_AVC_ME_SUPPORTS_PREEMPTION_INTEL 0x410D - -#define CL_AVC_ME_VERSION_0_INTEL 0x0 // No support. -#define CL_AVC_ME_VERSION_1_INTEL 0x1 // First supported version.
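The cl_intel_driver_diagnostics extension above is consumed purely through context properties: the desired verbosity is passed under CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL, and the driver reports through the regular clCreateContext notification callback. A sketch, assuming `platform` and `device` come from the usual discovery calls (callback and function names are illustrative):

#include <stdio.h>
#include <CL/cl.h>
#include <CL/cl_ext_intel.h>

/* Illustrative callback: diagnostic strings arrive via the standard
 * clCreateContext pfn_notify mechanism. */
static void CL_CALLBACK diag_cb(const char *errinfo, const void *priv,
                                size_t cb, void *user_data) {
  (void)priv; (void)cb; (void)user_data;
  fprintf(stderr, "cl driver: %s\n", errinfo);
}

/* Sketch: request only "bad" (error-level) driver diagnostics. */
static cl_context make_diagnostic_context(cl_platform_id platform,
                                          cl_device_id device, cl_int *err) {
  cl_context_properties props[] = {
      CL_CONTEXT_PLATFORM, (cl_context_properties)platform,
      CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL,
      (cl_context_properties)CL_CONTEXT_DIAGNOSTICS_LEVEL_BAD_INTEL,
      0};
  return clCreateContext(props, 1, &device, diag_cb, NULL, err);
}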
- -#define CL_AVC_ME_MAJOR_16x16_INTEL 0x0 -#define CL_AVC_ME_MAJOR_16x8_INTEL 0x1 -#define CL_AVC_ME_MAJOR_8x16_INTEL 0x2 -#define CL_AVC_ME_MAJOR_8x8_INTEL 0x3 - -#define CL_AVC_ME_MINOR_8x8_INTEL 0x0 -#define CL_AVC_ME_MINOR_8x4_INTEL 0x1 -#define CL_AVC_ME_MINOR_4x8_INTEL 0x2 -#define CL_AVC_ME_MINOR_4x4_INTEL 0x3 - -#define CL_AVC_ME_MAJOR_FORWARD_INTEL 0x0 -#define CL_AVC_ME_MAJOR_BACKWARD_INTEL 0x1 -#define CL_AVC_ME_MAJOR_BIDIRECTIONAL_INTEL 0x2 - -#define CL_AVC_ME_PARTITION_MASK_ALL_INTEL 0x0 -#define CL_AVC_ME_PARTITION_MASK_16x16_INTEL 0x7E -#define CL_AVC_ME_PARTITION_MASK_16x8_INTEL 0x7D -#define CL_AVC_ME_PARTITION_MASK_8x16_INTEL 0x7B -#define CL_AVC_ME_PARTITION_MASK_8x8_INTEL 0x77 -#define CL_AVC_ME_PARTITION_MASK_8x4_INTEL 0x6F -#define CL_AVC_ME_PARTITION_MASK_4x8_INTEL 0x5F -#define CL_AVC_ME_PARTITION_MASK_4x4_INTEL 0x3F - -#define CL_AVC_ME_SEARCH_WINDOW_EXHAUSTIVE_INTEL 0x0 -#define CL_AVC_ME_SEARCH_WINDOW_SMALL_INTEL 0x1 -#define CL_AVC_ME_SEARCH_WINDOW_TINY_INTEL 0x2 -#define CL_AVC_ME_SEARCH_WINDOW_EXTRA_TINY_INTEL 0x3 -#define CL_AVC_ME_SEARCH_WINDOW_DIAMOND_INTEL 0x4 -#define CL_AVC_ME_SEARCH_WINDOW_LARGE_DIAMOND_INTEL 0x5 -#define CL_AVC_ME_SEARCH_WINDOW_RESERVED0_INTEL 0x6 -#define CL_AVC_ME_SEARCH_WINDOW_RESERVED1_INTEL 0x7 -#define CL_AVC_ME_SEARCH_WINDOW_CUSTOM_INTEL 0x8 -#define CL_AVC_ME_SEARCH_WINDOW_16x12_RADIUS_INTEL 0x9 -#define CL_AVC_ME_SEARCH_WINDOW_4x4_RADIUS_INTEL 0x2 -#define CL_AVC_ME_SEARCH_WINDOW_2x2_RADIUS_INTEL 0xa - -#define CL_AVC_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0 -#define CL_AVC_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x2 - -#define CL_AVC_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0 -#define CL_AVC_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1 -#define CL_AVC_ME_SUBPIXEL_MODE_QPEL_INTEL 0x3 - -#define CL_AVC_ME_COST_PRECISION_QPEL_INTEL 0x0 -#define CL_AVC_ME_COST_PRECISION_HPEL_INTEL 0x1 -#define CL_AVC_ME_COST_PRECISION_PEL_INTEL 0x2 -#define CL_AVC_ME_COST_PRECISION_DPEL_INTEL 0x3 - -#define CL_AVC_ME_BIDIR_WEIGHT_QUARTER_INTEL 0x10 -#define CL_AVC_ME_BIDIR_WEIGHT_THIRD_INTEL 0x15 -#define CL_AVC_ME_BIDIR_WEIGHT_HALF_INTEL 0x20 -#define CL_AVC_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 0x2B -#define CL_AVC_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 0x30 - -#define CL_AVC_ME_BORDER_REACHED_LEFT_INTEL 0x0 -#define CL_AVC_ME_BORDER_REACHED_RIGHT_INTEL 0x2 -#define CL_AVC_ME_BORDER_REACHED_TOP_INTEL 0x4 -#define CL_AVC_ME_BORDER_REACHED_BOTTOM_INTEL 0x8 - -#define CL_AVC_ME_SKIP_BLOCK_PARTITION_16x16_INTEL 0x0 -#define CL_AVC_ME_SKIP_BLOCK_PARTITION_8x8_INTEL 0x4000 - -#define CL_AVC_ME_SKIP_BLOCK_16x16_FORWARD_ENABLE_INTEL ( 0x1 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_16x16_BACKWARD_ENABLE_INTEL ( 0x2 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_16x16_DUAL_ENABLE_INTEL ( 0x3 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_FORWARD_ENABLE_INTEL ( 0x55 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_BACKWARD_ENABLE_INTEL ( 0xAA << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_DUAL_ENABLE_INTEL ( 0xFF << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_0_FORWARD_ENABLE_INTEL ( 0x1 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_0_BACKWARD_ENABLE_INTEL ( 0x2 << 24 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_1_FORWARD_ENABLE_INTEL ( 0x1 << 26 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_1_BACKWARD_ENABLE_INTEL ( 0x2 << 26 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_2_FORWARD_ENABLE_INTEL ( 0x1 << 28 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_2_BACKWARD_ENABLE_INTEL ( 0x2 << 28 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_3_FORWARD_ENABLE_INTEL ( 0x1 << 30 ) -#define CL_AVC_ME_SKIP_BLOCK_8x8_3_BACKWARD_ENABLE_INTEL ( 0x2 << 30 ) - -#define CL_AVC_ME_BLOCK_BASED_SKIP_4x4_INTEL 
0x00 -#define CL_AVC_ME_BLOCK_BASED_SKIP_8x8_INTEL 0x80 - -#define CL_AVC_ME_INTRA_16x16_INTEL 0x0 -#define CL_AVC_ME_INTRA_8x8_INTEL 0x1 -#define CL_AVC_ME_INTRA_4x4_INTEL 0x2 - -#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_16x16_INTEL 0x6 -#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_8x8_INTEL 0x5 -#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_4x4_INTEL 0x3 - -#define CL_AVC_ME_INTRA_NEIGHBOR_LEFT_MASK_ENABLE_INTEL 0x60 -#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_MASK_ENABLE_INTEL 0x10 -#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_RIGHT_MASK_ENABLE_INTEL 0x8 -#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_LEFT_MASK_ENABLE_INTEL 0x4 - -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7 -#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8 -#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0 -#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1 -#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2 -#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3 - -#define CL_AVC_ME_FRAME_FORWARD_INTEL 0x1 -#define CL_AVC_ME_FRAME_BACKWARD_INTEL 0x2 -#define CL_AVC_ME_FRAME_DUAL_INTEL 0x3 - -#define CL_AVC_ME_SLICE_TYPE_PRED_INTEL 0x0 -#define CL_AVC_ME_SLICE_TYPE_BPRED_INTEL 0x1 -#define CL_AVC_ME_SLICE_TYPE_INTRA_INTEL 0x2 - -#define CL_AVC_ME_INTERLACED_SCAN_TOP_FIELD_INTEL 0x0 -#define CL_AVC_ME_INTERLACED_SCAN_BOTTOM_FIELD_INTEL 0x1 - -#ifdef __cplusplus -} -#endif - -#endif /* __CL_EXT_INTEL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl.h deleted file mode 100644 index 58b6449f9b4e98d561ee9a6f8b3daa6caede9f44..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl.h +++ /dev/null @@ -1,175 +0,0 @@ -/********************************************************************************** - * Copyright (c) 2008-2018 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - **********************************************************************************/ - -#ifndef __OPENCL_CL_GL_H -#define __OPENCL_CL_GL_H - -#ifdef __APPLE__ -#include <OpenCL/cl.h> -#else -#include <CL/cl.h> -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -typedef cl_uint cl_gl_object_type; -typedef cl_uint cl_gl_texture_info; -typedef cl_uint cl_gl_platform_info; -typedef struct __GLsync *cl_GLsync; - -/* cl_gl_object_type = 0x2000 - 0x200F enum values are currently taken */ -#define CL_GL_OBJECT_BUFFER 0x2000 -#define CL_GL_OBJECT_TEXTURE2D 0x2001 -#define CL_GL_OBJECT_TEXTURE3D 0x2002 -#define CL_GL_OBJECT_RENDERBUFFER 0x2003 -#ifdef CL_VERSION_1_2 -#define CL_GL_OBJECT_TEXTURE2D_ARRAY 0x200E -#define CL_GL_OBJECT_TEXTURE1D 0x200F -#define CL_GL_OBJECT_TEXTURE1D_ARRAY 0x2010 -#define CL_GL_OBJECT_TEXTURE_BUFFER 0x2011 -#endif - -/* cl_gl_texture_info */ -#define CL_GL_TEXTURE_TARGET 0x2004 -#define CL_GL_MIPMAP_LEVEL 0x2005 -#ifdef CL_VERSION_1_2 -#define CL_GL_NUM_SAMPLES 0x2012 -#endif - - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromGLBuffer(cl_context /* context */, - cl_mem_flags /* flags */, - cl_GLuint /* bufobj */, - int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; - -#ifdef CL_VERSION_1_2 - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromGLTexture(cl_context /* context */, - cl_mem_flags /* flags */, - cl_GLenum /* target */, - cl_GLint /* miplevel */, - cl_GLuint /* texture */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; - -#endif - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromGLRenderbuffer(cl_context /* context */, - cl_mem_flags /* flags */, - cl_GLuint /* renderbuffer */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetGLObjectInfo(cl_mem /* memobj */, - cl_gl_object_type * /* gl_object_type */, - cl_GLuint * /* gl_object_name */) CL_API_SUFFIX__VERSION_1_0; - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetGLTextureInfo(cl_mem /* memobj */, - cl_gl_texture_info /* param_name */, - size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueAcquireGLObjects(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueReleaseGLObjects(cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - - -/* Deprecated OpenCL 1.1 APIs */ -extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL -clCreateFromGLTexture2D(cl_context /* context */, - cl_mem_flags /* flags */, - cl_GLenum /* target */, - cl_GLint /* miplevel */, - cl_GLuint /* texture */, - cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED; - -extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL -clCreateFromGLTexture3D(cl_context /* context */, - cl_mem_flags /* flags */, - cl_GLenum /*
target */, - cl_GLint /* miplevel */, - cl_GLuint /* texture */, - cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED; - -/* cl_khr_gl_sharing extension */ - -#define cl_khr_gl_sharing 1 - -typedef cl_uint cl_gl_context_info; - -/* Additional Error Codes */ -#define CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR -1000 - -/* cl_gl_context_info */ -#define CL_CURRENT_DEVICE_FOR_GL_CONTEXT_KHR 0x2006 -#define CL_DEVICES_FOR_GL_CONTEXT_KHR 0x2007 - -/* Additional cl_context_properties */ -#define CL_GL_CONTEXT_KHR 0x2008 -#define CL_EGL_DISPLAY_KHR 0x2009 -#define CL_GLX_DISPLAY_KHR 0x200A -#define CL_WGL_HDC_KHR 0x200B -#define CL_CGL_SHAREGROUP_KHR 0x200C - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetGLContextInfoKHR(const cl_context_properties * /* properties */, - cl_gl_context_info /* param_name */, - size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetGLContextInfoKHR_fn)( - const cl_context_properties * properties, - cl_gl_context_info param_name, - size_t param_value_size, - void * param_value, - size_t * param_value_size_ret); - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_GL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl_ext.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl_ext.h deleted file mode 100644 index e3c14c6408c44160103bcb4c0dcd230a674643a5..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_gl_ext.h +++ /dev/null @@ -1,74 +0,0 @@ -/********************************************************************************** - * Copyright (c) 2008-2015 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - **********************************************************************************/ - -/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */ - -/* cl_gl_ext.h contains vendor (non-KHR) OpenCL extensions which have */ -/* OpenGL dependencies. 
*/ - -#ifndef __OPENCL_CL_GL_EXT_H -#define __OPENCL_CL_GL_EXT_H - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __APPLE__ - #include <OpenCL/cl_gl.h> -#else - #include <CL/cl_gl.h> -#endif - -/* - * For each extension, follow this template - * cl_VEN_extname extension */ -/* #define cl_VEN_extname 1 - * ... define new types, if any - * ... define new tokens, if any - * ... define new APIs, if any - * - * If you need GLtypes here, mirror them with a cl_GLtype, rather than including a GL header - * This allows us to avoid having to decide whether to include GL headers or GLES here. - */ - -/* - * cl_khr_gl_event extension - * See section 9.9 in the OpenCL 1.1 spec for more information - */ -#define CL_COMMAND_GL_FENCE_SYNC_OBJECT_KHR 0x200D - -extern CL_API_ENTRY cl_event CL_API_CALL -clCreateEventFromGLsyncKHR(cl_context /* context */, - cl_GLsync /* cl_GLsync */, - cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1; - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_GL_EXT_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_platform.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_platform.h deleted file mode 100644 index c2f408fed59fc42f9c2573061704610498890b40..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_platform.h +++ /dev/null @@ -1,1460 +0,0 @@ -/********************************************************************************** - * Copyright (c) 2008-2018 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - **********************************************************************************/ - -#ifndef __CL_PLATFORM_H -#define __CL_PLATFORM_H - -#ifdef __APPLE__ - #include - - /* Contains #defines for AVAILABLE_MAC_OS_X_VERSION_10_6_AND_LATER below */ - #include <AvailabilityMacros.h> -#else - #include -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#if defined(_WIN32) - #define CL_API_ENTRY - #define CL_API_CALL __stdcall - #define CL_CALLBACK __stdcall -#else - #define CL_API_ENTRY - #define CL_API_CALL - #define CL_CALLBACK -#endif - -/* - * Deprecation flags refer to the last version of the header in which the - * feature was not deprecated. - * - * E.g.
VERSION_1_1_DEPRECATED means the feature is present in 1.1 without - * deprecation but is deprecated in versions later than 1.1. - */ - -#ifdef __APPLE__ - #define CL_EXTENSION_WEAK_LINK __attribute__((weak_import)) - #define CL_API_SUFFIX__VERSION_1_0 AVAILABLE_MAC_OS_X_VERSION_10_6_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_0 CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_6_AND_LATER - #define CL_API_SUFFIX__VERSION_1_1 AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define GCL_API_SUFFIX__VERSION_1_1 AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_1 CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_6_AND_LATER_BUT_DEPRECATED_IN_MAC_OS_X_VERSION_10_7 - - #ifdef AVAILABLE_MAC_OS_X_VERSION_10_8_AND_LATER - #define CL_API_SUFFIX__VERSION_1_2 AVAILABLE_MAC_OS_X_VERSION_10_8_AND_LATER - #define GCL_API_SUFFIX__VERSION_1_2 AVAILABLE_MAC_OS_X_VERSION_10_8_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_2 CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_8_AND_LATER - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER_BUT_DEPRECATED_IN_MAC_OS_X_VERSION_10_8 - #else - #warning This path should never happen outside of internal operating system development. AvailabilityMacros do not function correctly here! - #define CL_API_SUFFIX__VERSION_1_2 AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define GCL_API_SUFFIX__VERSION_1_2 AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_2 CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED CL_EXTENSION_WEAK_LINK AVAILABLE_MAC_OS_X_VERSION_10_7_AND_LATER - #endif -#else - #define CL_EXTENSION_WEAK_LINK - #define CL_API_SUFFIX__VERSION_1_0 - #define CL_EXT_SUFFIX__VERSION_1_0 - #define CL_API_SUFFIX__VERSION_1_1 - #define CL_EXT_SUFFIX__VERSION_1_1 - #define CL_API_SUFFIX__VERSION_1_2 - #define CL_EXT_SUFFIX__VERSION_1_2 - #define CL_API_SUFFIX__VERSION_2_0 - #define CL_EXT_SUFFIX__VERSION_2_0 - #define CL_API_SUFFIX__VERSION_2_1 - #define CL_EXT_SUFFIX__VERSION_2_1 - #define CL_API_SUFFIX__VERSION_2_2 - #define CL_EXT_SUFFIX__VERSION_2_2 - - #ifdef __GNUC__ - #ifdef CL_USE_DEPRECATED_OPENCL_1_0_APIS - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_0_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED __attribute__((deprecated)) - #define CL_EXT_PREFIX__VERSION_1_0_DEPRECATED - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_1_1_APIS - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED __attribute__((deprecated)) - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_1_2_APIS - #define CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_2_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED __attribute__((deprecated)) - #define CL_EXT_PREFIX__VERSION_1_2_DEPRECATED - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_2_0_APIS - #define CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_0_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED __attribute__((deprecated)) - #define CL_EXT_PREFIX__VERSION_2_0_DEPRECATED - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_2_1_APIS - #define 
CL_EXT_SUFFIX__VERSION_2_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_1_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_2_1_DEPRECATED __attribute__((deprecated)) - #define CL_EXT_PREFIX__VERSION_2_1_DEPRECATED - #endif - #elif defined(_WIN32) - #ifdef CL_USE_DEPRECATED_OPENCL_1_0_APIS - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_0_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_0_DEPRECATED __declspec(deprecated) - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_1_1_APIS - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED __declspec(deprecated) - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_1_2_APIS - #define CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_2_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_2_DEPRECATED __declspec(deprecated) - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_2_0_APIS - #define CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_0_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_0_DEPRECATED __declspec(deprecated) - #endif - - #ifdef CL_USE_DEPRECATED_OPENCL_2_1_APIS - #define CL_EXT_SUFFIX__VERSION_2_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_1_DEPRECATED - #else - #define CL_EXT_SUFFIX__VERSION_2_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_1_DEPRECATED __declspec(deprecated) - #endif - #else - #define CL_EXT_SUFFIX__VERSION_1_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_0_DEPRECATED - - #define CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_1_DEPRECATED - - #define CL_EXT_SUFFIX__VERSION_1_2_DEPRECATED - #define CL_EXT_PREFIX__VERSION_1_2_DEPRECATED - - #define CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_0_DEPRECATED - - #define CL_EXT_SUFFIX__VERSION_2_1_DEPRECATED - #define CL_EXT_PREFIX__VERSION_2_1_DEPRECATED - #endif -#endif - -#if (defined (_WIN32) && defined(_MSC_VER)) - -/* scalar types */ -typedef signed __int8 cl_char; -typedef unsigned __int8 cl_uchar; -typedef signed __int16 cl_short; -typedef unsigned __int16 cl_ushort; -typedef signed __int32 cl_int; -typedef unsigned __int32 cl_uint; -typedef signed __int64 cl_long; -typedef unsigned __int64 cl_ulong; - -typedef unsigned __int16 cl_half; -typedef float cl_float; -typedef double cl_double; - -/* Macro names and corresponding values defined by OpenCL */ -#define CL_CHAR_BIT 8 -#define CL_SCHAR_MAX 127 -#define CL_SCHAR_MIN (-127-1) -#define CL_CHAR_MAX CL_SCHAR_MAX -#define CL_CHAR_MIN CL_SCHAR_MIN -#define CL_UCHAR_MAX 255 -#define CL_SHRT_MAX 32767 -#define CL_SHRT_MIN (-32767-1) -#define CL_USHRT_MAX 65535 -#define CL_INT_MAX 2147483647 -#define CL_INT_MIN (-2147483647-1) -#define CL_UINT_MAX 0xffffffffU -#define CL_LONG_MAX ((cl_long) 0x7FFFFFFFFFFFFFFFLL) -#define CL_LONG_MIN ((cl_long) -0x7FFFFFFFFFFFFFFFLL - 1LL) -#define CL_ULONG_MAX ((cl_ulong) 0xFFFFFFFFFFFFFFFFULL) - -#define CL_FLT_DIG 6 -#define CL_FLT_MANT_DIG 24 -#define CL_FLT_MAX_10_EXP +38 -#define CL_FLT_MAX_EXP +128 -#define CL_FLT_MIN_10_EXP -37 -#define CL_FLT_MIN_EXP -125 -#define CL_FLT_RADIX 2 -#define CL_FLT_MAX 340282346638528859811704183484516925440.0f -#define CL_FLT_MIN 1.175494350822287507969e-38f -#define CL_FLT_EPSILON 1.1920928955078125e-7f - -#define CL_HALF_DIG 
3 -#define CL_HALF_MANT_DIG 11 -#define CL_HALF_MAX_10_EXP +4 -#define CL_HALF_MAX_EXP +16 -#define CL_HALF_MIN_10_EXP -4 -#define CL_HALF_MIN_EXP -13 -#define CL_HALF_RADIX 2 -#define CL_HALF_MAX 65504.0f -#define CL_HALF_MIN 6.103515625e-05f -#define CL_HALF_EPSILON 9.765625e-04f - -#define CL_DBL_DIG 15 -#define CL_DBL_MANT_DIG 53 -#define CL_DBL_MAX_10_EXP +308 -#define CL_DBL_MAX_EXP +1024 -#define CL_DBL_MIN_10_EXP -307 -#define CL_DBL_MIN_EXP -1021 -#define CL_DBL_RADIX 2 -#define CL_DBL_MAX 1.7976931348623158e+308 -#define CL_DBL_MIN 2.225073858507201383090e-308 -#define CL_DBL_EPSILON 2.220446049250313080847e-16 - -#define CL_M_E 2.7182818284590452354 -#define CL_M_LOG2E 1.4426950408889634074 -#define CL_M_LOG10E 0.43429448190325182765 -#define CL_M_LN2 0.69314718055994530942 -#define CL_M_LN10 2.30258509299404568402 -#define CL_M_PI 3.14159265358979323846 -#define CL_M_PI_2 1.57079632679489661923 -#define CL_M_PI_4 0.78539816339744830962 -#define CL_M_1_PI 0.31830988618379067154 -#define CL_M_2_PI 0.63661977236758134308 -#define CL_M_2_SQRTPI 1.12837916709551257390 -#define CL_M_SQRT2 1.41421356237309504880 -#define CL_M_SQRT1_2 0.70710678118654752440 - -#define CL_M_E_F 2.718281828f -#define CL_M_LOG2E_F 1.442695041f -#define CL_M_LOG10E_F 0.434294482f -#define CL_M_LN2_F 0.693147181f -#define CL_M_LN10_F 2.302585093f -#define CL_M_PI_F 3.141592654f -#define CL_M_PI_2_F 1.570796327f -#define CL_M_PI_4_F 0.785398163f -#define CL_M_1_PI_F 0.318309886f -#define CL_M_2_PI_F 0.636619772f -#define CL_M_2_SQRTPI_F 1.128379167f -#define CL_M_SQRT2_F 1.414213562f -#define CL_M_SQRT1_2_F 0.707106781f - -#define CL_NAN (CL_INFINITY - CL_INFINITY) -#define CL_HUGE_VALF ((cl_float) 1e50) -#define CL_HUGE_VAL ((cl_double) 1e500) -#define CL_MAXFLOAT CL_FLT_MAX -#define CL_INFINITY CL_HUGE_VALF - -#else - -#include <stdint.h> - -/* scalar types */ -typedef int8_t cl_char; -typedef uint8_t cl_uchar; -typedef int16_t cl_short __attribute__((aligned(2))); -typedef uint16_t cl_ushort __attribute__((aligned(2))); -typedef int32_t cl_int __attribute__((aligned(4))); -typedef uint32_t cl_uint __attribute__((aligned(4))); -typedef int64_t cl_long __attribute__((aligned(8))); -typedef uint64_t cl_ulong __attribute__((aligned(8))); - -typedef uint16_t cl_half __attribute__((aligned(2))); -typedef float cl_float __attribute__((aligned(4))); -typedef double cl_double __attribute__((aligned(8))); - -/* Macro names and corresponding values defined by OpenCL */ -#define CL_CHAR_BIT 8 -#define CL_SCHAR_MAX 127 -#define CL_SCHAR_MIN (-127-1) -#define CL_CHAR_MAX CL_SCHAR_MAX -#define CL_CHAR_MIN CL_SCHAR_MIN -#define CL_UCHAR_MAX 255 -#define CL_SHRT_MAX 32767 -#define CL_SHRT_MIN (-32767-1) -#define CL_USHRT_MAX 65535 -#define CL_INT_MAX 2147483647 -#define CL_INT_MIN (-2147483647-1) -#define CL_UINT_MAX 0xffffffffU -#define CL_LONG_MAX ((cl_long) 0x7FFFFFFFFFFFFFFFLL) -#define CL_LONG_MIN ((cl_long) -0x7FFFFFFFFFFFFFFFLL - 1LL) -#define CL_ULONG_MAX ((cl_ulong) 0xFFFFFFFFFFFFFFFFULL) - -#define CL_FLT_DIG 6 -#define CL_FLT_MANT_DIG 24 -#define CL_FLT_MAX_10_EXP +38 -#define CL_FLT_MAX_EXP +128 -#define CL_FLT_MIN_10_EXP -37 -#define CL_FLT_MIN_EXP -125 -#define CL_FLT_RADIX 2 -#define CL_FLT_MAX 340282346638528859811704183484516925440.0f -#define CL_FLT_MIN 1.175494350822287507969e-38f -#define CL_FLT_EPSILON 1.1920928955078125e-7f - -#define CL_HALF_DIG 3 -#define CL_HALF_MANT_DIG 11 -#define CL_HALF_MAX_10_EXP +4 -#define CL_HALF_MAX_EXP +16 -#define CL_HALF_MIN_10_EXP -4 -#define CL_HALF_MIN_EXP -13 -#define 
CL_HALF_RADIX 2 -#define CL_HALF_MAX 65504.0f -#define CL_HALF_MIN 6.103515625e-05f -#define CL_HALF_EPSILON 9.765625e-04f - -#define CL_DBL_DIG 15 -#define CL_DBL_MANT_DIG 53 -#define CL_DBL_MAX_10_EXP +308 -#define CL_DBL_MAX_EXP +1024 -#define CL_DBL_MIN_10_EXP -307 -#define CL_DBL_MIN_EXP -1021 -#define CL_DBL_RADIX 2 -#define CL_DBL_MAX 179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0 -#define CL_DBL_MIN 2.225073858507201383090e-308 -#define CL_DBL_EPSILON 2.220446049250313080847e-16 - -#define CL_M_E 2.7182818284590452354 -#define CL_M_LOG2E 1.4426950408889634074 -#define CL_M_LOG10E 0.43429448190325182765 -#define CL_M_LN2 0.69314718055994530942 -#define CL_M_LN10 2.30258509299404568402 -#define CL_M_PI 3.14159265358979323846 -#define CL_M_PI_2 1.57079632679489661923 -#define CL_M_PI_4 0.78539816339744830962 -#define CL_M_1_PI 0.31830988618379067154 -#define CL_M_2_PI 0.63661977236758134308 -#define CL_M_2_SQRTPI 1.12837916709551257390 -#define CL_M_SQRT2 1.41421356237309504880 -#define CL_M_SQRT1_2 0.70710678118654752440 - -#define CL_M_E_F 2.718281828f -#define CL_M_LOG2E_F 1.442695041f -#define CL_M_LOG10E_F 0.434294482f -#define CL_M_LN2_F 0.693147181f -#define CL_M_LN10_F 2.302585093f -#define CL_M_PI_F 3.141592654f -#define CL_M_PI_2_F 1.570796327f -#define CL_M_PI_4_F 0.785398163f -#define CL_M_1_PI_F 0.318309886f -#define CL_M_2_PI_F 0.636619772f -#define CL_M_2_SQRTPI_F 1.128379167f -#define CL_M_SQRT2_F 1.414213562f -#define CL_M_SQRT1_2_F 0.707106781f - -#if defined( __GNUC__ ) - #define CL_HUGE_VALF __builtin_huge_valf() - #define CL_HUGE_VAL __builtin_huge_val() - #define CL_NAN __builtin_nanf( "" ) -#else - #define CL_HUGE_VALF ((cl_float) 1e50) - #define CL_HUGE_VAL ((cl_double) 1e500) - float nanf( const char * ); - #define CL_NAN nanf( "" ) -#endif -#define CL_MAXFLOAT CL_FLT_MAX -#define CL_INFINITY CL_HUGE_VALF - -#endif - -#include <stddef.h> - -/* Mirror types to GL types. Mirror types allow us to avoid deciding which 87s to load based on whether we are using GL or GLES here. */ -typedef unsigned int cl_GLuint; -typedef int cl_GLint; -typedef unsigned int cl_GLenum; - -/* - * Vector types - * - * Note: OpenCL requires that all types be naturally aligned. - * This means that vector types must be naturally aligned. - * For example, a vector of four floats must be aligned to - * a 16 byte boundary (calculated as 4 * the natural 4-byte - * alignment of the float). The alignment qualifiers here - * will only function properly if your compiler supports them - * and if you don't actively work to defeat them. For example, - * in order for a cl_float4 to be 16 byte aligned in a struct, - * the start of the struct must itself be 16-byte aligned. - * - * Maintaining proper alignment is the user's responsibility. - */ - -/* Define basic vector types */ -#if defined( __VEC__ ) - #include <altivec.h> /* may be omitted depending on compiler. AltiVec spec provides no way to detect whether the header is required. 
*/ - typedef vector unsigned char __cl_uchar16; - typedef vector signed char __cl_char16; - typedef vector unsigned short __cl_ushort8; - typedef vector signed short __cl_short8; - typedef vector unsigned int __cl_uint4; - typedef vector signed int __cl_int4; - typedef vector float __cl_float4; - #define __CL_UCHAR16__ 1 - #define __CL_CHAR16__ 1 - #define __CL_USHORT8__ 1 - #define __CL_SHORT8__ 1 - #define __CL_UINT4__ 1 - #define __CL_INT4__ 1 - #define __CL_FLOAT4__ 1 -#endif - -#if defined( __SSE__ ) - #if defined( __MINGW64__ ) - #include <intrin.h> - #else - #include <xmmintrin.h> - #endif - #if defined( __GNUC__ ) - typedef float __cl_float4 __attribute__((vector_size(16))); - #else - typedef __m128 __cl_float4; - #endif - #define __CL_FLOAT4__ 1 -#endif - -#if defined( __SSE2__ ) - #if defined( __MINGW64__ ) - #include <intrin.h> - #else - #include <emmintrin.h> - #endif - #if defined( __GNUC__ ) - typedef cl_uchar __cl_uchar16 __attribute__((vector_size(16))); - typedef cl_char __cl_char16 __attribute__((vector_size(16))); - typedef cl_ushort __cl_ushort8 __attribute__((vector_size(16))); - typedef cl_short __cl_short8 __attribute__((vector_size(16))); - typedef cl_uint __cl_uint4 __attribute__((vector_size(16))); - typedef cl_int __cl_int4 __attribute__((vector_size(16))); - typedef cl_ulong __cl_ulong2 __attribute__((vector_size(16))); - typedef cl_long __cl_long2 __attribute__((vector_size(16))); - typedef cl_double __cl_double2 __attribute__((vector_size(16))); - #else - typedef __m128i __cl_uchar16; - typedef __m128i __cl_char16; - typedef __m128i __cl_ushort8; - typedef __m128i __cl_short8; - typedef __m128i __cl_uint4; - typedef __m128i __cl_int4; - typedef __m128i __cl_ulong2; - typedef __m128i __cl_long2; - typedef __m128d __cl_double2; - #endif - #define __CL_UCHAR16__ 1 - #define __CL_CHAR16__ 1 - #define __CL_USHORT8__ 1 - #define __CL_SHORT8__ 1 - #define __CL_INT4__ 1 - #define __CL_UINT4__ 1 - #define __CL_ULONG2__ 1 - #define __CL_LONG2__ 1 - #define __CL_DOUBLE2__ 1 -#endif - -#if defined( __MMX__ ) - #include <mmintrin.h> - #if defined( __GNUC__ ) - typedef cl_uchar __cl_uchar8 __attribute__((vector_size(8))); - typedef cl_char __cl_char8 __attribute__((vector_size(8))); - typedef cl_ushort __cl_ushort4 __attribute__((vector_size(8))); - typedef cl_short __cl_short4 __attribute__((vector_size(8))); - typedef cl_uint __cl_uint2 __attribute__((vector_size(8))); - typedef cl_int __cl_int2 __attribute__((vector_size(8))); - typedef cl_ulong __cl_ulong1 __attribute__((vector_size(8))); - typedef cl_long __cl_long1 __attribute__((vector_size(8))); - typedef cl_float __cl_float2 __attribute__((vector_size(8))); - #else - typedef __m64 __cl_uchar8; - typedef __m64 __cl_char8; - typedef __m64 __cl_ushort4; - typedef __m64 __cl_short4; - typedef __m64 __cl_uint2; - typedef __m64 __cl_int2; - typedef __m64 __cl_ulong1; - typedef __m64 __cl_long1; - typedef __m64 __cl_float2; - #endif - #define __CL_UCHAR8__ 1 - #define __CL_CHAR8__ 1 - #define __CL_USHORT4__ 1 - #define __CL_SHORT4__ 1 - #define __CL_INT2__ 1 - #define __CL_UINT2__ 1 - #define __CL_ULONG1__ 1 - #define __CL_LONG1__ 1 - #define __CL_FLOAT2__ 1 -#endif - -#if defined( __AVX__ ) - #if defined( __MINGW64__ ) - #include <intrin.h> - #else - #include <immintrin.h> - #endif - #if defined( __GNUC__ ) - typedef cl_float __cl_float8 __attribute__((vector_size(32))); - typedef cl_double __cl_double4 __attribute__((vector_size(32))); - #else - typedef __m256 __cl_float8; - typedef __m256d __cl_double4; - #endif - #define __CL_FLOAT8__ 1 - #define __CL_DOUBLE4__ 1 -#endif - -/* Define capabilities for 
anonymous struct members. */ -#if !defined(__cplusplus) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L -#define __CL_HAS_ANON_STRUCT__ 1 -#define __CL_ANON_STRUCT__ -#elif defined( __GNUC__) && ! defined( __STRICT_ANSI__ ) -#define __CL_HAS_ANON_STRUCT__ 1 -#define __CL_ANON_STRUCT__ __extension__ -#elif defined( _WIN32) && defined(_MSC_VER) - #if _MSC_VER >= 1500 - /* Microsoft Developer Studio 2008 supports anonymous structs, but - * complains by default. */ - #define __CL_HAS_ANON_STRUCT__ 1 - #define __CL_ANON_STRUCT__ - /* Disable warning C4201: nonstandard extension used : nameless - * struct/union */ - #pragma warning( push ) - #pragma warning( disable : 4201 ) - #endif -#else -#define __CL_HAS_ANON_STRUCT__ 0 -#define __CL_ANON_STRUCT__ -#endif - -/* Define alignment keys */ -#if defined( __GNUC__ ) - #define CL_ALIGNED(_x) __attribute__ ((aligned(_x))) -#elif defined( _WIN32) && (_MSC_VER) - /* Alignment keys neutered on windows because MSVC can't swallow function arguments with alignment requirements */ - /* http://msdn.microsoft.com/en-us/library/373ak2y1%28VS.71%29.aspx */ - /* #include */ - /* #define CL_ALIGNED(_x) _CRT_ALIGN(_x) */ - #define CL_ALIGNED(_x) -#else - #warning Need to implement some method to align data here - #define CL_ALIGNED(_x) -#endif - -/* Indicate whether .xyzw, .s0123 and .hi.lo are supported */ -#if __CL_HAS_ANON_STRUCT__ - /* .xyzw and .s0123...{f|F} are supported */ - #define CL_HAS_NAMED_VECTOR_FIELDS 1 - /* .hi and .lo are supported */ - #define CL_HAS_HI_LO_VECTOR_FIELDS 1 -#endif - -/* Define cl_vector types */ - -/* ---- cl_charn ---- */ -typedef union -{ - cl_char CL_ALIGNED(2) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_char x, y; }; - __CL_ANON_STRUCT__ struct{ cl_char s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_char lo, hi; }; -#endif -#if defined( __CL_CHAR2__) - __cl_char2 v2; -#endif -}cl_char2; - -typedef union -{ - cl_char CL_ALIGNED(4) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_char x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_char s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_char2 lo, hi; }; -#endif -#if defined( __CL_CHAR2__) - __cl_char2 v2[2]; -#endif -#if defined( __CL_CHAR4__) - __cl_char4 v4; -#endif -}cl_char4; - -/* cl_char3 is identical in size, alignment and behavior to cl_char4. See section 6.1.5. 
*/ -typedef cl_char4 cl_char3; - -typedef union -{ - cl_char CL_ALIGNED(8) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_char x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_char s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_char4 lo, hi; }; -#endif -#if defined( __CL_CHAR2__) - __cl_char2 v2[4]; -#endif -#if defined( __CL_CHAR4__) - __cl_char4 v4[2]; -#endif -#if defined( __CL_CHAR8__ ) - __cl_char8 v8; -#endif -}cl_char8; - -typedef union -{ - cl_char CL_ALIGNED(16) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_char x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_char s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_char8 lo, hi; }; -#endif -#if defined( __CL_CHAR2__) - __cl_char2 v2[8]; -#endif -#if defined( __CL_CHAR4__) - __cl_char4 v4[4]; -#endif -#if defined( __CL_CHAR8__ ) - __cl_char8 v8[2]; -#endif -#if defined( __CL_CHAR16__ ) - __cl_char16 v16; -#endif -}cl_char16; - - -/* ---- cl_ucharn ---- */ -typedef union -{ - cl_uchar CL_ALIGNED(2) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uchar x, y; }; - __CL_ANON_STRUCT__ struct{ cl_uchar s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_uchar lo, hi; }; -#endif -#if defined( __cl_uchar2__) - __cl_uchar2 v2; -#endif -}cl_uchar2; - -typedef union -{ - cl_uchar CL_ALIGNED(4) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uchar x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_uchar s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_uchar2 lo, hi; }; -#endif -#if defined( __CL_UCHAR2__) - __cl_uchar2 v2[2]; -#endif -#if defined( __CL_UCHAR4__) - __cl_uchar4 v4; -#endif -}cl_uchar4; - -/* cl_uchar3 is identical in size, alignment and behavior to cl_uchar4. See section 6.1.5. 
*/ -typedef cl_uchar4 cl_uchar3; - -typedef union -{ - cl_uchar CL_ALIGNED(8) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uchar x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_uchar s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_uchar4 lo, hi; }; -#endif -#if defined( __CL_UCHAR2__) - __cl_uchar2 v2[4]; -#endif -#if defined( __CL_UCHAR4__) - __cl_uchar4 v4[2]; -#endif -#if defined( __CL_UCHAR8__ ) - __cl_uchar8 v8; -#endif -}cl_uchar8; - -typedef union -{ - cl_uchar CL_ALIGNED(16) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uchar x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_uchar s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_uchar8 lo, hi; }; -#endif -#if defined( __CL_UCHAR2__) - __cl_uchar2 v2[8]; -#endif -#if defined( __CL_UCHAR4__) - __cl_uchar4 v4[4]; -#endif -#if defined( __CL_UCHAR8__ ) - __cl_uchar8 v8[2]; -#endif -#if defined( __CL_UCHAR16__ ) - __cl_uchar16 v16; -#endif -}cl_uchar16; - - -/* ---- cl_shortn ---- */ -typedef union -{ - cl_short CL_ALIGNED(4) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_short x, y; }; - __CL_ANON_STRUCT__ struct{ cl_short s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_short lo, hi; }; -#endif -#if defined( __CL_SHORT2__) - __cl_short2 v2; -#endif -}cl_short2; - -typedef union -{ - cl_short CL_ALIGNED(8) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_short x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_short s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_short2 lo, hi; }; -#endif -#if defined( __CL_SHORT2__) - __cl_short2 v2[2]; -#endif -#if defined( __CL_SHORT4__) - __cl_short4 v4; -#endif -}cl_short4; - -/* cl_short3 is identical in size, alignment and behavior to cl_short4. See section 6.1.5. 
*/ -typedef cl_short4 cl_short3; - -typedef union -{ - cl_short CL_ALIGNED(16) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_short x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_short s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_short4 lo, hi; }; -#endif -#if defined( __CL_SHORT2__) - __cl_short2 v2[4]; -#endif -#if defined( __CL_SHORT4__) - __cl_short4 v4[2]; -#endif -#if defined( __CL_SHORT8__ ) - __cl_short8 v8; -#endif -}cl_short8; - -typedef union -{ - cl_short CL_ALIGNED(32) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_short x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_short s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_short8 lo, hi; }; -#endif -#if defined( __CL_SHORT2__) - __cl_short2 v2[8]; -#endif -#if defined( __CL_SHORT4__) - __cl_short4 v4[4]; -#endif -#if defined( __CL_SHORT8__ ) - __cl_short8 v8[2]; -#endif -#if defined( __CL_SHORT16__ ) - __cl_short16 v16; -#endif -}cl_short16; - - -/* ---- cl_ushortn ---- */ -typedef union -{ - cl_ushort CL_ALIGNED(4) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ushort x, y; }; - __CL_ANON_STRUCT__ struct{ cl_ushort s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_ushort lo, hi; }; -#endif -#if defined( __CL_USHORT2__) - __cl_ushort2 v2; -#endif -}cl_ushort2; - -typedef union -{ - cl_ushort CL_ALIGNED(8) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ushort x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_ushort s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_ushort2 lo, hi; }; -#endif -#if defined( __CL_USHORT2__) - __cl_ushort2 v2[2]; -#endif -#if defined( __CL_USHORT4__) - __cl_ushort4 v4; -#endif -}cl_ushort4; - -/* cl_ushort3 is identical in size, alignment and behavior to cl_ushort4. See section 6.1.5. 
*/ -typedef cl_ushort4 cl_ushort3; - -typedef union -{ - cl_ushort CL_ALIGNED(16) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ushort x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_ushort s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_ushort4 lo, hi; }; -#endif -#if defined( __CL_USHORT2__) - __cl_ushort2 v2[4]; -#endif -#if defined( __CL_USHORT4__) - __cl_ushort4 v4[2]; -#endif -#if defined( __CL_USHORT8__ ) - __cl_ushort8 v8; -#endif -}cl_ushort8; - -typedef union -{ - cl_ushort CL_ALIGNED(32) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ushort x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_ushort s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_ushort8 lo, hi; }; -#endif -#if defined( __CL_USHORT2__) - __cl_ushort2 v2[8]; -#endif -#if defined( __CL_USHORT4__) - __cl_ushort4 v4[4]; -#endif -#if defined( __CL_USHORT8__ ) - __cl_ushort8 v8[2]; -#endif -#if defined( __CL_USHORT16__ ) - __cl_ushort16 v16; -#endif -}cl_ushort16; - - -/* ---- cl_halfn ---- */ -typedef union -{ - cl_half CL_ALIGNED(4) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_half x, y; }; - __CL_ANON_STRUCT__ struct{ cl_half s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_half lo, hi; }; -#endif -#if defined( __CL_HALF2__) - __cl_half2 v2; -#endif -}cl_half2; - -typedef union -{ - cl_half CL_ALIGNED(8) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_half x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_half s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_half2 lo, hi; }; -#endif -#if defined( __CL_HALF2__) - __cl_half2 v2[2]; -#endif -#if defined( __CL_HALF4__) - __cl_half4 v4; -#endif -}cl_half4; - -/* cl_half3 is identical in size, alignment and behavior to cl_half4. See section 6.1.5. 
*/ -typedef cl_half4 cl_half3; - -typedef union -{ - cl_half CL_ALIGNED(16) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_half x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_half s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_half4 lo, hi; }; -#endif -#if defined( __CL_HALF2__) - __cl_half2 v2[4]; -#endif -#if defined( __CL_HALF4__) - __cl_half4 v4[2]; -#endif -#if defined( __CL_HALF8__ ) - __cl_half8 v8; -#endif -}cl_half8; - -typedef union -{ - cl_half CL_ALIGNED(32) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_half x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_half s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_half8 lo, hi; }; -#endif -#if defined( __CL_HALF2__) - __cl_half2 v2[8]; -#endif -#if defined( __CL_HALF4__) - __cl_half4 v4[4]; -#endif -#if defined( __CL_HALF8__ ) - __cl_half8 v8[2]; -#endif -#if defined( __CL_HALF16__ ) - __cl_half16 v16; -#endif -}cl_half16; - -/* ---- cl_intn ---- */ -typedef union -{ - cl_int CL_ALIGNED(8) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_int x, y; }; - __CL_ANON_STRUCT__ struct{ cl_int s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_int lo, hi; }; -#endif -#if defined( __CL_INT2__) - __cl_int2 v2; -#endif -}cl_int2; - -typedef union -{ - cl_int CL_ALIGNED(16) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_int x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_int s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_int2 lo, hi; }; -#endif -#if defined( __CL_INT2__) - __cl_int2 v2[2]; -#endif -#if defined( __CL_INT4__) - __cl_int4 v4; -#endif -}cl_int4; - -/* cl_int3 is identical in size, alignment and behavior to cl_int4. See section 6.1.5. 
*/ -typedef cl_int4 cl_int3; - -typedef union -{ - cl_int CL_ALIGNED(32) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_int x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_int s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_int4 lo, hi; }; -#endif -#if defined( __CL_INT2__) - __cl_int2 v2[4]; -#endif -#if defined( __CL_INT4__) - __cl_int4 v4[2]; -#endif -#if defined( __CL_INT8__ ) - __cl_int8 v8; -#endif -}cl_int8; - -typedef union -{ - cl_int CL_ALIGNED(64) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_int x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_int s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_int8 lo, hi; }; -#endif -#if defined( __CL_INT2__) - __cl_int2 v2[8]; -#endif -#if defined( __CL_INT4__) - __cl_int4 v4[4]; -#endif -#if defined( __CL_INT8__ ) - __cl_int8 v8[2]; -#endif -#if defined( __CL_INT16__ ) - __cl_int16 v16; -#endif -}cl_int16; - - -/* ---- cl_uintn ---- */ -typedef union -{ - cl_uint CL_ALIGNED(8) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uint x, y; }; - __CL_ANON_STRUCT__ struct{ cl_uint s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_uint lo, hi; }; -#endif -#if defined( __CL_UINT2__) - __cl_uint2 v2; -#endif -}cl_uint2; - -typedef union -{ - cl_uint CL_ALIGNED(16) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uint x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_uint s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_uint2 lo, hi; }; -#endif -#if defined( __CL_UINT2__) - __cl_uint2 v2[2]; -#endif -#if defined( __CL_UINT4__) - __cl_uint4 v4; -#endif -}cl_uint4; - -/* cl_uint3 is identical in size, alignment and behavior to cl_uint4. See section 6.1.5. 
*/ -typedef cl_uint4 cl_uint3; - -typedef union -{ - cl_uint CL_ALIGNED(32) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uint x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_uint s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_uint4 lo, hi; }; -#endif -#if defined( __CL_UINT2__) - __cl_uint2 v2[4]; -#endif -#if defined( __CL_UINT4__) - __cl_uint4 v4[2]; -#endif -#if defined( __CL_UINT8__ ) - __cl_uint8 v8; -#endif -}cl_uint8; - -typedef union -{ - cl_uint CL_ALIGNED(64) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_uint x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_uint s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_uint8 lo, hi; }; -#endif -#if defined( __CL_UINT2__) - __cl_uint2 v2[8]; -#endif -#if defined( __CL_UINT4__) - __cl_uint4 v4[4]; -#endif -#if defined( __CL_UINT8__ ) - __cl_uint8 v8[2]; -#endif -#if defined( __CL_UINT16__ ) - __cl_uint16 v16; -#endif -}cl_uint16; - -/* ---- cl_longn ---- */ -typedef union -{ - cl_long CL_ALIGNED(16) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_long x, y; }; - __CL_ANON_STRUCT__ struct{ cl_long s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_long lo, hi; }; -#endif -#if defined( __CL_LONG2__) - __cl_long2 v2; -#endif -}cl_long2; - -typedef union -{ - cl_long CL_ALIGNED(32) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_long x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_long s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_long2 lo, hi; }; -#endif -#if defined( __CL_LONG2__) - __cl_long2 v2[2]; -#endif -#if defined( __CL_LONG4__) - __cl_long4 v4; -#endif -}cl_long4; - -/* cl_long3 is identical in size, alignment and behavior to cl_long4. See section 6.1.5. 
*/ -typedef cl_long4 cl_long3; - -typedef union -{ - cl_long CL_ALIGNED(64) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_long x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_long s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_long4 lo, hi; }; -#endif -#if defined( __CL_LONG2__) - __cl_long2 v2[4]; -#endif -#if defined( __CL_LONG4__) - __cl_long4 v4[2]; -#endif -#if defined( __CL_LONG8__ ) - __cl_long8 v8; -#endif -}cl_long8; - -typedef union -{ - cl_long CL_ALIGNED(128) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_long x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_long s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_long8 lo, hi; }; -#endif -#if defined( __CL_LONG2__) - __cl_long2 v2[8]; -#endif -#if defined( __CL_LONG4__) - __cl_long4 v4[4]; -#endif -#if defined( __CL_LONG8__ ) - __cl_long8 v8[2]; -#endif -#if defined( __CL_LONG16__ ) - __cl_long16 v16; -#endif -}cl_long16; - - -/* ---- cl_ulongn ---- */ -typedef union -{ - cl_ulong CL_ALIGNED(16) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ulong x, y; }; - __CL_ANON_STRUCT__ struct{ cl_ulong s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_ulong lo, hi; }; -#endif -#if defined( __CL_ULONG2__) - __cl_ulong2 v2; -#endif -}cl_ulong2; - -typedef union -{ - cl_ulong CL_ALIGNED(32) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ulong x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_ulong s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_ulong2 lo, hi; }; -#endif -#if defined( __CL_ULONG2__) - __cl_ulong2 v2[2]; -#endif -#if defined( __CL_ULONG4__) - __cl_ulong4 v4; -#endif -}cl_ulong4; - -/* cl_ulong3 is identical in size, alignment and behavior to cl_ulong4. See section 6.1.5. 
*/ -typedef cl_ulong4 cl_ulong3; - -typedef union -{ - cl_ulong CL_ALIGNED(64) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ulong x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_ulong s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_ulong4 lo, hi; }; -#endif -#if defined( __CL_ULONG2__) - __cl_ulong2 v2[4]; -#endif -#if defined( __CL_ULONG4__) - __cl_ulong4 v4[2]; -#endif -#if defined( __CL_ULONG8__ ) - __cl_ulong8 v8; -#endif -}cl_ulong8; - -typedef union -{ - cl_ulong CL_ALIGNED(128) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_ulong x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_ulong s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_ulong8 lo, hi; }; -#endif -#if defined( __CL_ULONG2__) - __cl_ulong2 v2[8]; -#endif -#if defined( __CL_ULONG4__) - __cl_ulong4 v4[4]; -#endif -#if defined( __CL_ULONG8__ ) - __cl_ulong8 v8[2]; -#endif -#if defined( __CL_ULONG16__ ) - __cl_ulong16 v16; -#endif -}cl_ulong16; - - -/* --- cl_floatn ---- */ - -typedef union -{ - cl_float CL_ALIGNED(8) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_float x, y; }; - __CL_ANON_STRUCT__ struct{ cl_float s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_float lo, hi; }; -#endif -#if defined( __CL_FLOAT2__) - __cl_float2 v2; -#endif -}cl_float2; - -typedef union -{ - cl_float CL_ALIGNED(16) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_float x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_float s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_float2 lo, hi; }; -#endif -#if defined( __CL_FLOAT2__) - __cl_float2 v2[2]; -#endif -#if defined( __CL_FLOAT4__) - __cl_float4 v4; -#endif -}cl_float4; - -/* cl_float3 is identical in size, alignment and behavior to cl_float4. See section 6.1.5. 
*/ -typedef cl_float4 cl_float3; - -typedef union -{ - cl_float CL_ALIGNED(32) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_float x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_float s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_float4 lo, hi; }; -#endif -#if defined( __CL_FLOAT2__) - __cl_float2 v2[4]; -#endif -#if defined( __CL_FLOAT4__) - __cl_float4 v4[2]; -#endif -#if defined( __CL_FLOAT8__ ) - __cl_float8 v8; -#endif -}cl_float8; - -typedef union -{ - cl_float CL_ALIGNED(64) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_float x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_float8 lo, hi; }; -#endif -#if defined( __CL_FLOAT2__) - __cl_float2 v2[8]; -#endif -#if defined( __CL_FLOAT4__) - __cl_float4 v4[4]; -#endif -#if defined( __CL_FLOAT8__ ) - __cl_float8 v8[2]; -#endif -#if defined( __CL_FLOAT16__ ) - __cl_float16 v16; -#endif -}cl_float16; - -/* --- cl_doublen ---- */ - -typedef union -{ - cl_double CL_ALIGNED(16) s[2]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_double x, y; }; - __CL_ANON_STRUCT__ struct{ cl_double s0, s1; }; - __CL_ANON_STRUCT__ struct{ cl_double lo, hi; }; -#endif -#if defined( __CL_DOUBLE2__) - __cl_double2 v2; -#endif -}cl_double2; - -typedef union -{ - cl_double CL_ALIGNED(32) s[4]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_double x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_double s0, s1, s2, s3; }; - __CL_ANON_STRUCT__ struct{ cl_double2 lo, hi; }; -#endif -#if defined( __CL_DOUBLE2__) - __cl_double2 v2[2]; -#endif -#if defined( __CL_DOUBLE4__) - __cl_double4 v4; -#endif -}cl_double4; - -/* cl_double3 is identical in size, alignment and behavior to cl_double4. See section 6.1.5. */ -typedef cl_double4 cl_double3; - -typedef union -{ - cl_double CL_ALIGNED(64) s[8]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_double x, y, z, w; }; - __CL_ANON_STRUCT__ struct{ cl_double s0, s1, s2, s3, s4, s5, s6, s7; }; - __CL_ANON_STRUCT__ struct{ cl_double4 lo, hi; }; -#endif -#if defined( __CL_DOUBLE2__) - __cl_double2 v2[4]; -#endif -#if defined( __CL_DOUBLE4__) - __cl_double4 v4[2]; -#endif -#if defined( __CL_DOUBLE8__ ) - __cl_double8 v8; -#endif -}cl_double8; - -typedef union -{ - cl_double CL_ALIGNED(128) s[16]; -#if __CL_HAS_ANON_STRUCT__ - __CL_ANON_STRUCT__ struct{ cl_double x, y, z, w, __spacer4, __spacer5, __spacer6, __spacer7, __spacer8, __spacer9, sa, sb, sc, sd, se, sf; }; - __CL_ANON_STRUCT__ struct{ cl_double s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, sA, sB, sC, sD, sE, sF; }; - __CL_ANON_STRUCT__ struct{ cl_double8 lo, hi; }; -#endif -#if defined( __CL_DOUBLE2__) - __cl_double2 v2[8]; -#endif -#if defined( __CL_DOUBLE4__) - __cl_double4 v4[4]; -#endif -#if defined( __CL_DOUBLE8__ ) - __cl_double8 v8[2]; -#endif -#if defined( __CL_DOUBLE16__ ) - __cl_double16 v16; -#endif -}cl_double16; - -/* Macro to facilitate debugging - * Usage: - * Place CL_PROGRAM_STRING_DEBUG_INFO on the line before the first line of your source. 
- * The first line ends with: CL_PROGRAM_STRING_DEBUG_INFO \" - * Each line thereafter of OpenCL C source must end with: \n\ - * The last line ends in "; - * - * Example: - * - * const char *my_program = CL_PROGRAM_STRING_DEBUG_INFO "\ - * kernel void foo( int a, float * b ) \n\ - * { \n\ - * // my comment \n\ - * *b[ get_global_id(0)] = a; \n\ - * } \n\ - * "; - * - * This should correctly set up the line, (column) and file information for your source - * string so you can do source level debugging. - */ -#define __CL_STRINGIFY( _x ) # _x -#define _CL_STRINGIFY( _x ) __CL_STRINGIFY( _x ) -#define CL_PROGRAM_STRING_DEBUG_INFO "#line " _CL_STRINGIFY(__LINE__) " \"" __FILE__ "\" \n\n" - -#ifdef __cplusplus -} -#endif - -#undef __CL_HAS_ANON_STRUCT__ -#undef __CL_ANON_STRUCT__ -#if defined( _WIN32) && defined(_MSC_VER) - #if _MSC_VER >=1500 - #pragma warning( pop ) - #endif -#endif - -#endif /* __CL_PLATFORM_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_va_api_media_sharing_intel.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_va_api_media_sharing_intel.h deleted file mode 100644 index 7cb777e84623c5767abd715fd51ec3f3a0504248..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_va_api_media_sharing_intel.h +++ /dev/null @@ -1,171 +0,0 @@ -/********************************************************************************** - * Copyright (c) 2008-2016 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - **********************************************************************************/ -/*****************************************************************************\ - -Copyright (c) 2013-2016 Intel Corporation All Rights Reserved. - -THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. 
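(Aside on the cl_platform.h deletion that ends just above: it also drops the host-side cl_typeN unions. Below is a hedged, self-contained sketch of how those unions are typically consumed; it assumes a toolchain where the header's __CL_HAS_ANON_STRUCT__ probe succeeds, so that CL_HAS_NAMED_VECTOR_FIELDS and CL_HAS_HI_LO_VECTOR_FIELDS get defined. The s[] array view needs no such support.)

```cpp
// Sketch only: the aliased views of the cl_float4 union from cl_platform.h.
#define CL_TARGET_OPENCL_VERSION 120
#include <CL/cl_platform.h>

#include <cstdio>

int main() {
  cl_float4 v;
  // The s[] view is always available, even without anonymous structs.
  v.s[0] = 1.0f; v.s[1] = 2.0f; v.s[2] = 3.0f; v.s[3] = 4.0f;
#ifdef CL_HAS_NAMED_VECTOR_FIELDS
  std::printf("x=%.1f w=%.1f\n", v.x, v.w);  // same storage as s[0], s[3]
#endif
#ifdef CL_HAS_HI_LO_VECTOR_FIELDS
  std::printf("hi.y=%.1f\n", v.hi.y);        // .hi is the s[2], s[3] half
#endif
  return 0;
}
```

Note the alignment caveat from the header's own comment still applies: a cl_float4 member is only 16-byte aligned if its enclosing struct is.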
IN NO EVENT SHALL INTEL OR ITS -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE -MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -File Name: cl_va_api_media_sharing_intel.h - -Abstract: - -Notes: - -\*****************************************************************************/ - - -#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H -#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H - -#include <CL/cl.h> -#include <CL/cl_platform.h> -#include <va/va.h> - -#ifdef __cplusplus -extern "C" { -#endif - -/****************************************** -* cl_intel_va_api_media_sharing extension * -*******************************************/ - -#define cl_intel_va_api_media_sharing 1 - -/* error codes */ -#define CL_INVALID_VA_API_MEDIA_ADAPTER_INTEL -1098 -#define CL_INVALID_VA_API_MEDIA_SURFACE_INTEL -1099 -#define CL_VA_API_MEDIA_SURFACE_ALREADY_ACQUIRED_INTEL -1100 -#define CL_VA_API_MEDIA_SURFACE_NOT_ACQUIRED_INTEL -1101 - -/* cl_va_api_device_source_intel */ -#define CL_VA_API_DISPLAY_INTEL 0x4094 - -/* cl_va_api_device_set_intel */ -#define CL_PREFERRED_DEVICES_FOR_VA_API_INTEL 0x4095 -#define CL_ALL_DEVICES_FOR_VA_API_INTEL 0x4096 - -/* cl_context_info */ -#define CL_CONTEXT_VA_API_DISPLAY_INTEL 0x4097 - -/* cl_mem_info */ -#define CL_MEM_VA_API_MEDIA_SURFACE_INTEL 0x4098 - -/* cl_image_info */ -#define CL_IMAGE_VA_API_PLANE_INTEL 0x4099 - -/* cl_command_type */ -#define CL_COMMAND_ACQUIRE_VA_API_MEDIA_SURFACES_INTEL 0x409A -#define CL_COMMAND_RELEASE_VA_API_MEDIA_SURFACES_INTEL 0x409B - -typedef cl_uint cl_va_api_device_source_intel; -typedef cl_uint cl_va_api_device_set_intel; - -extern CL_API_ENTRY cl_int CL_API_CALL -clGetDeviceIDsFromVA_APIMediaAdapterINTEL( - cl_platform_id /* platform */, - cl_va_api_device_source_intel /* media_adapter_type */, - void* /* media_adapter */, - cl_va_api_device_set_intel /* media_adapter_set */, - cl_uint /* num_entries */, - cl_device_id* /* devices */, - cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL * clGetDeviceIDsFromVA_APIMediaAdapterINTEL_fn)( - cl_platform_id /* platform */, - cl_va_api_device_source_intel /* media_adapter_type */, - void* /* media_adapter */, - cl_va_api_device_set_intel /* media_adapter_set */, - cl_uint /* num_entries */, - cl_device_id* /* devices */, - cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_mem CL_API_CALL -clCreateFromVA_APIMediaSurfaceINTEL( - cl_context /* context */, - cl_mem_flags /* flags */, - VASurfaceID* /* surface */, - cl_uint /* plane */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_mem (CL_API_CALL * clCreateFromVA_APIMediaSurfaceINTEL_fn)( - cl_context /* context */, - cl_mem_flags /* flags */, - VASurfaceID* /* surface */, - cl_uint /* plane */, - cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueAcquireVA_APIMediaSurfacesINTEL( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef 
CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireVA_APIMediaSurfacesINTEL_fn)( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -extern CL_API_ENTRY cl_int CL_API_CALL -clEnqueueReleaseVA_APIMediaSurfacesINTEL( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseVA_APIMediaSurfacesINTEL_fn)( - cl_command_queue /* command_queue */, - cl_uint /* num_objects */, - const cl_mem* /* mem_objects */, - cl_uint /* num_events_in_wait_list */, - const cl_event* /* event_wait_list */, - cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2; - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_version.h b/mobile/third_party/opencl/OpenCL-Headers/CL/cl_version.h deleted file mode 100644 index bb766cb9bbddca65a3cd599375a24cb827789d08..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/cl_version.h +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2018 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - ******************************************************************************/ - -#ifndef __CL_VERSION_H -#define __CL_VERSION_H - -/* Detect which version to target */ -#if !defined(CL_TARGET_OPENCL_VERSION) -#pragma message("cl_version.h: CL_TARGET_OPENCL_VERSION is not defined. 
Defaulting to 220 (OpenCL 2.2)") -#define CL_TARGET_OPENCL_VERSION 220 -#endif -#if CL_TARGET_OPENCL_VERSION != 100 && \ - CL_TARGET_OPENCL_VERSION != 110 && \ - CL_TARGET_OPENCL_VERSION != 120 && \ - CL_TARGET_OPENCL_VERSION != 200 && \ - CL_TARGET_OPENCL_VERSION != 210 && \ - CL_TARGET_OPENCL_VERSION != 220 -#pragma message("cl_version: CL_TARGET_OPENCL_VERSION is not a valid value (100, 110, 120, 200, 210, 220). Defaulting to 220 (OpenCL 2.2)") -#undef CL_TARGET_OPENCL_VERSION -#define CL_TARGET_OPENCL_VERSION 220 -#endif - - -/* OpenCL Version */ -#if CL_TARGET_OPENCL_VERSION >= 220 && !defined(CL_VERSION_2_2) -#define CL_VERSION_2_2 1 -#endif -#if CL_TARGET_OPENCL_VERSION >= 210 && !defined(CL_VERSION_2_1) -#define CL_VERSION_2_1 1 -#endif -#if CL_TARGET_OPENCL_VERSION >= 200 && !defined(CL_VERSION_2_0) -#define CL_VERSION_2_0 1 -#endif -#if CL_TARGET_OPENCL_VERSION >= 120 && !defined(CL_VERSION_1_2) -#define CL_VERSION_1_2 1 -#endif -#if CL_TARGET_OPENCL_VERSION >= 110 && !defined(CL_VERSION_1_1) -#define CL_VERSION_1_1 1 -#endif -#if CL_TARGET_OPENCL_VERSION >= 100 && !defined(CL_VERSION_1_0) -#define CL_VERSION_1_0 1 -#endif - -/* Allow deprecated APIs for older OpenCL versions. */ -#if CL_TARGET_OPENCL_VERSION <= 210 && !defined(CL_USE_DEPRECATED_OPENCL_2_1_APIS) -#define CL_USE_DEPRECATED_OPENCL_2_1_APIS -#endif -#if CL_TARGET_OPENCL_VERSION <= 200 && !defined(CL_USE_DEPRECATED_OPENCL_2_0_APIS) -#define CL_USE_DEPRECATED_OPENCL_2_0_APIS -#endif -#if CL_TARGET_OPENCL_VERSION <= 120 && !defined(CL_USE_DEPRECATED_OPENCL_1_2_APIS) -#define CL_USE_DEPRECATED_OPENCL_1_2_APIS -#endif -#if CL_TARGET_OPENCL_VERSION <= 110 && !defined(CL_USE_DEPRECATED_OPENCL_1_1_APIS) -#define CL_USE_DEPRECATED_OPENCL_1_1_APIS -#endif -#if CL_TARGET_OPENCL_VERSION <= 100 && !defined(CL_USE_DEPRECATED_OPENCL_1_0_APIS) -#define CL_USE_DEPRECATED_OPENCL_1_0_APIS -#endif - -#endif /* __CL_VERSION_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/CL/opencl.h b/mobile/third_party/opencl/OpenCL-Headers/CL/opencl.h deleted file mode 100644 index b5cd5a62a1e085b5bb99521bea965cd015ad5680..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/CL/opencl.h +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2008-2015 The Khronos Group Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and/or associated documentation files (the - * "Materials"), to deal in the Materials without restriction, including - * without limitation the rights to use, copy, modify, merge, publish, - * distribute, sublicense, and/or sell copies of the Materials, and to - * permit persons to whom the Materials are furnished to do so, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Materials. - * - * MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS - * KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS - * SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - * https://www.khronos.org/registry/ - * - * THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. - ******************************************************************************/ - -/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */ - -#ifndef __OPENCL_H -#define __OPENCL_H - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __APPLE__ - -#include <OpenCL/cl.h> -#include <OpenCL/cl_gl.h> -#include <OpenCL/cl_gl_ext.h> -#include <OpenCL/cl_ext.h> - -#else - -#include <CL/cl.h> -#include <CL/cl_gl.h> -#include <CL/cl_gl_ext.h> -#include <CL/cl_ext.h> - -#endif - -#ifdef __cplusplus -} -#endif - -#endif /* __OPENCL_H */ diff --git a/mobile/third_party/opencl/OpenCL-Headers/LICENSE b/mobile/third_party/opencl/OpenCL-Headers/LICENSE deleted file mode 100644 index 020ce65fcac2a60e44dab1626fa4924dec17ea23..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -Copyright (c) 2008-2015 The Khronos Group Inc. - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and/or associated documentation files (the -"Materials"), to deal in the Materials without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Materials, and to -permit persons to whom the Materials are furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Materials. - -MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS -KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS -SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT - https://www.khronos.org/registry/ - -THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. diff --git a/mobile/third_party/opencl/OpenCL-Headers/README.md b/mobile/third_party/opencl/OpenCL-Headers/README.md deleted file mode 100644 index 757e56e152f8bc2fed68d2cdf38164c3171f929d..0000000000000000000000000000000000000000 --- a/mobile/third_party/opencl/OpenCL-Headers/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# OpenCL<sup>TM</sup> API Headers - -This repository contains C language headers for the OpenCL API. - -The authoritative public repository for these headers is located at: - -https://github.com/KhronosGroup/OpenCL-Headers - -Issues, proposed fixes for issues, and other suggested changes should be -created using Github. - -## Branch Structure - -The OpenCL API headers in this repository are Unified headers and are designed -to work with all released OpenCL versions. This differs from previous OpenCL -API headers, where version-specific API headers either existed in separate -branches, or in separate folders in a branch. - -## Compiling for a Specific OpenCL Version - -By default, the OpenCL API headers in this repository are for the latest -OpenCL version (currently OpenCL 2.2). 
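(Aside: a hedged sketch of what that version pinning does in practice, relying only on the CL_TARGET_OPENCL_VERSION / CL_VERSION_* / CL_USE_DEPRECATED_OPENCL_*_APIS macros shown in the cl_version.h hunk above, and assuming a header set of this vintage in which CL/opencl.h ultimately pulls in cl_version.h and cl_platform.h:)

```cpp
// Sketch: pinning the unified headers to the OpenCL 1.2 API surface.
// With CL_TARGET_OPENCL_VERSION == 120, the removed cl_version.h defines
// CL_VERSION_1_0/1_1/1_2 but no CL_VERSION_2_x, and it defines
// CL_USE_DEPRECATED_OPENCL_1_2_APIS, so 1.2-era entry points such as
// clCreateCommandQueue stay usable without deprecation warnings.
#define CL_TARGET_OPENCL_VERSION 120
#include <CL/opencl.h>

#include <cstdio>

int main() {
#if defined(CL_VERSION_1_2) && !defined(CL_VERSION_2_0)
  std::printf("headers expose at most the OpenCL 1.2 API\n");
#endif
#ifdef CL_USE_DEPRECATED_OPENCL_1_2_APIS
  std::printf("APIs deprecated after 1.2 are not flagged as deprecated\n");
#endif
  return 0;
}
```

Nothing here links against an OpenCL library; the checks are purely preprocessor-level, which is all the removed cl_version.h controls.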
To use these API headers to target -a different OpenCL version, an application may `#define` the preprocessor -value `CL_TARGET_OPENCL_VERSION` before including the OpenCL API headers. -The `CL_TARGET_OPENCL_VERSION` is a three digit decimal value representing -the OpenCL API version. - -For example, to enforce usage of no more than the OpenCL 1.2 APIs, you may -include the OpenCL API headers as follows: - -``` -#define CL_TARGET_OPENCL_VERSION 120 -#include <CL/opencl.h> -``` - -## Directory Structure - -``` -README.md This file -LICENSE Source license for the OpenCL API headers -CL/ Unified OpenCL API headers tree -``` - -## License - -See [LICENSE](LICENSE). - ---- - -OpenCL and the OpenCL logo are trademarks of Apple Inc. used by permission by Khronos. diff --git a/mobile/tools/build.sh b/mobile/tools/build.sh index 877791ff7bdb4fc64f2d62210ff974c0cd6bced0..741e6a590e685a0f723f364336ac1dc6061fe0ba 100755 --- a/mobile/tools/build.sh +++ b/mobile/tools/build.sh @@ -12,6 +12,25 @@ fi python gen_code.py "${merge_cl_to_so}" > "${opencl_kernels}" cd - +# get cl headers +opencl_header_dir="../third_party/opencl/OpenCL-Headers" +commit_id="320d7189b3e0e7b6a8fc5c10334c79ef364b5ef6" +if [[ -d "$opencl_header_dir" && -d "$opencl_header_dir/.git" ]]; then + echo "pulling opencl headers" + cd $opencl_header_dir + git stash + git pull + git checkout $commit_id + cd - +else + echo "cloning opencl headers" + rm -rf $opencl_header_dir + git clone https://github.com/KhronosGroup/OpenCL-Headers $opencl_header_dir + cd $opencl_header_dir + git checkout $commit_id + cd - +fi + build_for_mac() { if [ ! `which brew` ]; then echo "building failed! homebrew not found, please install homebrew." @@ -61,7 +80,7 @@ build_for_android() { elif [ "${PLATFORM}" = "arm-v8a" ]; then ABI="arm64-v8a" ARM_PLATFORM="V8" - CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog" + CXX_FLAGS="-march=armv8-a -pie -fPIE -w -Wno-error=format-security -llog -fuse-ld=gold" else echo "unknown platform!" 
exit -1 diff --git a/mobile/tools/op.cmake b/mobile/tools/op.cmake index c973b1b20f7448af5739a4522d2190f91ade11f4..923380940aa10147d65e374265c1073ec37cb11e 100755 --- a/mobile/tools/op.cmake +++ b/mobile/tools/op.cmake @@ -377,6 +377,8 @@ if(NOT FOUND_MATCH) set(FILL_CONSTANT_BATCH_SIZE_LIKE_OP ON) set(RANGE_OP ON) set(REDUCE_PROD_OP ON) + set(FUSION_INSTANCENORM_RELU_OP ON) + set(PIXEL_SHUFFLE_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -413,6 +415,9 @@ endif() if (INSTANCENORM_OP) add_definitions(-DINSTANCENORM_OP) endif() +if (FUSION_INSTANCENORM_RELU_OP) + add_definitions(-DFUSION_INSTANCENORM_RELU_OP) +endif() if (BOXCODER_OP) add_definitions(-DBOXCODER_OP) endif() @@ -747,3 +752,6 @@ endif() if (REDUCE_PROD_OP) add_definitions(-DREDUCE_PROD_OP) endif() +if (PIXEL_SHUFFLE_OP) + add_definitions(-DPIXEL_SHUFFLE_OP) +endif() diff --git a/mobile/tools/pre-commit.hooks/cpplint.hook b/mobile/tools/pre-commit.hooks/cpplint.hook index 78ca3cfcdda52a223be609801e6b12ec58b79323..3740e64c7331e63954fc85f8958b7613e48cce57 100644 --- a/mobile/tools/pre-commit.hooks/cpplint.hook +++ b/mobile/tools/pre-commit.hooks/cpplint.hook @@ -5,7 +5,7 @@ TOTAL_ERRORS=0 # The trick to remove deleted files: https://stackoverflow.com/a/2413151 for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}' | \ grep -v ".pb.cpp" | grep -v ".pb.h" | grep -v ".pb-c.h" | grep -v ".pb-c.c" | \ - grep -v "protobuf-c.h" | grep -v "protobuf-c.c"); do + grep -v "protobuf-c.h" | grep -v "protobuf-c.c" | grep -v "^mobile/tools/quantification"); do cpplint $file; TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); done diff --git a/mobile/tools/python/fluidtools/run.py b/mobile/tools/python/fluidtools/run.py index a77943e2af40361876e950c316eda67cb6457191..6f82e426bd1ab1e376783c0d1015e625d7d47068 100644 --- a/mobile/tools/python/fluidtools/run.py +++ b/mobile/tools/python/fluidtools/run.py @@ -22,8 +22,11 @@ checked_encrypt_model_path = "checked_encrypt_model" output_var_filter = [] output_key_filter = {} check_shape = False +quantification = False +quantification_fold = 1000 architecture = "arm-v7a" # architecture = "arm-v8a" +correct_persistable = False np.set_printoptions(linewidth=150) @@ -67,6 +70,18 @@ exe.run(fluid.default_startup_program()) # 加载模型 def load_model(model_path): prog, feeds, fetches = fluid.io.load_inference_model(dirname=model_path, executor=exe, model_filename="model", params_filename="params") + global correct_persistable + if correct_persistable: + ops = prog.current_block().ops + vars = prog.current_block().vars + for op in ops: + for var_name in op.output_arg_names: + if var_name == "fetch": + continue + var = vars[var_name] + if var.persistable: + pp_red("has found non-persistable output var : {}".format(var_name)) + var.persistable = False return (prog, feeds, fetches) prog, feeds, fetches = load_model(model_path) @@ -107,7 +122,8 @@ def resave_model(feed_kv): for name in p_names: v = fluid.framework._get_var(name, prog) v.persistable = False - fluid.io.save_inference_model(dirname=checked_model_path, feeded_var_names=feeds, target_vars=fetches, executor=exe, main_program=prog, model_filename="model", params_filename="params") + if not quantification: + fluid.io.save_inference_model(dirname=checked_model_path, feeded_var_names=feeds, target_vars=fetches, executor=exe, main_program=prog, model_filename="model", params_filename="params") if has_found_wrong_shape: pp_red("has found wrong shape", 1) else: @@ -392,7 +408,7 @@ for op in ops: pp_tab("op types : {}".format(op_types), 1) def 
check_mobile_results(args, fuse, mem_opt): - args = "{} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", args) + args = "{} {} {} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", "1" if quantification else "0", quantification_fold, args) res = sh("adb shell \"cd {} && export LD_LIBRARY_PATH=. && ./test-net {}\"".format(mobile_exec_root, args)) lines = res.split("\n") # for line in lines: @@ -425,6 +441,26 @@ def check_mobile_results(args, fuse, mem_opt): fetch_names = [] for fetch in fetches: fetch_names.append(fetch.name) + fetch_diff = 0.0 + fetch_count = 0 + for index in op_cache: + op_output_var_name, op = op_cache[index] + if not op_output_var_name in output_var_cache: + continue + if not op_output_var_name in mobile_var_cache: + continue + if op_output_var_name not in fetch_names: + continue + values1 = output_var_cache[op_output_var_name] + values2 = mobile_var_cache[op_output_var_name] + shape = get_var_shape(op_output_var_name) if check_shape else [] + for i in range(len(values1)): + v1 = values1[i] + v2 = values2[len(shape) + i] + fetch_diff += abs(v1 - v2) + fetch_count += 1 + if fetch_count != 0: + pp_yellow("output avg diff : {}".format(fetch_diff / fetch_count), 1) for index in op_cache: op_output_var_name, op = op_cache[index] if mem_opt: @@ -523,7 +559,7 @@ def check_mobile_results(args, fuse, mem_opt): for i in range(len(values1)): v1 = values1[i] v2 = values2[len(shape) + i] - if abs(v1 - v2) > diff_threshold: + if ((not math.isnan(v1)) and math.isnan(v2)) or abs(v1 - v2) > diff_threshold: error_index = index break checked_names.append(op_output_var_name) diff --git a/mobile/tools/quantification/convert.cpp b/mobile/tools/quantification/convert.cpp index 3473f9a1181f141f7bfa240e678826704b503eb4..0d675de205296c8942e59575dffb7f7002bc7d7f 100644 --- a/mobile/tools/quantification/convert.cpp +++ b/mobile/tools/quantification/convert.cpp @@ -17,6 +17,22 @@ const size_t kSize64 = sizeof(uint64_t); const size_t kSize32 = sizeof(uint32_t); +const int minimal_fold_size = 2; +float max_entropy = 0.0; + +float entropy(std::vector<uint8_t> &factors) { + int n = factors.size(); + std::vector<int> counts(256); + for (uint8_t &factor : factors) { + counts[factor]++; + } + float res = 1.0; + float shift = 100000.0; + for (int i = 0; i < 256; i++) { + res *= (counts[i] + shift) / (n + shift); + } + return 1.0 / res; +}
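(Aside: two notes on the convert.cpp hunks around this point. The entropy() helper above scores each window's quantized byte histogram; because the product of shift-smoothed bucket frequencies is largest when the histogram is uniform, the reciprocal it returns, tracked in max_entropy, grows as the bytes concentrate on few values, so it behaves as an inverse entropy despite the name. The fold logic further below splits each tensor into up to quantification_fold windows and min/max-quantizes each window to uint8 independently, which keeps a tensor that mixes small and large magnitudes from wasting its 256 levels on one global range. Here is a minimal, self-contained sketch of that windowed scheme with made-up weights; it is an illustration, not the tool itself, and it deliberately seeds the running max with std::numeric_limits<float>::lowest(), whereas the tool uses ::min(), which for floats is the smallest positive normal value.)

```cpp
// Minimal sketch of the windowed ("fold") uint8 quantization used in
// convert.cpp below. Illustrative data; not the tool itself.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // Hypothetical weights mixing small and large magnitudes.
  std::vector<float> weights = {-1.5f, -0.2f, 0.1f, 0.4f,
                                7.9f,  8.0f,  8.1f, 9.0f};
  const int memory_size = static_cast<int>(weights.size());
  const int quantification_fold = 2;  // number of windows
  const int step = std::max(memory_size / quantification_fold, 1);

  for (int fold = 0; fold * step < memory_size; ++fold) {
    const int begin = fold * step;
    const int end = std::min(begin + step, memory_size);

    // Per-window range. lowest() is the safe identity for a running max.
    float min_value = std::numeric_limits<float>::max();
    float max_value = std::numeric_limits<float>::lowest();
    for (int k = begin; k < end; ++k) {
      min_value = std::min(min_value, weights[k]);
      max_value = std::max(max_value, weights[k]);
    }

    // One byte per value: factor = round((v - min) / (max - min) * 255),
    // reconstructed later as v' = min + factor / 255 * (max - min).
    for (int g = begin; g < end; ++g) {
      const auto factor = static_cast<uint8_t>(
          std::round((weights[g] - min_value) / (max_value - min_value) * 255));
      const float dequantized =
          min_value + (factor / 255.0f) * (max_value - min_value);
      std::printf("fold %d: %+.3f -> %3d -> %+.3f\n",
                  fold, weights[g], factor, dequantized);
    }
  }
  return 0;
}
```

With a single global range the quantization step would be (9.0 + 1.5) / 255, roughly 0.041, while the first window's step is 1.9 / 255, roughly 0.0075; that resolution gain is the point of folding.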
version uint32_t version = *reinterpret_cast<uint32_t *>(*dataP); @@ -162,27 +178,37 @@ void LoadWithDumpForInt8(const paddle_mobile::framework::VarDesc &var_desc, char } *dataP += tensorSize; - // for float 32 - float min_value = std::numeric_limits<float>::max(); - float max_value = std::numeric_limits<float>::min(); + quantification_fold = std::min(std::max(1, memory_size / minimal_fold_size), quantification_fold); + int step = std::max(memory_size / quantification_fold, 1); - for (int k = 0; k < memory_size; ++k) { - min_value = std::min(min_value, static_cast<float *>(memory)[k]); - max_value = std::max(max_value, static_cast<float *>(memory)[k]); - } + int visited_fold = 0; + while (visited_fold * step < memory_size) { + // for float 32 + float min_value = std::numeric_limits<float>::max(); + float max_value = std::numeric_limits<float>::min(); - fwrite(&min_value, sizeof(float), 1, out_file); - fwrite(&max_value, sizeof(float), 1, out_file); + for (int k = visited_fold * step; k < std::min((visited_fold + 1) * step, memory_size); ++k) { + min_value = std::min(min_value, static_cast<float *>(memory)[k]); + max_value = std::max(max_value, static_cast<float *>(memory)[k]); + } + + fwrite(&min_value, sizeof(float), 1, out_file); + fwrite(&max_value, sizeof(float), 1, out_file); - for (int g = 0; g < memory_size; ++g) { - float value = static_cast<float *>(memory)[g]; - auto factor = (uint8_t) round((value - min_value) / (max_value - min_value) * 255); - fwrite(&factor, sizeof(uint8_t), 1, out_file); + std::vector<uint8_t> factors; + for (int g = visited_fold * step; g < std::min((visited_fold + 1) * step, memory_size); ++g) { + float value = static_cast<float *>(memory)[g]; + auto factor = (uint8_t) round((value - min_value) / (max_value - min_value) * 255); + factors.push_back(factor); + fwrite(&factor, sizeof(uint8_t), 1, out_file); + } + max_entropy = fmax(max_entropy, entropy(factors)); + visited_fold++; } } void -quantificate_combined_int8(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) { +quantificate_combined_int8(const std::string &model_path, const std::string &param_path, const std::string &param_min_path, int quantification_fold) { auto program = loadParams(model_path); char *origin_data = Get_binary_data(param_path); char *data = origin_data; @@ -193,7 +219,7 @@ quantificate_combined_int8(const std::string &model_path, const std::string &par if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { continue; } - LoadWithDumpForInt8(*var_desc, &data, out_file); + LoadWithDumpForInt8(*var_desc, &data, out_file, quantification_fold); } } } @@ -201,7 +227,7 @@ quantificate_combined_int8(const std::string &model_path, const std::string &par delete origin_data; } -void quantificate_seperated_int8(const std::string model_dir, const std::string param_min_path) { +void quantificate_seperated_int8(const std::string model_dir, const std::string param_min_path, int quantification_fold) { auto program = loadParams(model_dir + "/__model__"); std::string shell_command = "mkdir " + param_min_path; @@ -217,7 +243,7 @@ void quantificate_seperated_int8(const std::string model_dir, const std::string FILE *out_file = fopen(file_name.c_str(), "wb"); char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name()); char *data = origin_data; - LoadWithDumpForInt8(*var_desc, &data, out_file); + LoadWithDumpForInt8(*var_desc, &data, out_file, quantification_fold); delete origin_data; fclose(out_file); } @@ -225,7 +251,7 @@ void quantificate_seperated_int8(const std::string model_dir, const std::string } } -void LoadWithDumpForFloat32(const
paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) { +void LoadWithDumpForFloat32(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file, int quantification_fold) { // 1. version uint32_t version = *reinterpret_cast<uint32_t *>(*dataP); @@ -319,30 +345,40 @@ void LoadWithDumpForFloat32(const paddle_mobile::framework::VarDesc &var_desc, c } *dataP += tensorSize; - // for float 32 - float min_value = std::numeric_limits<float>::max(); - float max_value = std::numeric_limits<float>::min(); + quantification_fold = std::min(std::max(1, memory_size / minimal_fold_size), quantification_fold); + int step = std::max(memory_size / quantification_fold, 1); - for (int k = 0; k < memory_size; ++k) { - min_value = std::min(min_value, static_cast<float *>(memory)[k]); - max_value = std::max(max_value, static_cast<float *>(memory)[k]); - } + int visited_fold = 0; + while (visited_fold * step < memory_size) { + // for float 32 + float min_value = std::numeric_limits<float>::max(); + float max_value = std::numeric_limits<float>::min(); - float diff = 0.0; - for (int g = 0; g < memory_size; ++g) { - float value = static_cast<float *>(memory)[g]; - auto factor = (uint8_t) round((value - min_value) / (max_value - min_value) * 255); - float value_quantized = min_value + (factor / 255.0) * (max_value - min_value); - diff += fabs(value - value_quantized); - fwrite(&value_quantized, sizeof(float), 1, out_file); - } - if (memory_size > 0) { - std::cout << "avg diff caused by quantization for var " << var_desc.Name() << " is: " << (diff / memory_size) << std::endl; + for (int k = visited_fold * step; k < std::min((visited_fold + 1) * step, memory_size); ++k) { + min_value = std::min(min_value, static_cast<float *>(memory)[k]); + max_value = std::max(max_value, static_cast<float *>(memory)[k]); + } + + float diff = 0.0; + std::vector<uint8_t> factors; + for (int g = visited_fold * step; g < std::min((visited_fold + 1) * step, memory_size); ++g) { + float value = static_cast<float *>(memory)[g]; + auto factor = (uint8_t) round((value - min_value) / (max_value - min_value) * 255); + factors.push_back(factor); + float value_quantized = min_value + (factor / 255.0) * (max_value - min_value); + diff += fabs(value - value_quantized); + fwrite(&value_quantized, sizeof(float), 1, out_file); + } + max_entropy = fmax(max_entropy, entropy(factors)); + if (memory_size > 0) { + std::cout << "avg diff caused by quantization for var " << var_desc.Name() << " is: " << (diff / memory_size) << std::endl; + } + visited_fold++; } } void -quantificate_combined_float32(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) { +quantificate_combined_float32(const std::string &model_path, const std::string &param_path, const std::string &param_min_path, int quantification_fold) { auto program = loadParams(model_path); char *origin_data = Get_binary_data(param_path); char *data = origin_data; @@ -353,7 +389,7 @@ quantificate_combined_float32(const std::string &model_path, const std::string & if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { continue; } - LoadWithDumpForFloat32(*var_desc, &data, out_file); + LoadWithDumpForFloat32(*var_desc, &data, out_file, quantification_fold); } } } @@ -361,7 +397,7 @@ quantificate_combined_float32(const std::string &model_path, const std::string & delete origin_data; } -void quantificate_seperated_float32(const std::string model_dir, const std::string param_min_path) { +void quantificate_seperated_float32(const std::string model_dir, const std::string param_min_path, int quantification_fold) { auto program =
loadParams(model_dir + "/__model__"); std::string shell_command = "mkdir " + param_min_path; @@ -377,7 +413,7 @@ void quantificate_seperated_float32(const std::string model_dir, const std::stri FILE *out_file = fopen(file_name.c_str(), "wb"); char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name()); char *data = origin_data; - LoadWithDumpForFloat32(*var_desc, &data, out_file); + LoadWithDumpForFloat32(*var_desc, &data, out_file, quantification_fold); delete origin_data; fclose(out_file); } @@ -402,10 +438,15 @@ int main(int argc, char **argv) { PADDLE_MOBILE_ENFORCE(argc > 3, "we need your output path. %s ", kNoteEg.c_str()); std::string output_path = argv[3]; + int quantification_fold = 1; + if (argc > 4) { + quantification_fold = std::stoi(argv[4]); + } + if (action_type == "0") { // for seperated const std::string &seperated_min_dir = output_path; - quantificate_seperated_int8(base_path, seperated_min_dir); + quantificate_seperated_int8(base_path, seperated_min_dir, quantification_fold); return 0; } @@ -414,14 +455,15 @@ int main(int argc, char **argv) { const std::string &combined_min_dir = output_path; std::string model_path = base_path + "/model"; std::string param_path = base_path + "/params"; - quantificate_combined_int8(model_path, param_path, combined_min_dir); + quantificate_combined_int8(model_path, param_path, combined_min_dir, quantification_fold); + std::cout << "max entropy : " << max_entropy << std::endl; return 0; } if (action_type == "2") { // for seperated const std::string &seperated_min_dir = output_path; - quantificate_seperated_float32(base_path, seperated_min_dir); + quantificate_seperated_float32(base_path, seperated_min_dir, quantification_fold); return 0; } @@ -430,7 +472,7 @@ int main(int argc, char **argv) { const std::string &combined_min_dir = output_path; std::string model_path = base_path + "/model"; std::string param_path = base_path + "/params"; - quantificate_combined_float32(model_path, param_path, combined_min_dir); + quantificate_combined_float32(model_path, param_path, combined_min_dir, quantification_fold); return 0; } diff --git a/mobile/tools/quantification/scripts/run.py b/mobile/tools/quantification/scripts/run.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3444147092f0963e918369e933057dd5f28b38 --- /dev/null +++ b/mobile/tools/quantification/scripts/run.py @@ -0,0 +1,661 @@ +# -*- coding: utf-8 -* +import os +import sys +import math +import subprocess +import numpy as np +import paddle.fluid as fluid + +model_path = "model" +checked_model_path = "quantification_model" +feed_path = "feeds" +output_path = "outputs" +diff_threshold = 0.1 +is_lod = False +mobile_model_path = "" +fast_check = False +is_sample_step = False +sample_step = 1 +sample_num = 20 +need_encrypt = False +checked_encrypt_model_path = "checked_encrypt_model" +output_var_filter = [] +output_key_filter = {} +check_shape = False +quantification = True +quantification_fold = int(sys.argv[1]) +architecture = "arm-v7a" +# architecture = "arm-v8a" + +np.set_printoptions(linewidth=150) + +mobile_exec_root = "/data/local/tmp/bin" +mobile_src_root = os.path.abspath("../../../") +if mobile_src_root.endswith("/"): + mobile_src_root = mobile_src_root[:-1] + +dot = "•" +black = lambda x: "\033[30m" + str(x) + "\033[0m" +red = lambda x: "\033[31m" + str(x) + "\033[0m" +green = lambda x: "\033[32m" + str(x) + "\033[0m" +yellow = lambda x: "\033[33m" + str(x) + "\033[0m" +reset = lambda x: "\033[0m" + str(x) + +def pp_tab(x, level=0): + header = "" + 
for i in range(0, level): + header += "\t" + # print(header + str(x)) +def pp_black(x, level=0): + pp_tab(black(x) + reset(""), level) +def pp_red(x, level=0): + pp_tab(red(x) + reset(""), level) +def pp_green(x, level=0): + pp_tab(green(x) + reset(""), level) +def pp_yellow(x, level=0): + pp_tab(yellow(x) + reset(""), level) + +def sh(command): + pipe = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + return pipe.stdout.read().decode("utf-8") +def push(src, dest=""): + sh("adb push {} {}".format(src, mobile_exec_root + "/" + dest)) + +pp_yellow(dot + " start inspecting fluid model") + +exe = fluid.Executor(fluid.CPUPlace()) +exe.run(fluid.default_startup_program()) + +# Load the model +def load_model(model_path): + prog, feeds, fetches = fluid.io.load_inference_model(dirname=model_path, executor=exe, model_filename="model", params_filename="params") + return (prog, feeds, fetches) + +prog, feeds, fetches = load_model(model_path) + +# Force every tensor shape to be consistent between model and params, then re-save the model +def resave_model(feed_kv): + if len(mobile_model_path) > 0: + pp_green("has set mobile_model_path, stop checking model & params", 1) + sh("cp {}/* {}".format(mobile_model_path, checked_model_path)) + return + ops = prog.current_block().ops + vars = prog.current_block().vars + # Force all vars to be persistable + p_names = [] + for name in vars: + name = str(name) + v = fluid.framework._get_var(name, prog) + if not v.persistable: + v.persistable = True + p_names.append(name) + outputs = run_model(feed_kv=feed_kv) + has_found_wrong_shape = False + # Correct the shape of every var + for name in vars: + name = str(name) + v = vars[name] + if v.persistable: + v1 = fluid.global_scope().find_var(name) + try: + t1 = v1.get_tensor() + shape = t1.shape() + except: + continue + if v.desc.shape() != shape: + has_found_wrong_shape = True + v.desc.set_shape(shape) + # Restore the persistable attribute of each var + for name in p_names: + v = fluid.framework._get_var(name, prog) + v.persistable = False + if not quantification: + fluid.io.save_inference_model(dirname=checked_model_path, feeded_var_names=feeds, target_vars=fetches, executor=exe, main_program=prog, model_filename="model", params_filename="params") + if has_found_wrong_shape: + pp_red("has found wrong shape", 1) + else: + pp_green("has not found wrong shape", 1) + pp_green("new model is saved into directory 【{}】".format(checked_model_path), 1) + +# Encrypt model and params separately, using the same key for both +def encrypt_model(): + if not need_encrypt: + return + pp_yellow(dot + dot + " encrypting model") + if not os.path.exists(checked_encrypt_model_path): + os.mkdir(checked_encrypt_model_path) + res = sh("model-encrypt-tool/enc_key_gen -l 20 -c 232") + lines = res.split("\n") + + for line in lines: + if line.startswith("key:"): + line = line.replace('key:','') + sh("model-encrypt-tool/enc_model_gen -k '{}' -c 2 -i {}/model -o {}/model.ml".format(line, checked_model_path, checked_model_path)) + sh("model-encrypt-tool/enc_model_gen -k '{}' -c 2 -i {}/params -o {}/params.ml".format(line, checked_model_path, checked_model_path)) + pp_green("model has been encrypted, key is : {}".format(line), 1) + sh("mv {} {}".format(checked_model_path + "/*.ml", checked_encrypt_model_path)) + return + pp_red("model encrypt error", 1) + +# Generate feed key-value pairs +def gen_feed_kv(): + feed_kv = {} + for feed_name in feeds: + feed_shape = get_feed_var_shape(feed_name) + data = np.random.random(feed_shape).astype("float32") + feed_kv[feed_name] = data + return feed_kv + +# Save feed key-value pairs +def save_feed_kv(feed_kv): + for feed_name in feed_kv: + feed_data =
feed_kv[feed_name] + feed_list = feed_data.flatten().tolist() + if not os.path.exists(feed_path): + os.mkdir(feed_path) + file_name = feed_name.replace("/", "_") + out_file = open(feed_path + "/" + file_name, "w") + for feed_item in feed_list: + out_file.write("{}\n".format(feed_item)) + out_file.close() + +last_feed_var_name = None +last_feed_file_name = None +last_feed_var_lod = None +# Load feed key-value pairs +def load_feed_kv(): + if not os.path.exists(feed_path): + return None + global last_feed_var_name + global last_feed_file_name + global last_feed_var_lod + feed_kv = {} + pp_yellow(dot + dot + " checking feed info") + pp_green("feed data is saved into directory 【{}】".format(feed_path), 1) + for feed_name in feeds: + feed_shape = get_feed_var_shape(feed_name) + pp_tab("feed var name : {}; feed var shape : {}".format(feed_name, feed_shape), 1) + file_name = feed_name.replace("/", "_") + last_feed_var_name = feed_name + last_feed_file_name = file_name + feed_file_path = feed_path + "/" + file_name + if not os.path.exists(feed_file_path): + return None + data = np.loadtxt(feed_file_path) + expected_len = 1 + for dim in feed_shape: + expected_len *= dim + if len(np.atleast_1d(data)) != expected_len: + return None + data = data.reshape(feed_shape).astype("float32") + + if is_lod: + data_shape = [1] + for dim in feed_shape: + data_shape.append(dim) + data = data.reshape(data_shape).astype("float32") + tensor = fluid.LoDTensor() + seq_lens = [len(seq) for seq in data] + cur_len = 0 + lod = [cur_len] + for l in seq_lens: + cur_len += l + lod.append(cur_len) + data = data.reshape(feed_shape) + tensor.set(data, fluid.CPUPlace()) + tensor.set_lod([lod]) + last_feed_var_lod = lod + feed_kv[feed_name] = tensor + else: + feed_kv[feed_name] = data + return feed_kv + +# Run the model +def run_model(feed_kv=None): + if feed_kv is None: + feed_kv = gen_feed_kv() + outputs = exe.run(prog, feed=feed_kv, fetch_list=fetches, return_numpy=False) + results = [] + for output in outputs: + results.append(np.array(output)) + return results + +# Get a variable's shape +def get_var_shape(var_name): + vars = prog.current_block().vars + shape = vars[var_name].desc.shape() + for i in range(len(shape)): + dim = shape[i] + if dim == -1: + shape[i] = 1 + return shape + +# Get a feed variable's shape +def get_feed_var_shape(var_name): + # To hard-code the input shape, uncomment the following line + # return [1, 3, 224, 224] + return get_var_shape(var_name) + +persistable_cache = [] +# Make all vars persistable +def force_all_vars_to_persistable(): + global persistable_cache + for var_name in vars.keys(): + var_name = str(var_name) + v = fluid.framework._get_var(var_name, prog) + persistable = v.persistable + if not persistable: + persistable_cache.append(var_name) + v.persistable = True + +# Restore the persistable attributes +def restore_all_vars_persistable(): + global persistable_cache + for var_name in vars.keys(): + var_name = str(var_name) + v = fluid.framework._get_var(var_name, prog) + persistable = v.persistable + if var_name in persistable_cache: + v.persistable = False + persistable_cache = [] + +# Get a var's data +def get_var_data(var_name, feed_kv=None): + output = np.array(fluid.global_scope().var(var_name).get_tensor()) + return output + +output_var_cache = {} +def tensor_sample(tensor): + if is_sample_step: + step = sample_step + else: + step = math.floor(len(tensor) / sample_num) + step = max(step, 1) + step = int(step) + sample = [] + for i in range(0, len(tensor), step): + sample.append(tensor[i]) + return sample + +op_cache = {} +# Save the output data of every op +def save_all_op_output(feed_kv=None): + force_all_vars_to_persistable() + outputs =
run_model(feed_kv=feed_kv) + if not os.path.exists(output_path): + os.mkdir(output_path) + ops = prog.current_block().ops + fetch_names = [] + for fetch in fetches: + fetch_names.append(fetch.name) + feed_names = feeds + if len(output_var_filter) > 0: + for fetch_name in fetch_names: + output_var_filter.append(fetch_name) + for i in range(len(ops)): + op = ops[i] + var_name = None + var_name_index = -1 + for index in range(len(op.output_names)): + if op.output_names[index] in ["Y", "Out", "Output"]: + var_name_index = index + break + if var_name_index != -1: + var_name = op.output_arg_names[var_name_index] + else: + for name in op.output_arg_names: + var_name = name + if "tmp" in name: + break + if len(output_var_filter) > 0: + if var_name not in output_var_filter: + continue + # real_var_name = None + # if op.type == "fetch": + # for name in op.input_arg_names: + # real_var_name = name + # if "tmp" in name: + # break + # else: + # real_var_name = var_name + if fast_check: + if var_name not in fetch_names and var_name not in feed_names: + continue + try: + data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() + sample = tensor_sample(data) + output_var_cache[var_name] = (sample) + op_cache[i] = (var_name, op) + file_name = var_name.replace("/", "_") + out_file = open(output_path + "/" + file_name, "w") + if var_name in feed_names: + for item in data: + out_file.write("{}\n".format(item)) + else: + for item in sample: + out_file.write("{}\n".format(item)) + out_file.close() + except: + pass + for i in range(len(ops)): + op = ops[i] + if op.type not in output_key_filter: + continue + var_name = None + var_name_index = -1 + for index in range(len(op.output_names)): + if op.output_names[index] in output_key_filter[op.type]: + var_name_index = index + break + if var_name_index != -1: + var_name = op.output_arg_names[var_name_index] + else: + continue + if len(output_var_filter) > 0: + if var_name not in output_var_filter: + continue + # real_var_name = None + # if op.type == "fetch": + # for name in op.input_arg_names: + # real_var_name = name + # if "tmp" in name: + # break + # else: + # real_var_name = var_name + if fast_check: + if var_name not in fetch_names and var_name not in feed_names: + continue + try: + data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() + sample = tensor_sample(data) + output_var_cache[var_name] = (sample) + op_cache[i] = (var_name, op) + file_name = var_name.replace("/", "_") + out_file = open(output_path + "/" + file_name, "w") + if var_name in feed_names: + for item in data: + out_file.write("{}\n".format(item)) + else: + for item in sample: + out_file.write("{}\n".format(item)) + out_file.close() + except: + pass + pp_green("all the op outputs are saved into directory 【{}】".format(output_path), 1) + restore_all_vars_persistable() + +ops = prog.current_block().ops +vars = prog.current_block().vars + +pp_yellow(dot + dot + " checking op list") +op_types = set() +for op in ops: + op_types.add(op.type) +pp_tab("op types : {}".format(op_types), 1) + +def check_mobile_results(args, fuse, mem_opt): + args = "{} {} {} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", "1" if quantification else "0", quantification_fold, args) + res = sh("adb shell \"cd {} && export LD_LIBRARY_PATH=. 
&& ./test-net {}\"".format(mobile_exec_root, args)) + lines = res.split("\n") + # for line in lines: + # print(line) + for line in lines: + if line.startswith("auto-test-debug"): + print(line) + pp_yellow(dot + dot + " checking paddle mobile results for {} -- {} ".format(green("【fusion】" if fuse else "【non fusion】"), green("【memory-optimization】" if mem_opt else "【non-memory-optimization】"))) + mobile_var_cache = {} + for line in lines: + parts = line.split(" ") + if len(parts) < 2: + continue + if "auto-test" != parts[0]: + continue + if parts[1] == "load-time-cost": + pp_green("load time cost : {}".format(parts[2]), 1) + elif parts[1] == "predict-time-cost": + pp_green("predict time cost : {}".format(parts[2]), 1) + elif parts[1] == "preprocess-time-cost": + pp_green("preprocess time cost : {}".format(parts[2]), 1) + elif parts[1] == "var": + var_name = parts[2] + values = list(map(lambda x: float(x), parts[3:])) + mobile_var_cache[var_name] = values + error_index = None + error_values1 = None + error_values2 = None + checked_names = [] + fetch_names = [] + for fetch in fetches: + fetch_names.append(fetch.name) + fetch_diff = 0.0 + fetch_count = 0 + for index in op_cache: + op_output_var_name, op = op_cache[index] + if not op_output_var_name in output_var_cache: + continue + if not op_output_var_name in mobile_var_cache: + continue + if op_output_var_name not in fetch_names: + continue + values1 = output_var_cache[op_output_var_name] + values2 = mobile_var_cache[op_output_var_name] + shape = get_var_shape(op_output_var_name) if check_shape else [] + for i in range(len(values1)): + v1 = values1[i] + v2 = values2[len(shape) + i] + fetch_diff += abs(v1 - v2) + fetch_count += 1 + if fetch_count != 0: + pp_yellow("output avg diff : {}".format(fetch_diff / fetch_count), 1) + print(fetch_diff / fetch_count) + for index in op_cache: + op_output_var_name, op = op_cache[index] + if mem_opt: + found_in_fetch = False + for fetch in fetches: + if op_output_var_name == fetch.name: + found_in_fetch = True + break + if not found_in_fetch: + continue + if not op_output_var_name in output_var_cache: + continue + if not op_output_var_name in mobile_var_cache: + continue + if op_output_var_name not in fetch_names: + continue + values1 = output_var_cache[op_output_var_name] + values2 = mobile_var_cache[op_output_var_name] + shape = get_var_shape(op_output_var_name) if check_shape else [] + if len(values1) + len(shape) != len(values2): + error_index = index + for i in range(len(shape)): + v1 = shape[i] + v2 = values2[i] + if v1 != v2: + error_index = index + break + if error_index == None: + for i in range(len(values1)): + v1 = values1[i] + v2 = values2[len(shape) + i] + if abs(v1 - v2) > diff_threshold: + error_index = index + break + checked_names.append(op_output_var_name) + if error_index != None: + error_values1 = values1 + error_values2 = values2 + break + if error_index == None: + for name in fetch_names: + if name not in checked_names: + error_index = -1 + break + if error_index == None: + pp_green("outputs are all correct", 1) + elif error_index == -1: + pp_red("outputs are missing") + else: + error_values1 = np.array(error_values1) + error_values2 = np.array(error_values2) + # pp_red("mobile op is not correct, error occurs at {}th op, op's type is {}") + pp_red("outputs are incorrect", 1) + pp_red("fluid results are : ", 1) + pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1) + pp_yellow("paddle mobile results are : ", 1) + pp_red(str(error_values2).replace("\n", "\n" + "\t" * 1), 1) + if 
not fuse and not mem_opt: + pp_yellow("checking individual ops : ", 1) + error_index = None + error_values1 = None + error_values2 = None + checked_names = [] + fetch_names = [] + for fetch in fetches: + fetch_names.append(fetch.name) + for index in op_cache: + op_output_var_name, op = op_cache[index] + if mem_opt: + found_in_fetch = False + for fetch in fetches: + if op_output_var_name == fetch.name: + found_in_fetch = True + break + if not found_in_fetch: + continue + if not op_output_var_name in output_var_cache: + continue + if not op_output_var_name in mobile_var_cache: + continue + if fuse or mem_opt: + if op_output_var_name not in fetch_names: + continue + values1 = output_var_cache[op_output_var_name] + values2 = mobile_var_cache[op_output_var_name] + shape = get_var_shape(op_output_var_name) if check_shape else [] + if len(values1) + len(shape) != len(values2): + error_index = index + for i in range(len(shape)): + v1 = shape[i] + v2 = values2[i] + if v1 != v2: + error_index = index + break + if error_index == None: + for i in range(len(values1)): + v1 = values1[i] + v2 = values2[len(shape) + i] + if abs(v1 - v2) > diff_threshold: + error_index = index + break + checked_names.append(op_output_var_name) + if error_index != None: + error_values1 = values1 + error_values2 = values2 + break + if error_index == None: + for name in fetch_names: + if name not in checked_names: + error_index = -1 + break + if error_index == None: + pp_green("outputs are all correct", 1) + elif error_index == -1: + pp_red("outputs are missing") + else: + error_values1 = np.array(error_values1) + error_values2 = np.array(error_values2) + # pp_red("mobile op is not correct, error occurs at {}th op, op's type is {}") + pp_red("corresponding fluid op is {}th op, op's type is {}, wrong var name is {}".format( + error_index,op_cache[error_index][1].type,op_output_var_name), 1) + pp_red("fluid results are : ", 1) + pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1) + pp_yellow("paddle mobile results are : ", 1) + pp_red(str(error_values2).replace("\n", "\n" + "\t" * 1), 1) + # print(output_var_cache) + # print(mobile_var_cache) + +def main(): + # Load feed key-values + feed_kv = load_feed_kv() + if feed_kv == None: + feed_kv = gen_feed_kv() + save_feed_kv(feed_kv) + feed_kv = load_feed_kv() + # Run inference + pp_yellow(dot + dot + " checking inference") + outputs = run_model(feed_kv=feed_kv) + pp_tab("fluid output : {}".format(outputs), 1) + # Re-save the model + pp_yellow(dot + dot + " checking model correctness") + resave_model(feed_kv=feed_kv) + # Output the encrypted model + encrypt_model() + # Dump all intermediate results + pp_yellow(dot + dot + " checking output result of every op") + save_all_op_output(feed_kv=feed_kv) + pp_yellow(dot + dot + " checking fetch info") + for fetch in fetches: + fetch_name = fetch.name + fetch_shape = get_var_shape(fetch_name) + pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1) + # Dump info for every op and var + info_file = open("info.txt", "w") + for i in range(len(ops)): + op = ops[i] + info_file.write("{}th op: type - {}\n".format(i, op.type)) + info_file.write("inputs:\n") + for var_name in op.input_arg_names: + try: + shape = get_var_shape(var_name) + shape_str = ", ".join(list(map(lambda x: str(x), shape))) + info_file.write("var {} : {}\n".format(var_name, shape_str)) + except: + pass + info_file.write("outputs:\n") + for var_name in op.output_arg_names: + try: + shape = get_var_shape(var_name) + shape_str = ", ".join(list(map(lambda x: str(x), shape))) + info_file.write("var {} : {}\n".format(var_name, shape_str))
+ except: + pass + info_file.close() + # Start checking paddle mobile correctness + pp_yellow(dot + " start inspecting paddle mobile correctness & performance") + sh("rm -rf checked_model") + sh("cp -r {} checked_model".format(checked_model_path)) + push("checked_model") + push(feed_path + "/" + last_feed_file_name, "input.txt") + push(mobile_src_root + "/build/release/{}/build/libpaddle-mobile.so".format(architecture)) + push(mobile_src_root + "/build/release/{}/build/cl_kernel".format(architecture)) + push(mobile_src_root + "/test/build/test-net") + last_feed_var_shape = get_feed_var_shape(last_feed_var_name) + args = str(len(last_feed_var_shape)) + for dim in last_feed_var_shape: + args += " " + str(dim) + if is_lod: + args += " 1" + args += " " + str(len(last_feed_var_lod)) + for dim in last_feed_var_lod: + args += " " + str(dim) + else: + args += " 0" + args += " " + str(len(output_var_cache)) + args += " " + str(1 if is_sample_step else 0) + if is_sample_step: + args += " " + str(sample_step) + else: + args += " " + str(sample_num) + for var_name in output_var_cache.keys(): + args += " " + var_name + args += " " + str(1 if check_shape else 0) + # if not fast_check: + # check_mobile_results(args, False, False) + # check_mobile_results(args, False, True) + # check_mobile_results(args, True, False) + check_mobile_results(args, True, True) + +if __name__ == "__main__": + main() diff --git a/mobile/tools/quantification/tune_n_fold.py b/mobile/tools/quantification/tune_n_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..6126a397b33f5c51cd2bcfad265e313e7fa84657 --- /dev/null +++ b/mobile/tools/quantification/tune_n_fold.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -* + +import os +import sys +import math +import subprocess +import numpy as np +import paddle.fluid as fluid + +def sh(command): + pipe = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + return pipe.stdout.read().decode("utf-8") + +for fold in range(100, 1001, 100): + print("checking fold : {}".format(fold)) + max_entropy = sh("./quantify 1 model params {}".format(fold)) + print("max entropy :", max_entropy, end="") + sh("rm -rf scripts/model") + sh("rm -rf scripts/quantification_model") + sh("cp -r model scripts/model") + sh("cp -r model scripts/quantification_model") + sh("mv params scripts/quantification_model") + diff = sh("cd scripts && python run.py {}".format(fold)) + print("output diff :", diff, end="")
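
Note on the scheme this patch implements: LoadWithDumpForInt8/LoadWithDumpForFloat32 now split each parameter tensor into at most quantification_fold folds, store a separate (min, max) range per fold, and encode each value as a uint8 code within its fold's range; entropy() scores how unevenly a fold's codes fill the 256 bins, and tune_n_fold.py sweeps the fold count, comparing the reported max entropy against the on-device output diff printed by scripts/run.py. Below is a minimal NumPy sketch of the same per-fold logic for reference; quantize_folds and score_entropy are illustrative names and are not part of this patch.

# -*- coding: utf-8 -*-
# Illustrative sketch of the per-fold min/max quantization in convert.cpp.
import numpy as np

def score_entropy(factors, shift=100000.0):
    # Mirrors entropy() in convert.cpp: the inverse of a smoothed product
    # over the 256 bin frequencies. By AM-GM the product is largest for a
    # flat histogram, so the score grows as codes concentrate in few bins.
    counts = np.bincount(factors, minlength=256).astype(np.float64)
    n = float(len(factors))
    return 1.0 / np.prod((counts + shift) / (n + shift))

def quantize_folds(values, quantification_fold, minimal_fold_size=2):
    # Clamp the fold count so every fold covers at least minimal_fold_size
    # values, mirroring the std::min/std::max clamp in convert.cpp.
    n = len(values)
    fold = min(max(1, n // minimal_fold_size), quantification_fold)
    step = max(n // fold, 1)
    chunks, max_entropy = [], 0.0
    for start in range(0, n, step):
        chunk = values[start:start + step]
        lo, hi = float(chunk.min()), float(chunk.max())
        if hi > lo:
            factors = np.round((chunk - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            # Degenerate fold of constant values; emit all-zero codes so the
            # sketch avoids the divide-by-zero the C++ would hit here.
            factors = np.zeros(len(chunk), dtype=np.uint8)
        chunks.append((lo, hi, factors))
        max_entropy = max(max_entropy, score_entropy(factors))
    return chunks, max_entropy

if __name__ == "__main__":
    weights = np.random.randn(4096).astype(np.float32)
    chunks, ent = quantize_folds(weights, quantification_fold=1000)
    print("folds: {}, max entropy: {}".format(len(chunks), ent))

Under this sketch, a larger max entropy flags a fold whose 256 codes cluster in a few bins; tune_n_fold.py prints that score next to the output diff for each fold count (via "./quantify 1 model params <fold>" and "python run.py <fold>"), so a fold count can be picked by weighing the two metrics side by side.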