diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4b7f4e375eb5147fad010af038677028a89f5f2..2a7235a9d653b0da544a006dda6f9a9c957364f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,8 @@ repos: sha: v1.0.1 hooks: - id: remove-crlf - files: (?!.*third_party)^.*$ | (?!.*book)^.*$ ^mobile/ ^metal/ ^web/ + files: (?!.*third_party)^.*$|(?!.*book)^.*$ + exclude: ^(mobile/|metal/|web/) #- repo: https://github.com/PaddlePaddle/mirrors-yapf.git #sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 #hooks: @@ -16,7 +17,7 @@ repos: - id: check-merge-conflict - id: check-symlinks - id: detect-private-key - files: (?!.*third_party)^.*$ | (?!.*book)^.*$ + files: (?!.*third_party)^.*$|(?!.*book)^.*$ - id: end-of-file-fixer - repo: local hooks: @@ -25,7 +26,8 @@ repos: description: Format files with ClangFormat. entry: bash ./tools/codestyle/clang_format.hook -i language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ ^mobile/ ^metal/ ^web/ + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ + exclude: ^(mobile/|metal/|web/) - repo: local hooks: - id: cpplint-cpp-source @@ -33,7 +35,8 @@ repos: description: Check C++ code style using cpplint.py. entry: bash ./tools/codestyle/cpplint_pre_commit.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ ^mobile/ ^metal/ ^web/ + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ + exclude: ^(mobile/|metal/|web/) #- repo: local #hooks: #- id: pylint-doc-string @@ -48,5 +51,6 @@ repos: name: copyright_checker entry: python ./tools/codestyle/copyright.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ ^mobile/ ^metal/ ^web/ - exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ + exclude: (?!.*third_party)^.*$|(?!.*book)^.*$ + exclude: ^(mobile/|metal/|web/) diff --git a/.travis.yml b/.travis.yml index 20fdddd5a172d63b6b3df3fb2a57265a08ed3732..c902afef91b816390170f1b7e1c8e4b07c7b0645 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,18 +9,17 @@ os: addons: apt: packages: - - git - - python - - python-pip - - python2.7-dev - - libc6-i386 - - curl - -compiler: - - clang +# - git +# - python +# - python-pip +# - python2.7-dev +# - libc6-i386 +# - curl + - clang-format-3.8 before_install: - - sudo pip install -U virtualenv pre-commit pip + - sudo pip install cpplint pre-commit + - sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format # Download and install recent cmake script: diff --git a/.travis/pre-commit-job.sh b/.travis/pre-commit-job.sh index a0ae98dddd27a7f24467ce2ce441aba9e4ffe156..cf4dd30659fe62eefa16b4365df4529ce78baa2a 100755 --- a/.travis/pre-commit-job.sh +++ b/.travis/pre-commit-job.sh @@ -11,6 +11,8 @@ cd `dirname $0` cd .. export PATH=/usr/bin:$PATH pre-commit install +which clang-format +clang-format --version if ! pre-commit run -a ; then ls -lh diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a23d869aa12ffc4534a3f6cb1fb32beed8aa283..a3336caa8463ceca536a81f53665a6809426514c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,8 @@ option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF) option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF) option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." 
OFF) +# publish options +option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF) set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING "A path setting third party libraries download & build directories.") @@ -93,7 +95,7 @@ endif() # check options if (LITE_ON_TINY_PUBLISH) - if (NOT (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_JAVA AND NOT WITH_TESTING)) + if (NOT (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND NOT WITH_TESTING))#LITE_WITH_JAVA AND message(FATAL_ERROR "LITE_ON_TINY_PUBLISH=ON must be used with WITH_LITE=ON LITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON LITE_WITH_JAVA=ON WITH_TESTING=OFF") return() endif() diff --git a/cmake/cross_compiling/ios.cmake b/cmake/cross_compiling/ios.cmake index b8df182cd6dabc8b2ffc3dce5769b139329b18c1..76f62765aff791594123d689341b0876b3d0184d 100644 --- a/cmake/cross_compiling/ios.cmake +++ b/cmake/cross_compiling/ios.cmake @@ -127,6 +127,7 @@ elseif(ARM_TARGET_OS STREQUAL "ios64") else() return() endif() +add_definitions(-DTARGET_IOS) # if do not specify the ARM_TARGET_ARCH_ABI then use default all supported if(ARM_TARGET_ARCH_ABI STREQUAL "armv7" diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 03d6cafcf98c9531ec77844a4723162b2657c251..89918b7cb94997d1a38cad81532e81f90756474f 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -57,6 +57,8 @@ function (lite_deps TARGET) endforeach(var) endif() + + if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) foreach(var ${lite_deps_HVY_DEPS}) set(deps ${deps} ${var}) @@ -182,9 +184,16 @@ function(lite_cc_test TARGET) set(oneValueArgs "") set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS - ARGS) + ARGS + COMPILE_LEVEL # (basic|extra) + ) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if (args_COMPILE_LEVEL STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) + MESSAGE(STATUS "Ignore test ${TARGET} due to compile level ${args_COMPILE_LEVEL}") + return() + endif() + set(deps "") lite_deps(deps DEPS ${args_DEPS} @@ -207,6 +216,117 @@ function(lite_cc_test TARGET) endif() endfunction() +set(arm_kernels CACHE INTERNAL "arm kernels") +set(x86_kernels CACHE INTERNAL "x86 kernels") +set(fpga_kernels CACHE INTERNAL "fpga kernels") +set(npu_kernels CACHE INTERNAL "npu kernels") +set(opencl_kernels CACHE INTERNAL "opencl kernels") +set(host_kernels CACHE INTERNAL "host kernels") + +set(kernels_src_list "${CMAKE_BINARY_DIR}/kernels_src_list.txt") +file(WRITE ${kernels_src_list} "") # clean +# add a kernel for some specific device +# device: one of (Host, ARM, X86, NPU, FPGA, OPENCL, CUDA) +# level: one of (basic, extra) +function(add_kernel TARGET device level) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS + ARGS) + cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) + return() + endif() + + if ("${device}" STREQUAL "Host") + set(host_kernels "${host_kernels};${TARGET}" CACHE INTERNAL "") + endif() + if ("${device}" STREQUAL "ARM") + if (NOT LITE_WITH_ARM) + return() + endif() + set(arm_kernels "${arm_kernels};${TARGET}" CACHE INTERNAL "") + endif() + if ("${device}" STREQUAL "X86") + if (NOT LITE_WITH_X86) + return() + endif() + set(x86_kernels "${x86_kernels};${TARGET}" CACHE INTERNAL "") + endif() + if ("${device}" 
STREQUAL "NPU") + if (NOT LITE_WITH_NPU) + return() + endif() + set(npu_kernels "${npu_kernels};${TARGET}" CACHE INTERNAL "") + endif() + if ("${device}" STREQUAL "FPGA") + if (NOT LITE_WITH_FPGA) + return() + endif() + set(fpga_kernels "${fpga_kernels};${TARGET}" CACHE INTERNAL "") + endif() + if ("${device}" STREQUAL "OPENCL") + if (NOT LITE_WITH_OPENCL) + return() + endif() + set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "") + endif() + + foreach(src ${args_SRCS}) + file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endforeach() + + lite_cc_library(${TARGET} SRCS ${args_SRCS} + DEPS ${args_DEPS} + X86_DEPS ${args_X86_DEPS} + CUDA_DEPS ${args_CUDA_DEPS} + CL_DEPS ${args_CL_DEPS} + ARM_DEPS ${args_ARM_DEPS} + FPGA_DEPS ${args_FPGA_DEPS} + PROFILE_DEPS ${args_PROFILE_DEPS} + LIGHT_DEPS ${args_LIGHT_DEPS} + HVY_DEPS ${args_HVY_DEPS} + ) +endfunction() + +set(ops CACHE INTERNAL "ops") +set(ops_src_list "${CMAKE_BINARY_DIR}/ops_src_list.txt") +file(WRITE ${ops_src_list} "") # clean +# add an operator +# level: one of (basic, extra) +function(add_operator TARGET level) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS + ARGS) + cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) + return() + endif() + + set(ops "${ops};${TARGET}" CACHE INTERNAL "source") + + foreach(src ${args_SRCS}) + file(APPEND ${ops_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") + endforeach() + + lite_cc_library(${TARGET} SRCS ${args_SRCS} + DEPS ${args_DEPS} + X86_DEPS ${args_X86_DEPS} + CUDA_DEPS ${args_CUDA_DEPS} + CL_DEPS ${args_CL_DEPS} + ARM_DEPS ${args_ARM_DEPS} + FPGA_DEPS ${args_FPGA_DEPS} + PROFILE_DEPS ${args_PROFILE_DEPS} + LIGHT_DEPS ${args_LIGHT_DEPS} + HVY_DEPS ${args_HVY_DEPS} + ) +endfunction() + # Bundle several static libraries into one. function(bundle_static_library tgt_name bundled_tgt_name fake_target) diff --git a/cmake/system.cmake b/cmake/system.cmake index 65db05bebe957d740e391847d980e211b0e9e750..ba00df928a0c52bfe05f4d3f6d7af2a50d2576f9 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -32,7 +32,11 @@ ELSE(WIN32) SET(CMAKE_OSX_DEPLOYMENT_TARGET ${MACOS_VERSION} CACHE STRING "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. 
Set to empty string for default value.") ENDIF() - set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + IF(ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux" + OR ARM_TARGET_OS STREQUAL "ios" OR ARM_TARGET_OS STREQUAL "ios64") + ELSE() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + ENDIF() ELSE(APPLE) IF(EXISTS "/etc/issue") diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index be23ff25403efbf7590607fce5e1e4f11956d533..8c473904808454aa75dbd2fabfe5cc0bee75eff0 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -13,7 +13,6 @@ set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") set(LITE_ON_MOBILE ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}) - add_subdirectory(utils) add_subdirectory(operators) add_subdirectory(kernels) @@ -78,14 +77,16 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin" COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin" ) - add_dependencies(publish_inference_cxx_lib model_optimize_tool) - add_dependencies(publish_inference_cxx_lib paddle_code_generator) - add_dependencies(publish_inference_cxx_lib bundle_full_api) - add_dependencies(publish_inference_cxx_lib bundle_light_api) - add_dependencies(publish_inference_cxx_lib test_model_bin) - add_dependencies(publish_inference publish_inference_cxx_lib) - add_custom_command(TARGET publish_inference_cxx_lib POST_BUILD - COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a) + if(NOT IOS) + add_dependencies(publish_inference_cxx_lib model_optimize_tool) + add_dependencies(publish_inference_cxx_lib paddle_code_generator) + add_dependencies(publish_inference_cxx_lib bundle_full_api) + add_dependencies(publish_inference_cxx_lib bundle_light_api) + add_dependencies(publish_inference_cxx_lib test_model_bin) + add_dependencies(publish_inference publish_inference_cxx_lib) + add_custom_command(TARGET publish_inference_cxx_lib POST_BUILD + COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a) + endif() endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 5212d7a4ca763bd582e829e190fdf7ad56d78da5..85097a3e42c18ca3d154ef34783b68c90ced975b 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -17,6 +17,7 @@ if(LITE_WITH_FPGA) endif() message(STATUS "get ops ${ops}") +message(STATUS "get X86 kernels ${x86_kernels}") message(STATUS "get Host kernels ${host_kernels}") message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get NPU kernels ${npu_kernels}") @@ -117,7 +118,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) add_dependencies(test_mobilenetv1 extern_lite_download_mobilenet_v1_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") set_target_properties(test_mobilenetv1 PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - + lite_cc_test(test_mobilenetv2 SRCS mobilenetv2_test.cc DEPS ${lite_model_test_DEPS} CL_DEPS ${opencl_kernels} @@ -125,7 +126,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) --model_dir=${LITE_MODEL_DIR}/mobilenet_v2_relu SERIAL) add_dependencies(test_mobilenetv2 extern_lite_download_mobilenet_v2_relu_tar_gz) set_target_properties(test_mobilenetv2 PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - + lite_cc_test(test_resnet50 SRCS resnet50_test.cc DEPS ${lite_model_test_DEPS} CL_DEPS ${opencl_kernels} @@ -145,8 +146,13 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND 
WITH_TESTING) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL) add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz) -# lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc -# DEPS ${lite_model_test_DEPS}) + # lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc + # DEPS ${lite_model_test_DEPS}) + + # lite_cc_test(model_run_test_image SRCS model_run_test_image.cc + # DEPS ${lite_model_test_DEPS} + # CL_DEPS ${opencl_kernels} + # FPGA_DEPS ${fpga_kernels}) endif() # These tests needs CLI arguments, and is not supported in ARM CI. @@ -169,7 +175,11 @@ lite_cc_library(paddle_api SRCS paddle_api.cc DEPS op_params tensor) #----------------------------------------------------------------------------------------------------- # The final inference library for both CxxConfig and MobileConfig. -lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api) +if (LITE_ON_TINY_PUBLISH) + lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api stream) +else() + lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api) +endif() if (NOT LITE_ON_TINY_PUBLISH) lite_cc_library(paddle_api_full SRCS cxx_api_impl.cc DEPS cxx_api paddle_api light_api ${ops} diff --git a/lite/api/paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h similarity index 98% rename from lite/api/paddle_use_kernels.h rename to lite/api/_paddle_use_kernels.h index 1f2d4229c64d1ec91643609cd55433b335eff03a..16924fbb0b0952411c6d73e675ecd57fc0236b92 100644 --- a/lite/api/paddle_use_kernels.h +++ b/lite/api/_paddle_use_kernels.h @@ -21,6 +21,8 @@ #ifndef LITE_WITH_FPGA USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_KERNEL(flatten, kHost, kAny, kAny, def); +USE_LITE_KERNEL(flatten2, kHost, kAny, kAny, def); #else USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def); USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); diff --git a/lite/api/paddle_use_ops.h b/lite/api/_paddle_use_ops.h similarity index 97% rename from lite/api/paddle_use_ops.h rename to lite/api/_paddle_use_ops.h index be50801ed72d75e869209aece850d5fbab69dc25..5e3c5f2e28037e71d4d6a7053f7e0d8559531807 100644 --- a/lite/api/paddle_use_ops.h +++ b/lite/api/_paddle_use_ops.h @@ -73,9 +73,12 @@ USE_LITE_OP(prior_box) USE_LITE_OP(density_prior_box) USE_LITE_OP(reshape) USE_LITE_OP(reshape2) +USE_LITE_OP(flatten) +USE_LITE_OP(flatten2) USE_LITE_OP(split) USE_LITE_OP(fake_quantize_moving_average_abs_max); USE_LITE_OP(fake_dequantize_max_abs); +USE_LITE_OP(fake_quantize_range_abs_max); USE_LITE_OP(calib); USE_LITE_OP(calib_once); USE_LITE_OP(norm); diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt index 0d9f466fbd6dfe71d11dfc2f863a810dfb1e8014..afe051a437f4de83931bdaa3f2d03427b78d13ad 100644 --- a/lite/api/android/jni/native/CMakeLists.txt +++ b/lite/api/android/jni/native/CMakeLists.txt @@ -20,7 +20,7 @@ if (NOT LITE_ON_TINY_PUBLISH) else() add_library(paddle_lite_jni SHARED "") target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc) - #add_dependencies(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels}) + add_dependencies(paddle_lite_jni op_list_h kernel_list_h) endif() if (APPLE) diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index 42f89e7e66f36c1e19ec59d999f0b93d2c5ec08c..832194e6d41e3c70ba2c7a9a6b452885264cc085 100644 --- a/lite/api/benchmark.cc +++ 
b/lite/api/benchmark.cc
@@ -30,6 +30,9 @@ DEFINE_string(input_shape,
               "1,3,224,224",
               "input shapes, separated by colon and comma");
 DEFINE_string(result_filename, "", "save test result");
+DEFINE_bool(run_model_optimize,
+            false,
+            "apply model_optimize_tool to model, use optimized model to test");
 
 namespace paddle {
 namespace lite_api {
@@ -69,10 +72,10 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
 #ifdef LITE_WITH_ARM
   lite::DeviceInfo::Init();
   if (thread_num == 1) {
-    lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_HIGH, thread_num);
+    lite::DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, thread_num);
     LOG(INFO) << "LITE_POWER_HIGH";
   } else {
-    lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_NO_BIND, thread_num);
+    lite::DeviceInfo::Global().SetRunMode(LITE_POWER_NO_BIND, thread_num);
     LOG(INFO) << "LITE_POWER_NO_BIND";
   }
 #endif
@@ -172,13 +175,17 @@ int main(int argc, char** argv) {
   }
 
   // Output optimized model
-  paddle::lite_api::OutputOptModel(
-      FLAGS_model_dir, save_optimized_model_dir, input_shapes);
+  if (FLAGS_run_model_optimize) {
+    paddle::lite_api::OutputOptModel(
+        FLAGS_model_dir, save_optimized_model_dir, input_shapes);
+  }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
   // Run inference using optimized model
+  std::string run_model_dir =
+      FLAGS_run_model_optimize ? save_optimized_model_dir : FLAGS_model_dir;
   paddle::lite_api::Run(input_shapes,
-                        save_optimized_model_dir,
+                        run_model_dir,
                         FLAGS_repeats,
                         FLAGS_threads,
                         FLAGS_warmup,
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 36529ecf30003f5749eb2160ebe856d77f5539b4..622db412853cd780d6e2d2b00ec6c3c3fa788ae3 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -71,6 +71,13 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) const {
   return &fetch_list.at(offset);
 }
 
+const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  return &fetch_list;
+}
+
 const cpp::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
 }
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 5d94a75bb1233ffc157f06096dfd32c9848951f6..d664565993c80d3853907906f53672d9b7df4a71 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -69,6 +69,7 @@ class LITE_API Predictor {
   // Get offset-th col of fetch results.
   const lite::Tensor* GetOutput(size_t offset) const;
+  const std::vector<lite::Tensor>* GetOutputs() const;
 
   const cpp::ProgramDesc& program_desc() const;
   const lite::Tensor* GetTensor(const std::string& name) const;
 
diff --git a/lite/api/efficientnet_b0_test.cc b/lite/api/efficientnet_b0_test.cc
index 14e5e956511b70d37edb8cc3e017597454196b24..aab41fcf0df1f0060aa2c3411e34f604c6b29b12 100644
--- a/lite/api/efficientnet_b0_test.cc
+++ b/lite/api/efficientnet_b0_test.cc
@@ -28,7 +28,7 @@ namespace lite {
 void TestModel(const std::vector<Place> &valid_places,
                const Place &preferred_place) {
   DeviceInfo::Init();
-  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
   predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
diff --git a/lite/api/inceptionv4_test.cc b/lite/api/inceptionv4_test.cc
index 9b23a3ba4ef8adb254d0d5c0de82523836d61d8a..c81933deea77776d91031439c9a2d2f30557e125 100644
--- a/lite/api/inceptionv4_test.cc
+++ b/lite/api/inceptionv4_test.cc
@@ -28,7 +28,7 @@ namespace lite {
 #ifdef LITE_WITH_ARM
 TEST(InceptionV4, test) {
   DeviceInfo::Init();
-  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc
index 7020b9b0e82d6491658e088cae0558a40ded9862..545c7f4829eaa457813ee9db12f1d4f75507feab 100644
--- a/lite/api/light_api_impl.cc
+++ b/lite/api/light_api_impl.cc
@@ -40,6 +40,10 @@ class LightPredictorImpl : public PaddlePredictor {
 
 void LightPredictorImpl::Init(const MobileConfig& config) {
   // LightPredictor Only support NaiveBuffer backend in publish lib
+#ifdef LITE_WITH_ARM
+  lite::DeviceInfo::Init();
+  lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads());
+#endif
   raw_predictor_.reset(new lite::LightPredictor(config.model_dir(),
                                                 LiteModelType::kNaiveBuffer));
 }
diff --git a/lite/api/mobilenetv1_int8_test.cc b/lite/api/mobilenetv1_int8_test.cc
index 7a87e11819a35975e789335b146539ae75eb228f..5bf40fe69835b36f0c980dcc5840d5b9dd4c4e91 100644
--- a/lite/api/mobilenetv1_int8_test.cc
+++ b/lite/api/mobilenetv1_int8_test.cc
@@ -29,7 +29,7 @@ void TestModel(const std::vector<Place>& valid_places,
                const Place& preferred_place,
                bool use_npu = false) {
   DeviceInfo::Init();
-  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
   predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
diff --git a/lite/api/mobilenetv1_ssd_test.cc b/lite/api/mobilenetv1_ssd_test.cc
index 9f8ab4624104048d5f564b01e08be203f469a75f..921b17d67be4bb055c4ffadcf1b646e21201cd07 100644
--- a/lite/api/mobilenetv1_ssd_test.cc
+++ b/lite/api/mobilenetv1_ssd_test.cc
@@ -29,7 +29,7 @@ namespace lite {
 void TestModel(const std::vector<Place>& valid_places,
                const Place& preferred_place) {
   DeviceInfo::Init();
-  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
   predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
diff --git a/lite/api/mobilenetv1_test.cc b/lite/api/mobilenetv1_test.cc
index fb40ccf7c6eaadecce3fc54a61786f096a75cff4..e97730b757a6df627b052c0785256df2e7804e4a 100644
---
a/lite/api/mobilenetv1_test.cc +++ b/lite/api/mobilenetv1_test.cc @@ -33,7 +33,7 @@ void TestModel(const std::vector& valid_places, bool gen_npu = false, bool save_model = false) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(model_dir, preferred_place, valid_places); diff --git a/lite/api/mobilenetv1_yolov3_test.cc b/lite/api/mobilenetv1_yolov3_test.cc index ec373fb115d0f8e6f855d435b0b568b709a6d485..cf37aefe556c691b3879c8524c402ec7f5e93758 100644 --- a/lite/api/mobilenetv1_yolov3_test.cc +++ b/lite/api/mobilenetv1_yolov3_test.cc @@ -29,7 +29,7 @@ namespace lite { void TestModel(const std::vector& valid_places, const Place& preferred_place) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(FLAGS_model_dir, preferred_place, valid_places); diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc index 380d6a1fb582bbc4add8cc3bba2e20167e5fbb1d..737caccc9c6296ca778a4f5760e79d9fc8216869 100644 --- a/lite/api/mobilenetv2_test.cc +++ b/lite/api/mobilenetv2_test.cc @@ -34,7 +34,7 @@ void TestModel(const std::vector& valid_places, bool gen_npu = false, bool save_model = false) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(model_dir, preferred_place, valid_places); diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc index cb29d1b8fd323e9268ed0035942fade4a9d575a4..1162cd1e1095a353cb2151f0889c3acb2f07d772 100644 --- a/lite/api/model_optimize_tool.cc +++ b/lite/api/model_optimize_tool.cc @@ -33,7 +33,7 @@ DEFINE_string(valid_targets, "arm", "The targets this model optimized for, should be one of (arm, " "opencl, x86), splitted by space"); -DEFINE_bool(int8_mode, false, "Support Int8 quantitative mode"); +DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); namespace paddle { namespace lite_api { @@ -62,7 +62,7 @@ void Main() { CHECK(!valid_places.empty()) << "At least one target should be set, should set the " "command argument 'valid_targets'"; - if (FLAGS_int8_mode) { + if (FLAGS_prefer_int8_kernel) { LOG(WARNING) << "Int8 mode is only support by ARM target"; valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)}); config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)}); diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc new file mode 100644 index 0000000000000000000000000000000000000000..25184879906d0385bdf64083001b5bdbeb4ffae5 --- /dev/null +++ b/lite/api/model_run_test_image.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +TEST(model, test) { +#ifdef LITE_WITH_ARM + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt8)}}); + + auto precision = PRECISION(kFloat); + if (FLAGS_int8) { + precision = PRECISION(kInt8); + } + predictor.Build( + FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places); + int im_width = FLAGS_im_width; + int im_height = FLAGS_im_height; + auto* input_tensor = predictor.GetInput(0); + auto in_dims = input_tensor->dims(); + input_tensor->Resize( + DDim(std::vector({1, 3, im_width, im_height}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + auto* output_tensors = predictor.GetOutputs(); + + LOG(INFO) << "======output:========"; + for (auto t : *output_tensors) { + LOG(INFO) << t; + } + LOG(INFO) + << "=====RUN_finished!!============= Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc index cf350ee0742f64daf00a421ada860c097235a3fd..271fe4a330a373a9007e78f890a68b005f38d15a 100644 --- a/lite/api/model_test.cc +++ b/lite/api/model_test.cc @@ -64,7 +64,7 @@ void Run(const std::vector>& input_shapes, const int warmup_times = 0) { #ifdef LITE_WITH_ARM lite::DeviceInfo::Init(); - lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_HIGH, thread_num); + lite::DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, thread_num); #endif lite_api::MobileConfig config; config.set_model_dir(model_dir); diff --git a/lite/api/ocr_attention_test.cc b/lite/api/ocr_attention_test.cc index 26cdde3ea7950abf5218439f119fd108aef8545f..336dad2791342723d973fb9bc8385dcb422a87e4 100644 --- a/lite/api/ocr_attention_test.cc +++ b/lite/api/ocr_attention_test.cc @@ -29,7 +29,7 @@ void TestModel(const std::vector& valid_places, const Place& preferred_place, bool use_npu = false) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(FLAGS_model_dir, preferred_place, valid_places); diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 62df111e0aceafb5167b76d74e600926b37fd560..b728b7c482e8bae0290c6a189f71876bac957215 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -110,7 +110,18 @@ class LITE_API CxxConfig : public ConfigBase { /// MobileConfig is the config for the light weight predictor, it will skip 
/// IR optimization or other unnecessary stages. -class LITE_API MobileConfig : public ConfigBase {}; +class LITE_API MobileConfig : public ConfigBase { + PowerMode mode_{LITE_POWER_HIGH}; + int threads_{1}; +public: + MobileConfig(Place preferred_place=Place(TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW)), + PowerMode mode=LITE_POWER_HIGH, int threads=1) : mode_(mode), threads_(threads) {} + void set_power_mode(PowerMode mode) { mode_ = mode; } + void set_threads(int threads) { threads_ = threads; } + + PowerMode power_mode() const { return mode_; } + int threads() const { return threads_; } +}; template std::shared_ptr CreatePaddlePredictor(const ConfigT&); diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 4a75539d3a082401ab33588ef576c597e14743f1..f7fc29e7d6c5a902ab7d7a4f18e314885aaf2ac0 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -70,6 +70,14 @@ enum class DataLayoutType : int { kAny = 2, // any data layout NUM = 4, // number of fields. }; +typedef enum { + LITE_POWER_HIGH = 0, + LITE_POWER_LOW = 1, + LITE_POWER_FULL = 2, + LITE_POWER_NO_BIND = 3, + LITE_POWER_RAND_HIGH = 4, + LITE_POWER_RAND_LOW = 5 +} PowerMode; enum class ActivationType : int { kIndentity = 0, diff --git a/lite/api/resnet18_test.cc b/lite/api/resnet18_test.cc index ad8248160c8930dd116ce279ec203a39151e7ff9..5176ad8e4cb95f1173952a6593e41b1fb8450431 100644 --- a/lite/api/resnet18_test.cc +++ b/lite/api/resnet18_test.cc @@ -28,7 +28,7 @@ namespace lite { #ifdef LITE_WITH_ARM TEST(ResNet18, test) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}}); diff --git a/lite/api/resnet50_test.cc b/lite/api/resnet50_test.cc index 75404d173fff59615a7aefbe810268ee1eb3b571..098e5988ad3aa2f9d77d81c90ee298496b67c828 100644 --- a/lite/api/resnet50_test.cc +++ b/lite/api/resnet50_test.cc @@ -29,7 +29,7 @@ namespace lite { void TestModel(const std::vector& valid_places, const Place& preferred_place) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(FLAGS_model_dir, preferred_place, valid_places); diff --git a/lite/api/shufflenetv2_test.cc b/lite/api/shufflenetv2_test.cc index e3b119ec7a3bd1c58b69c3d12113ef3a36c5139a..bba6b72d8f0c975c6334d5848c08702d9de50c20 100644 --- a/lite/api/shufflenetv2_test.cc +++ b/lite/api/shufflenetv2_test.cc @@ -28,7 +28,7 @@ namespace lite { void TestModel(const std::vector& valid_places, const Place& preferred_place) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; predictor.Build(FLAGS_model_dir, preferred_place, valid_places); diff --git a/lite/api/test_helper.h b/lite/api/test_helper.h index 1a5ab31abd3e97c5bfc484547af5d36d53e49b39..d835c030f03a3c95575217020cd298dabbf1a15a 100644 --- a/lite/api/test_helper.h +++ b/lite/api/test_helper.h @@ -23,6 +23,9 @@ DEFINE_string(model_dir, "", "model dir"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(im_width, 224, "image width"); +DEFINE_int32(im_height, 224, "image height"); 
+DEFINE_bool(int8, false, "is run int8"); namespace paddle { namespace lite { diff --git a/lite/api/unet_test.cc b/lite/api/unet_test.cc index e1d8c9ec1e2535ec016f2ce41e01d83f32d5a357..f330bf065d23d82d0fd4b2b16e16f69ca65f6b42 100644 --- a/lite/api/unet_test.cc +++ b/lite/api/unet_test.cc @@ -28,7 +28,7 @@ namespace lite { #ifdef LITE_WITH_ARM TEST(unet, test) { DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); lite::Predictor predictor; std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}}); diff --git a/lite/arm/CMakeLists.txt b/lite/arm/CMakeLists.txt index 8abd04b52338299f75399903aa68fe834ce81d04..2767b4e7ae22881fcf0b2025b3c715fd09d01731 100644 --- a/lite/arm/CMakeLists.txt +++ b/lite/arm/CMakeLists.txt @@ -1,2 +1 @@ - add_subdirectory(math) diff --git a/lite/arm/math/CMakeLists.txt b/lite/arm/math/CMakeLists.txt index 9924425609df49ab22fd73d763d58c95534590b7..981ca1b6fb65dfb210227713fd4e410402586640 100644 --- a/lite/arm/math/CMakeLists.txt +++ b/lite/arm/math/CMakeLists.txt @@ -65,7 +65,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR) conv_direct_3x3s1.cc conv_direct_3x3s2.cc conv_direct.cc - conv_depthwise_3x3_int7.cc conv_depthwise_3x3_int8.cc conv_depthwise_5x5s1_int8.cc conv_depthwise_3x3p0.cc diff --git a/lite/arm/math/conv_depthwise_3x3_int7.cc b/lite/arm/math/conv_depthwise_3x3_int7.cc deleted file mode 100644 index 18dd2225ae6a2cb9353e4f476f5f55236cd270ef..0000000000000000000000000000000000000000 --- a/lite/arm/math/conv_depthwise_3x3_int7.cc +++ /dev/null @@ -1,5322 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "lite/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/operators/op_params.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p1_bias_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 8 -void conv_depthwise_3x3s1p1_bias_s_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! 
for input width <= 8 -void conv_depthwise_3x3s2p1_bias_s_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p1_bias_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3_int7(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - ARMContext* ctx, - PrecisionType out_type, - const float* scale) { - int w_in = win; - int h_in = hin; - int ch_in = chin; - - int w_out = wout; - int h_out = hout; - int ch_out = chout; - int stride_h = param.strides[0]; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active) { - // if (param.activation_param.active == Active_relu || - // fabs(param.activation_param.negative_slope) > 1e-6f) { - // flag_relu = true; - // } - // } - //! only support stride = 1 or 2 - if (stride_h == 1) { - if (flag_relu) { - if (w_in > 8) { - conv_depthwise_3x3s1p1_bias_relu_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_relu_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 8) { - conv_depthwise_3x3s1p1_bias_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! 
stride = 2 - if (flag_relu) { - if (w_in > 16) { - conv_depthwise_3x3s2p1_bias_relu_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_relu_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 16) { - conv_depthwise_3x3s2p1_bias_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_int7(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ - -// 4line w_in > 8 -void conv_depthwise_3x3s1p1_bias_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - const unsigned char right_pad_idx[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 7) >> 3; - int tile_h = (h_out + 1) >> 1; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); - - int size_pad_bottom = h_out % 2; - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? 
bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; - -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "sub %[din_ptr0], %[din_ptr0], #1 \n" - "sub %[din_ptr1], %[din_ptr1], #1 \n" - - "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ - - "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "sub %[din_ptr2], %[din_ptr2], #1 \n" - "sub %[din_ptr3], %[din_ptr3], #1 \n" - - "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext 
v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, - 7); 00123456 */ - "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 - */ - - // r2 - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 - */ - - // r3 - "smlal v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * - w11 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * - w11 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ - "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ - "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ - "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - // r1 - "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ - "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to 
- q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 - to q0*/ - - "smlal v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ - "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - - "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ - - // r3 - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v21.8b, v14.8b \n" - "bif v1.8b, v21.8b, v15.8b \n" - "bif v2.8b, v21.8b, v14.8b \n" - "bif v3.8b, v21.8b, v15.8b \n" - - "ext v4.8b, v0.8b, v1.8b, #1 \n" - "ext v5.8b, v0.8b, v1.8b, #2 \n" - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v16.8b, v2.8b, v3.8b, #1 \n" - "ext v17.8b, v2.8b, v3.8b, #2 \n" - - "bif v6.8b, v21.8b, v14.8b \n" - "bif v7.8b, v21.8b, v15.8b \n" - - "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v8.8b, v21.8b, v14.8b \n" - "bif v9.8b, v21.8b, v15.8b \n" - - "ext v20.8b, v6.8b, v7.8b, #1 \n" - "ext v22.8b, v6.8b, v7.8b, #2 \n" - - "smlal v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - // r1 - "ext v4.8b, v8.8b, v9.8b, #1 \n" - "ext v5.8b, v8.8b, v9.8b, #2 \n" - - "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v14.4s}, [%[rmask]], #16 \n" - "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.4s}, 
[%[ptr_out0]], #16 \n" - "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" - - "smlal v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v1.4s}, [%[ptr_out0]] \n" - "ld1 {v3.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "sub %[ptr_out0], %[ptr_out0], #16 \n" - "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "smlal v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - // r3 - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v10.16b, v0.16b, v14.16b \n" - "bif v11.16b, v1.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "bif v12.16b, v2.16b, v14.16b \n" - "bif v13.16b, v3.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [ptr_out0] "+r"(doutr0), - [ptr_out1] "+r"(doutr1), - [vmask] "+r"(val_mask), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load 
din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "add %[din_ptr0], #7 @add \n" - "add %[din_ptr1], #7 @add \n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #7 @add \n" - "add %[din_ptr3], #7 @add \n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "add %[din_ptr0], #8 @add \n" - "add %[din_ptr1], #8 @add \n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #8 @add \n" - "add %[din_ptr3], #8 @add \n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! 
@ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "subs %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 - - "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d12, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with " - "right pad\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 - "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 - - "vbif.8 d14, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with " - "right pad\n" - - "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 - "sub %[dout_ptr1], #16 @ sub \n" - "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 - "vbif q8, q14, q1 @ bit select, deal with right " - "pad\n" - "vbif q9, q6, q2 @ bit select, deal with right " - "pad\n" - "sub %[dout_ptr2], #16 @ sub \n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vbif q10, q7, q1 @ bit select, deal with right pad\n" - "vbif q11, q12, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += 2 * w_out; - } - } - } -} - -// w_in <= 8 -void conv_depthwise_3x3s1p1_bias_s_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! 
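Every kernel removed in this file leans on the same inner pattern: broadcast each of the nine int8 weights into its own 8-lane vector, build the shifted neighbour columns with ext/vext, widen the int8 products to int16 with smull/smlal (vmull.s8/vmlal.s8 on armv7), and only then fold them into int32 accumulators that were seeded with the per-channel bias (saddw/saddw2, vaddw.s16). A minimal intrinsics sketch of one input row's contribution for the stride-1 middle columns follows; the helper name row_3tap_s1 and the surrounding assumptions are illustrative, not taken from the original source.

#include <arm_neon.h>

// Sketch only (not the removed kernel): add one input row's three taps into
// eight int32 accumulators. `cur` points at the input column under output 0;
// acc_lo/acc_hi were pre-loaded with the per-channel bias, the way the asm
// seeds q8/q9 (armv7) or v10/v11 (aarch64) from bias_val.
static inline void row_3tap_s1(const int8_t* cur,
                               int8x8_t w0, int8x8_t w1, int8x8_t w2,
                               int32x4_t& acc_lo, int32x4_t& acc_hi) {
  int8x8_t d0 = vld1_s8(cur);        // columns 0..7
  int8x8_t d1 = vld1_s8(cur + 8);    // columns 8..15 (only the first lanes matter)
  int8x8_t c1 = vext_s8(d0, d1, 1);  // columns 1..8, the "12345678" vector
  int8x8_t c2 = vext_s8(d0, d1, 2);  // columns 2..9, the "23456789" vector

  int16x8_t sum = vmull_s8(d0, w0);  // int8 x int8 -> int16 (smull / vmull.s8)
  sum = vmlal_s8(sum, c1, w1);       // accumulate in int16 (smlal / vmlal.s8)
  sum = vmlal_s8(sum, c2, w2);       // assumes a few 8-bit products fit in 16 bits, as the asm does

  acc_lo = vaddw_s16(acc_lo, vget_low_s16(sum));   // saddw  / vaddw.s16
  acc_hi = vaddw_s16(acc_hi, vget_high_s16(sum));  // saddw2 / vaddw.s16 on the high half
}

Calling this for the three input rows with the matching weight rows and then storing acc_lo/acc_hi with vst1q_s32 reproduces one 8-wide output chunk of the stride-1 path; like the asm it mirrors, it relies on the quantized values being small enough that the int16 partial sums do not overflow before the widening step.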
for 4x6 convolution window - const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_h = (h_out + 1) >> 1; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[8]; - vst1_u8(vmask, vmask_rp); - - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - - int out_buf1[8]; - int out_buf2[8]; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v4.8b}, [%[vmask]] \n" - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "bif v0.8b, v21.8b, v4.8b \n" - "bif v1.8b, v21.8b, v4.8b \n" - "bif v2.8b, v21.8b, v4.8b \n" - "bif v3.8b, v21.8b, v4.8b \n" - - "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v10.4s}, [%[vbias]] \n" - "ld1 {v11.4s}, [%[vbias]] \n" - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v12.4s}, [%[vbias]] \n" - "ld1 {v13.4s}, [%[vbias]] \n" - - "smlal v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v14.4s}, [%[rmask]], #16 \n" - // "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" - // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" - - "smlal v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, - // 1); 12345678 - - // "ld1 {v0.4s}, [%[ptr_out0]] \n" - // "ld1 {v1.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - // "sub %[ptr_out0], %[ptr_out0], #16 \n" - // "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - // r3 - "smlal v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "bif v10.16b, v16.16b, v14.16b \n" - // "bif v11.16b, v0.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, 
%[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out - */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // "bif v12.16b, v17.16b, v14.16b \n" - // "bif v13.16b, v1.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [vbias] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [vmask] "r"(vmask), - [ptr_out0] "r"(out_buf1), - [ptr_out1] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" - "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 - - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - // "vld1.32 {d28-d29}, [%[dout_ptr1]]! 
@ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - // "sub %[dout_ptr1], #16 @ sub \n" - "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - - // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 - // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 - // 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - - // "vbif q8, q14, q1 @ bit select, deal with right - // pad\n" "vbif q9, q6, q2 @ bit select, deal - // with right pad\n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - // "sub %[dout_ptr2], #16 @ sub \n" - - "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // "vbif q10, q3, q1 @ bit select, deal with right - // pad\n" "vbif q11, q7, q2 @ bit select, deal - // with right pad\n" - - "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" - // "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [dout_ptr1] "r"(out_buf1), - [dout_ptr2] "r"(out_buf2) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - dout_ptr += 2 * w_out; - } - } - } -} - -// 4line w_in > 16 -void conv_depthwise_3x3s2p1_bias_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 15) >> 4; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4)); - if (size_pad_right == 17) { - size_pad_right = 0; - cnt_col++; - } - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - // printf("cnt_col: %d, rst_remain: %d, size_pad_right: %d\n", cnt_col, - // rst_remain, size_pad_right); - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? 
bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - int cnt = cnt_col; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v10.4s, #0x0\n" - // left - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - "add %[din_ptr0], %[din_ptr0], #15 \n" - "add %[din_ptr1], %[din_ptr1], #15 \n" - "add %[din_ptr2], %[din_ptr2], #15 \n" - - // r1 - "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - // r2 - "smlal v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - 
"ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ - "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ - "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ - - "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - - // r0 - "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ - - // r2 - "smlal v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smlal v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v10.8b, v14.8b \n" - "bif v1.8b, v10.8b, v15.8b \n" - "bif v2.8b, v10.8b, v14.8b \n" - "bif v3.8b, v10.8b, v15.8b \n" - "bif v4.8b, v10.8b, v14.8b \n" - "bif v5.8b, v10.8b, v15.8b \n" - - "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. */ - "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468..*/ - "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. 
*/ - - // r0 - "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ - - // r2 - "smlal v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ - - "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ - "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "bif v12.16b, v0.16b, v9.16b \n" - "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [ptr_out0] "+r"(doutr0), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); -#else - unsigned int* rst_mask = rmask; - int cnt = cnt_col; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // 
vbias - - // r2 - "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - "add %[din_ptr0], #15 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "add %[din_ptr1], #15 @add \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - "add %[din_ptr2], #15 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - - "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 - // 4 6 - "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 - // 4 6 - "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 - // 4 6 - - "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 - // 4 6 8 - "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 - "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 - - // r2 - "vmlal.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 - - // "add %[din_ptr0], #16 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - // "add %[din_ptr1], #16 @add \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - // "add %[din_ptr2], #16 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - - "subs %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "cmp %[size_pad_right], #1 \n" - "blt 3f \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 - // 1 3 5 - "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - // r2 - "vmlal.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vbif q11, q6, q1 @ bit select, deal with right pad\n" - "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "3: \n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [dout_ptr1] "+r"(doutr0), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += w_out; - } - } - } -} -// w_in <= 16 -void conv_depthwise_3x3s2p1_bias_s_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! 
for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - - int out_buf1[8]; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr2 + w_in; - dr2 = dr1 + w_in; - } - //! 
process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v16.4s, #0x0\n" - // left - "ld1 {v10.8b}, [%[vmask]], #8 \n" - "ld1 {v11.8b}, [%[vmask]] \n" - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "bif v0.8b, v16.8b, v10.8b \n" - "bif v1.8b, v16.8b, v11.8b \n" - "bif v2.8b, v16.8b, v10.8b \n" - "bif v3.8b, v16.8b, v11.8b \n" - "bif v4.8b, v16.8b, v10.8b \n" - "bif v5.8b, v16.8b, v11.8b \n" - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, - // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* - // dup v10, bias */ - - // r1 - "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - // r2 - "smlal v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - // "bif v12.16b, v0.16b, v10.16b \n" - // "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out - */ - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask), - [ptr_out0] "r"(out_buf1) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 
6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // "pld [%[dout_ptr1]] @ preload data\n" - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r2 - "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - // "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - // "vbif q11, q6, q1 @ bit select, deal with right pad\n" - // "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [size_pad_right] "r"(size_pad_right), - [dout_ptr1] "r"(out_buf1) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - } - dout_ptr += w_out; - } - } - } -} - -// relu -void conv_depthwise_3x3s1p1_bias_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 7) >> 3; - int tile_h = (h_out + 1) >> 1; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); - - int size_pad_bottom = h_out % 2; - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), - // vld1_u8(right_pad_idx + 8)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? 
bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "sub %[din_ptr0], %[din_ptr0], #1 \n" - "sub %[din_ptr1], %[din_ptr1], #1 \n" - - "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ - - "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "sub %[din_ptr2], %[din_ptr2], #1 \n" - "sub %[din_ptr3], %[din_ptr3], #1 \n" - - "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext 
v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, - 7); 00123456 */ - "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 - */ - "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 - */ - - // r2 - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 to - q0*/ - - "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 - */ - - "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 - */ - "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 - */ - - // r3 - "smlal v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 - */ - - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 to - q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * - w11 */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * - w11 */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ - "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ - "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ - "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ - - "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - // r1 - "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ - "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ - - "smull v19.8h, 
%[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - - "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load - a00-a015 - to q0*/ - - "smlal v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ - "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ - - "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ - - // r3 - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ - - "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load - a00-a015 - to q0*/ - "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load - a00-a015 - to q0*/ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "subs %[cnt], %[cnt], #1 \n" - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v21.8b, v14.8b \n" - "bif v1.8b, v21.8b, v15.8b \n" - "bif v2.8b, v21.8b, v14.8b \n" - "bif v3.8b, v21.8b, v15.8b \n" - - "ext v4.8b, v0.8b, v1.8b, #1 \n" - "ext v5.8b, v0.8b, v1.8b, #2 \n" - - // r0 - "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v16.8b, v2.8b, v3.8b, #1 \n" - "ext v17.8b, v2.8b, v3.8b, #2 \n" - - "bif v6.8b, v21.8b, v14.8b \n" - "bif v7.8b, v21.8b, v15.8b \n" - - "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "bif v8.8b, v21.8b, v14.8b \n" - "bif v9.8b, v21.8b, v15.8b \n" - - "ext v20.8b, v6.8b, v7.8b, #1 \n" - "ext v22.8b, v6.8b, v7.8b, #2 \n" - - "smlal v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - // r1 - "ext v4.8b, v8.8b, v9.8b, #1 \n" - "ext v5.8b, v8.8b, v9.8b, #2 \n" - 
- "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v14.4s}, [%[rmask]], #16 \n" - "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" - "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" - - "smlal v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v1.4s}, [%[ptr_out0]] \n" - "ld1 {v3.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "sub %[ptr_out0], %[ptr_out0], #16 \n" - "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "smlal v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 - */ - - // r3 - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 - */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ - "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ - - "bif v10.16b, v0.16b, v14.16b \n" - "bif v11.16b, v1.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ - "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ - - "bif v12.16b, v2.16b, v14.16b \n" - "bif v13.16b, v3.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [ptr_out0] "+r"(doutr0), - [ptr_out1] "+r"(doutr1), - [vmask] "+r"(val_mask), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, 
#0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "add %[din_ptr0], #7 @add \n" - "add %[din_ptr1], #7 @add \n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #7 @add \n" - "add %[din_ptr3], #7 @add \n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 - "vmov.u32 q0, #0 @ mov \n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ 
preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "add %[din_ptr0], #8 @add \n" - "add %[din_ptr1], #8 @add \n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "add %[din_ptr2], #8 @add \n" - "add %[din_ptr3], #8 @add \n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d8 @ out1 = 
din3 * w21 \n" // q13 = d12 * w01 - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "subs %[cnt], #1 \n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - "bne 2b \n" - // right - "1: \n" - "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = vbias - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // r0 - "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 - - "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - - "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 - - "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vbif.8 d12, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with " - "right pad\n" - - "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 - "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 - - "vbif.8 d14, d11, d28 @ bit select, deal with " - "right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with " - 
"right pad\n" - - "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 - "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 - "vbif q8, q14, q1 @ bit select, deal with right " - "pad\n" - "vbif q9, q6, q2 @ bit select, deal with right " - "pad\n" - "sub %[dout_ptr1], #16 @ sub \n" - "sub %[dout_ptr2], #16 @ sub \n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - "vbif q10, q7, q1 @ bit select, deal with right pad\n" - "vbif q11, q12, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" - "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += 2 * w_out; - } - } - } -} -// w_in <= 8 -void conv_depthwise_3x3s1p1_bias_s_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s1 mult height \n"); - //! pad is done implicit - const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! 
for 4x6 convolution window - const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_in; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_h = (h_out + 3) >> 2; - - unsigned int size_pad_right = (unsigned int)(w_in); - - int size_pad_bottom = h_out % 4; - - uint8x8_t vmask_rp = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - unsigned char vmask[8]; - vst1_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - const signed char* dr3 = dr2 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - const signed char* din_ptr3 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - - int out_buf1[8]; - int out_buf2[8]; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } -#ifdef __aarch64__ - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - // left - "ld1 {v4.8b}, [%[vmask]] \n" - "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load - a00-a015 - to - q0*/ - "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load - a00-a015 - to - q0*/ - - "bif v0.8b, v21.8b, v4.8b \n" - "bif v1.8b, v21.8b, v4.8b \n" - "bif v2.8b, v21.8b, v4.8b \n" - "bif v3.8b, v21.8b, v4.8b \n" - - "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "ld1 {v10.4s}, [%[vbias]] \n" - "ld1 {v11.4s}, [%[vbias]] \n" - - // r0 - "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 - */ - - "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "ld1 {v12.4s}, [%[vbias]] \n" - "ld1 {v13.4s}, [%[vbias]] \n" - - "smlal v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, - 1); 12345678 */ - - // r1 - "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v14.4s}, [%[rmask]], #16 \n" - // "ld1 {v15.4s}, [%[rmask]] \n" - - "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" - // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" - - "smlal v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 00123456 */ - "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, - // 1); 12345678 - - // "ld1 {v0.4s}, [%[ptr_out0]] \n" - // "ld1 {v1.4s}, [%[ptr_out1]] \n" - - // r2 - "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 - */ - - // "sub %[ptr_out0], %[ptr_out0], #16 \n" - // "sub %[ptr_out1], %[ptr_out1], #16 \n" - - "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 - */ - - "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 - */ - - // r3 - "smlal v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 - */ - - "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ - - "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 - */ - - "smax v10.4s, v10.4s, v21.4s \n" /* relu */ - "smax v11.4s, v11.4s, v21.4s \n" /* relu */ - - // "bif v10.16b, v16.16b, v14.16b \n" - // "bif v11.16b, v0.16b, v15.16b \n" - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += 
outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 - */ - - "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v21.4s \n" /* relu */ - "smax v13.4s, v13.4s, v21.4s \n" /* relu */ - - // "bif v12.16b, v17.16b, v14.16b \n" - // "bif v13.16b, v1.16b, v15.16b \n" - - "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out - */ - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [rmask] "+r"(rst_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [vbias] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [vmask] "r"(vmask), - [ptr_out0] "r"(out_buf1), - [ptr_out1] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "pld [%[din_ptr3]] @ preload data\n" - "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - - "vmov.u32 d11, #0 @ zero\n" - // out0 - "vdup.32 q8, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q9, %[bias] @ and \n" // q9 = - // vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" - "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" - // out1 - "vdup.32 q10, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q11, %[bias] @ and \n" // q9 = - // vbias - - // r0 - "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 - - "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" - "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" - - "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 - - "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" - - "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 - - // r1 - "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 - "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - "vmull.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 - - "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" - - "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += 
d10 * w00 - "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 - - "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" - // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 - - // r2 - "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 - "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 - - // "sub %[dout_ptr1], #16 @ sub \n" - "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 - "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 - - // r3 - "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 - "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vaddw.s16 q8, q8, d24 @addw \n" // out0 += - // vget_low_s16(out00) - "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += - // vget_high_s16(out00) - - "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 - - "vmov.u32 q0, #0 @ zero\n" - - // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 - // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 - // 2 3 4 5 6 7 8 9\n" - - "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 - - "vmax.s32 q8, q8, q0 @ max \n" - "vmax.s32 q9, q9, q0 @ max \n" - - "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 - - // "sub %[dout_ptr2], #16 @ sub \n" - // "vbif q8, q14, q1 @ bit select, deal with right - // pad\n" "vbif q9, q6, q2 @ bit select, deal - // with right pad\n" - - "vaddw.s16 q10, q10, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" - - "vmax.s32 q10, q10, q0 @ max \n" - "vmax.s32 q11, q11, q0 @ max \n" - - // "vbif q10, q3, q1 @ bit select, deal with right - // pad\n" "vbif q11, q7, q2 @ bit select, deal - // with right pad\n" - - "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" - // "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [dout_ptr1] "r"(out_buf1), - [dout_ptr2] "r"(out_buf2) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - dout_ptr += 2 * w_out; - } - } - } -} - -// 1 line w_in > 16 -void conv_depthwise_3x3s2p1_bias_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 15) >> 4; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4)); - if (size_pad_right == 17) { - size_pad_right = 0; - cnt_col++; - } - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; - -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? 
bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; -#ifdef __aarch64__ - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v10.4s, #0x0\n" - // left - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - "add %[din_ptr0], %[din_ptr0], #15 \n" - "add %[din_ptr1], %[din_ptr1], #15 \n" - "add %[din_ptr2], %[din_ptr2], #15 \n" - - // r1 - "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - // r2 - "smlal v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smlal v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smlal v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* 
store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "cmp %[cnt], #1 \n" - "blt 3f \n" - // mid - "1: \n" - "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ - "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ - "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ - - "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 246810 */ - - // r0 - "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ - - // r2 - "smlal v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smlal v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smlal v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ - - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load - a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load - a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load - a00-a015 - to q0*/ - - "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "subs %[cnt], %[cnt], #1 \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - "bne 1b \n" - // right - "3: \n" - "ld1 {v14.8b}, [%[vmask]], #8 \n" - "ld1 {v15.8b}, [%[vmask]] \n" - - "bif v0.8b, v10.8b, v14.8b \n" - "bif v1.8b, v10.8b, v15.8b \n" - "bif v2.8b, v10.8b, v14.8b \n" - "bif v3.8b, v10.8b, v15.8b \n" - "bif v4.8b, v10.8b, v14.8b \n" - "bif v5.8b, v10.8b, v15.8b \n" - - "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. */ - "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468..*/ - "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); - 2468.. 
*/ - - // r0 - "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ - "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ - "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ - - // r1 - "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ - - // r2 - "smlal v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ - "smlal v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ - "smlal v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ - - "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ - "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ - - "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ - - "bif v12.16b, v0.16b, v9.16b \n" - "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> - ptr_out */ - - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [ptr_out0] "+r"(doutr0), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - 
"vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r2 - "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - "add %[din_ptr0], #15 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vmov.u32 q8, #0 @ max \n" // max - "add %[din_ptr1], #15 @add \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - "add %[din_ptr2], #15 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "cmp %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" - "blt 1f \n" - - // mid - "2: \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - - "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 - // 4 6 - "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 - // 4 6 - "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 - // 4 6 - - "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 - // 4 6 8 - "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 - "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 - - // r2 - "vmlal.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 - - // "add %[din_ptr0], #16 @add \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - // "add %[din_ptr1], #16 @add \n" - "vmov.u32 q8, #0 @ mov \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - // "add %[din_ptr2], #16 @add \n" - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - - "subs %[cnt], #1 \n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - "bne 2b \n" - // right - "1: \n" - "cmp %[size_pad_right], #1 \n" - "blt 3f \n" - "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = vbias - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 - // 1 3 5 - "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 - - // r0 - "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 - "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 - "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 - - // r1 - "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " - "7 8 9\n" - - // r2 - "vmlal.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 - "vmlal.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 - "vmlal.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 - - "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " - "9\n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "sub %[dout_ptr1], #16 @ sub \n" - "vmov.u32 q8, #0 @mov \n" - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - "vbif q11, q6, q1 @ bit select, deal with right pad\n" - "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" - "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - "3: \n" - - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [dout_ptr1] "+r"(doutr0), - [cnt] "+r"(cnt), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - dout_ptr += w_out; - } - } - } -} -// w_in <= 16 -void conv_depthwise_3x3s2p1_bias_s_relu_int7(int* dout, - const signed char* din, - const signed char* weights, - const int* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - // printf("3x3s2 mult height \n"); - //! pad is done implicit - // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - //! for 4x6 convolution window - const unsigned char right_pad_idx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - - // printf("conv3x3_dw start \n"); - signed char* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(signed char)); - int* write_ptr = - reinterpret_cast(ctx->workspace_data()) + w_out; - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - unsigned int size_pad_right = (unsigned int)(w_in); - - uint8x8_t vmask_rp1 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); - uint8x8_t vmask_rp2 = - vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); - unsigned int rst_remain = (unsigned int)w_out; - uint32x4_t vmask_result1 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); - uint32x4_t vmask_result2 = - vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); - - uint8x16_t vmask_rp = - vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); - unsigned char vmask[16]; - vst1q_u8(vmask, vmask_rp); - - unsigned int rmask[8]; - vst1q_u32(rmask, vmask_result1); - vst1q_u32(rmask + 4, vmask_result2); - int8x8_t vzero = vdup_n_s8(0); - int32x4_t vzero_32 = vdupq_n_s32(0); - - for (int n = 0; n < num; ++n) { - const signed char* din_batch = din + n * ch_in * size_in_channel; - int* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - int* dout_ptr = dout_batch + c * size_out_channel; - - const signed char* din_ch_ptr = din_batch + c * size_in_channel; - - int bias_val = flag_bias ? bias[c] : 0; - - const signed char* wei_ptr = weights + c * w_stride; - -#ifdef __aarch64__ - int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); - int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); - int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); - - int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); - - int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); -#endif - - int* doutr0 = nullptr; - - const signed char* dr0 = din_ch_ptr; - const signed char* dr1 = dr0 + w_in; - const signed char* dr2 = dr1 + w_in; - - const signed char* din_ptr0 = nullptr; - const signed char* din_ptr1 = nullptr; - const signed char* din_ptr2 = nullptr; - - for (int i = 0; i < h_in; i += 2) { - //! 
process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - - doutr0 = dout_ptr; - - int out_buf1[8]; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr2 + w_in; - dr2 = dr1 + w_in; - } - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din_ptr1 = zero_ptr; - case 1: - din_ptr2 = zero_ptr; - default: - break; - } - } -#ifdef __aarch64__ - unsigned int* rst_mask = rmask; - unsigned char* val_mask = vmask; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "movi v16.4s, #0x0\n" - // left - "ld1 {v10.8b}, [%[vmask]], #8 \n" - "ld1 {v11.8b}, [%[vmask]] \n" - "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 - to q0*/ - "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 - to q0*/ - "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 - to q0*/ - - "bif v0.8b, v16.8b, v10.8b \n" - "bif v1.8b, v16.8b, v11.8b \n" - "bif v2.8b, v16.8b, v10.8b \n" - "bif v3.8b, v16.8b, v11.8b \n" - "bif v4.8b, v16.8b, v10.8b \n" - "bif v5.8b, v16.8b, v11.8b \n" - - "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ - - "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); - 013579 */ - - // r0 - "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ - "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ - "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ - - // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, - // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* - // dup v10, bias */ - - // r1 - "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ - - // r2 - "smlal v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ - "smlal v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ - "smlal v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ - - "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ - - "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ - "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ - - "smax v12.4s, v12.4s, v16.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v16.4s \n" /*relu*/ - - // "bif v12.16b, v0.16b, v10.16b \n" - // "bif v13.16b, v1.16b, v11.16b \n" - - "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out - */ - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [vmask] "+r"(val_mask) - : [v0] "w"(wr00), - [v1] "w"(wr01), - [v2] "w"(wr02), - [v3] "w"(wr10), - [bias_val] "r"(vbias), - [v4] "w"(wr11), - [v5] "w"(wr12), - [v6] "w"(wr20), - [v7] "w"(wr21), - [v8] "w"(wr22), - [rst_mask] "r"(rmask), - [ptr_out0] "r"(out_buf1) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - unsigned int* rst_mask = rmask; - // prefetch input - // store weights - asm volatile("vld1.8 {d0-d1}, 
[%[wei_ptr]] \n" - : - : [wei_ptr] "r"(wei_ptr) - : "memory"); - asm volatile( - // left - "pld [%[din_ptr0]] @ preload data\n" - "pld [%[din_ptr1]] @ preload data\n" - "pld [%[din_ptr2]] @ preload data\n" - "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" - "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 - "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 - "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 - "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " - "8 9\n" - "vmov.u32 d11, #0 @ zero\n" - - "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" - - "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" - - "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" - "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" - - "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 - "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 - "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 - - // "pld [%[dout_ptr1]] @ preload data\n" - - // r0 - "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 - "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 - "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 - - "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" - "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" - "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" - - // r1 - "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 - - // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 - // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 - // 1 2 3 4 5 6 7 8 9\n" - - // out0 - "vdup.32 q11, %[bias] @ and \n" // q8 = - // vbias - "vdup.32 q12, %[bias] @ and \n" // q9 = - // vbias - - // r2 - "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 - "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 - "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 - - // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 - // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 - // 5 6 7 8 9\n" - - // "sub %[dout_ptr1], #16 @ sub \n" - - "vaddw.s16 q11, q11, d26 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += - // vget_high_s16(out10) - "vmov.u32 q8, #0 @ mov \n" - - "vaddw.s16 q11, q11, d28 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vaddw.s16 q11, q11, d30 @addw \n" // out1 += - // vget_low_s16(out10) - "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += - // vget_high_s16(out10) - - "vmax.s32 q11, q11, q8 @ max\n" - "vmax.s32 q12, q12, q8 @ max\n" - - // "vbif q11, q6, q1 @ bit select, deal with right pad\n" - // "vbif q12, q7, q2 @ bit select, deal with right pad\n" - - "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" - // "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" - : [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [bias] "+r"(bias_val), - [rs_mask] "+r"(rst_mask) - : [mask] "r"(vmask), - [size_pad_right] "r"(size_pad_right), - [dout_ptr1] "r"(out_buf1) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - } - dout_ptr += w_out; - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/arm/math/prior_box.cc b/lite/arm/math/prior_box.cc index e6f455e72a231f36e10b2cde54140ca68fcd4a43..6ec312796578863dc9c7a046950aa4dcf79d38fa 100644 --- a/lite/arm/math/prior_box.cc +++ b/lite/arm/math/prior_box.cc @@ -51,7 +51,7 @@ void density_prior_box(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, @@ -82,14 +82,12 @@ void density_prior_box(const lite::Tensor* input, img_width = image->dims()[3]; img_height = image->dims()[2]; } - float step_w = step_w_; float step_h = step_h_; if (step_w == 0 || step_h == 0) { step_w = static_cast(img_width) / width; step_h = static_cast(img_height) / height; } - float offset = offset_; int step_average = static_cast((step_w + step_h) * 0.5); // add int channel_size = height * width * prior_num_ * 4; @@ -343,7 +341,7 @@ void prior_box(const lite::Tensor* input, min_size, std::vector(), std::vector(), - std::vector(), + std::vector(), max_size, aspect_ratio, variance, diff --git a/lite/arm/math/prior_box.h b/lite/arm/math/prior_box.h index 59efb2ab0027d3d5cab68118ea48fa70436d1c48..ffa821b75e54ee3e2329e4dcced8ddee2a003802 100644 --- a/lite/arm/math/prior_box.h +++ b/lite/arm/math/prior_box.h @@ -30,7 +30,7 @@ void density_prior_box(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt index cc80637dd4460c3da898ef15c41ca46a13c06bfb..35b235221c94c2ee1d732505aa3344f706e39dba 100644 --- a/lite/core/CMakeLists.txt +++ b/lite/core/CMakeLists.txt @@ -37,9 +37,36 @@ lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info CL_DEPS cl_cont else() lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info eigen3 CL_DEPS cl_context gflags) endif() -lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor) + +#----------------------------------------------- NOT CHANGE ----------------------------------------------- +# A trick to generate the paddle_use_kernels.h +add_custom_command( + COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_kernel_registry.py + ${kernels_src_list} + ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h + OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h + ) +# A trick to generate the paddle_use_ops.h +add_custom_command( + COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_op_registry.py + ${ops_src_list} + ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h + OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h + ) 
+add_custom_target(op_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h) +add_custom_target(kernel_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h) + +#----------------------------------------------- NOT CHANGE ----------------------------------------------- +lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor + ) lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel - cpp_op_desc tensor) + cpp_op_desc tensor + ) + +add_dependencies(kernel kernel_list_h) +add_dependencies(op op_list_h) + + lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper) lite_cc_library(program SRCS program.cc @@ -73,3 +100,17 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils) lite_cc_test(test_types SRCS types_test.cc DEPS types) lite_cc_test(test_memory SRCS memory_test.cc DEPS memory) lite_cc_test(test_context SRCS context_test.cc DEPS context) + + +# # A trick to generate the paddle_use_kernels.h +# execute_process( +# COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_kernel_registry.py +# ${kernels_src_list} +# ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h +# ) +# # A trick to generate the paddle_use_ops.h +# execute_process( +# COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/parse_op_registry.py +# ${ops_src_list} +# ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h +# ) diff --git a/lite/core/context.h b/lite/core/context.h index f36744dc00f8f88804987370aad05edd8eec0fa2..c8e84fb19e7d1e2f5544cc5b19c03900d40dd3d8 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -101,7 +101,7 @@ class Context { void CopySharedTo(ARMContext* ctx) {} - void SetRunMode(PowerMode mode, int threads) { + void SetRunMode(lite_api::PowerMode mode, int threads) { return DeviceInfo::Global().SetRunMode(mode, threads); } void SetCache(int l1size, int l2size, int l3size) { @@ -109,7 +109,7 @@ class Context { } void SetArch(ARMArch arch) { return DeviceInfo::Global().SetArch(arch); } - PowerMode mode() const { return DeviceInfo::Global().mode(); } + lite_api::PowerMode mode() const { return DeviceInfo::Global().mode(); } int threads() const { return DeviceInfo::Global().threads(); } ARMArch arch() const { return DeviceInfo::Global().arch(); } int l1_cache_size() const { return DeviceInfo::Global().l1_cache_size(); } diff --git a/lite/core/cpu_info.cc b/lite/core/cpu_info.cc index 4b352eee93289847bab42635b3731ab0db548021..e882ef59bdbd0a3ab55d01626ad9480cbf74a8a7 100644 --- a/lite/core/cpu_info.cc +++ b/lite/core/cpu_info.cc @@ -119,7 +119,8 @@ size_t get_mem_size() { return memsize; #elif defined(TARGET_IOS) // to be implemented - printf("not implemented\n"); + printf("not implemented, set to default 4GB\n"); + return 4096 * 1024; #endif return 0; } @@ -209,7 +210,7 @@ void get_cpu_arch(std::vector* archs, const int cpu_num) { } #elif defined(TARGET_IOS) for (int i = 0; i < cpu_num; ++i) { - archs->at(i) = APPLE; + archs->at(i) = kAPPLE; } #endif } @@ -818,7 +819,7 @@ void DeviceInfo::RequestPowerFullMode(int thread_num) { active_ids_.push_back(little_core_ids_[i - big_core_size]); } } - mode_ = LITE_POWER_FULL; + mode_ = lite_api::PowerMode::LITE_POWER_FULL; } void DeviceInfo::RequestPowerHighMode(int thread_num) { @@ -826,7 +827,7 @@ void DeviceInfo::RequestPowerHighMode(int thread_num) { int little_core_size = little_core_ids_.size(); active_ids_.clear(); if (big_core_size > 0) { - mode_ = LITE_POWER_HIGH; + mode_ =lite_api::PowerMode::LITE_POWER_HIGH; 
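The mode_ assignments above belong to a wider move of PowerMode into the lite_api namespace (the old typedef is deleted from lite/core/cpu_info.h further down in this patch). The following is a compact sketch of the binding logic, reusing the enum values from the removed typedef; the core-id layout and the clamping behaviour are simplified assumptions, not the full DeviceInfo implementation.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

namespace lite_api {
// Same values as the typedef removed from lite/core/cpu_info.h.
enum PowerMode {
  LITE_POWER_HIGH = 0,
  LITE_POWER_LOW = 1,
  LITE_POWER_FULL = 2,
  LITE_POWER_NO_BIND = 3,
  LITE_POWER_RAND_HIGH = 4,
  LITE_POWER_RAND_LOW = 5
};
}  // namespace lite_api

// Simplified stand-in for DeviceInfo: pick a core cluster for the requested
// mode and clamp the thread count to the cluster size, as the Request*Mode
// helpers do. The core ids are made-up example values.
struct DeviceInfoSketch {
  std::vector<int> big_core_ids{4, 5, 6, 7};
  std::vector<int> little_core_ids{0, 1, 2, 3};
  std::vector<int> active_ids;
  lite_api::PowerMode mode{lite_api::LITE_POWER_NO_BIND};

  void SetRunMode(lite_api::PowerMode m, int thread_num) {
    const std::vector<int>& pool =
        (m == lite_api::LITE_POWER_LOW) ? little_core_ids : big_core_ids;
    int n = std::min<int>(thread_num, static_cast<int>(pool.size()));
    active_ids.assign(pool.begin(), pool.begin() + n);
    mode = m;
    std::printf("mode=%d, threads=%d\n", static_cast<int>(mode), n);
  }
};

int main() {
  DeviceInfoSketch dev;
  dev.SetRunMode(lite_api::LITE_POWER_HIGH, 8);  // clamped to the 4 big cores
  return 0;
}
```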
if (thread_num > big_core_size) { LOG(ERROR) << "Request thread num: " << thread_num << ", exceed the big cores size: " << big_core_size @@ -838,7 +839,7 @@ void DeviceInfo::RequestPowerHighMode(int thread_num) { } } } else { - mode_ = LITE_POWER_LOW; + mode_ = lite_api::PowerMode::LITE_POWER_LOW; LOG(ERROR) << "HIGH POWER MODE is not support, switch to little cores."; if (thread_num > little_core_size) { active_ids_ = little_core_ids_; @@ -855,7 +856,7 @@ void DeviceInfo::RequestPowerLowMode(int thread_num) { int little_core_size = little_core_ids_.size(); active_ids_.clear(); if (little_core_size > 0) { - mode_ = LITE_POWER_LOW; + mode_ = lite_api::PowerMode::LITE_POWER_LOW; if (thread_num > little_core_size) { LOG(WARNING) << "Request thread num: " << thread_num << ", exceed the little cores size: " << little_core_size @@ -867,7 +868,7 @@ void DeviceInfo::RequestPowerLowMode(int thread_num) { } } } else { - mode_ = LITE_POWER_HIGH; + mode_ = lite_api::PowerMode::LITE_POWER_HIGH; LOG(WARNING) << "LOW POWER MODE is not support, switch to big cores"; if (thread_num > big_core_size) { active_ids_ = big_core_ids_; @@ -893,7 +894,7 @@ void DeviceInfo::RequestPowerNoBindMode(int thread_num) { } } } - mode_ = LITE_POWER_NO_BIND; + mode_ = lite_api::PowerMode::LITE_POWER_NO_BIND; } void DeviceInfo::RequestPowerRandHighMode(int shift_num, int thread_num) { @@ -901,7 +902,7 @@ void DeviceInfo::RequestPowerRandHighMode(int shift_num, int thread_num) { int little_core_size = little_core_ids_.size(); active_ids_.clear(); if (big_core_size > 0) { - mode_ = LITE_POWER_RAND_HIGH; + mode_ = lite_api::PowerMode::LITE_POWER_RAND_HIGH; if (thread_num > big_core_size) { LOG(WARNING) << "Request thread num: " << thread_num << ", exceed the big cores size: " << big_core_size @@ -913,7 +914,7 @@ void DeviceInfo::RequestPowerRandHighMode(int shift_num, int thread_num) { } } } else { - mode_ = LITE_POWER_LOW; + mode_ = lite_api::PowerMode::LITE_POWER_LOW; LOG(WARNING) << "HIGH POWER MODE is not support, switch to little cores."; if (thread_num > little_core_size) { active_ids_ = little_core_ids_; @@ -930,7 +931,7 @@ void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) { int little_core_size = little_core_ids_.size(); active_ids_.clear(); if (little_core_size > 0) { - mode_ = LITE_POWER_RAND_LOW; + mode_ = lite_api::PowerMode::LITE_POWER_RAND_LOW; if (thread_num > little_core_size) { LOG(WARNING) << "Request thread num: " << thread_num << ", exceed the little cores size: " << little_core_size @@ -943,7 +944,7 @@ void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) { } } } else { - mode_ = LITE_POWER_HIGH; + mode_ = lite_api::PowerMode::LITE_POWER_HIGH; LOG(WARNING) << "LOW POWER MODE is not support, switch to big cores."; if (thread_num > big_core_size) { active_ids_ = big_core_ids_; @@ -957,6 +958,7 @@ void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) { int DeviceInfo::Setup() { core_num_ = get_cpu_num(); + printf("core number: %d\n", core_num_); mem_size_ = get_mem_size(); get_cpu_arch(&archs_, core_num_); // set defalut CPU info @@ -966,10 +968,10 @@ int DeviceInfo::Setup() { SetFP32Info(1, 1); SetFP16Info(1, 0); SetDotInfo(1, 0); -#ifdef LITE_WITH_LINUX - // get max&min freq max_freqs_.resize(core_num_); min_freqs_.resize(core_num_); +#ifdef LITE_WITH_LINUX + // get max&min freq for (int i = 0; i < core_num_; ++i) { int max_freq, min_freq; get_cpu_max_min_freq(i, &max_freq, &min_freq); @@ -981,6 +983,30 @@ int DeviceInfo::Setup() { if 
(!SetCPUInfoByName()) { SetCPUInfoByProb(); } + core_ids_.resize(core_num_); + cluster_ids_.resize(core_num_); + for (int i = 0; i < core_num_; ++i) { + max_freqs_[i] = 1000000; + min_freqs_[i] = 1000000; + cluster_ids_[i] = 0; + } +#else +#ifdef TARGET_IOS + dev_name_ = "Apple"; +#else + dev_name_ = "Unknown"; +#endif + core_ids_.resize(core_num_); + cluster_ids_.resize(core_num_); + big_core_ids_.resize(core_num_); + for (int i = 0; i < core_num_; ++i) { + max_freqs_[i] = 1000000; + min_freqs_[i] = 1000000; + cluster_ids_[i] = 0; + core_ids_[i] = i; + big_core_ids_[i] = i; + } +#endif // output info LOG(INFO) << "ARM multiprocessors name: " << dev_name_; LOG(INFO) << "ARM multiprocessors number: " << core_num_; @@ -1004,13 +1030,12 @@ int DeviceInfo::Setup() { LOG(INFO) << L3_cache_[i] / 1024 << " KB"; } LOG(INFO) << "Total memory: " << mem_size_ << "KB"; -#endif // set default run mode - SetRunMode(LITE_POWER_NO_BIND, 1); // use single thread by default + SetRunMode(lite_api::PowerMode::LITE_POWER_NO_BIND, 1); // use single thread by default return 0; } -void DeviceInfo::SetRunMode(PowerMode mode, int thread_num) { +void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) { #ifdef ARM_WITH_OMP thread_num = std::min(thread_num, core_num_); #else @@ -1024,22 +1049,22 @@ void DeviceInfo::SetRunMode(PowerMode mode, int thread_num) { count_++; int shift_num = (count_ / 10) % big_core_size; switch (mode) { - case LITE_POWER_FULL: + case lite_api::LITE_POWER_FULL: RequestPowerFullMode(thread_num); break; - case LITE_POWER_HIGH: + case lite_api::LITE_POWER_HIGH: RequestPowerHighMode(thread_num); break; - case LITE_POWER_LOW: + case lite_api::LITE_POWER_LOW: RequestPowerLowMode(thread_num); break; - case LITE_POWER_NO_BIND: + case lite_api::LITE_POWER_NO_BIND: RequestPowerNoBindMode(thread_num); break; - case LITE_POWER_RAND_HIGH: + case lite_api::LITE_POWER_RAND_HIGH: RequestPowerRandHighMode(shift_num, thread_num); break; - case LITE_POWER_RAND_LOW: + case lite_api::LITE_POWER_RAND_LOW: RequestPowerRandLowMode(shift_num, thread_num); break; default: @@ -1052,12 +1077,12 @@ void DeviceInfo::SetRunMode(PowerMode mode, int thread_num) { #ifdef ARM_WITH_OMP omp_set_num_threads(active_ids_.size()); #endif - if (mode_ != LITE_POWER_NO_BIND) { + if (mode_ != lite_api::LITE_POWER_NO_BIND) { if (check_cpu_online(active_ids_)) { bind_threads(active_ids_); } else { LOG(WARNING) << "Some cores are offline, switch to NO BIND MODE"; - mode_ = LITE_POWER_NO_BIND; + mode_ = lite_api::LITE_POWER_NO_BIND; } } #else // LITE_WITH_LINUX @@ -1080,7 +1105,7 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) { workspace_.Resize({2 * (l1size + l2size)}); } -bool DeviceInfo::ExtendWorkspace(size_t size) { +bool DeviceInfo::ExtendWorkspace(int size) { workspace_.Resize({size + llc_size()}); workspace_.mutable_data(); return true; diff --git a/lite/core/cpu_info.h b/lite/core/cpu_info.h index 495f95943e9112812fce952e5597196408c3e6a2..b05b8c07a68473d103384d599e657e6795f5402f 100644 --- a/lite/core/cpu_info.h +++ b/lite/core/cpu_info.h @@ -25,15 +25,6 @@ namespace lite { #ifdef LITE_WITH_ARM -typedef enum { - LITE_POWER_HIGH = 0, - LITE_POWER_LOW = 1, - LITE_POWER_FULL = 2, - LITE_POWER_NO_BIND = 3, - LITE_POWER_RAND_HIGH = 4, - LITE_POWER_RAND_LOW = 5 -} PowerMode; - typedef enum { kAPPLE = 0, kA53 = 53, @@ -60,11 +51,11 @@ class DeviceInfo { int Setup(); - void SetRunMode(PowerMode mode, int thread_num); + void SetRunMode(lite_api::PowerMode mode, int thread_num); void SetCache(int l1size, int 
l2size, int l3size); void SetArch(ARMArch arch) { arch_ = arch; } - PowerMode mode() const { return mode_; } + lite_api::PowerMode mode() const { return mode_; } int threads() const { return active_ids_.size(); } ARMArch arch() const { return arch_; } int l1_cache_size() const { return L1_cache_[active_ids_[0]]; } @@ -82,7 +73,7 @@ class DeviceInfo { T* workspace_data() { return reinterpret_cast(workspace_.mutable_data()); } - bool ExtendWorkspace(size_t size); + bool ExtendWorkspace(int size); private: int core_num_; @@ -107,7 +98,7 @@ class DeviceInfo { // LITE_POWER_HIGH stands for using big cores, // LITE_POWER_LOW stands for using small core, // LITE_POWER_FULL stands for using all cores - PowerMode mode_; + lite_api::PowerMode mode_; std::vector active_ids_; TensorLite workspace_; int64_t count_{0}; diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index d4370837c0be611038548a6e33be8f51653fdcec..6e54bd07859c8137f73823c3f5696c9c32928671 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -37,7 +37,7 @@ namespace lite { namespace mir { namespace subgraph { -void GenerateNPUProgramPass::NPUSortHelper( +void GenerateNPUProgramPass::SubgraphSortHelper( Node* node, const std::unordered_set& nodes_all, std::unordered_set* visited_nodes, @@ -46,7 +46,7 @@ void GenerateNPUProgramPass::NPUSortHelper( if (var_node->inlinks.empty()) continue; auto* op_node = var_node->inlinks.front(); if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) { - NPUSortHelper(op_node, nodes_all, visited_nodes, ret); + SubgraphSortHelper(op_node, nodes_all, visited_nodes, ret); } } ret->push_back(node); @@ -55,40 +55,68 @@ void GenerateNPUProgramPass::NPUSortHelper( void GenerateNPUProgramPass::CvtOpNodes( const std::vector& nodes2cvt, - std::vector* in_vars_name, - std::vector* out_vars_name, - lite::npu::bridge::node_map_type* cvted_vars, - std::unordered_set* nodes2rm) { + lite::npu::bridge::node_map_type* cvted_vars) { const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& cvtfunc_map = bridges.AllFunctions(); + // record all converted vars + // op node's inputs must be found in cvted_vars for (auto& node : nodes2cvt) { lite::npu::bridge::node_map_type node_inputs; auto& stmt = node->AsStmt(); for (auto& var_node : node->inlinks) { auto& arg = var_node->AsArg(); + if (arg.is_weight) continue; auto var_name = arg.name; if (!cvted_vars->count(var_name)) { - if (arg.is_weight) continue; cvted_vars->insert(std::make_pair( var_name, lite::npu::bridge::CvtNode(var_node, stmt.op()->scope()))); - in_vars_name->push_back(var_name); } node_inputs.insert(*cvted_vars->find(var_name)); } auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs); cvted_vars->insert(node_outputs.begin(), node_outputs.end()); - nodes2rm->insert(node); - for (auto& var_node : node->outlinks) { - for (auto& next_op_node : var_node->outlinks) { - if (std::find(nodes2cvt.begin(), nodes2cvt.end(), next_op_node) == - nodes2cvt.end()) { - out_vars_name->push_back(var_node->AsArg().name); - break; - } + } +} + +void GenerateNPUProgramPass::GetIOVars( + const std::vector& nodes2cvt, + const lite::npu::bridge::node_map_type& cvted_vars, + std::unordered_set* nodes2rm, + std::vector* in_vars, + std::vector* out_vars, + lite::npu::bridge::node_map_type* in_cvted_vars, + lite::npu::bridge::node_map_type* out_cvted_vars) { + std::unordered_set 
op_nodes_all(nodes2cvt.begin(), nodes2cvt.end()); + for (auto& op_node : nodes2cvt) { + for (auto& in_var : op_node->inlinks) { + if (in_var->AsArg().is_weight) continue; + auto* pre_op_node = in_var->inlinks.front(); + if (op_nodes_all.count(pre_op_node)) { + nodes2rm->insert(in_var); + continue; + } + in_vars->push_back(in_var); + auto arg_name = in_var->AsArg().name; + in_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name))); + } + for (auto& out_var : op_node->outlinks) { + if (out_var->outlinks.empty()) { + nodes2rm->insert(out_var); + continue; + } + auto* next_op_node = out_var->outlinks.front(); + + if (op_nodes_all.count(next_op_node)) { + nodes2rm->insert(out_var); + continue; } + out_vars->push_back(out_var); + auto arg_name = out_var->AsArg().name; + out_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name))); } } + nodes2rm->insert(nodes2cvt.begin(), nodes2cvt.end()); } void GenerateNPUProgramPass::GenNPUGraphOpNode( @@ -100,23 +128,38 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( for (auto& node : nodes_all) { if (!node->IsStmt()) continue; if (visited_nodes.count(node)) continue; - NPUSortHelper(node, nodes_all, &visited_nodes, &ret); + SubgraphSortHelper(node, nodes_all, &visited_nodes, &ret); } - std::vector in_vars_name; - std::vector out_vars_name; lite::npu::bridge::node_map_type cvted_vars; + CvtOpNodes(ret, &cvted_vars); + std::unordered_set nodes2rm; - CvtOpNodes(ret, &in_vars_name, &out_vars_name, &cvted_vars, &nodes2rm); - // insert new graph op node + std::vector in_vars; + std::vector out_vars; + lite::npu::bridge::node_map_type in_cvted_vars; + lite::npu::bridge::node_map_type out_cvted_vars; + GetIOVars(ret, + cvted_vars, + &nodes2rm, + &in_vars, + &out_vars, + &in_cvted_vars, + &out_cvted_vars); + + std::vector in_vars_name; + std::vector out_vars_name; std::vector inputs; std::vector outputs; - for (auto i : in_vars_name) { - inputs.push_back(*cvted_vars.at(i)); + for (auto i : in_cvted_vars) { + in_vars_name.push_back(i.first); + inputs.push_back(*i.second); } - for (auto i : out_vars_name) { - outputs.push_back(*cvted_vars.at(i)); + for (auto i : out_cvted_vars) { + out_vars_name.push_back(i.first); + outputs.push_back(*i.second); } + std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om"); if (!npu::BuildNPUClient(inputs, outputs, model_name)) { LOG(FATAL) << "Build NPU failed subgraph " << sub_id; @@ -125,27 +168,25 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( cpp::OpDesc op_desc; op_desc.SetType("graph_op"); + std::vector in_var_names; + op_desc.SetInput("Inputs", in_vars_name); op_desc.SetOutput("Outputs", out_vars_name); op_desc.SetAttr("model_name", model_name); auto graph_op = LiteOpRegistry::Global().Create("graph_op"); - // TODO(zpy): support multi inputs op - auto start_op = ret.front()->AsStmt().op(); - auto* scope = start_op->scope(); + + auto any_op = ret.front()->AsStmt().op(); + auto* scope = any_op->scope(); graph_op->Attach(op_desc, scope); - auto valid_places = start_op->valid_places(); + auto valid_places = any_op->valid_places(); auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); - for (auto& var_node : ret.front()->inlinks) { - auto& arg = var_node->AsArg(); - if (arg.is_weight) continue; - IR_NODE_LINK_TO(var_node, new_op_node); + for (auto& in_var : in_vars) { + IR_NODE_LINK_TO(in_var, new_op_node); } - for (auto& var_node : ret.back()->outlinks) { - auto& arg = var_node->AsArg(); - if (arg.is_weight) continue; - IR_NODE_LINK_TO(var_node, 
new_op_node); + for (auto& out_var : out_vars) { + IR_OP_VAR_LINK(new_op_node, out_var); } // assign context @@ -159,8 +200,10 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( void GenerateNPUProgramPass::ConvertSubgraph( const std::unique_ptr& graph, int sub_num) { std::unordered_map> nodes_all; + int ops_num = 0; for (auto& item : graph->StmtTopologicalOrder()) { if (!item->IsStmt()) continue; + ops_num++; auto& stmt = item->AsStmt(); int sub_id = stmt.subgraph_id(); if (sub_id < 1) continue; @@ -178,6 +221,7 @@ void GenerateNPUProgramPass::ConvertSubgraph( void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get()); + const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& op_map = bridges.AllFunctions(); std::vector supported_op_types; @@ -215,5 +259,3 @@ std::unique_ptr GenerateNPUProgramPass::GenProgram() { REGISTER_MIR_PASS(generate_npu_program_pass, paddle::lite::mir::subgraph::GenerateNPUProgramPass); - -// USE_LITE_OP(graph_op); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h index 0ce60fb22b7d8bde2605ce5d0b2f166920ece381..151138476e76d774301a50cfc0142adac53c2558 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ b/lite/core/mir/subgraph/generate_npu_program_pass.h @@ -38,21 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { std::unique_ptr GenProgram(); protected: - void NPUSortHelper(Node* node, - const std::unordered_set& nodes_all, - std::unordered_set* visited_nodes, - std::vector* ret); + // sort nodes to operational sequence + void SubgraphSortHelper(Node* node, + const std::unordered_set& nodes_all, + std::unordered_set* visited_nodes, + std::vector* ret); // nodes2cvt: op nodes to convert - // in_vars_name: graph op's inputs var name - // out_vars_name: graph op's outputs var name - // vcted_vars: + // cvted_vars: converted var nodes // nodes2rm: op nodes and var nodes that need to be removed void CvtOpNodes(const std::vector& nodes2cvt, - std::vector* in_vars_name, - std::vector* out_vars_name, - lite::npu::bridge::node_map_type* cvted_vars, - std::unordered_set* nodes2rm); + lite::npu::bridge::node_map_type* cvted_vars); + + // achieve input and output vars/cvted_vars; + // achieve all nodes to remove + void GetIOVars(const std::vector& nodes2cvt, + const lite::npu::bridge::node_map_type& cvted_vars, + std::unordered_set* nodes2rm, + std::vector* in_vars, + std::vector* out_vars, + lite::npu::bridge::node_map_type* in_cvted_vars, + lite::npu::bridge::node_map_type* out_cvted_vars); void GenNPUGraphOpNode(const std::unique_ptr& graph, int sub_id, diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h index 65cc1600773297f935149c040a264400e13f91cc..d9111e5c46c9217b181e5a3e5a8c7981f46250df 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -26,17 +26,49 @@ namespace paddle { namespace lite { namespace profile { +template +static void write_tensorfile(const Tensor* tensor, const std::string& locate) { + if (locate.find('/') != std::string::npos) { + return; + } + FILE* fp = fopen(locate.c_str(), "w"); + if (fp == nullptr) { + LOG(ERROR) << "file open field " << locate; + } else { + const dtype* data = tensor->data(); + for (int i = 0; i < tensor->numel(); ++i) { + fprintf(fp, "[%d] %f \n", i, static_cast(data[i])); + } + } + fclose(fp); +} + class PrecisionProfiler { public: explicit 
PrecisionProfiler(const Instruction* inst) : inst_(inst) {} ~PrecisionProfiler() { LOG(INFO) << ">> Running kernel: " << inst_->op()->op_info()->Repr() - << " on Target " << TargetToStr(inst_->kernel()->target()); - auto tensor_mean = [](const Tensor* in, PrecisionType ptype) -> double { + << " on Target " << TargetToStr(inst_->kernel()->target()) << " " + << PrecisionToStr(inst_->kernel()->precision()); + auto tensor_mean = [](const Tensor* in, + PrecisionType ptype, + std::string name = "inst") -> double { + if (!in->data()) { + return -99999; + } double sum = 0.; switch (ptype) { case PRECISION(kFloat): { auto ptr = in->data(); + // write_tensorfile(in, name); + for (int i = 0; i < in->numel(); ++i) { + sum += ptr[i]; + } + return sum / in->numel(); + } + case PRECISION(kAny): { + auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -44,6 +76,7 @@ class PrecisionProfiler { } case PRECISION(kInt8): { auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -51,6 +84,7 @@ class PrecisionProfiler { } case PRECISION(kInt32): { auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -70,17 +104,18 @@ class PrecisionProfiler { std::string out_arg_name; op->op_info()->GetOutputArgname(out_name, &out_arg_name); auto type = kernel->GetOutputDeclType(out_arg_name); + if (type->IsTensor()) { auto tout = op_scope->FindVar(out_name)->GetMutable(); - double mean = tensor_mean(tout, type->precision()); + double mean = tensor_mean(tout, type->precision(), out_name); LOG(INFO) << "output name: " << out_name << ", dims: " << tout->dims() << ", precision: " << PrecisionToStr(type->precision()) - << ", mean value: " << mean; + << ", mean value: " << mean << " shape:" << tout->dims(); } else if (type->IsTensorList()) { auto tout = op_scope->FindVar(out_name)->GetMutable>(); for (auto& t : *tout) { - double mean = tensor_mean(&t, type->precision()); + double mean = tensor_mean(&t, type->precision(), out_name); LOG(INFO) << "output name: " << out_name << ", dims: " << t.dims() << ", precision: " << PrecisionToStr(type->precision()) << ", mean value: " << mean; diff --git a/lite/demo/java/android/PaddlePredictor/gradlew.bat b/lite/demo/java/android/PaddlePredictor/gradlew.bat index e95643d6a2ca62258464e83c72f5156dc941c609..f9553162f122c71b34635112e717c3e733b5b212 100644 --- a/lite/demo/java/android/PaddlePredictor/gradlew.bat +++ b/lite/demo/java/android/PaddlePredictor/gradlew.bat @@ -1,84 +1,84 @@ -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. 
-echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/lite/kernels/CMakeLists.txt b/lite/kernels/CMakeLists.txt index d5a3f6d9f02428ce59a300040cfe0c1d44556cbe..1996f50133acc6f3bdf651e8c0daae5b68c96832 100644 --- a/lite/kernels/CMakeLists.txt +++ b/lite/kernels/CMakeLists.txt @@ -1,6 +1,6 @@ message(STATUS "add lite kernels") -set(lite_kernel_deps type_system kernel op op_registry context tensor CACHE INTERNAL "" FORCE) +set(lite_kernel_deps type_system kernel op op_registry context tensor any CACHE INTERNAL "" FORCE) add_subdirectory(host) add_subdirectory(arm) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 99098102dc0cc57c0845e5f46f5214150c7a9ecc..524a235ef4489977b6673864aaf8d3e9ed4a6e93 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -4,64 +4,66 @@ endif() message(STATUS "compile with lite ARM kernels") -lite_cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(activation_compute_arm SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(matmul_compute_arm SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(elementwise_compute_arm SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(lrn_compute_arm SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(decode_bboxes_compute_arm SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(multiclass_nms_compute_arm SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(pad2d_compute_arm SRCS pad2d_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(prior_box_compute_arm SRCS prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(density_prior_box_compute_arm SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(negative_compute_arm SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(crop_compute_arm SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(dropout_compute_arm SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(calib_compute_arm SRCS calib_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(transpose_compute_arm SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(power_compute_arm SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(yolo_box_compute_arm SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(shuffle_channel_compute_arm SRCS shuffle_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(argmax_compute_arm SRCS argmax_compute.cc DEPS 
${lite_kernel_deps} math_arm) -lite_cc_library(axpy_compute_arm SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(conv_transpose_compute_arm SRCS conv_transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(gru_unit_compute_arm SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(gru_compute_arm SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(beam_search_decode_compute_arm SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(lookup_table_compute_arm SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(im2sequence_compute_arm SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(sequence_softmax_compute_arm SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(norm_compute_arm SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(interpolate_compute_arm SRCS interpolate_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(logical_compute_arm SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(less_than_arm SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(while_compute_arm SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(compare_compute_arm SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(topk_compute_arm SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(increment_compute_arm SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(write_to_array_compute_arm SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(read_from_array_compute_arm SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(beam_search_compute_arm SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(fill_constant_compute_arm SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(lod_reset_compute_arm SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(box_coder_compute_arm SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(sequence_pool_compute_arm SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(sequence_expand_compute_arm SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(reduce_max_compute_arm SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(is_empty_compute_arm SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(shape_compute_arm SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(slice_compute_arm SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(cast_compute_arm SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(squeeze_compute_arm SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) -lite_cc_library(expand_compute_arm SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(fc_compute_arm ARM basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(activation_compute_arm ARM basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(mul_compute_arm ARM basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(matmul_compute_arm ARM basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(scale_compute_arm 
ARM basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(softmax_compute_arm ARM basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_compute_arm ARM basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(decode_bboxes_compute_arm ARM basic SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(multiclass_nms_compute_arm ARM basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(pad2d_compute_arm ARM basic SRCS pad2d_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(prior_box_compute_arm ARM basic SRCS prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(density_prior_box_compute_arm ARM basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(negative_compute_arm ARM basic SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(crop_compute_arm ARM basic SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(calib_compute_arm ARM basic SRCS calib_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(transpose_compute_arm ARM basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(power_compute_arm ARM basic SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(yolo_box_compute_arm ARM basic SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(shuffle_channel_compute_arm ARM basic SRCS shuffle_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(axpy_compute_arm ARM basic SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(norm_compute_arm ARM basic SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(shape_compute_arm ARM basic SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(expand_compute_arm ARM basic SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} 
math_arm) + +# for OCR specific +add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(less_than_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(compare_compute_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(fill_constant_compute_arm ARM extra SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) @@ -77,71 +79,7 @@ lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) -lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm) +lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra) lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm) - - -set(arm_kernels - fc_compute_arm - activation_compute_arm - mul_compute_arm - matmul_compute_arm - scale_compute_arm - softmax_compute_arm - conv_compute_arm - batch_norm_compute_arm - elementwise_compute_arm - lrn_compute_arm - decode_bboxes_compute_arm - multiclass_nms_compute_arm - pool_compute_arm - split_compute_arm - concat_compute_arm - pad2d_compute_arm - prior_box_compute_arm - density_prior_box_compute_arm - 
negative_compute_arm - crop_compute_arm - dropout_compute_arm - transpose_compute_arm - calib_compute_arm - argmax_compute_arm - axpy_compute_arm - conv_transpose_compute_arm - gru_unit_compute_arm - gru_compute_arm - beam_search_decode_compute_arm - lookup_table_compute_arm - im2sequence_compute_arm - sequence_softmax_compute_arm - norm_compute_arm - power_compute_arm - shuffle_channel_compute_arm - yolo_box_compute_arm - interpolate_compute_arm - logical_compute_arm - less_than_arm - while_compute_arm - compare_compute_arm - topk_compute_arm - increment_compute_arm - write_to_array_compute_arm - read_from_array_compute_arm - beam_search_compute_arm - fill_constant_compute_arm - lod_reset_compute_arm - box_coder_compute_arm - reduce_max_compute_arm - sequence_expand_compute_arm - sequence_pool_compute_arm - is_empty_compute_arm - shape_compute_arm - slice_compute_arm - cast_compute_arm - squeeze_compute_arm - expand_compute_arm - ) - -set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/lite/kernels/arm/density_prior_box_compute.cc b/lite/kernels/arm/density_prior_box_compute.cc index 47d14d6572281e212322be38cab67cdb5c1581b5..35616bc6e8ac1e8c142616cf633578a057bb967f 100644 --- a/lite/kernels/arm/density_prior_box_compute.cc +++ b/lite/kernels/arm/density_prior_box_compute.cc @@ -48,13 +48,12 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, void DensityPriorBoxCompute::Run() { auto& param = Param(); - bool is_flip = param.flip; bool is_clip = param.clip; std::vector min_size = param.min_sizes; std::vector fixed_size = param.fixed_sizes; std::vector fixed_ratio = param.fixed_ratios; - std::vector density_size = param.density_sizes; + auto density_size = param.density_sizes; std::vector max_size = param.max_sizes; std::vector aspect_ratio = param.aspect_ratios; std::vector variance = param.variances_; diff --git a/lite/kernels/fpga/CMakeLists.txt b/lite/kernels/fpga/CMakeLists.txt index 36d6ccc25548d89253d2022c7fcb910c3498a426..dc8860188043dde6538a303eef82617e46c2a6c9 100644 --- a/lite/kernels/fpga/CMakeLists.txt +++ b/lite/kernels/fpga/CMakeLists.txt @@ -1,50 +1,32 @@ if (NOT LITE_WITH_FPGA) return() endif() -message("fpga : ${lite_kernel_deps}") set(fpga_deps fpga_target_wrapper kernel_fpga) -lite_cc_library(activation_compute_fpga SRCS activation_compute.cc DEPS ${fpga_deps}) +add_kernel(activation_compute_fpga FPGA basic SRCS activation_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_acivation_fpga SRCS activation_compute_test.cc DEPS ${lite_kernel_deps} activation_compute_fpga ${fpga_deps}) -lite_cc_library(conv_compute_fpga SRCS conv_compute.cc DEPS ${fpga_deps}) +add_kernel(conv_compute_fpga FPGA basic SRCS conv_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_conv_fpga SRCS conv_compute_test.cc DEPS ${lite_kernel_deps} conv_compute_fpga ${fpga_deps}) -lite_cc_library(elementwise_compute_fpga SRCS elementwise_compute.cc DEPS ${fpga_deps}) +add_kernel(elementwise_compute_fpga FPGA basic SRCS elementwise_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_elementwise_fpga SRCS elementwise_compute_test.cc DEPS ${lite_kernel_deps} elementwise_compute_fpga ${fpga_deps}) -lite_cc_library(pooling_compute_fpga SRCS pooling_compute.cc DEPS ${fpga_deps}) +add_kernel(pooling_compute_fpga FPGA basic SRCS pooling_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_pooling_compute_fpga SRCS pooling_compute_test.cc DEPS ${lite_kernel_deps} pooling_compute_fpga ${fpga_deps}) -lite_cc_library(scale_compute_fpga SRCS scale_compute.cc DEPS ${fpga_deps}) 
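Side note on the density_prior_box_compute.cc hunk above: its header shows an ExpandAspectRatios helper next to the density/fixed-size parameters. For readers unfamiliar with prior-box generation, this is the usual SSD-style expansion (keep the 1:1 ratio, drop near-duplicates, and add reciprocals when flip is set). The sketch below illustrates the typical semantics only; it is not copied from the Paddle-Lite source.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Typical SSD-style aspect-ratio expansion: keep 1:1, skip near-duplicate
// ratios, and add the reciprocal of each ratio when flip is enabled.
void ExpandAspectRatios(const std::vector<float>& input_ratios,
                        bool flip,
                        std::vector<float>* output_ratios) {
  constexpr float kEps = 1e-6f;
  output_ratios->clear();
  output_ratios->push_back(1.0f);
  for (float ar : input_ratios) {
    bool exists = false;
    for (float prev : *output_ratios) {
      if (std::fabs(ar - prev) < kEps) {
        exists = true;
        break;
      }
    }
    if (!exists) {
      output_ratios->push_back(ar);
      if (flip) output_ratios->push_back(1.0f / ar);
    }
  }
}

int main() {
  std::vector<float> out;
  ExpandAspectRatios({2.0f, 3.0f}, /*flip=*/true, &out);
  for (float ar : out) std::printf("%.3f ", ar);  // 1.000 2.000 0.500 3.000 0.333
  std::printf("\n");
  return 0;
}
```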
+add_kernel(scale_compute_fpga FPGA basic SRCS scale_compute.cc DEPS ${fpga_deps}) -lite_cc_library(softmax_compute_fpga SRCS softmax_compute.cc DEPS ${fpga_deps}) +add_kernel(softmax_compute_fpga FPGA basic SRCS softmax_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_softmax_compute_fpga SRCS softmax_compute_test.cc DEPS ${lite_kernel_deps} softmax_compute_fpga ${fpga_deps}) -lite_cc_library(fc_compute_fpga SRCS fc_compute.cc DEPS ${fpga_deps}) +add_kernel(fc_compute_fpga FPGA basic SRCS fc_compute.cc DEPS ${fpga_deps}) lite_cc_test(test_fc_compute_fpga SRCS fc_compute_test.cc DEPS ${lite_kernel_deps} fc_compute_fpga ${fpga_deps}) -lite_cc_library(io_copy_compute_fpga SRCS io_copy_compute.cc DEPS ${fpga_deps}) -lite_cc_library(calib_compute_fpga SRCS calib_compute.cc DEPS ${fpga_deps}) -lite_cc_library(layout_compute_fpga SRCS layout_compute.cc DEPS ${fpga_deps}) -lite_cc_library(feed_compute_fpga SRCS feed_compute.cc DEPS ${fpga_deps}) -lite_cc_library(fetch_compute_fpga SRCS fetch_compute.cc DEPS ${fpga_deps}) - -set (fpga_kernels - activation_compute_fpga - conv_compute_fpga - elementwise_compute_fpga - pooling_compute_fpga - scale_compute_fpga - softmax_compute_fpga - fc_compute_fpga - io_copy_compute_fpga - calib_compute_fpga - layout_compute_fpga - feed_compute_fpga - fetch_compute_fpga -) - -set(fpga_kernels "${fpga_kernels}" CACHE INTERNAL "fpga kernels") +add_kernel(io_copy_compute_fpga FPGA basic SRCS io_copy_compute.cc DEPS ${fpga_deps}) +add_kernel(calib_compute_fpga FPGA basic SRCS calib_compute.cc DEPS ${fpga_deps}) +add_kernel(layout_compute_fpga FPGA basic SRCS layout_compute.cc DEPS ${fpga_deps}) +add_kernel(feed_compute_fpga FPGA basic SRCS feed_compute.cc DEPS ${fpga_deps}) +add_kernel(fetch_compute_fpga FPGA basic SRCS fetch_compute.cc DEPS ${fpga_deps}) diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 5f93051f2be08e0ee02ee23b9196dcb39cd35a0a..abd96317cc2180ecf94e99835ab89216762b8f52 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -1,17 +1,7 @@ message(STATUS "compile with lite host kernels") -lite_cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) -lite_cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) -lite_cc_library(reshape_compute_host SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) +add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) -lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host) - -set(host_kernels - feed_compute_host - fetch_compute_host - reshape_compute_host - ) - -set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels") - - +lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc index cb3420fbbda06f34772ca672b2bc7a8444056185..a5934999cdd9c88037936bbf73f7d810aaffc3e7 100644 --- a/lite/kernels/host/reshape_compute.cc +++ b/lite/kernels/host/reshape_compute.cc @@ -93,3 +93,40 @@ REGISTER_LITE_KERNEL(reshape2, {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); + +REGISTER_LITE_KERNEL(flatten, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::ReshapeCompute, + 
def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindInput("Shape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .Finalize(); + +REGISTER_LITE_KERNEL(flatten2, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::ReshapeCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindInput("Shape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("XShape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .Finalize(); diff --git a/lite/kernels/npu/CMakeLists.txt b/lite/kernels/npu/CMakeLists.txt index 2ef9bf03b2b989d599414dc137547e4af6191144..960dbff8dba4e391761d323cbfc24946853f9e3a 100644 --- a/lite/kernels/npu/CMakeLists.txt +++ b/lite/kernels/npu/CMakeLists.txt @@ -2,12 +2,8 @@ if(NOT LITE_WITH_NPU) return () endif() - + message(STATUS "compile with lite NPU kernels") -lite_cc_library(graph_compute_npu SRCS graph_compute.cc DEPS ${lite_kernel_deps} ${npu_ddk_libs}) +add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} ${npu_ddk_libs}) # lite_cc_test(test_graph_compute_npu SRCS graph_compute_test.cc DEPS graph_compute_npu) - -set(npu_kernels graph_compute_npu) -set(npu_kernels "${npu_kernels}" CACHE INTERNAL "npu kernels") - diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 68662bf17e3d9cff1c09e1a38643477c02e80185..dc1ff6b97fad6e946452efe47a869998250c6ea2 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -4,17 +4,17 @@ endif() set(cl_kernel_deps op_params cl_runtime cl_context cl_wrapper cl_target_wrapper) -lite_cc_library(fc_opencl SRCS fc_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(mul_opencl SRCS mul_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(elementwise_add_opencl SRCS elementwise_add_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(fusion_elementwise_add_activation_opencl - SRCS fusion_elementwise_add_activation_compute.cc +add_kernel(fc_opencl OPENCL basic SRCS fc_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(mul_opencl OPENCL basic SRCS mul_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(elementwise_add_opencl OPENCL basic SRCS elementwise_add_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(fusion_elementwise_add_activation_opencl + OPENCL basic SRCS fusion_elementwise_add_activation_compute.cc DEPS elementwise_add_opencl ${cl_kernel_deps}) -lite_cc_library(pool_opencl SRCS pool_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(io_copy_compute_opencl SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps}) -lite_cc_library(relu_opencl SRCS relu_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(depthwise_conv2d_opencl SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) -lite_cc_library(conv_opencl SRCS conv_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps}) +add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS 
depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps}) lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context @@ -47,15 +47,3 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl) - -set(opencl_kernels - io_copy_compute_opencl - elementwise_add_opencl - fusion_elementwise_add_activation_opencl - pool_opencl - relu_opencl - mul_opencl - fc_opencl - depthwise_conv2d_opencl - conv_opencl - CACHE INTERNAL "opencl_kernels") diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 3a6d72f5fbf7d2c9b627e91624e6fcdbbc251f86..7080cc8c554da5698f4462302f2fcf4f94db6649 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -10,7 +10,7 @@ endif() # lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps}) -lite_cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op) # lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) # lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) @@ -31,23 +31,3 @@ lite_cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps} # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86) # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) - - -set(x86_kernels -# activation_compute_x86 -# elementwise_compute_x86 -# mean_compute_x86 -# fill_constant_compute_x86 -# mul_compute_x86 -# relu_compute_x86 -# fc_compute_x86 - scale_compute_x86 -# softmax_compute_x86 -# dropout_compute_x86 -# concat_compute_x86 -# conv_compute_x86 -# pool_compute_x86 -# batch_norm_compute_x86 -# uniform_random_compute_x86 -# sgd_compute_x86 - CACHE INTERNAL "x86 kernels") diff --git a/lite/npu/bridge/batch_norm_op.cc b/lite/npu/bridge/batch_norm_op.cc index e07b94763aa0a980a47da8c4a0dec3f5270d2da8..4fffb85cf181f5bfe53b96b75d34f1a7a4ba1398 100644 --- a/lite/npu/bridge/batch_norm_op.cc +++ b/lite/npu/bridge/batch_norm_op.cc @@ -30,12 +30,14 @@ namespace bridge { node_map_type BatchNormConverter( const std::shared_ptr batch_norm_op, const node_map_type& inputs_map) { - LOG(INFO) << "converting batchnorm..."; - lite::Scope* scope = batch_norm_op->scope(); - const lite::OpInfo* op_info = batch_norm_op->op_info(); + auto scope = batch_norm_op->scope(); + auto op_info = batch_norm_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; - std::shared_ptr output_node = - std::make_shared(UniqueName("batch_norm")); + std::shared_ptr batch_norm_node = + std::make_shared(unique_op_type); auto x_var_name = 
op_info->Input("X").front(); auto scale_var_name = op_info->Input("Scale").front(); @@ -68,21 +70,21 @@ node_map_type BatchNormConverter( int npu_mode = 1; // bnScale, bnBias tensor dims are 1xCx1x1 bool npu_use_global_stats = op_info->GetAttr("use_global_stats"); - output_node->set_input_x(*inputs_map.at(x_var_name)); - output_node->set_input_scale(*npu_scale); - output_node->set_input_b(*npu_bias); - output_node->set_input_mean(*npu_mean); - output_node->set_input_variance(*npu_variance); - output_node->set_attr_momentum(npu_momentum); - output_node->set_attr_epsilon(npu_epsilon); - output_node->set_attr_mode(npu_mode); - output_node->set_attr_use_global_stats(npu_use_global_stats); + batch_norm_node->set_input_x(*inputs_map.at(x_var_name)); + batch_norm_node->set_input_scale(*npu_scale); + batch_norm_node->set_input_b(*npu_bias); + batch_norm_node->set_input_mean(*npu_mean); + batch_norm_node->set_input_variance(*npu_variance); + batch_norm_node->set_attr_momentum(npu_momentum); + batch_norm_node->set_attr_epsilon(npu_epsilon); + batch_norm_node->set_attr_mode(npu_mode); + batch_norm_node->set_attr_use_global_stats(npu_use_global_stats); OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(output_node); + OpList::Global().add(batch_norm_node); node_map_type outputs_map; - outputs_map[op_info->Output("Y").front()] = output_node; + outputs_map[op_info->Output("Y").front()] = batch_norm_node; return outputs_map; } diff --git a/lite/npu/bridge/elementwise_ops.cc b/lite/npu/bridge/elementwise_ops.cc index 68e1120f57cd3686f81e1c4d19ce7031d1f940fe..784caf6a7acde1203741662cdb2ab58c1b5af6e8 100644 --- a/lite/npu/bridge/elementwise_ops.cc +++ b/lite/npu/bridge/elementwise_ops.cc @@ -30,11 +30,14 @@ namespace bridge { node_map_type ElementwiseConverter( const std::shared_ptr elementwise_op, const node_map_type& inputs_map) { + auto scope = elementwise_op->scope(); + auto op_info = elementwise_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); LOG(INFO) << "converting elementwise..."; - lite::Scope* scope = elementwise_op->scope(); - const lite::OpInfo* op_info = elementwise_op->op_info(); - std::shared_ptr output_node = - std::make_shared(UniqueName("elementwise")); + + std::shared_ptr elementwise_node = + std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); auto y_var_name = op_info->Input("Y").front(); @@ -43,27 +46,27 @@ node_map_type ElementwiseConverter( << "npu elementwise only support inputs with same size"; CHECK(inputs_map.find(x_var_name) != inputs_map.end()); - output_node->set_input_x1(*inputs_map.at(x_var_name)); + elementwise_node->set_input_x1(*inputs_map.at(x_var_name)); OpList::Global().add(inputs_map.at(x_var_name)); if (inputs_map.find(y_var_name) != inputs_map.end()) { - output_node->set_input_x2(*inputs_map.at(y_var_name)); + elementwise_node->set_input_x2(*inputs_map.at(y_var_name)); OpList::Global().add(inputs_map.at(y_var_name)); } else { auto consty = std::make_shared(y_var_name); auto* y = scope->FindVar(y_var_name)->GetMutable(); consty->set_attr_value(CvtFromLiteTensor(y)); - output_node->set_input_x2(*consty); + elementwise_node->set_input_x2(*consty); OpList::Global().add(consty); } - OpList::Global().add(output_node); + OpList::Global().add(elementwise_node); // paddlelite has sum only - output_node->set_attr_mode(1); + elementwise_node->set_attr_mode(1); node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; + 
outputs_map[op_info->Output("Out").front()] = elementwise_node; return outputs_map; } diff --git a/lite/npu/bridge/pool_op.cc b/lite/npu/bridge/pool_op.cc index e7208ab8507d29fb3b85d28f5572295ac12ad796..7a701c62fb745d04c86097e5b49bef8c18c313e2 100644 --- a/lite/npu/bridge/pool_op.cc +++ b/lite/npu/bridge/pool_op.cc @@ -29,12 +29,14 @@ namespace bridge { node_map_type PoolConverter(const std::shared_ptr pool_op, const node_map_type& inputs_map) { - LOG(INFO) << "converting pool..."; - lite::Scope* scope = pool_op->scope(); - const lite::OpInfo* op_info = pool_op->op_info(); + auto scope = pool_op->scope(); + auto op_info = pool_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; - std::shared_ptr output_node = - std::make_shared(UniqueName("pool")); + std::shared_ptr pool_node = + std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); auto pooling_type = op_info->GetAttr("pooling_type"); int npu_mode = 0; @@ -61,21 +63,21 @@ node_map_type PoolConverter(const std::shared_ptr pool_op, npu_ceil_mode = op_info->GetAttr("ceil_mode") ? 1 : 0; } - output_node->set_input_x(*inputs_map.at(x_var_name)); - output_node->set_attr_mode(npu_mode); - output_node->set_attr_pad_mode(0); - output_node->set_attr_global_pooling(npu_global_pooling); - output_node->set_attr_window(npu_window); - output_node->set_attr_pad(npu_pad); - output_node->set_attr_stride(npu_stride); - output_node->set_attr_ceil_mode(npu_ceil_mode); + pool_node->set_input_x(*inputs_map.at(x_var_name)); + pool_node->set_attr_mode(npu_mode); + pool_node->set_attr_pad_mode(0); + pool_node->set_attr_global_pooling(npu_global_pooling); + pool_node->set_attr_window(npu_window); + pool_node->set_attr_pad(npu_pad); + pool_node->set_attr_stride(npu_stride); + pool_node->set_attr_ceil_mode(npu_ceil_mode); // output_node->set_attr_data_mode(npu_data_mode); OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(output_node); + OpList::Global().add(pool_node); node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; + outputs_map[op_info->Output("Out").front()] = pool_node; return outputs_map; } diff --git a/lite/npu/bridge/shuffle_channel_op.cc b/lite/npu/bridge/shuffle_channel_op.cc index cb1bcdbec57823bfdd94c74cd1067d7728fc86b4..c87bcfe1a9d160a95f4f9edf179284d4921f6b18 100644 --- a/lite/npu/bridge/shuffle_channel_op.cc +++ b/lite/npu/bridge/shuffle_channel_op.cc @@ -30,22 +30,24 @@ namespace bridge { node_map_type ShuffleChannelConverter( const std::shared_ptr shuffle_channel_op, const node_map_type& inputs_map) { - LOG(INFO) << "converting shuffle_channel..."; - lite::Scope* scope = shuffle_channel_op->scope(); - const lite::OpInfo* op_info = shuffle_channel_op->op_info(); - - std::shared_ptr output_node = - std::make_shared(UniqueName("shuffle_channel")); + auto scope = shuffle_channel_op->scope(); + auto op_info = shuffle_channel_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; + + std::shared_ptr shuffle_channel_node = + std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); - output_node->set_input_x(*inputs_map.at(x_var_name)); - output_node->set_attr_group(op_info->GetAttr("group")); + shuffle_channel_node->set_input_x(*inputs_map.at(x_var_name)); + shuffle_channel_node->set_attr_group(op_info->GetAttr("group")); OpList::Global().add(inputs_map.at(x_var_name)); - 
OpList::Global().add(output_node); + OpList::Global().add(shuffle_channel_node); node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; + outputs_map[op_info->Output("Out").front()] = shuffle_channel_node; return outputs_map; } diff --git a/lite/npu/bridge/softmax_op.cc b/lite/npu/bridge/softmax_op.cc index 6532a283e439bae423bc18b7fe1065b20f23b486..3062e7e45479f55dd80f9e685e369f1e83a4ea17 100644 --- a/lite/npu/bridge/softmax_op.cc +++ b/lite/npu/bridge/softmax_op.cc @@ -29,12 +29,14 @@ namespace bridge { node_map_type SoftmaxConverter(const std::shared_ptr softmax_op, const node_map_type& inputs_map) { - LOG(INFO) << "converting softmax..."; - lite::Scope* scope = softmax_op->scope(); - const lite::OpInfo* op_info = softmax_op->op_info(); + auto scope = softmax_op->scope(); + auto op_info = softmax_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; - std::shared_ptr output_node = - std::make_shared(UniqueName("softmax")); + std::shared_ptr softmax_node = + std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); auto x_dims = scope->FindVar(x_var_name)->GetMutable()->dims(); @@ -46,14 +48,14 @@ node_map_type SoftmaxConverter(const std::shared_ptr softmax_op, } CHECK(inputs_map.count(x_var_name)); - output_node->set_input_x(*inputs_map.at(x_var_name)); - output_node->set_attr_axis(axis); + softmax_node->set_input_x(*inputs_map.at(x_var_name)); + softmax_node->set_attr_axis(axis); OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(output_node); + OpList::Global().add(softmax_node); node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; + outputs_map[op_info->Output("Out").front()] = softmax_node; return outputs_map; } diff --git a/lite/npu/bridge/transpose_op.cc b/lite/npu/bridge/transpose_op.cc index 41c88cda0165925187b36657e3df60dc26b8973a..cc10a9b44a1960e26b471b45add8a22f6bd36674 100644 --- a/lite/npu/bridge/transpose_op.cc +++ b/lite/npu/bridge/transpose_op.cc @@ -30,19 +30,21 @@ namespace bridge { node_map_type TransposeConverter( const std::shared_ptr transpose_op, const node_map_type& inputs_map) { - LOG(INFO) << "converting transpose..."; - lite::Scope* scope = transpose_op->scope(); - const lite::OpInfo* op_info = transpose_op->op_info(); + auto scope = transpose_op->scope(); + auto op_info = transpose_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; - std::shared_ptr output_node = - std::make_shared(UniqueName("transpose")); + std::shared_ptr transpose_node = + std::make_shared(unique_op_type); auto x_var_name = op_info->Input("X").front(); // paddlelite doesn't have this input // w must be set, but it does nothing - auto w_var_name = "transpose_w"; + auto w_var_name = unique_op_type + "/w"; auto* w = scope->Var(w_var_name)->GetMutable(); - w->Resize(scope->FindVar(x_var_name)->GetMutable()->dims()); + w->Resize({1}); auto* w_data = w->mutable_data(); for (int i = 0; i < w->numel(); i++) { w_data[i] = 1.f; @@ -55,15 +57,15 @@ node_map_type TransposeConverter( auto npu_axis = ge::AttrValue::LIST_INT(axis.begin(), axis.end()); CHECK(inputs_map.count(x_var_name)); - output_node->set_input_x(*inputs_map.at(x_var_name)); - output_node->set_input_w(*npu_w); - output_node->set_attr_order(npu_axis); + transpose_node->set_input_x(*inputs_map.at(x_var_name)); + transpose_node->set_input_w(*npu_w); + 
transpose_node->set_attr_order(npu_axis); OpList::Global().add(inputs_map.at(x_var_name)); - OpList::Global().add(output_node); + OpList::Global().add(transpose_node); node_map_type outputs_map; - outputs_map[op_info->Output("Out").front()] = output_node; + outputs_map[op_info->Output("Out").front()] = transpose_node; return outputs_map; } diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 1362a86797f699ba37e328d8a4a2ffd166bb55b2..f46c0f02d6301ba80860ab83de0a66b8e3679d4a 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -1,176 +1,92 @@ -set(op_DEPS tensor op op_params) - -lite_cc_library(conv_op SRCS conv_op.cc DEPS ${op_DEPS}) -lite_cc_library(pool_op SRCS pool_op.cc DEPS ${op_DEPS}) -lite_cc_library(fc_op SRCS fc_op.cc DEPS ${op_DEPS}) -lite_cc_library(relu_op SRCS relu_op.cc DEPS ${op_DEPS}) -lite_cc_library(mul_op SRCS mul_op.cc DEPS ${op_DEPS}) -lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS}) -lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS}) -lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS}) -lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS} ) -lite_cc_library(batch_norm_op SRCS batch_norm_op.cc DEPS ${op_DEPS}) -lite_cc_library(feed_op SRCS feed_op.cc DEPS ${op_DEPS}) -lite_cc_library(fetch_op SRCS fetch_op.cc DEPS ${op_DEPS}) -lite_cc_library(io_copy_op SRCS io_copy_op.cc DEPS ${op_DEPS}) -lite_cc_library(io_copy_once_op SRCS io_copy_once_op.cc DEPS io_copy_op ${op_DEPS}) -lite_cc_library(activation_ops SRCS activation_ops.cc DEPS ${op_DEPS}) -lite_cc_library(elementwise_ops SRCS elementwise_ops.cc DEPS ${op_DEPS}) -lite_cc_library(lrn_op_lite SRCS lrn_op.cc DEPS ${op_DEPS}) -lite_cc_library(decode_bboxes_op_lite SRCS decode_bboxes_op.cc DEPS ${op_DEPS}) -lite_cc_library(box_coder_op_lite SRCS box_coder_op.cc DEPS ${op_DEPS}) -lite_cc_library(multiclass_nms_op_lite SRCS multiclass_nms_op.cc DEPS ${op_DEPS}) -lite_cc_library(fusion_elementwise_activation_ops SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS}) -lite_cc_library(mean_op SRCS mean_op.cc DEPS ${op_DEPS}) -lite_cc_library(fill_constant_op SRCS fill_constant_op.cc DEPS ${op_DEPS}) -lite_cc_library(sgd_op SRCS sgd_op.cc DEPS ${op_DEPS}) -lite_cc_library(uniform_random_op SRCS uniform_random_op.cc DEPS ${op_DEPS}) -lite_cc_library(power_op SRCS power_op.cc DEPS ${op_DEPS}) -lite_cc_library(shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${op_DEPS}) -lite_cc_library(yolo_box_op SRCS yolo_box_op.cc DEPS ${op_DEPS}) -lite_cc_library(interpolate_op SRCS interpolate_op.cc DEPS ${op_DEPS}) -lite_cc_library(argmax_op SRCS argmax_op.cc DEPS ${op_DEPS}) -lite_cc_library(axpy_op SRCS axpy_op.cc DEPS ${op_DEPS}) -lite_cc_library(gru_unit_op SRCS gru_unit_op.cc DEPS ${op_DEPS}) -lite_cc_library(gru_op SRCS gru_op.cc DEPS ${op_DEPS}) -lite_cc_library(layout_op SRCS layout_op.cc DEPS ${op_DEPS}) -lite_cc_library(layout_once_op SRCS layout_once_op.cc DEPS ${op_DEPS}) -lite_cc_library(while_op SRCS while_op.cc DEPS ${op_DEPS}) -lite_cc_library(lookup_table_op SRCS lookup_table_op.cc DEPS ${op_DEPS}) -lite_cc_library(beam_search_decode_op SRCS beam_search_decode_op.cc DEPS ${op_DEPS}) -lite_cc_library(prior_box_op SRCS prior_box_op.cc DEPS ${op_DEPS}) -lite_cc_library(density_prior_box_op SRCS density_prior_box_op.cc DEPS ${op_DEPS}) +set(op_DEPS tensor op op_params scope memory) lite_cc_library(op_params SRCS op_params.cc DEPS tensor any) -lite_cc_library(dropout_op SRCS dropout_op.cc DEPS ${op_DEPS}) 
-lite_cc_library(concat_op SRCS concat_op.cc DEPS ${op_DEPS}) -lite_cc_library(pad2d_op SRCS pad2d_op.cc DEPS ${op_DEPS}) -lite_cc_library(negative_op SRCS negative_op.cc DEPS ${op_DEPS}) -lite_cc_library(crop_op SRCS crop_op.cc DEPS ${op_DEPS}) -lite_cc_library(calib_op SRCS calib_op.cc DEPS ${op_DEPS}) -lite_cc_library(calib_once_op SRCS calib_once_op.cc DEPS ${op_DEPS}) -lite_cc_library(split_op SRCS split_op.cc DEPS ${op_DEPS}) -lite_cc_library(transpose_op SRCS transpose_op.cc DEPS ${op_DEPS}) -lite_cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) -lite_cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS}) -lite_cc_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS ${op_DEPS}) -lite_cc_library(im2sequence_op SRCS im2sequence_op.cc DEPS ${op_DEPS}) -lite_cc_library(sequence_softmax_op SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) -lite_cc_library(norm_op SRCS norm_op.cc DEPS ${op_DEPS}) -lite_cc_library(graph_op SRCS graph_op.cc DEPS ${op_DEPS}) -lite_cc_library(topk_op SRCS topk_op.cc DEPS ${op_DEPS}) -lite_cc_library(increment_op SRCS increment_op.cc DEPS ${op_DEPS}) -lite_cc_library(write_to_array_op SRCS write_to_array_op.cc DEPS ${op_DEPS}) -lite_cc_library(graph_op_lite SRCS graph_op.cc DEPS ${op_DEPS}) -lite_cc_library(logical_xor SRCS logical_op.cc DEPS ${op_DEPS}) -lite_cc_library(logical_and SRCS logical_op.cc DEPS ${op_DEPS}) -lite_cc_library(logical_or SRCS logical_op.cc DEPS ${op_DEPS}) -lite_cc_library(logical_not SRCS logical_op.cc DEPS ${op_DEPS}) -lite_cc_library(less_than SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(equal SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(not_equal SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(less_equal SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(greater_than SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(greater_equal SRCS compare_op.cc DEPS ${op_DEPS}) -lite_cc_library(read_from_array_op SRCS read_from_array_op.cc DEPS ${op_DEPS}) -lite_cc_library(beam_search_op SRCS beam_search_op.cc DEPS ${op_DEPS}) -lite_cc_library(sequence_pool_op_lite SRCS sequence_pool_op.cc DEPS ${op_DEPS}) -lite_cc_library(sequence_expand_op_lite SRCS sequence_expand_op.cc DEPS ${op_DEPS}) -lite_cc_library(reduce_max_op_lite SRCS reduce_max_op.cc DEPS ${op_DEPS}) -lite_cc_library(lod_reset_op SRCS lod_reset_op.cc DEPS ${op_DEPS}) -lite_cc_library(is_empty SRCS is_empty_op.cc DEPS ${op_DEPS}) -lite_cc_library(shape_op_lite SRCS shape_op.cc DEPS ${op_DEPS}) -lite_cc_library(cast_op_lite SRCS cast_op.cc DEPS ${op_DEPS}) -lite_cc_library(slice_op_lite SRCS slice_op.cc DEPS ${op_DEPS}) -lite_cc_library(squeeze_op_lite SRCS squeeze_op.cc DEPS ${op_DEPS}) -lite_cc_library(expand_op_lite SRCS expand_op.cc DEPS ${op_DEPS}) +add_operator(conv_op basic SRCS conv_op.cc DEPS ${op_DEPS}) +add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS}) +add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS}) +add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS}) +add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS}) +add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS}) +add_operator(scale_op basic SRCS scale_op.cc DEPS ${op_DEPS}) +add_operator(softmax_op basic SRCS softmax_op.cc DEPS ${op_DEPS}) +add_operator(reshape_op basic SRCS reshape_op.cc DEPS ${op_DEPS} ) +add_operator(batch_norm_op basic SRCS batch_norm_op.cc DEPS ${op_DEPS}) +add_operator(feed_op basic SRCS feed_op.cc DEPS ${op_DEPS}) +add_operator(fetch_op basic SRCS fetch_op.cc DEPS ${op_DEPS}) 
+add_operator(io_copy_op basic SRCS io_copy_op.cc DEPS ${op_DEPS}) +add_operator(io_copy_once_op basic SRCS io_copy_once_op.cc DEPS io_copy_op ${op_DEPS}) +add_operator(activation_ops basic SRCS activation_ops.cc DEPS ${op_DEPS}) +add_operator(elementwise_ops basic SRCS elementwise_ops.cc DEPS ${op_DEPS}) +add_operator(lrn_op_lite basic SRCS lrn_op.cc DEPS ${op_DEPS}) +add_operator(decode_bboxes_op_lite basic SRCS decode_bboxes_op.cc DEPS ${op_DEPS}) +add_operator(box_coder_op_lite basic SRCS box_coder_op.cc DEPS ${op_DEPS}) +add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DEPS}) +add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS}) +add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS}) +add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS}) +#add_operator(sgd_op basic SRCS sgd_op.cc DEPS ${op_DEPS}) +add_operator(uniform_random_op basic SRCS uniform_random_op.cc DEPS ${op_DEPS}) +add_operator(power_op basic SRCS power_op.cc DEPS ${op_DEPS}) +add_operator(shuffle_channel_op basic SRCS shuffle_channel_op.cc DEPS ${op_DEPS}) +add_operator(yolo_box_op basic SRCS yolo_box_op.cc DEPS ${op_DEPS}) +add_operator(interpolate_op basic SRCS interpolate_op.cc DEPS ${op_DEPS}) +add_operator(argmax_op basic SRCS argmax_op.cc DEPS ${op_DEPS}) +add_operator(axpy_op basic SRCS axpy_op.cc DEPS ${op_DEPS}) +add_operator(gru_unit_op basic SRCS gru_unit_op.cc DEPS ${op_DEPS}) +add_operator(gru_op basic SRCS gru_op.cc DEPS ${op_DEPS}) +add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS}) +add_operator(layout_once_op basic SRCS layout_once_op.cc DEPS ${op_DEPS}) +add_operator(prior_box_op basic SRCS prior_box_op.cc DEPS ${op_DEPS}) +add_operator(density_prior_box_op basic SRCS density_prior_box_op.cc DEPS ${op_DEPS}) +add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS}) +add_operator(concat_op basic SRCS concat_op.cc DEPS ${op_DEPS}) +add_operator(pad2d_op basic SRCS pad2d_op.cc DEPS ${op_DEPS}) +add_operator(negative_op basic SRCS negative_op.cc DEPS ${op_DEPS}) +add_operator(crop_op basic SRCS crop_op.cc DEPS ${op_DEPS}) +add_operator(calib_op basic SRCS calib_op.cc DEPS ${op_DEPS}) +add_operator(calib_once_op basic SRCS calib_once_op.cc DEPS ${op_DEPS}) +add_operator(split_op basic SRCS split_op.cc DEPS ${op_DEPS}) +add_operator(transpose_op basic SRCS transpose_op.cc DEPS ${op_DEPS}) +add_operator(fake_quant basic SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +add_operator(fake_dequant basic SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS}) +add_operator(conv_transpose_op basic SRCS conv_transpose_op.cc DEPS ${op_DEPS}) +add_operator(graph_op basic SRCS graph_op.cc DEPS ${op_DEPS}) +add_operator(expand_op_lite basic SRCS expand_op.cc DEPS ${op_DEPS}) +add_operator(reduce_max_op_lite basic SRCS reduce_max_op.cc DEPS ${op_DEPS}) +add_operator(norm_op basic SRCS norm_op.cc DEPS ${op_DEPS}) +add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS}) +add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS}) +add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS}) + +# for OCR specific +add_operator(im2sequence_op extra SRCS im2sequence_op.cc DEPS ${op_DEPS}) +add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) +add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS}) +add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS}) 
+add_operator(graph_op_lite extra SRCS graph_op.cc DEPS ${op_DEPS}) +add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS}) +add_operator(logical_and extra SRCS logical_op.cc DEPS ${op_DEPS}) +add_operator(logical_or extra SRCS logical_op.cc DEPS ${op_DEPS}) +add_operator(logical_not extra SRCS logical_op.cc DEPS ${op_DEPS}) +add_operator(less_than extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(equal extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(not_equal extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(less_equal extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(greater_than extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(greater_equal extra SRCS compare_op.cc DEPS ${op_DEPS}) +add_operator(read_from_array_op extra SRCS read_from_array_op.cc DEPS ${op_DEPS}) +add_operator(beam_search_op extra SRCS beam_search_op.cc DEPS ${op_DEPS}) +add_operator(sequence_pool_op_lite extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) +add_operator(lod_reset_op extra SRCS lod_reset_op.cc DEPS ${op_DEPS}) +add_operator(is_empty extra SRCS is_empty_op.cc DEPS ${op_DEPS}) +add_operator(cast_op_lite extra SRCS cast_op.cc DEPS ${op_DEPS}) +add_operator(slice_op_lite extra SRCS slice_op.cc DEPS ${op_DEPS}) +add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS}) +add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS}) +add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS}) +add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) -set(ops - conv_op - pool_op - fc_op - relu_op - mul_op - matmul_op - scale_op - softmax_op - reshape_op - batch_norm_op - feed_op - fetch_op - gru_unit_op - gru_op - beam_search_decode_op - lookup_table_op - io_copy_op - io_copy_once_op - elementwise_ops - fusion_elementwise_activation_ops - lrn_op_lite - decode_bboxes_op_lite - multiclass_nms_op_lite - decode_bboxes_op_lite - box_coder_op_lite - multiclass_nms_op_lite - mean_op - fill_constant_op - activation_ops - dropout_op - concat_op - pad2d_op - crop_op - prior_box_op - density_prior_box_op - negative_op - calib_op - calib_once_op - split_op - transpose_op - fake_quant - fake_dequant - sgd_op - uniform_random_op - power_op - yolo_box_op - shuffle_channel_op - argmax_op - axpy_op - conv_transpose_op - im2sequence_op - sequence_softmax_op - norm_op - layout_op - layout_once_op - interpolate_op - logical_xor - logical_and - logical_or - logical_not - equal - not_equal - less_than - while_op - less_equal - greater_than - greater_equal - graph_op - topk_op - increment_op - write_to_array_op - read_from_array_op - beam_search_op - sequence_pool_op_lite - sequence_expand_op_lite - reduce_max_op_lite - lod_reset_op - is_empty - shape_op_lite - cast_op_lite - slice_op_lite - squeeze_op_lite - expand_op_lite - CACHE INTERNAL "ops lite") if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc @@ -184,7 +100,7 @@ if (NOT LITE_WITH_X86) lite_cc_test(test_softmax_op SRCS softmax_op_test.cc DEPS softmax_op memory) #lite_cc_test(test_reshape_op SRCS reshape_op_test.cc DEPS reshape_op memory) lite_cc_test(test_batch_norm_op SRCS batch_norm_op_test.cc DEPS batch_norm_op memory) - lite_cc_test(test_concat_op SRCS concat_op_test.cc DEPS concat_op memory) + lite_cc_test(test_concat_op SRCS concat_op_test.cc DEPS concat_op memory scope) lite_cc_test(test_calib_op SRCS calib_op_test.cc DEPS calib_op memory ARM_DEPS calib_compute_arm) lite_cc_test(test_fusion_elementwise_activation_ops SRCS 
fusion_elementwise_activation_ops_test.cc diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index b84b4ff16993b51410bf741db91c5ec46960d410..fb6b431fff8ab20dd1a6d1abc8aff7443771ee2f 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -85,7 +85,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, } } } - param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + if (op_desc.HasAttr("fuse_relu")) { + param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + } return true; } diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc index c6b646b33d64eaf8dc3ca34254d9a756e01fb1d6..86830df2f19b5615e8b9cfb4b3b57eb22000f588 100644 --- a/lite/operators/density_prior_box_op.cc +++ b/lite/operators/density_prior_box_op.cc @@ -41,15 +41,29 @@ bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, param_.boxes = scope->FindVar(boxes)->GetMutable(); param_.variances = scope->FindVar(variances)->GetMutable(); - param_.flip = opdesc.GetAttr("flip"); param_.clip = opdesc.GetAttr("clip"); - param_.min_sizes = opdesc.GetAttr>("min_sizes"); param_.fixed_sizes = opdesc.GetAttr>("fixed_sizes"); param_.fixed_ratios = opdesc.GetAttr>("fixed_ratios"); - param_.density_sizes = opdesc.GetAttr>("density_sizes"); - param_.max_sizes = opdesc.GetAttr>("max_sizes"); - param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); param_.variances_ = opdesc.GetAttr>("variances"); + + if (opdesc.HasAttr("aspect_ratios")) { + param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); + } + if (opdesc.HasAttr("max_sizes")) { + param_.max_sizes = opdesc.GetAttr>("max_sizes"); + } + if (opdesc.HasAttr("density_sizes")) { + param_.density_sizes = opdesc.GetAttr>("density_sizes"); + } + if (opdesc.HasAttr("densities")) { + param_.density_sizes = opdesc.GetAttr>("densities"); + } + if (opdesc.HasAttr("min_sizes")) { + param_.min_sizes = opdesc.GetAttr>("min_sizes"); + } + if (opdesc.HasAttr("flip")) { + param_.flip = opdesc.GetAttr("flip"); + } if (opdesc.HasAttr("img_w")) { param_.img_w = opdesc.GetAttr("img_w"); } diff --git a/lite/operators/fake_quantize_range_abs_max.cc b/lite/operators/fake_quantize_range_abs_max.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8ce3f75a59fec5b032c60f51177f428bd15fe0d --- /dev/null +++ b/lite/operators/fake_quantize_range_abs_max.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/fake_quantize_range_abs_max.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fake_quantize_range_abs_max, + paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h new file mode 100644 index 0000000000000000000000000000000000000000..726731595a9c4b7cd2e30db911230cc2f00b5b92 --- /dev/null +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeQuantizeRangeMaxAbsOpLite : public OpLite { + public: + FakeQuantizeRangeMaxAbsOpLite() {} + + explicit FakeQuantizeRangeMaxAbsOpLite(const std::string &type) + : OpLite(type) {} + + bool CheckShape() const override { return true; } + + bool InferShape() const override { return true; } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("InScale").front(); + + auto out = op_desc.Output("Out").front(); + auto out_scale = op_desc.Output("OutScale").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.out_scale = scope->FindVar(out_scale)->GetMutable(); + param_.bit_length = op_desc.GetAttr("bit_length"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fake_quantize_range_max_abs"; + } + + private: + mutable FakeQuantizeMovingAvgMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6deab45023876b1a5707ef5cea6ec69af3875328 --- /dev/null +++ b/lite/operators/flatten_op.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/flatten_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool FlattenOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool FlattenOp::InferShape() const { + auto x_dims = param_.x->dims(); + + auto out_lod = param_.output->mutable_lod(); + *out_lod = param_.x->lod(); + + int64_t outer = 1, inner = 1; + for (int i = 0; i < x_dims.size(); ++i) { + if (i < axis_) { + outer *= x_dims[i]; + } else { + inner *= x_dims[i]; + } + } + std::vector out_shape(2); + out_shape[0] = outer; + out_shape[1] = inner; + + param_.output->Resize(out_shape); + + return true; +} + +bool FlattenOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + auto x_var = scope->FindVar(opdesc.Input("X").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + CHECK(x_var); + CHECK(output_var); + param_.x = const_cast(&(x_var->Get())); + param_.output = output_var->GetMutable(); + axis_ = opdesc.GetAttr("axis"); + + param_.inplace = false; + + CHECK(param_.x) << "Input(X) of FlattenOp should not be null."; + CHECK(param_.output) << "Output(Out) of FlattenOp should not be null."; + CHECK_GE(axis_, 0) << "Flatten op axis should >=0."; + return true; +} + +bool Flatten2Op::CheckShape() const { + FlattenOp::CheckShape(); + CHECK_OR_FALSE(param_.xshape); + return true; +} + +bool Flatten2Op::InferShape() const { + FlattenOp::InferShape(); + auto x_dims = param_.x->dims(); + std::vector xshape_dims(x_dims.size() + 1, 0); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.xshape->Resize(DDim(xshape_dims)); + return true; +} + +bool Flatten2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + FlattenOp::AttachImpl(opdesc, scope); + auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); + CHECK(xshape_var); + param_.xshape = xshape_var->GetMutable(); + CHECK(param_.xshape) << "Output(XShape) of FlattenOp should not be null."; + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(flatten, paddle::lite::operators::FlattenOp); +REGISTER_LITE_OP(flatten2, paddle::lite::operators::Flatten2Op); diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h new file mode 100644 index 0000000000000000000000000000000000000000..61680fd3903b77f8826cda6f6a242739720155d7 --- /dev/null +++ b/lite/operators/flatten_op.h @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FlattenOp : public OpLite { + public: + FlattenOp() {} + explicit FlattenOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "flatten"; } + + protected: + mutable ReshapeParam param_; + int axis_; +}; + +class Flatten2Op : public FlattenOp { + public: + Flatten2Op() : FlattenOp() {} + explicit Flatten2Op(const std::string &op_type) : FlattenOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "flatten2"; } +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 84693224dbb473eb202b774d0ee945da506a079d..deac6410b31da20b0456f419f6d53411f25d12c2 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -521,7 +521,7 @@ struct PriorBoxParam { struct DensityPriorBoxParam : public PriorBoxParam { std::vector fixed_sizes; std::vector fixed_ratios; - std::vector density_sizes; + std::vector density_sizes; }; /// ----------------------- GRU operators ----------------------f struct GRUParam { diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index 8053b24b623e38491876efc1ff486193a5a08cce..3cc8938f4eb3ffc5720a6e1cfc1746e1defd048e 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -40,12 +40,14 @@ bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { param_.boxes = scope->FindVar(boxes)->GetMutable(); param_.variances = scope->FindVar(variances)->GetMutable(); - param_.flip = opdesc.GetAttr("flip"); param_.clip = opdesc.GetAttr("clip"); param_.min_sizes = opdesc.GetAttr>("min_sizes"); param_.max_sizes = opdesc.GetAttr>("max_sizes"); param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); param_.variances_ = opdesc.GetAttr>("variances"); + if (opdesc.HasAttr("flip")) { + param_.flip = opdesc.GetAttr("flip"); + } if (opdesc.HasAttr("img_w")) { param_.img_w = opdesc.GetAttr("img_w"); } diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 777e4408c020d76142f556c2fbf2f1730d722759..8ef2532ac04d8bbaaa64fb4eaa101340ce7ffebe 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -21,7 +21,13 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH #lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + +if(LITE_BUILD_EXTRA) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc 
DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) +endif() + lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) @@ -31,9 +37,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) - lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) - lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index 1a6fefb8f18fa9da9108626e0a9b8ddc0c7593a6..95a8167701aa72dcc992f3ba829182bea6f3d143 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -171,9 +171,9 @@ void test_fc(Place place) { DDim bdim{{bflag ? 
n : 0}}; std::unique_ptr tester( new FcOPTest(place, "def", dim_in, wdim, bdim, 1)); -#ifdef WITH_ARM_LITE +#ifdef LITE_WITH_ARM auto& ctx = tester->context()->As(); - ctx.SetRunMode(LITE_POWER_HIGH, 1); + ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1); #endif arena::Arena arena(std::move(tester), place, 6e-5); if (!arena.TestPrecision()) { diff --git a/lite/tests/kernels/gru_unit_test.cc b/lite/tests/kernels/gru_unit_test.cc index e218d6db2588145f22f7ea80d212c8e274112571..bf4b7dd5e285d30a3227ee463653186cd3b42953 100644 --- a/lite/tests/kernels/gru_unit_test.cc +++ b/lite/tests/kernels/gru_unit_test.cc @@ -344,7 +344,7 @@ void test_gru_unit(Place place) { place, "def", 1 /* sigomoid */, 2 /* tanh */, false, dims)); #ifdef LITE_WITH_ARM auto& ctx = tester->context()->template As(); - ctx.SetRunMode(LITE_POWER_HIGH, 1); + ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1); #endif arena::Arena arena(std::move(tester), place, 2e-5); arena.TestPrecision(); diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc index 57bea3e96dbf55014d55b0c7d34e8aa4db7b4b48..47f7bc9447b1b33b57c4bc4a495a106f49d6abbc 100644 --- a/lite/tests/kernels/prior_box_compute_test.cc +++ b/lite/tests/kernels/prior_box_compute_test.cc @@ -75,7 +75,7 @@ void prior_box_compute_ref(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, @@ -352,7 +352,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase { std::vector min_size_; std::vector fixed_size_; std::vector fixed_ratio_; - std::vector density_size_; + std::vector density_size_; std::vector max_size_; std::vector aspect_ratio_; std::vector variance_; @@ -375,7 +375,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase { const std::vector& min_size, const std::vector& fixed_size, const std::vector& fixed_ratio, - const std::vector& density_size, + const std::vector& density_size, const std::vector& max_size, const std::vector& aspect_ratio, const std::vector& variance, @@ -561,7 +561,7 @@ class PriorBoxComputeTester : public arena::TestCase { min_size_, std::vector(), std::vector(), - std::vector(), + std::vector(), max_size_, aspect_ratio_, variance_, @@ -621,7 +621,7 @@ void test_density_prior_box(Place place) { std::vector variance{0.1f, 0.1f, 0.2f, 0.2f}; std::vector fixed_size{60, 30}; std::vector fixed_ratio{1., 2.}; - std::vector density_size{1., 3.}; + std::vector density_size{1, 3}; bool flip = true; bool clip = false; float step_h = 0; diff --git a/lite/tools/benchmark.sh b/lite/tools/benchmark.sh index 66b4025f91d5fa9e21a337235e26131edcda2c22..8a48a16c732b0db65d90426a990dc5e2568127e5 100644 --- a/lite/tools/benchmark.sh +++ b/lite/tools/benchmark.sh @@ -5,18 +5,22 @@ if [ $# -lt 2 ]; then echo "Input error" echo "USAGE:" - echo " sh benchmark.sh benchmark_bin_path test_models_dir" - echo " sh benchmark.sh benchmark_bin_path test_models_dir arm_bi" + echo " sh benchmark.sh benchmark_bin_path benchmark_models_path" + echo " sh benchmark.sh benchmark_bin_path benchmark_models_path is_run_model_optimize" exit fi -BENCHMARK_BIN=$1 -MODELS_DIR=$2 -ARM_BI=$3 ANDROID_DIR=/data/local/tmp RESULT_FILENAME="result.txt" WARMUP=10 REPEATS=30 +BENCHMARK_BIN=$1 +MODELS_DIR=$2 +IS_RUN_MODEL_OPTIMIZE=false +if [ $# -gt 2 ]; +then + IS_RUN_MODEL_OPTIMIZE=$3 +fi adb push $BENCHMARK_BIN 
$ANDROID_DIR/benchmark_bin adb shell chmod 777 $ANDROID_DIR/benchmark_bin @@ -25,11 +29,11 @@ adb push $MODELS_DIR $ANDROID_DIR adb shell "echo PaddleLite Benchmark > $ANDROID_DIR/$RESULT_FILENAME" for threads in 1 2 4 do -adb shell "echo ABI=$ARM_BI Threads=$threads Warmup=$WARMUP Repeats=$REPEATS >> $ANDROID_DIR/$RESULT_FILENAME" +adb shell "echo Threads=$threads Warmup=$WARMUP Repeats=$REPEATS >> $ANDROID_DIR/$RESULT_FILENAME" for model_name in `ls $MODELS_DIR` do echo $model_name - adb shell "$ANDROID_DIR/benchmark_bin --model_dir=$ANDROID_DIR/${MODELS_DIR##*/}/$model_name --warmup=$WARMUP --repeats=$REPEATS --threads=$threads --result_filename=$ANDROID_DIR/$RESULT_FILENAME" + adb shell "$ANDROID_DIR/benchmark_bin --model_dir=$ANDROID_DIR/${MODELS_DIR##*/}/$model_name --warmup=$WARMUP --repeats=$REPEATS --threads=$threads --result_filename=$ANDROID_DIR/$RESULT_FILENAME --run_model_optimize=$IS_RUN_MODEL_OPTIMIZE" done adb shell "echo >> $ANDROID_DIR/$RESULT_FILENAME" done diff --git a/lite/tools/build.sh b/lite/tools/build.sh index 1fe8fd7dc4c5deeb731758648a17b6e4c038d5ba..7041e5b8f41c4bcf3641ad6574967cafdff4e00f 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -ex readonly CMAKE_COMMON_OPTIONS="-DWITH_GPU=OFF \ -DWITH_MKL=OFF \ @@ -31,6 +32,10 @@ function make_tiny_publish_so { cur_dir=$(pwd) build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} + if [ -d $build_dir ] + then + rm -rf $build_dir + fi mkdir -p $build_dir cd $build_dir @@ -55,6 +60,10 @@ function make_full_publish_so { cur_dir=$(pwd) build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} + if [ -d $build_dir ] + then + rm -rf $build_dir + fi mkdir -p $build_dir cd $build_dir @@ -78,6 +87,10 @@ function make_all_tests { cur_dir=$(pwd) build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} + if [ -d $build_dir ] + then + rm -rf $build_dir + fi mkdir -p $build_dir cd $build_dir diff --git a/lite/tools/build_armlinux.sh b/lite/tools/build_armlinux.sh new file mode 100755 index 0000000000000000000000000000000000000000..3c240ccea9ec9b95de90f1dc7211cdf5cd3cce1a --- /dev/null +++ b/lite/tools/build_armlinux.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +os=armlinux +abi=armv8 +lang=gcc + +if [ x$1 != x ]; then + abi=$1 +fi + +if [ x$2 != x ]; then + lang=$2 +fi + +cur_dir=$(pwd) +build_dir=$cur_dir/build.lite.${os}.${abi}.${lang} +mkdir -p $build_dir +cd $build_dir + +GEN_CODE_PATH_PREFIX=lite/gen_code +mkdir -p ./${GEN_CODE_PATH_PREFIX} +touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc + +cmake .. \ + -DWITH_GPU=OFF \ + -DWITH_MKL=OFF \ + -DWITH_LITE=ON \ + -DLITE_WITH_CUDA=OFF \ + -DLITE_WITH_X86=OFF \ + -DLITE_WITH_ARM=ON \ + -DWITH_ARM_DOTPROD=ON \ + -DLITE_WITH_OPENMP=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ + -DWITH_TESTING=ON \ + -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} + +make -j4 publish_inference + +cd - diff --git a/lite/tools/build_ios_armv7_arm64.sh b/lite/tools/build_ios_armv7_arm64.sh index 04994e7adb765800eb2235e942cf6fdf1472e2f8..718c3e37e92dfe14269b92103bb0e19c15375301 100755 --- a/lite/tools/build_ios_armv7_arm64.sh +++ b/lite/tools/build_ios_armv7_arm64.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -e build_dir=build.ios.armv7.arm64 mkdir -p ${build_dir} @@ -15,11 +16,15 @@ cmake .. 
\ -DLITE_WITH_CUDA=OFF \ -DLITE_WITH_X86=OFF \ -DLITE_WITH_ARM=ON \ - -DLITE_WITH_OPENMP=ON \ + -DWITH_TESTING=OFF \ + -DLITE_WITH_JAVA=OFF \ + -DLITE_SHUTDOWN_LOG=ON \ + -DLITE_ON_TINY_PUBLISH=ON \ + -DLITE_WITH_OPENMP=OFF \ + -DWITH_ARM_DOTPROD=OFF \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ - -DWITH_TESTING=ON \ -DARM_TARGET_OS=ios -make -j2 +make -j4 cd - diff --git a/lite/tools/cmake_tools/parse_kernel_registry.py b/lite/tools/cmake_tools/parse_kernel_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..99804748f3780990194c429b050d364e3fa20b53 --- /dev/null +++ b/lite/tools/cmake_tools/parse_kernel_registry.py @@ -0,0 +1,59 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import logging + +ops_list_path = sys.argv[1] +dest_path = sys.argv[2] + +out_lines = [ + '#pragma once', + '#include "paddle_lite_factory_helper.h"', + '', +] + +with open(ops_list_path) as f: + for line in f: + path = line.strip() + + status = '' + with open(path) as g: + lines = [v for v in g] + for i in range(len(lines)): + line = lines[i].strip() + + if not status: + key = 'REGISTER_LITE_KERNEL' + if line.startswith(key): + forward = i + min(7, len(lines) - i) + remaining = line[len(key) + 1:] + ' '.join( + [v.strip() for v in lines[i + 1:forward]]) + + x = remaining.find('.') + if x > 0: + remaining = remaining[:x] + + fs = [v.strip() for v in remaining.split(',')] + assert (len(fs) >= 4) + op, target, precision, layout, __, alias = fs[:6] + alias = alias.replace(')', '') + + key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % ( + op, target, precision, layout, alias) + out_lines.append(key) + +with open(dest_path, 'w') as f: + logging.info("write kernel list to %s" % dest_path) + f.write('\n'.join(out_lines)) diff --git a/lite/tools/cmake_tools/parse_op_registry.py b/lite/tools/cmake_tools/parse_op_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..423036f6e84ffed39bb6d12589bbe354fcf8b883 --- /dev/null +++ b/lite/tools/cmake_tools/parse_op_registry.py @@ -0,0 +1,45 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' Collect op registry information. 
''' + +import sys +import logging + +ops_list_path = sys.argv[1] +dest_path = sys.argv[2] + +out_lines = [ + '#pragma once', + '#include "paddle_lite_factory_helper.h"', + '', +] + +with open(ops_list_path) as f: + for line in f: + path = line.strip() + + with open(path) as g: + for line in g: + key = 'REGISTER_LITE_OP' + if line.startswith(key): + end = line.find(',') + op = line[len(key) + 1:end] + if not op: continue + if "_grad" in op: continue + out = "USE_LITE_OP(%s);" % op + out_lines.append(out) + +with open(dest_path, 'w') as f: + logging.info("write op list to %s" % dest_path) + f.write('\n'.join(out_lines)) diff --git a/lite/tools/debug/debug_utils.h b/lite/tools/debug/debug_utils.h index 644cdaba1d43c897d91ba9c3e7014cdc1e1f0a7b..7f77b90488657aab96c7942d703e86d64723f5fc 100644 --- a/lite/tools/debug/debug_utils.h +++ b/lite/tools/debug/debug_utils.h @@ -115,7 +115,7 @@ void FillTensorData(lite::Tensor* tensor, const DebugConfig& conf, int col) { data[i] = input_data[i]; } } else { - LOG(INFO) << "------------> Use all-ones input"; + LOG(INFO) << "-------------> Use all-ones input"; for (int i = 0; i < dim_size; i++) { data[i] = 1; } diff --git a/lite/tools/debug/model_debug_tool.cc b/lite/tools/debug/model_debug_tool.cc index 02ef376d90a320b326a9d6a0fa3a56ac1c6068ca..38afc969140bd9ac24f4a0f305c01b61895877f3 100644 --- a/lite/tools/debug/model_debug_tool.cc +++ b/lite/tools/debug/model_debug_tool.cc @@ -33,7 +33,7 @@ void Run(DebugConfig* conf) { CHECK(conf); #ifdef LITE_WITH_ARM DeviceInfo::Init(); - DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, conf->arm_thread_num); + DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, conf->arm_thread_num); #endif lite::Predictor predictor; std::vector valid_places({ diff --git a/lite/utils/io.h b/lite/utils/io.h index 72f00bd1ca355b173d026357dd87654061a108d0..ddd7e39b0d1d3e0b425cff1b31641a6a145f7bfa 100644 --- a/lite/utils/io.h +++ b/lite/utils/io.h @@ -35,7 +35,7 @@ static bool IsFileExists(const std::string& path) { // ARM mobile not support mkdir in C++ static void MkDirRecur(const std::string& path) { #ifndef LITE_WITH_ARM - if(system(string_format("mkdir -p %s", path.c_str()).c_str()) != 0) { + if (system(string_format("mkdir -p %s", path.c_str()).c_str()) != 0) { LOG(ERROR) << "Cann't mkdir " << path; } #else // On ARM diff --git a/lite/utils/logging.h b/lite/utils/logging.h index 095cdffc3eb768cbbb55717a258eee965d5f0d9e..c593bcf76b351bbcbe7ca03426f25d29f3998296 100644 --- a/lite/utils/logging.h +++ b/lite/utils/logging.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "lite/utils/replace_stl/stream.h" // NOLINTFILE() diff --git a/mobile/.clang-format b/mobile/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..d59e0885794e037ab02cd1e385cc8c16b93d3a76 --- /dev/null +++ b/mobile/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... 
diff --git a/mobile/.clang-tidy b/mobile/.clang-tidy new file mode 100644 index 0000000000000000000000000000000000000000..c788efe69d23e69ee6add3b0be9e09e567494662 --- /dev/null +++ b/mobile/.clang-tidy @@ -0,0 +1,67 @@ +Checks: > + * + -android-* + -bugprone-bool-pointer-implicit-conversion + -cert-env33-c + -cert-dcl50-cpp + -cert-dcl59-cpp + -cppcoreguidelines-* + -fuchsia-* + -google-* + google-default-arguments + google-explicit-constructor + google-runtime-member-string-references + google-runtime-operator + -hicpp-braces-around-statements + -hicpp-named-parameter + -hicpp-no-array-decay + -hicpp-no-assembler + -hicpp-no-malloc + -hicpp-function-size + -hicpp-special-member-functions + -hicpp-vararg + -llvm-* + -objc-* + -readability-else-after-return + -readability-implicit-bool-conversion + -readability-named-parameter + -readability-simplify-boolean-expr + -readability-braces-around-statements + -readability-identifier-naming + -readability-function-size + -readability-redundant-member-init + -misc-bool-pointer-implicit-conversion + -misc-definitions-in-headers + -misc-unused-alias-decls + -misc-unused-parameters + -misc-unused-using-decls + -modernize-use-using + -modernize-use-default-member-init + -clang-diagnostic-* + -clang-analyzer-* +WarningsAsErrors: '*' +HeaderFilterRegex: '' +AnalyzeTemporaryDtors: false +FormatStyle: none +User: allonli +CheckOptions: + - key: google-readability-braces-around-statements.ShortStatementLines + value: '1' + - key: google-readability-function-size.StatementThreshold + value: '800' + - key: google-readability-namespace-comments.ShortNamespaceLines + value: '10' + - key: google-readability-namespace-comments.SpacesBeforeComments + value: '2' + - key: modernize-loop-convert.MaxCopySize + value: '16' + - key: modernize-loop-convert.MinConfidence + value: reasonable + - key: modernize-loop-convert.NamingStyle + value: CamelCase + - key: modernize-pass-by-value.IncludeStyle + value: llvm + - key: modernize-replace-auto-ptr.IncludeStyle + value: llvm + - key: modernize-use-nullptr.NullMacros + value: 'NULL' diff --git a/mobile/.gitignore b/mobile/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..70d0b40927d434c6108a5845faf393b84aa40d34 --- /dev/null +++ b/mobile/.gitignore @@ -0,0 +1,103 @@ +opencl_kernels.cpp +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.lib +*.a + +# Executables +*.exe +*.out +*.app + +.DS_Store + +build/ + +.idea/ + +CMakeCache.txt + +CMakeFiles/ + +Makefile + +cmake_install.cmake + + +*.cbp + +paddle-mobile.cbp + +.idea + +compile_commands.json + +cmake-build-debug/ +cmake-build-release/ + +test/models/ + +test/images/ + +# Emacs intermediate files +*~ + +# CMake building directory +build + +# clion building directories +cmake-build-debug +cmake-build-release + +# ios +tools/libomp.a + +# ios demo +demo/ios/PaddleMobileDemo/PaddleMobileDemo/googlenet_combine/ +demo/ios/PaddleMobileDemo/PaddleMobileDemo/*.jpg +demo/ios/PaddleMobileDemo/PaddleMobileDemo/PaddleMobile/*.a +*.xcuserstate +/tools/quantification/quantify + +# metal +Podfile.lock +metal/Pods/ +SwiftProtobuf.framework +paddle-mobile.xcworkspace +metal/models/ +metal/images/ +*.a +metal/paddle-mobile/paddle-mobile/CPU/libpaddle-mobile.a +*.xcuserdatad/ +*/xcuserdata/ +/venv/ + +metal/paddle-mobile-demo/paddle-mobile-demo/images 
+metal/paddle-mobile-demo/paddle-mobile-demo/models +metal/paddle-mobile-demo/paddle-mobile-demo/Resources +metal/paddle-mobile-demo/paddle-mobile-demo/Resources/images +metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models +metal/MobileNetDemo/MobileNetDemo/Resources diff --git a/mobile/.pre-commit-config.yaml b/mobile/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9827afcd0ce2b7b8ce5aacd35f0d5a06fe9af3a --- /dev/null +++ b/mobile/.pre-commit-config.yaml @@ -0,0 +1,69 @@ +repos: +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + sha: v1.0.1 + hooks: + - id: remove-crlf + files: ^(mobile/src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$ + exclude: ^(lite/) + - id: remove-tabs + files: ^(mobile/test/|mobile/src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$ + exclude: ^(lite/) + +- repo: https://github.com/pre-commit/pre-commit-hooks + sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0 + hooks: + - id: check-added-large-files + exclude: ^(lite/) + - id: check-merge-conflict + exclude: ^(lite/) + - id: check-symlinks + exclude: ^(lite/) + - id: detect-private-key + files: (?!.*tar.gz)^.*$ + exclude: ^(lite/) + - id: end-of-file-fixer + files: ^(mobile/test/|mobile/src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|h|hpp|hxx)$ + exclude: ^(lite/) + - id: trailing-whitespace + files: ^(mobile/test/|mobile/src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|h|hpp|hxx)$ + exclude: ^(lite/) + +- repo: local + hooks: + - id: copyright + name: copyright + entry: python ./mobile/tools/pre-commit.hooks/copyright.hook + language: system + files: ^(mobile/test/|mobile/src/).*\.(c|cc|cxx|cpp|h|hpp|hxx|py)$ + exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ | ^(lite/) + +- repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: bash ./mobile/tools/pre-commit.hooks/clang-format.hook -i + language: system + files: ^(mobile/test/|mobile/src/).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ + exclude: ^(lite/) + +- repo: local + hooks: + - id: cpplint + name: cpplint + description: Check C++ code style using cpplint. + entry: bash ./mobile/tools/pre-commit.hooks/cpplint.hook + language: system + files: ^(mobile/test/|mobile/src/).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ + exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$i | *\.pb\.cpp | ^(lite/) + + +# +#- repo: local +# hooks: +# - id: clang-tidy +# name: clang-tidy +# description: Check C++ code style using clang-tidy. 
+# entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i +# language: system +# files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ diff --git a/mobile/.travis.yml b/mobile/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..20fdddd5a172d63b6b3df3fb2a57265a08ed3732 --- /dev/null +++ b/mobile/.travis.yml @@ -0,0 +1,36 @@ +language: cpp +cache: ccache +sudo: required +dist: trusty + +os: + - linux + +addons: + apt: + packages: + - git + - python + - python-pip + - python2.7-dev + - libc6-i386 + - curl + +compiler: + - clang + +before_install: + - sudo pip install -U virtualenv pre-commit pip + # Download and install recent cmake + +script: + - | + function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } + - | + timeout 600 .travis/pre-commit-job.sh # 10min timeout + RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else exit 1; fi; + +notifications: + email: + on_success: change + on_failure: always diff --git a/mobile/.travis/pre-commit-job.sh b/mobile/.travis/pre-commit-job.sh new file mode 100755 index 0000000000000000000000000000000000000000..a0ae98dddd27a7f24467ce2ce441aba9e4ffe156 --- /dev/null +++ b/mobile/.travis/pre-commit-job.sh @@ -0,0 +1,21 @@ +#!/bin/bash +function abort(){ + echo "Your change doesn't follow Paddle-Moible's code style" 1>&2 + echo "Please use pre-commit to auto-format your code." 1>&2 + exit 1 +} + +trap 'abort' 0 +set -e +cd `dirname $0` +cd .. +export PATH=/usr/bin:$PATH +pre-commit install + +if ! pre-commit run -a ; then + ls -lh + git diff --exit-code + exit 1 +fi + +trap : 0 diff --git a/mobile/src/framework/cl/cl_engine.h b/mobile/src/framework/cl/cl_engine.h index 0c18cd5f1a4628bc6d157bc664af175f2bdd2be4..439209d0e8590429ce961b894af6735271099112 100644 --- a/mobile/src/framework/cl/cl_engine.h +++ b/mobile/src/framework/cl/cl_engine.h @@ -96,6 +96,21 @@ class CLEngine { return std::move(program_ptr); } + std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWithSource( + cl_context context, const char *source) { + size_t sourceSize[] = {strlen(source)}; + cl_program p = + clCreateProgramWithSource(context, 1, &source, sourceSize, &status_); + + DLOG << " cl kernel from source"; + DLOG << " source size: " << sourceSize[0]; + CL_CHECK_ERRORS(status_); + + std::unique_ptr<_cl_program, CLProgramDeleter> program_ptr(p); + + return std::move(program_ptr); + } + std::unique_ptr<_cl_event, CLEventDeleter> CreateEvent(cl_context context) { cl_event event = clCreateUserEvent(context, &status_); std::unique_ptr<_cl_event, CLEventDeleter> event_ptr(event); diff --git a/mobile/src/framework/cl/cl_scope.h b/mobile/src/framework/cl/cl_scope.h index e45baa525c5738d255a8296aaea19cd082734279..5f15a1f6d17f81dd18c5acdd923b8ad5c71e644c 100644 --- a/mobile/src/framework/cl/cl_scope.h +++ b/mobile/src/framework/cl/cl_scope.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include #include #include #include +#include #include "CL/cl.h" #include "framework/cl/cl_deleter.h" @@ -24,6 +26,10 @@ limitations under the License. 
*/ #include "framework/cl/cl_tool.h" namespace paddle_mobile { + +extern const std::map> opencl_kernels; +extern const std::vector need_conv_header_kernels; + namespace framework { class CLScope { @@ -62,15 +68,35 @@ class CLScope { return it->second.get(); } - auto program = CLEngine::Instance()->CreateProgramWith( - context_, - CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name); - - DLOG << " --- begin build program -> " << program_key << " --- "; - CLEngine::Instance()->BuildProgram(program.get(), options); - DLOG << " --- end build program -> " << program_key << " --- "; - - programs_[program_key] = std::move(program); + if (opencl_kernels.find(file_name) != opencl_kernels.end()) { + auto it = opencl_kernels.find(file_name); + std::string source(it->second.begin(), it->second.end()); + if (std::find(need_conv_header_kernels.begin(), + need_conv_header_kernels.end(), + file_name) != need_conv_header_kernels.end()) { + auto it = opencl_kernels.find("conv_kernel.inc.cl"); + std::string header(it->second.begin(), it->second.end()); + source = header + source; + } + auto program = CLEngine::Instance()->CreateProgramWithSource( + context_, source.c_str()); + + DLOG << " --- begin build program -> " << program_key << " --- "; + CLEngine::Instance()->BuildProgram(program.get(), options); + DLOG << " --- end build program -> " << program_key << " --- "; + + programs_[program_key] = std::move(program); + } else { + auto program = CLEngine::Instance()->CreateProgramWith( + context_, + CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name); + + DLOG << " --- begin build program -> " << program_key << " --- "; + CLEngine::Instance()->BuildProgram(program.get(), options); + DLOG << " --- end build program -> " << program_key << " --- "; + + programs_[program_key] = std::move(program); + } return programs_[program_key].get(); } diff --git a/mobile/src/io/paddle_mobile_wrap.h b/mobile/src/io/paddle_mobile_wrap.h index 72d85b8a5727b54bdcea12ddb061e8b1675cec4d..5048b1234e318dfd7606114989587ef2ffbc4244 100644 --- a/mobile/src/io/paddle_mobile_wrap.h +++ b/mobile/src/io/paddle_mobile_wrap.h @@ -16,9 +16,9 @@ limitations under the License. */ #include #include +#include #include #include -#include #include #include diff --git a/mobile/src/operators/kernel/arm/conditional_block_kernel.cpp b/mobile/src/operators/kernel/arm/conditional_block_kernel.cpp index df98f74b8f5322c2d492b4c349e499a0ac82014c..a5530559d1a3a90996eb1a4ed94b31b85edad521 100644 --- a/mobile/src/operators/kernel/arm/conditional_block_kernel.cpp +++ b/mobile/src/operators/kernel/arm/conditional_block_kernel.cpp @@ -14,10 +14,10 @@ limitations under the License. */ #ifdef CONDITIONAL_BLOCK_OP -#include #include "operators/kernel/conditional_block_kernel.h" #include #include +#include #include "framework/data_type.h" namespace paddle_mobile { diff --git a/mobile/src/operators/kernel/cl/gen_code.py b/mobile/src/operators/kernel/cl/gen_code.py new file mode 100644 index 0000000000000000000000000000000000000000..14608c95fc0924ea95eea1b25493262f81c45505 --- /dev/null +++ b/mobile/src/operators/kernel/cl/gen_code.py @@ -0,0 +1,103 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import os
+import sys
+
+source = """
+#pragma
+#ifdef PADDLE_MOBILE_CL
+#include
+#include
+#include
+namespace paddle_mobile {
+    extern const std::map> opencl_kernels = {
+%s
+    };
+    extern const std::vector need_conv_header_kernels = {
+    %s
+    };
+}
+#endif
+"""
+
+def string_to_hex(str):
+    hex_list = []
+    for i in range(len(str)):
+        hex_ = hex(ord(str[i]))
+        hex_list.append(hex_)
+    return hex_list
+
+infile = open("cl_kernel/cl_common.h", "r")
+common_content = infile.read()
+infile.close()
+common_content = re.sub(r"/\*[^*]*\*/", "", common_content, flags=re.DOTALL)
+lines = common_content.split("\n")
+new_lines = []
+for i in range(len(lines)):
+    line = lines[i]
+    line = line.strip()
+    if line == "":
+        continue
+    if line.startswith("//"):
+        continue
+    line = re.sub(r"//.*$", "", line)
+    new_lines.append(line)
+common_content = "\n".join(new_lines)
+
+need_conv_header_kernels = []
+
+cores = ""
+filenames = os.listdir("cl_kernel")
+file_count = len(filenames)
+for i in range(file_count):
+    filename = filenames[i]
+    infile = open("cl_kernel/" + filename, "r")
+    new_lines = []
+    content = infile.read()
+    content = re.sub(r"/\*[^*]*\*/", "", content, flags=re.DOTALL)
+    infile.close()
+    lines = content.split("\n")
+    for j in range(len(lines)):
+        line = lines[j]
+        line = line.strip()
+        if line == "":
+            continue
+        if line.startswith("//"):
+            continue
+        line = re.sub(r"//.*$", "", line)
+        if "cl_common.h" in line:
+            line = common_content
+        elif "conv_kernel.inc.cl" in line:
+            need_conv_header_kernels.append("\"%s\"" % filename)
+            continue
+        new_lines.append(line)
+    content = "\n".join(new_lines)
+    if content == "":
+        content = " "
+    hexes = []
+    for char in content:
+        hexes.append(hex(ord(char)))
+    core = " {\"%s\", {" % filename
+    for item in hexes:
+        core += str(item) + ", "
+    core = core[: -2]
+    core += "}}"
+    if i != file_count - 1:
+        core += ",\n"
+    cores += core
+
+source = source % (cores, ",".join(need_conv_header_kernels))
+print(source)
diff --git a/mobile/test/CMakeLists.txt b/mobile/test/CMakeLists.txt
index 36293ab8846741fd7e5c4de66fe6537eca277270..1b6675f43eed3d97213aef198e33db622676227a 100644
--- a/mobile/test/CMakeLists.txt
+++ b/mobile/test/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(dir ${CMAKE_CURRENT_SOURCE_DIR})
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${dir}/build")
 set(FOUND_MATCH OFF)
+set(ENABLE_ALL_TEST ON)
 
 set(CON -1)
 
@@ -197,334 +198,340 @@ if (CON GREATER -1)
     set(FOUND_MATCH ON)
 endif ()
 
-if (NOT FOUND_MATCH)
-    # gen test
-    ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
-    target_link_libraries(test-resnet paddle-mobile)
-
-    # gen test
-    ADD_EXECUTABLE(test-squeezenet net/test_squeezenet.cpp test_helper.h test_include.h executor_for_test.h)
-    target_link_libraries(test-squeezenet paddle-mobile)
-
-    # gen test
-    ADD_EXECUTABLE(test-yolo net/test_yolo.cpp test_helper.h test_include.h executor_for_test.h)
-    target_link_libraries(test-yolo paddle-mobile)
-
-    # gen test
-    ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h)
-
target_link_libraries(test_yolo_combined paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-op-in-net net/test_op_in_net.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-op-in-net paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-googlenet paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-googlenet-quali net/test_googlenet_quali.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-googlenet-quali paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-conv-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h) - target_link_libraries(test-mul-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-elementwiseadd-op operators/test_elementwise_add_op.cpp test_helper.h test_include.h) - target_link_libraries(test-elementwiseadd-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-elementwisesub-op operators/test_elementwise_sub_op.cpp test_helper.h test_include.h) - target_link_libraries(test-elementwisesub-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-im2sequence-op operators/test_im2sequence_op.cpp test_helper.h test_include.h) - target_link_libraries(test-im2sequence-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h) - target_link_libraries(test-concat-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-lrn-op operators/test_lrn_op.cpp test_helper.h test_include.h) - target_link_libraries(test-lrn-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h) - target_link_libraries(test-batchnorm-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-priorbox-op operators/test_prior_box_op.cpp test_helper.h test_include.h) - target_link_libraries(test-priorbox-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-boxcoder-op operators/test_box_coder_op.cpp test_helper.h test_include.h) - target_link_libraries(test-boxcoder-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h) - target_link_libraries(test-transpose-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-transpose2-op operators/test_transpose2_op.cpp test_helper.h test_include.h) - target_link_libraries(test-transpose2-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h) - target_link_libraries(test-multiclassnms-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h) - target_link_libraries(test-polygon-box-transform-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h) - target_link_libraries(test-fill-constant-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) - target_link_libraries(test-reshape-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-reshape2-op operators/test_reshape2_op.cpp test_helper.h test_include.h) - target_link_libraries(test-reshape2-op paddle-mobile) - - # gen test - 
ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h) - target_link_libraries(test-relu-op paddle-mobile) - - ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h) - target_link_libraries(test-relu6-op paddle-mobile) - - ADD_EXECUTABLE(test-tanh-op operators/test_tanh_op.cpp test_helper.h test_include.h) - target_link_libraries(test-tanh-op paddle-mobile) - - ADD_EXECUTABLE(test-log-op operators/test_log_op.cpp test_helper.h test_include.h) - target_link_libraries(test-log-op paddle-mobile) - - ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h) - target_link_libraries(test-topk-op paddle-mobile) - - ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h) - target_link_libraries(test-cast-op paddle-mobile) - - ADD_EXECUTABLE(test-less-than-op operators/test_less_than_op.cpp test_helper.h test_include.h) - target_link_libraries(test-less-than-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h) - target_link_libraries(test-fc-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-sum-op operators/test_sum_op.cpp test_helper.h test_include.h) - target_link_libraries(test-sum-op paddle-mobile) - - # test quantize op - ADD_EXECUTABLE(test-quantize-op operators/test_quantize_op.cpp test_helper.h test_include.h) - target_link_libraries(test-quantize-op paddle-mobile) - - # test dequantize op - ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h) - target_link_libraries(test-dequantize-op paddle-mobile) - - # gen test log - ADD_EXECUTABLE(test-log common/test_log.cpp) - target_link_libraries(test-log paddle-mobile) - - # gen test log - ADD_EXECUTABLE(test-load framework/test_load.cpp) - target_link_libraries(test-load paddle-mobile) - - # gen test log - ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp) - target_link_libraries(test-loadmemory paddle-mobile) - - # gen test log - ADD_EXECUTABLE(test-loadmemory-inference framework/test_load_memory_inference_api.cpp) - target_link_libraries(test-loadmemory-inference paddle-mobile) - - ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp) - target_link_libraries(test-inference-api paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp) - target_link_libraries(test-optimize paddle-mobile) - - #gen test - ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-pool-op paddle-mobile) - - #gen test - ADD_EXECUTABLE(test-softmax-op operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-softmax-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp) - target_link_libraries(test-gemm-accuracy paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-gemm-int8-accuracy common/test_gemm_int8_accuracy.cpp) - target_link_libraries(test-gemm-int8-accuracy paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp) - target_link_libraries(test-gemm-perf paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-enforce common/test_enforce.cpp) - target_link_libraries(test-enforce paddle-mobile) - - # gen test - test if openmp works - ADD_EXECUTABLE(test-openmp common/test_openmp.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-openmp paddle-mobile) 
- - # gen test - ADD_EXECUTABLE(test-mobilenetssd net/test_mobilenet+ssd.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-mobilenetssd paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-mobilenet-combine net/test_mobilenet_combine.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-mobilenet-combine paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-genet net/test_genet_combine.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-genet paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-sigmoid-op operators/test_sigmoid_op.cpp test_include.h) - target_link_libraries(test-sigmoid-op paddle-mobile) - - # gen test log - ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp) - target_link_libraries(test-leakyrelu paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-depthwise-conv-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-mobilenet net/test_mobilenet.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-mobilenet paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-conv-add-relu-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-conv-add-bn-relu-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-nlp paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h) - target_link_libraries(test-gru-op paddle-mobile) - - # gen test - - ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-inceptionv4 paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-alexnet paddle-mobile) - - ADD_EXECUTABLE(test-googlenetv1 net/test_googlenetv1_combine.cpp test_helper.h test_include.h) - target_link_libraries(test-googlenetv1 paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h) - target_link_libraries(test-fssd paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-mobilenetgpu net/test_mobilenet_GPU.cpp test_helper.h test_include.h) - target_link_libraries(test-mobilenetgpu paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-yologpu net/test_yologpu.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-yologpu paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h) - target_link_libraries(test-multi-process paddle-mobile) - - # gen test benchmark - ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp) - target_link_libraries(test-benchmark paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h) - target_link_libraries(test-eng paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-super net/test_super.cpp test_helper.h test_include.h) - target_link_libraries(test-super paddle-mobile) - - # gen test - 
ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h) - target_link_libraries(test-ocr paddle-mobile) - - ADD_EXECUTABLE(test-gesture net/test_gesture.cpp test_helper.h test_include.h) - target_link_libraries(test-gesture paddle-mobile) - - - ADD_EXECUTABLE(test-sequence-expand-op operators/test_sequence_expand_op.cpp test_helper.h test_include.h) - target_link_libraries(test-sequence-expand-op paddle-mobile) - - ADD_EXECUTABLE(test-sequence-pool-op operators/test_sequence_pool_op.cpp test_helper.h test_include.h) - target_link_libraries(test-sequence-pool-op paddle-mobile) - - ADD_EXECUTABLE(test-sequence-softmax-op operators/test_sequence_softmax_op.cpp test_helper.h test_include.h) - target_link_libraries(test-sequence-softmax-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-vgg16ssd net/test_vgg16ssd.cpp test_helper.h test_include.h) - target_link_libraries(test-vgg16ssd paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-logical-and-op operators/test_logical_and_op.cpp test_helper.h test_include.h) - target_link_libraries(test-logical-and-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-logical-or-op operators/test_logical_or_op.cpp test_helper.h test_include.h) - target_link_libraries(test-logical-or-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-logical-not-op operators/test_logical_not_op.cpp test_helper.h test_include.h) - target_link_libraries(test-logical-not-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h) - target_link_libraries(test-logical-xor-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-increment-op operators/test_increment_op.cpp test_helper.h test_include.h) - target_link_libraries(test-increment-op paddle-mobile) - - # gen test - ADD_EXECUTABLE(test-is-empty-op operators/test_is_empty_op.cpp test_helper.h test_include.h) - target_link_libraries(test-is-empty-op paddle-mobile) - - ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h) - target_link_libraries(test-conv-bn-relu-op paddle-mobile) - - ADD_EXECUTABLE(test-dwconv-bn-relu-op operators/test_dwconv_bn_relu_op.cpp test_helper.h test_include.h) - target_link_libraries(test-dwconv-bn-relu-op paddle-mobile) - - ADD_EXECUTABLE(test-conv-gpu operators/test_conv_gpu.cpp test_helper.h test_include.h) - target_link_libraries(test-conv-gpu paddle-mobile) - - ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h) - target_link_libraries(test-net-benchmark paddle-mobile) - +if (ENABLE_ALL_TEST) + if (NOT FOUND_MATCH) + # gen test + ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-resnet paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-squeezenet net/test_squeezenet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-squeezenet paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-yolo net/test_yolo.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-yolo paddle-mobile) + + # gen test + ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test_yolo_combined paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-op-in-net net/test_op_in_net.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-op-in-net paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-googlenet 
net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-googlenet paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-googlenet-quali net/test_googlenet_quali.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-googlenet-quali paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-conv-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h) + target_link_libraries(test-mul-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-elementwiseadd-op operators/test_elementwise_add_op.cpp test_helper.h test_include.h) + target_link_libraries(test-elementwiseadd-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-elementwisesub-op operators/test_elementwise_sub_op.cpp test_helper.h test_include.h) + target_link_libraries(test-elementwisesub-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-im2sequence-op operators/test_im2sequence_op.cpp test_helper.h test_include.h) + target_link_libraries(test-im2sequence-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h) + target_link_libraries(test-concat-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-lrn-op operators/test_lrn_op.cpp test_helper.h test_include.h) + target_link_libraries(test-lrn-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h) + target_link_libraries(test-batchnorm-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-priorbox-op operators/test_prior_box_op.cpp test_helper.h test_include.h) + target_link_libraries(test-priorbox-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-boxcoder-op operators/test_box_coder_op.cpp test_helper.h test_include.h) + target_link_libraries(test-boxcoder-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h) + target_link_libraries(test-transpose-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-transpose2-op operators/test_transpose2_op.cpp test_helper.h test_include.h) + target_link_libraries(test-transpose2-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h) + target_link_libraries(test-multiclassnms-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h) + target_link_libraries(test-polygon-box-transform-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h) + target_link_libraries(test-fill-constant-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) + target_link_libraries(test-reshape-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-reshape2-op operators/test_reshape2_op.cpp test_helper.h test_include.h) + target_link_libraries(test-reshape2-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h) + target_link_libraries(test-relu-op paddle-mobile) + + ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h) + target_link_libraries(test-relu6-op paddle-mobile) + + 
ADD_EXECUTABLE(test-tanh-op operators/test_tanh_op.cpp test_helper.h test_include.h) + target_link_libraries(test-tanh-op paddle-mobile) + + ADD_EXECUTABLE(test-log-op operators/test_log_op.cpp test_helper.h test_include.h) + target_link_libraries(test-log-op paddle-mobile) + + ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h) + target_link_libraries(test-topk-op paddle-mobile) + + ADD_EXECUTABLE(test-cast-op operators/test_cast_op.cpp test_helper.h test_include.h) + target_link_libraries(test-cast-op paddle-mobile) + + ADD_EXECUTABLE(test-less-than-op operators/test_less_than_op.cpp test_helper.h test_include.h) + target_link_libraries(test-less-than-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h) + target_link_libraries(test-fc-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-sum-op operators/test_sum_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sum-op paddle-mobile) + + # test quantize op + ADD_EXECUTABLE(test-quantize-op operators/test_quantize_op.cpp test_helper.h test_include.h) + target_link_libraries(test-quantize-op paddle-mobile) + + # test dequantize op + ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h) + target_link_libraries(test-dequantize-op paddle-mobile) + + # gen test log + ADD_EXECUTABLE(test-log common/test_log.cpp) + target_link_libraries(test-log paddle-mobile) + + # gen test log + ADD_EXECUTABLE(test-load framework/test_load.cpp) + target_link_libraries(test-load paddle-mobile) + + # gen test log + ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp) + target_link_libraries(test-loadmemory paddle-mobile) + + # gen test log + ADD_EXECUTABLE(test-loadmemory-inference framework/test_load_memory_inference_api.cpp) + target_link_libraries(test-loadmemory-inference paddle-mobile) + + ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp) + target_link_libraries(test-inference-api paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp) + target_link_libraries(test-optimize paddle-mobile) + + #gen test + ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-pool-op paddle-mobile) + + #gen test + ADD_EXECUTABLE(test-softmax-op operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-softmax-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp) + target_link_libraries(test-gemm-accuracy paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-gemm-int8-accuracy common/test_gemm_int8_accuracy.cpp) + target_link_libraries(test-gemm-int8-accuracy paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp) + target_link_libraries(test-gemm-perf paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-enforce common/test_enforce.cpp) + target_link_libraries(test-enforce paddle-mobile) + + # gen test - test if openmp works + ADD_EXECUTABLE(test-openmp common/test_openmp.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-openmp paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-mobilenetssd net/test_mobilenet+ssd.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-mobilenetssd paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-mobilenet-combine net/test_mobilenet_combine.cpp test_helper.h 
test_include.h executor_for_test.h) + target_link_libraries(test-mobilenet-combine paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-genet net/test_genet_combine.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-genet paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-sigmoid-op operators/test_sigmoid_op.cpp test_include.h) + target_link_libraries(test-sigmoid-op paddle-mobile) + + # gen test log + ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp) + target_link_libraries(test-leakyrelu paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-depthwise-conv-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-mobilenet net/test_mobilenet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-mobilenet paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-conv-add-relu-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-conv-add-bn-relu-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-nlp paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h) + target_link_libraries(test-gru-op paddle-mobile) + + # gen test + + ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-inceptionv4 paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-alexnet paddle-mobile) + + ADD_EXECUTABLE(test-googlenetv1 net/test_googlenetv1_combine.cpp test_helper.h test_include.h) + target_link_libraries(test-googlenetv1 paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h) + target_link_libraries(test-fssd paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-mobilenetgpu net/test_mobilenet_GPU.cpp test_helper.h test_include.h) + target_link_libraries(test-mobilenetgpu paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-yologpu net/test_yologpu.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-yologpu paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h) + target_link_libraries(test-multi-process paddle-mobile) + + # gen test benchmark + ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp) + target_link_libraries(test-benchmark paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h) + target_link_libraries(test-eng paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-super net/test_super.cpp test_helper.h test_include.h) + target_link_libraries(test-super paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h) + target_link_libraries(test-ocr paddle-mobile) + + ADD_EXECUTABLE(test-gesture net/test_gesture.cpp test_helper.h test_include.h) + target_link_libraries(test-gesture paddle-mobile) + + + ADD_EXECUTABLE(test-sequence-expand-op 
operators/test_sequence_expand_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-expand-op paddle-mobile) + + ADD_EXECUTABLE(test-sequence-pool-op operators/test_sequence_pool_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-pool-op paddle-mobile) + + ADD_EXECUTABLE(test-sequence-softmax-op operators/test_sequence_softmax_op.cpp test_helper.h test_include.h) + target_link_libraries(test-sequence-softmax-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-vgg16ssd net/test_vgg16ssd.cpp test_helper.h test_include.h) + target_link_libraries(test-vgg16ssd paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-logical-and-op operators/test_logical_and_op.cpp test_helper.h test_include.h) + target_link_libraries(test-logical-and-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-logical-or-op operators/test_logical_or_op.cpp test_helper.h test_include.h) + target_link_libraries(test-logical-or-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-logical-not-op operators/test_logical_not_op.cpp test_helper.h test_include.h) + target_link_libraries(test-logical-not-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h) + target_link_libraries(test-logical-xor-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-increment-op operators/test_increment_op.cpp test_helper.h test_include.h) + target_link_libraries(test-increment-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-is-empty-op operators/test_is_empty_op.cpp test_helper.h test_include.h) + target_link_libraries(test-is-empty-op paddle-mobile) + + ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h) + target_link_libraries(test-conv-bn-relu-op paddle-mobile) + + ADD_EXECUTABLE(test-dwconv-bn-relu-op operators/test_dwconv_bn_relu_op.cpp test_helper.h test_include.h) + target_link_libraries(test-dwconv-bn-relu-op paddle-mobile) + + ADD_EXECUTABLE(test-conv-gpu operators/test_conv_gpu.cpp test_helper.h test_include.h) + target_link_libraries(test-conv-gpu paddle-mobile) + + ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h) + target_link_libraries(test-net-benchmark paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-net paddle-mobile) + endif () +else() # gen test ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-net paddle-mobile) -endif () +endif() diff --git a/mobile/tools/build.sh b/mobile/tools/build.sh index f7ae755a818fd605de70739f088a4b49530e22a8..f0e192805b1709f9ee721e79758f650ec18c16df 100755 --- a/mobile/tools/build.sh +++ b/mobile/tools/build.sh @@ -2,6 +2,15 @@ NETS="" declare -a supportedNets=("googlenet" "mobilenet" "yolo" "squeezenet" "resnet" "mobilenetssd" "nlp" "mobilenetfssd" "genet" "super" "op") +# merge cl to so +merge_cl_to_so=1 +rm ../src/operators/kernel/cl/opencl_kernels.cpp +if [ $merge_cl_to_so == 1 ]; then + cd ../src/operators/kernel/cl + python gen_code.py > opencl_kernels.cpp + cd - +fi + build_for_mac() { if [ ! `which brew` ]; then echo "building failed! homebrew not found, please install homebrew." 
diff --git a/mobile/tools/pre-commit.hooks/clang-format.hook b/mobile/tools/pre-commit.hooks/clang-format.hook
index 92377d2dd6b53c69aaff41e4ea204b80fef31671..ffba8744f4b96c53907f7848592418e4356bf6bb 100644
--- a/mobile/tools/pre-commit.hooks/clang-format.hook
+++ b/mobile/tools/pre-commit.hooks/clang-format.hook
@@ -1,5 +1,5 @@
 #!/bin/bash
-set -e
+# set -e
 
 readonly VERSION="5.0"
 
diff --git a/mobile/tools/python/fluidtools/run.py b/mobile/tools/python/fluidtools/run.py
index fc65f19a1dfc0e3fce2c55f487ba901cd9132242..2bf704fb8d119b657ba114fc90c01f92020bc7ce 100644
--- a/mobile/tools/python/fluidtools/run.py
+++ b/mobile/tools/python/fluidtools/run.py
@@ -535,6 +535,7 @@ def main():
     push(checked_model_path)
     push(feed_path + "/" + last_feed_file_name, "input.txt")
     push(mobile_src_root + "/build/release/arm-v7a/build/libpaddle-mobile.so")
+    push(mobile_src_root + "/build/release/arm-v7a/build/cl_kernel")
     push(mobile_src_root + "/test/build/test-net")
     last_feed_var_shape = get_feed_var_shape(last_feed_var_name)
     args = str(len(last_feed_var_shape))