diff --git a/README.md b/README.md
index 23974beee9a8af5ee7e2c454575efff2e3d96ee2..22b84888294b5ef60c3d91d7a7909aef8f601d81 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Framework compatibility: In addition to models trained on PaddlePaddle, those tr
 
 Paddle Lite is designed to support a wide range of hardware and devices, and it enables mixed execution of a single model on multiple devices, optimizations at various phases, and lightweight applications on devices.
 
-![img](https://github.com/Superjomn/_tmp_images/raw/master/images/paddle-lite-architecture.png)
+![img](https://user-images.githubusercontent.com/45189361/70908123-6ce4fd00-2045-11ea-97e1-ad08446c5c86.png)
 
 As shown in the figure above, the analysis phase includes the Machine IR module, and it enables optimizations like op fusion and redundant computation pruning. Besides, the execution phase only involves kernel execution, so it can be deployed on its own to ensure a maximally lightweight deployment.
diff --git a/README_cn.md b/README_cn.md
index 99d38c47ffbbaa3b8593801701e3528167899f97..11d3967fe8ce88826ca982b71d96268c1a7e5c3a 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -34,7 +34,7 @@ Paddle Lite为Paddle-Mobile的升级版,定位支持包括手机移动端在
 
 PaddleLite 的架构设计着重考虑了对多硬件和平台的支持,并且强化了多个硬件在一个模型中混合执行的能力,多个层面的性能优化处理,以及对端侧应用的轻量化设计。
 
-![](https://github.com/Superjomn/_tmp_images/raw/master/images/paddle-lite-architecture.png)
+![](https://user-images.githubusercontent.com/45189361/70908123-6ce4fd00-2045-11ea-97e1-ad08446c5c86.png)
 
 其中,Analysis Phase 包括了 MIR(Machine IR) 相关模块,能够对原有的模型的计算图针对具体的硬件列表进行算子融合、计算裁剪 在内的多种优化。Execution Phase 只涉及到Kernel 的执行,且可以单独部署,以支持极致的轻量级部署。
diff --git a/cmake/cross_compiling/postproject.cmake b/cmake/cross_compiling/postproject.cmake
index 88ac3e101a686cb49ef5a4c3b1879c15b8f7b57b..7466b3e6d438277ad31020f76665bf689df436f5 100644
--- a/cmake/cross_compiling/postproject.cmake
+++ b/cmake/cross_compiling/postproject.cmake
@@ -63,7 +63,7 @@ if (LITE_ON_TINY_PUBLISH)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions")
     endif()
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections")
     check_linker_flag(-Wl,--gc-sections)
 endif()
diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake
index 574baa86a82963ffa76795e029a6ba14f537c80a..e5f1ec4cf21806992b22558f102c806e90e8858e 100644
--- a/cmake/cudnn.cmake
+++ b/cmake/cudnn.cmake
@@ -32,10 +32,9 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
     $ENV{CUDNN_ROOT}/lib64
     $ENV{CUDNN_ROOT}/lib
     /usr/lib
-    ${CUDA_TOOLKIT_ROOT_DIR}
-    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
-    ${CUDA_TOOLKIT_ROOT_DIR}/lib64
-    )
+    ${CUDA_TOOLKIT_ROOT_DIR}
+    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
+    ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
 
 if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0))
   find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH)
diff --git a/cmake/lite.cmake b/cmake/lite.cmake
index a095eea6d1cce9ba09ee631a50b8029e769f6d37..d6b374529e27119f1c48c03c667aa694481e45e8 100644
--- a/cmake/lite.cmake
+++ b/cmake/lite.cmake
@@ -118,7 +118,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean
 function(lite_cc_library TARGET)
     set(options SHARED shared STATIC static MODULE module)
     set(oneValueArgs "")
-    set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS
NPU_DEPS XPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -128,10 +128,10 @@ function(lite_cc_library TARGET) X86_DEPS ${args_X86_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} - NPU_DEPS ${args_NPU_DEPS} - XPU_DEPS ${args_XPU_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -161,7 +161,7 @@ function(lite_cc_binary TARGET) set(options " -g ") endif() set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -173,6 +173,8 @@ function(lite_cc_binary TARGET) CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -205,7 +207,7 @@ function(lite_cc_test TARGET) endif() set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS COMPILE_LEVEL # (basic|extra) @@ -225,6 +227,8 @@ function(lite_cc_test TARGET) CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -267,7 +271,7 @@ endif() function(add_kernel TARGET device level) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -360,11 +364,12 @@ function(add_kernel TARGET device level) lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} - XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -383,7 +388,7 @@ endif() function(add_operator TARGET level) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -409,11 +414,12 @@ function(add_operator TARGET level) lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} - XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} 
CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} diff --git a/cmake/xpu.cmake b/cmake/xpu.cmake index 8d99343c3041351102820cb20890031fa3f5807e..ab34f409b8fa08af4eb01ff1289107a599d8c27d 100644 --- a/cmake/xpu.cmake +++ b/cmake/xpu.cmake @@ -89,7 +89,7 @@ else() endif() find_library(XPU_SDK_LLVM_FILE NAMES LLVM-8 - PATHS ${XPU_SDK_ROOT}/XTDK/shlib) + PATHS ${XPU_SDK_ROOT}/XTDK/shlib/gcc482) if(NOT XPU_SDK_LLVM_FILE) message(FATAL_ERROR "Can not find LLVM Library in ${XPU_SDK_ROOT}") diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 21e53bde34af66cadeea84b831fda3eccf77c643..df6b7d3648409e13d88c049ec86173905f8b3cb6 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -172,13 +172,17 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/include" COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/libpaddle_light_api_shared.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" ) add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared) + add_dependencies(tiny_publish_cxx_lib bundle_light_api) add_dependencies(publish_inference tiny_publish_cxx_lib) - add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD - COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so) + if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD + COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so) + endif() endif() endif() endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 408a63e3f5bd911ec93575d7cd6b2e2ef3b2b2d8..70239e94e7a3064fb383246623d05a2079dda1fa 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -16,8 +16,11 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto) target_link_libraries(paddle_full_api_shared framework_proto) if(LITE_WITH_X86) - add_dependencies(paddle_full_api_shared xxhash) - target_link_libraries(paddle_full_api_shared xxhash) + add_dependencies(paddle_full_api_shared xxhash) + target_link_libraries(paddle_full_api_shared xxhash) + if (NOT LITE_ON_MODEL_OPTIMIZE_TOOL) + add_dependencies(paddle_full_api_shared dynload_mklml) + endif() endif() if(LITE_WITH_CUDA) target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive") @@ -38,10 +41,11 @@ else() if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux")) add_library(paddle_light_api_shared SHARED "") target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc) - add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) + set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "-flto -fdata-sections") + add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency - 
target_link_libraries(paddle_light_api_shared ${npu_runtime_libs})
+            target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs})
         endif()
     endif()
 endif()
@@ -77,8 +81,8 @@ if (NOT LITE_ON_TINY_PUBLISH)
             DEPS ${cxx_api_deps} ${ops} ${host_kernels} program
             X86_DEPS ${x86_kernels}
             ARM_DEPS ${arm_kernels}
-            NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
-            XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass
+            NPU_DEPS ${npu_kernels}
+            XPU_DEPS ${xpu_kernels}
             CL_DEPS ${opencl_kernels}
             FPGA_DEPS ${fpga_kernels})
 endif()
diff --git a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h
index bdccfab5df67e485b9fef110dc6cc1e9d74b21c3..6da47e53789d651f4a36d0b8d6a7ca1ea5a0a3d3 100644
--- a/lite/api/_paddle_use_ops.h
+++ b/lite/api/_paddle_use_ops.h
@@ -108,7 +108,7 @@ USE_LITE_OP(while)
 USE_LITE_OP(lod_reset)
 USE_LITE_OP(lookup_table)
 USE_LITE_OP(multiclass_nms)
-USE_LITE_OP(graph_op)
+USE_LITE_OP(subgraph)
 USE_LITE_OP(sequence_expand)
 USE_LITE_OP(sequence_pool)
 USE_LITE_OP(reduce_max)
diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt
index 3efa980332f25d786d5c880fab9b3ba5af0a1013..c1766772f8aaa417c3da1d72f2692c10c10194b4 100644
--- a/lite/api/android/jni/native/CMakeLists.txt
+++ b/lite/api/android/jni/native/CMakeLists.txt
@@ -25,11 +25,12 @@ if (NOT LITE_ON_TINY_PUBLISH)
     endif()
 else()
     add_library(paddle_lite_jni SHARED "")
+    set_target_properties(paddle_lite_jni PROPERTIES COMPILE_FLAGS "-flto -fdata-sections")
     target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc)
     add_dependencies(paddle_lite_jni op_list_h kernel_list_h)
     if (LITE_WITH_NPU)
         # Need to add HIAI runtime libs (libhiai.so) dependency
-        target_link_libraries(paddle_lite_jni ${npu_runtime_libs})
+        target_link_libraries(paddle_lite_jni ${npu_builder_libs} ${npu_runtime_libs})
     endif()
 endif()
diff --git a/lite/api/android/jni/native/tensor_jni.cc b/lite/api/android/jni/native/tensor_jni.cc
index 59cafa19399c4d265915e2dac8653e9ed7d10851..5212fe9a6eba2b034883da93c9ea5d845a63c773 100644
--- a/lite/api/android/jni/native/tensor_jni.cc
+++ b/lite/api/android/jni/native/tensor_jni.cc
@@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
   return JNI_TRUE;
 }
 
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
+    JNIEnv *env, jobject jtensor, jintArray buf) {
+  std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+  if (tensor == nullptr || (*tensor == nullptr)) {
+    return JNI_FALSE;
+  }
+  int64_t buf_size = (int64_t)env->GetArrayLength(buf);
+  if (buf_size != product((*tensor)->shape())) {
+    return JNI_FALSE;
+  }
+
+  int32_t *input = (*tensor)->mutable_data<int32_t>();
+  env->GetIntArrayRegion(buf, 0, buf_size, input);
+  return JNI_TRUE;
+}
+
 JNIEXPORT jfloatArray JNICALL
 Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
   if (is_const_tensor(env, jtensor)) {
@@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) {
   }
 }
 
+JNIEXPORT jintArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) {
+  if (is_const_tensor(env, jtensor)) {
+    std::unique_ptr<const Tensor> *tensor =
+        get_read_only_tensor_pointer(env, jtensor);
+    return cpp_array_to_jintarray(
+        env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
+  } else {
+    std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+    return cpp_array_to_jintarray(
+        env, (*tensor)->data<int32_t>(), product((*tensor)->shape()));
+  }
+}
+
 JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
     JNIEnv *env, jobject jtensor, jlong java_pointer) {
   if (java_pointer == 0) {
diff --git a/lite/api/android/jni/native/tensor_jni.h b/lite/api/android/jni/native/tensor_jni.h
index 34c35b6a76f777895dbe88dc5eadf48c659ee544..9b029dfb4c7431354d5de20c6132236764c6cc66 100644
--- a/lite/api/android/jni/native/tensor_jni.h
+++ b/lite/api/android/jni/native/tensor_jni.h
@@ -16,8 +16,8 @@
 #include <jni.h>
 /* Header for class com_baidu_paddle_lite_Tensor */
 
-#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
-#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
+#ifndef LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
+#define LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject);
 JNIEXPORT jbyteArray JNICALL
 Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
 
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    getIntData
+ * Signature: ()[I
+ */
+JNIEXPORT jintArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject);
+
 /*
  * Class:     com_baidu_paddle_lite_Tensor
  * Method:    nativeResize
@@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(
 JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
     JNIEnv *, jobject, jbyteArray);
 
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    nativeSetData
+ * Signature: ([I)Z
+ */
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I(
+    JNIEnv *, jobject, jintArray);
+
 /*
  * Class:     com_baidu_paddle_lite_Tensor
  * Method:    deleteCppTensor
@@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong);
 #ifdef __cplusplus
 }
 #endif
-#endif  // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
+#endif  // LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
index ac78800bd2e4903b44332a0a0aefe9c69b75abab..f76841dd413ddda86678eecf8241068dd98b74a4 100644
--- a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
@@ -108,6 +108,19 @@ public class Tensor {
         return nativeSetData(buf);
     }
 
+    /**
+     * Set the tensor int data.
+     *
+     * @param buf the int array buffer which will be copied into the tensor.
+     * @return true if the data is set successfully.
+     */
+    public boolean setData(int[] buf) {
+        if (readOnly) {
+            return false;
+        }
+        return nativeSetData(buf);
+    }
+
     /**
      * @return shape of the tensor as long array.
      */
@@ -123,12 +136,19 @@
      */
     public native byte[] getByteData();
 
+    /**
+     * @return the tensor data as int array.
+     */
+    public native int[] getIntData();
+
     private native boolean nativeResize(long[] dims);
 
     private native boolean nativeSetData(float[] buf);
 
     private native boolean nativeSetData(byte[] buf);
 
+    private native boolean nativeSetData(int[] buf);
+
    /**
     * Delete the C++ Tensor object pointed to by the input pointer, which is represented by a
    * long value.
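The JNI and Java changes above expose int32 tensor I/O to Java (`setData(int[])` and `getIntData()`) alongside the existing float and byte paths; `nativeSetData___3I` copies the `jintArray` into `mutable_data<int32_t>()` via `GetIntArrayRegion`. For comparison, here is a minimal sketch of the same int32 flow through the C++ light API. The model path and shapes are illustrative, and it assumes a model whose first input and output are int32 tensors:

```cpp
#include <vector>

#include "paddle_api.h"          // NOLINT
#include "paddle_use_kernels.h"  // NOLINT
#include "paddle_use_ops.h"      // NOLINT

int main() {
  // Load an optimized model with the light-weight API.
  paddle::lite_api::MobileConfig config;
  config.set_model_dir("./mobilenet_v1_opt");  // illustrative path

  auto predictor = paddle::lite_api::CreatePaddlePredictor<
      paddle::lite_api::MobileConfig>(config);

  // Fill an int32 input tensor; this is what Tensor.setData(int[]) does
  // through nativeSetData___3I on the Java side.
  auto input = predictor->GetInput(0);
  input->Resize({1, 64});  // illustrative shape
  auto* in_data = input->mutable_data<int32_t>();
  for (int i = 0; i < 64; ++i) {
    in_data[i] = i;
  }

  predictor->Run();

  // Read an int32 output tensor; the Java counterpart is Tensor.getIntData().
  auto output = predictor->GetOutput(0);
  const auto* out_data = output->data<int32_t>();
  (void)out_data;
  return 0;
}
```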
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 4647f20bbe476d8763f94f707f3d88da7c7544df..990d08f18f541088d797510e9dbd4881d42b164f 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -139,22 +139,15 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
 
 // append the names of inputs and outputs into input_names_ and output_names_
 void Predictor::PrepareFeedFetch() {
-  std::vector<const cpp::OpDesc *> feeds;
-  std::vector<const cpp::OpDesc *> fetchs;
-#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
-  // The shape of input tensors must be determined before generating NPU and XPU
-  // program.
-  auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
-  for (size_t i = 0; i < current_block->OpsSize(); i++) {
-    auto op = current_block->GetOp<cpp::OpDesc>(i);
-#else
   if (!program_) {
     GenRuntimeProgram();
   }
+
+  std::vector<const cpp::OpDesc *> feeds;
+  std::vector<const cpp::OpDesc *> fetchs;
   const auto &insts = program_->instructions();
   for (size_t i = 0; i < program_->num_instructions(); i++) {
     const auto &op = insts[i].op()->op_info();
-#endif
     if (op->Type() == "feed") {
       feeds.push_back(op);
     } else if (op->Type() == "fetch") {
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index 6fa400db6da9f029c38b496cd70d593a876628c9..3e6e10103e9f3af51923459a5921f9781431f352 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -20,6 +20,12 @@
 #include "lite/core/device_info.h"
 #include "lite/core/version.h"
 
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+#include <omp.h>
+#include "lite/backends/x86/mklml.h"
+#endif
+
 namespace paddle {
 namespace lite {
 
@@ -33,6 +39,17 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   mode_ = config.power_mode();
   threads_ = config.threads();
+
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+  int num_threads = config.cpu_math_library_num_threads();
+  int real_num_threads = num_threads > 1 ? num_threads : 1;
+  paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
+  omp_set_num_threads(real_num_threads);
+  VLOG(3) << "set_cpu_math_library_math_threads() is set successfully and the "
+             "number of threads is:"
+          << num_threads;
+#endif
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc
index 1c426e8568cf71b6f48edbbeb8a93fec2e89c594..b678c7ecd24c5ffbf3e9e3531264ac195c6a7325 100644
--- a/lite/api/model_optimize_tool.cc
+++ b/lite/api/model_optimize_tool.cc
@@ -90,6 +90,10 @@ std::vector<Place> ParserValidPlaces() {
         TARGET(kARM));  // enable kARM CPU kernel when no opencl kernel
   } else if (target_repr == "x86") {
     valid_places.emplace_back(TARGET(kX86));
+  } else if (target_repr == "npu") {
+    valid_places.emplace_back(TARGET(kNPU));
+  } else if (target_repr == "xpu") {
+    valid_places.emplace_back(TARGET(kXPU));
   } else {
     LOG(FATAL) << lite::string_format(
         "Wrong target '%s' found, please check the command flag "
diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc
index a04e86b7d2a1e06a52c38b5f00e9c07966be1bfe..cf5fa4981a173ceb77e091ea9be0e510eb53980a 100644
--- a/lite/api/model_test.cc
+++ b/lite/api/model_test.cc
@@ -72,10 +72,6 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
          const int thread_num,
          const int repeat,
          const int warmup_times = 0) {
-#ifdef LITE_WITH_PROFILE
-  lite::profile::BasicProfiler<lite::profile::BasicTimer>::Global().SetWarmup(
-      warmup_times);
-#endif
   lite_api::MobileConfig config;
   config.set_model_dir(model_dir);
   config.set_power_mode(power_mode);
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index c578769bd5159d27ad43e4e93de33f601223004b..339117cd503247a91694d1a9ca63b930af5658de 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -133,6 +133,7 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file_;
   std::string param_file_;
   bool model_from_memory_{false};
+  int cpu_math_library_math_threads_ = 1;
 
  public:
  void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
@@ -151,6 +152,13 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file() const { return model_file_; }
   std::string param_file() const { return param_file_; }
   bool model_from_memory() const { return model_from_memory_; }
+
+  void set_cpu_math_library_math_threads(int threads) {
+    cpu_math_library_math_threads_ = threads;
+  }
+  int cpu_math_library_num_threads() const {
+    return cpu_math_library_math_threads_;
+  }
 };
 
 /// MobileConfig is the config for the light weight predictor, it will skip
diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc
index 894d839185ea9e1b6b47b87c398f249f044c2b51..6d12df67ac70d5d922680fc76763123117045175 100644
--- a/lite/api/paddle_place.cc
+++ b/lite/api/paddle_place.cc
@@ -77,7 +77,8 @@ const std::string& PrecisionToStr(PrecisionType precision) {
 }
 
 const std::string& DataLayoutToStr(DataLayoutType layout) {
-  static const std::string datalayout2string[] = {"unk", "NCHW", "any", "NHWC"};
+  static const std::string datalayout2string[] = {
+      "unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
   auto x = static_cast<int>(layout);
   CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
   return datalayout2string[x];
@@ -115,8 +116,13 @@ const std::string& PrecisionRepr(PrecisionType precision) {
 }
 
 const std::string& DataLayoutRepr(DataLayoutType layout) {
-  static const std::string datalayout2string[] = {
-      "kUnk", "kNCHW", "kAny", "kNHWC"};
+  static const std::string datalayout2string[] = {"kUnk",
+                                                  "kNCHW",
+                                                  "kAny",
+                                                  "kNHWC",
+                                                  "kImageDefault",
+                                                  "kImageFolder",
+                                                  "kImageNW"};
   auto x = static_cast<int>(layout);
   CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
   return datalayout2string[x];
@@ -146,8 +152,12 @@ std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
 }
 
 std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
-  static const std::set<DataLayoutType> valid_set(
-      {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
+  static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
+                                                   DATALAYOUT(kAny),
+                                                   DATALAYOUT(kNHWC),
+                                                   DATALAYOUT(kImageDefault),
+                                                   DATALAYOUT(kImageFolder),
+                                                   DATALAYOUT(kImageNW)});
   if (layout == DATALAYOUT(kAny)) {
     return valid_set;
   }
diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h
index 07284be095c05e5dfa069b0973d5982cf1f07c8a..1aa41522352e9c2832e3c9919249887480e871a3 100644
--- a/lite/api/paddle_place.h
+++ b/lite/api/paddle_place.h
@@ -71,8 +71,11 @@ enum class DataLayoutType : int {
   kUnk = 0,
   kNCHW = 1,
   kNHWC = 3,
-  kAny = 2,  // any data layout
-  NUM = 4,   // number of fields.
+  kImageDefault = 4,  // for opencl image2d
+  kImageFolder = 5,   // for opencl image2d
+  kImageNW = 6,       // for opencl image2d
+  kAny = 2,           // any data layout
+  NUM = 7,            // number of fields.
 };
 
 typedef enum {
diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 9d56d262abf549584819ab893144e41fc399439f..ac29cdda019c29ee208df391e0c637dc07329abe 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -20,12 +20,6 @@ USE_MIR_PASS(static_kernel_pick_pass);
 USE_MIR_PASS(variable_place_inference_pass);
 USE_MIR_PASS(type_target_cast_pass);
 USE_MIR_PASS(generate_program_pass);
-#ifdef LITE_WITH_NPU
-USE_MIR_PASS(generate_npu_program_pass);
-#endif
-#ifdef LITE_WITH_XPU
-USE_MIR_PASS(generate_xpu_program_pass);
-#endif
 
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
@@ -40,8 +34,12 @@ USE_MIR_PASS(lite_interpolate_fuse_pass);
 USE_MIR_PASS(identity_scale_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
+USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
 USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
 USE_MIR_PASS(memory_optimize_pass);
+USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
+USE_MIR_PASS(npu_subgraph_pass);
+USE_MIR_PASS(xpu_subgraph_pass);
diff --git a/lite/api/python/pybind/CMakeLists.txt b/lite/api/python/pybind/CMakeLists.txt
index 178f167e6a1627d01df13b2e105e0af36b20601a..eabb6b150b93a722282118c3932676cd1aee5da8 100644
--- a/lite/api/python/pybind/CMakeLists.txt
+++ b/lite/api/python/pybind/CMakeLists.txt
@@ -4,3 +4,6 @@ if (NOT LITE_ON_TINY_PUBLISH)
 endif()
 
 lite_cc_library(lite_pybind SHARED SRCS pybind.cc DEPS ${PYBIND_DEPS})
+if (LITE_ON_TINY_PUBLISH)
+  set_target_properties(lite_pybind PROPERTIES COMPILE_FLAGS "-flto -fdata-sections")
+endif()
diff --git a/lite/api/python/pybind/pybind.cc b/lite/api/python/pybind/pybind.cc
index 2df2e8f8f8aa56bb71b0e1cb293df2ecbbafd0bb..7d4ed4e98701a5328b0f05387dc73ad8b93dfe18 100644
--- a/lite/api/python/pybind/pybind.cc
+++ b/lite/api/python/pybind/pybind.cc
@@ -165,6 +165,9 @@ void BindLitePlace(py::module *m) {
   py::enum_<DataLayoutType>(*m, "DataLayoutType")
       .value("NCHW", DataLayoutType::kNCHW)
       .value("NHWC", DataLayoutType::kNHWC)
+      .value("ImageDefault", DataLayoutType::kImageDefault)
+      .value("ImageFolder", DataLayoutType::kImageFolder)
+      .value("ImageNW", DataLayoutType::kImageNW)
       .value("Any", DataLayoutType::kAny);
 
   //
Place diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc index 5314c5ed75d862635a1b87cdad33bf3c58dcd6cc..4d0aefbc06a9d0678d8b401629b7cc4355967f6c 100644 --- a/lite/api/test_step_rnn_lite_x86.cc +++ b/lite/api/test_step_rnn_lite_x86.cc @@ -30,6 +30,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { std::string model_dir = FLAGS_model_dir; lite_api::CxxConfig config; config.set_model_dir(model_dir); + config.set_cpu_math_library_math_threads(10); config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index 076c791daab182c4eff477a621ecd2ec52a0c3e7..3bf1a00dd2701a2aaf79183eb6eb476e5cf67fff 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -120,5 +120,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) stack.cc affine_channel.cc anchor_generator.cc + split_merge_lod_tenosr.cc + reduce_prod.cc DEPS ${lite_kernel_deps} context tensor) endif() diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc index 5834461b8fe0b2d37f174d5f66269fb58f2504a1..67d60b18141f64fd4e0048e1a5d1e2c5373c7484 100644 --- a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -24,29 +24,48 @@ namespace paddle { namespace lite { namespace arm { namespace math { -void input_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride); -void output_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride); -void output_trans_c4_post(const float* src, - int src_stride, - float* dest, - int dest_stride, - float* bias_value, - bool has_relu); -void weight_trans_c4( +void input_trans_c4_8x8(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4_post_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu); +void input_trans_c4_4x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride); +void output_trans_c4_post_2x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride, + float* bias_value, + bool has_relu); +void weight_trans_c4_8x8( + float* dest, const float* src, int ic, int oc, void* workspace); +void weight_trans_c4_4x4( float* dest, const float* src, int ic, int oc, void* workspace); /* -*The following function conv_compute_6x6_3x3 is base on +*The following function conv_compute_6x6_3x3 and conv_compute_2x2_3x3[_small] is +*base on *MNN[https://github.com/alibaba/MNN] * *Copyright © 2018, Alibaba Group Holding Limited */ + +// F(6,3) void conv_compute_6x6_3x3(const float* input, float* output, int num, @@ -75,11 +94,14 @@ void conv_compute_6x6_3x3(const float* input, int tile_w = (wout + 5) / 6; int tile_h = (hout + 5) / 6; int size_tile = tile_h * tile_w; - float zero_ptr[8]; - memset(zero_ptr, 0, 8 * sizeof(float)); int w_pad = win + pad_w * 2; int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + float* input_c4 = tmp_work_space; int new_h_stride = w_pad * 4; int new_c_stride = new_h_stride * h_pad; @@ 
-88,9 +110,6 @@ void conv_compute_6x6_3x3(const float* input, int oc_4_stride = wout * hout * 4; int tile_block = 8; -#ifdef __aarch64__ - tile_block = 16; -#endif int block_count = (size_tile + tile_block - 1) / tile_block; int threads = ctx->threads(); @@ -102,7 +121,8 @@ void conv_compute_6x6_3x3(const float* input, // begin compute for (int ni = 0; ni < num; ++ni) { - // trans input to c4 +// trans input to c4 +#pragma omp parallel for num_threads(threads) for (int i = 0; i < ic_4; ++i) { prepack_input_nxwc4_dw(input + ni * in_n_stride, input_c4 + i * new_c_stride, @@ -161,14 +181,14 @@ void conv_compute_6x6_3x3(const float* input, const float* src_ci = src_ptr + ci * ic_4_stride; for (int i = 0; i < 8; ++i) { const float* ci_ptr = src_ci + i * w_pad * 4; - input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); } float* dst_ci = dst_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - input_trans_c4(trans_tmp_data + i * 32, - 4, - dst_ci + i * b_gi_stride * 8, - b_gi_stride); + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); } } } else { @@ -189,14 +209,14 @@ void conv_compute_6x6_3x3(const float* input, // trans for (int i = 0; i < 8; ++i) { float* ci_ptr = trans_remain_tmp_data + i * 32; - input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); } float* dst_ci = dst_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - input_trans_c4(trans_tmp_data + i * 32, - 4, - dst_ci + i * b_gi_stride * 8, - b_gi_stride); + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); } } // for ci_4 } @@ -213,16 +233,8 @@ void conv_compute_6x6_3x3(const float* input, float* origin_C = dst_temp_data + gi * c_gi_stride; float* origin_B = b_ptr + gi * b_gi_stride; const float* origin_A = weight + gi * w_gi_stride; - sgemm_prepack_c4_small(oc_4 * 4, - tile_count, - ic_4 * 4, - origin_A, - origin_B, - origin_C, - nullptr, - false, - false, - ctx); + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); } //*/ //* @@ -258,18 +270,18 @@ void conv_compute_6x6_3x3(const float* input, float* dst_ci = dst_ptr + ci * oc_4_stride; float* src_ci = src_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - output_trans_c4(src_ci + i * c_gi_stride * 8, - c_gi_stride, - trans_tmp_data + i * 4, - 32); + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); } for (int i = 0; i < ey; ++i) { - output_trans_c4_post(trans_tmp_data + i * 32, - 4, - trans_remain_tmp_data + i * 24, - 4, - bias_value, - param.fuse_relu); + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); } write_to_output_c4_fp32(trans_remain_tmp_data, output_ptr, @@ -283,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } else { for (int ci = 0; ci < oc_4; ++ci) { @@ -297,18 +310,18 @@ void conv_compute_6x6_3x3(const float* input, float* dst_ci = dst_ptr + ci * oc_4_stride; float* src_ci = src_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - output_trans_c4(src_ci + i * c_gi_stride * 8, - c_gi_stride, - trans_tmp_data + i * 4, - 32); + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); } for (int i = 0; i < ey; ++i) { - output_trans_c4_post(trans_tmp_data + 
i * 32, - 4, - trans_remain_tmp_data + i * 24, - 4, - bias_value, - param.fuse_relu); + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); } // copy to dest memset(trans_tmp_data, 0, 144 * sizeof(float)); @@ -329,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } } @@ -338,10 +352,526 @@ void conv_compute_6x6_3x3(const float* input, } // for num } // conv_compute -void output_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride) { +// F(2,3) +void conv_compute_2x2_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { +// trans input to c4 +#pragma omp parallel for num_threads(threads) + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + float* tmp_data = + g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride; + float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 64; + float* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 64; +#else + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? 
tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + // trans + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + nullptr); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + nullptr); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void conv_compute_2x2_3x3_small(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { +// trans input to c4 + +#pragma omp parallel for num_threads(threads) + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + 
chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; + for (int tbi = 0; tbi < block_count; ++tbi) { + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; +#pragma omp parallel for num_threads(threads) + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + nullptr); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + nullptr); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void output_trans_c4_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { const float32x4_t src0 = vld1q_f32(src); const float32x4_t src1 = vld1q_f32(src + src_stride); const float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -381,12 +911,13 @@ void output_trans_c4(const float* src, vst1q_f32(dest + dest_stride * 4, dest4); vst1q_f32(dest + dest_stride * 5, dest5); } -void output_trans_c4_post(const float* src, - int src_stride, - float* dest, - int dest_stride, - float* bias_value, - bool has_relu = false) { + +void output_trans_c4_post_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu = false) { const float32x4_t src0 = vld1q_f32(src); const float32x4_t src1 = vld1q_f32(src + src_stride); const float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -447,10 +978,10 @@ void output_trans_c4_post(const float* src, vst1q_f32(dest + dest_stride * 5, dest5); } -void input_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride) { +void input_trans_c4_8x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { float32x4_t src0 = vld1q_f32(src); float32x4_t src1 = vld1q_f32(src + src_stride); float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -497,7 +1028,165 @@ void input_trans_c4(const float* src, vst1q_f32(dest + dest_stride * 6, dst6); vst1q_f32(dest + dest_stride * 7, dst7); } -void weight_trans_c4( + +// BT=[1, 0, -1, 0, +// 0, 1, 1, 0, +// 0, -1, 1, 0, +// 0, 1, 0, -1] +void input_trans_c4_4x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src03 = 
vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src10 = vld1q_f32(src); + float32x4_t src11 = vld1q_f32(src + src_stride); + float32x4_t src12 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src20 = vld1q_f32(src); + float32x4_t src21 = vld1q_f32(src + src_stride); + float32x4_t src22 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src30 = vld1q_f32(src); + float32x4_t src31 = vld1q_f32(src + src_stride); + float32x4_t src32 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride); + + float32x4_t dst00 = vsubq_f32(src00, src02); + float32x4_t dst10 = vaddq_f32(src01, src02); + float32x4_t dst20 = vsubq_f32(src02, src01); + float32x4_t dst30 = vsubq_f32(src01, src03); + + float32x4_t dst01 = vsubq_f32(src10, src12); + float32x4_t dst11 = vaddq_f32(src11, src12); + float32x4_t dst21 = vsubq_f32(src12, src11); + float32x4_t dst31 = vsubq_f32(src11, src13); + + float32x4_t dst02 = vsubq_f32(src20, src22); + float32x4_t dst12 = vaddq_f32(src21, src22); + float32x4_t dst22 = vsubq_f32(src22, src21); + float32x4_t dst32 = vsubq_f32(src21, src23); + + float32x4_t dst03 = vsubq_f32(src30, src32); + float32x4_t dst13 = vaddq_f32(src31, src32); + float32x4_t dst23 = vsubq_f32(src32, src31); + float32x4_t dst33 = vsubq_f32(src31, src33); + + float32x4_t dest00 = vsubq_f32(dst00, dst02); + float32x4_t dest10 = vaddq_f32(dst01, dst02); + float32x4_t dest20 = vsubq_f32(dst02, dst01); + float32x4_t dest30 = vsubq_f32(dst01, dst03); + + float32x4_t dest01 = vsubq_f32(dst10, dst12); + float32x4_t dest11 = vaddq_f32(dst11, dst12); + float32x4_t dest21 = vsubq_f32(dst12, dst11); + float32x4_t dest31 = vsubq_f32(dst11, dst13); + + float32x4_t dest02 = vsubq_f32(dst20, dst22); + float32x4_t dest12 = vaddq_f32(dst21, dst22); + float32x4_t dest22 = vsubq_f32(dst22, dst21); + float32x4_t dest32 = vsubq_f32(dst21, dst23); + + float32x4_t dest03 = vsubq_f32(dst30, dst32); + float32x4_t dest13 = vaddq_f32(dst31, dst32); + float32x4_t dest23 = vsubq_f32(dst32, dst31); + float32x4_t dest33 = vsubq_f32(dst31, dst33); + + vst1q_f32(dest, dest00); + vst1q_f32(dest + dest_stride, dest10); + vst1q_f32(dest + dest_stride + dest_stride, dest20); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest30); + dest += dest_h_stride; + vst1q_f32(dest, dest01); + vst1q_f32(dest + dest_stride, dest11); + vst1q_f32(dest + dest_stride + dest_stride, dest21); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest31); + dest += dest_h_stride; + vst1q_f32(dest, dest02); + vst1q_f32(dest + dest_stride, dest12); + vst1q_f32(dest + dest_stride + dest_stride, dest22); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest32); + dest += dest_h_stride; + vst1q_f32(dest, dest03); + vst1q_f32(dest + dest_stride, dest13); + vst1q_f32(dest + dest_stride + dest_stride, dest23); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest33); +} + +// AT=[1, 1, 1, 0, +// 0, 1, -1, -1] +void output_trans_c4_post_2x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride, + float* bias_value, + bool has_relu) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + 
src_stride + src_stride); + float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src10 = vld1q_f32(src); + float32x4_t src11 = vld1q_f32(src + src_stride); + float32x4_t src12 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src20 = vld1q_f32(src); + float32x4_t src21 = vld1q_f32(src + src_stride); + float32x4_t src22 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src30 = vld1q_f32(src); + float32x4_t src31 = vld1q_f32(src + src_stride); + float32x4_t src32 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride); + + float32x4_t dst00 = vaddq_f32(vaddq_f32(src00, src01), src02); + float32x4_t dst10 = vsubq_f32(vsubq_f32(src01, src02), src03); + float32x4_t dst01 = vaddq_f32(vaddq_f32(src10, src11), src12); + float32x4_t dst11 = vsubq_f32(vsubq_f32(src11, src12), src13); + float32x4_t dst02 = vaddq_f32(vaddq_f32(src20, src21), src22); + float32x4_t dst12 = vsubq_f32(vsubq_f32(src21, src22), src23); + float32x4_t dst03 = vaddq_f32(vaddq_f32(src30, src31), src32); + float32x4_t dst13 = vsubq_f32(vsubq_f32(src31, src32), src33); + + float32x4_t dest00 = vaddq_f32(vaddq_f32(dst00, dst01), dst02); + float32x4_t dest10 = vsubq_f32(vsubq_f32(dst01, dst02), dst03); + float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12); + float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13); + + if (bias_value) { + float32x4_t bias = vld1q_f32(bias_value); + dest00 = vaddq_f32(dest00, bias); + dest10 = vaddq_f32(dest10, bias); + dest01 = vaddq_f32(dest01, bias); + dest11 = vaddq_f32(dest11, bias); + } + + if (has_relu) { + float32x4_t zeros = vdupq_n_f32(0); + dest00 = vmaxq_f32(dest00, zeros); + dest10 = vmaxq_f32(dest10, zeros); + dest01 = vmaxq_f32(dest01, zeros); + dest11 = vmaxq_f32(dest11, zeros); + } + + vst1q_f32(dest, dest00); + vst1q_f32(dest + dest_stride, dest10); + dest += dest_h_stride; + vst1q_f32(dest, dest01); + vst1q_f32(dest + dest_stride, dest11); +} +void weight_trans_c4_8x8( float* dest, const float* din, int ch_in, int ch_out, void* workspace) { const float coeff[8][3] = {{1.0f, 0.0f, 0.0f}, {-2.0f / 9, -2.0f / 9, -2.0f / 9}, @@ -558,6 +1247,63 @@ void weight_trans_c4( } } +void weight_trans_c4_4x4( + float* dest, const float* din, int ch_in, int ch_out, void* workspace) { + const float coeff[4][3] = {{1.0f, 0.0f, 0.0f}, + {0.5f, 0.5f, 0.5f}, + {0.5f, -0.5f, 0.5f}, + {0.0f, 0.0f, 1.0f}}; + + float* ptr_out = static_cast(workspace); + + for (int i = 0; i < ch_out; i++) { + for (int j = 0; j < ch_in; j++) { + const float* kernel0 = + static_cast(din) + (i * ch_in + j) * 9; + float* ptr_channel = ptr_out + (i * ch_in + j) * 16; + + //! transform kernel, transposed + const float* k0 = kernel0; + const float* k1 = kernel0 + 3; + const float* k2 = kernel0 + 6; + + //! h + float tmp[4][3]; + for (int i = 0; i < 4; i++) { + tmp[i][0] = + k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2]; + tmp[i][1] = + k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2]; + tmp[i][2] = + k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2]; + } + + //! 
v + for (int j = 0; j < 4; j++) { + float* tmpp = &tmp[j][0]; + for (int i = 0; i < 4; i++) { + ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] + + tmpp[1] * coeff[i][1] + + tmpp[2] * coeff[i][2]; + } + } + } + } + + int oc_pad = (ch_out + 3) / 4 * 4; + int ic_pad = (ch_in + 3) / 4 * 4; + int c_stride = ic_pad * oc_pad; + for (int i = 0; i < ch_out * ch_in * 16; ++i) { + int new_c = i % 16; + int new_oc = i / ch_in / 16 / 4; + int new_ic = i / 16 % (ch_in * 4) % ch_in; + int new_inner = i / ch_in / 16 % 4; + int dest_ind = + new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; + dest[dest_ind] = ptr_out[i]; + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc index b4972a1ecab151947f8aaa7d6db0f6e82a08e5e4..5cee02b639af7e04a9184af765a5e96be4cb4cdb 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc @@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); auto paddings = *param.paddings; + auto act_param = param.activation_param; const int pad_h = paddings[0]; const int pad_w = paddings[2]; @@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } const float* weight_remain_ptr = weights + c_round_down * w_stride; #pragma omp parallel for num_threads(threads) @@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } } } diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index e4c9fb99ef9a6b5d3987a1efd5a644f322ea043c..6f056677378ad0499e0f2ce8b0dd56cee5d6a6ae 100644 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -32,6 +32,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p0_bias_s(float *dout, @@ -46,6 +47,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p1_bias(float *dout, @@ -60,6 +62,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p1_bias_s(float *dout, @@ -74,6 +77,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1_fp32(const float *din, @@ -90,6 +94,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, int pad, bool flag_bias, bool flag_relu, + const operators::ActivationParam act_param, ARMContext *ctx) { if (pad == 0) { if (w_in > 5) { @@ -105,6 +110,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s1p0_bias_s(dout, @@ -119,6 +125,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } } @@ -136,6 +143,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s1p1_bias_s(dout, @@ 
-150,11 +158,12 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } } } - +// clang-format on #ifdef __aarch64__ #define INIT_S1 \ "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ @@ -255,14 +264,12 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ \ - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ /* r4 */ \ + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ #define LEFT_RESULT_S1 \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ @@ -345,16 +352,15 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ #define MID_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -411,30 +417,31 @@ void conv_depthwise_3x3s1_fp32(const float *din, #define RIGHT_COMPUTE_S1 \ "3: \n" \ + "movi v20.4s, #0 \n" \ "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ "ld1 {v22.4s}, [%[doutr0]] \n" \ "ld1 {v23.4s}, [%[doutr1]] \n" \ "ld1 {v24.4s}, [%[doutr2]] \n" \ "ld1 {v25.4s}, [%[doutr3]] \n" \ \ - "bif v0.16b, %[vzero].16b, v18.16b \n" \ - "bif v1.16b, %[vzero].16b, v19.16b \n" \ - "bif v2.16b, %[vzero].16b, v18.16b \n" \ - "bif v3.16b, %[vzero].16b, v19.16b \n" \ + "bif v0.16b, v20.16b, v18.16b \n" \ + "bif v1.16b, v20.16b, v19.16b \n" \ + "bif v2.16b, v20.16b, v18.16b \n" \ + "bif v3.16b, v20.16b, v19.16b \n" \ \ - "bif v4.16b, %[vzero].16b, v18.16b \n" \ - "bif v5.16b, %[vzero].16b, v19.16b \n" \ - "bif v6.16b, %[vzero].16b, v18.16b \n" \ - "bif v7.16b, %[vzero].16b, v19.16b \n" \ + "bif v4.16b, v20.16b, v18.16b \n" \ + "bif v5.16b, v20.16b, v19.16b \n" \ + "bif v6.16b, v20.16b, v18.16b \n" \ + "bif v7.16b, v20.16b, v19.16b \n" \ \ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "bif v8.16b, %[vzero].16b, v18.16b \n" 
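
All of the LEFT/MID/RIGHT asm macros in this file build on one vector pattern: load eight consecutive inputs, form the two shifted views with `ext`, and issue three fused multiply-adds per filter row. A standalone NEON-intrinsics sketch of that pattern (an illustration, not code from this patch):

```cpp
#include <arm_neon.h>

// One filter row of a 3x3 stride-1 depthwise conv, four outputs at a time:
// acc += din[0..3] * w[0] + din[1..4] * w[1] + din[2..5] * w[2]
static inline float32x4_t row3x3s1(float32x4_t acc, const float* din,
                                   const float w[3]) {
  float32x4_t v0 = vld1q_f32(din);        // din 0 1 2 3
  float32x4_t v1 = vld1q_f32(din + 4);    // din 4 5 6 7
  float32x4_t d1 = vextq_f32(v0, v1, 1);  // din 1 2 3 4  ("ext ..., #4")
  float32x4_t d2 = vextq_f32(v0, v1, 2);  // din 2 3 4 5  ("ext ..., #8")
  acc = vmlaq_n_f32(acc, v0, w[0]);
  acc = vmlaq_n_f32(acc, d1, w[1]);
  acc = vmlaq_n_f32(acc, d2, w[2]);
  return acc;
}
```
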
\ - "bif v9.16b, %[vzero].16b, v19.16b \n" \ - "bif v10.16b, %[vzero].16b, v18.16b \n" \ - "bif v11.16b, %[vzero].16b, v19.16b \n" \ + "bif v8.16b, v20.16b, v18.16b \n" \ + "bif v9.16b, v20.16b, v19.16b \n" \ + "bif v10.16b, v20.16b, v18.16b \n" \ + "bif v11.16b, v20.16b, v19.16b \n" \ \ "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -467,15 +474,13 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ #define RIGHT_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ "bif v12.16b, v22.16b, v18.16b \n" \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -520,10 +525,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v15.4s}, [%[doutr3]], #16 \n" #define LEFT_RESULT_S1_RELU \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ \ @@ -570,14 +571,113 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "blt 3f \n" +#define LEFT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext 
v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + "bif v13.16b, v21.16b, v19.16b \n" /* choose*/ \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + "cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "cmp %w[cnt], #1 \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + #define MID_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -598,7 +698,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, 
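
The relu6 variants interleave the clamp with the remaining fmlas and stores for scheduling, but per result vector the epilogue is just a two-sided clamp (sketch):

```cpp
#include <arm_neon.h>

// relu6 epilogue: out = min(max(x, 0), six), matching the fmax/fmin pairs.
static inline float32x4_t relu6(float32x4_t x, float32x4_t vsix) {
  return vminq_f32(vmaxq_f32(x, vdupq_n_f32(0.f)), vsix);
}
```
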
%[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -617,7 +717,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, /* r3 */ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -633,20 +733,157 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "subs %w[cnt], %w[cnt], #1 \n" \ \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ \ "st1 {v15.4s}, [%[doutr3]], #16 \n" \ "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ \ "bne 1b \n" -#define RIGHT_RESULT_S1_RELU \ +#define MID_RESULT_S1_RELU6 \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, 
%[vsix].4s \n" /*relu6*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define MID_RESULT_S1_LEAKY_RELU \ + "movi v21.4s, #0 \n" \ + "cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v13.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "bne 1b \n" + +#define RIGHT_RESULT_S1_RELU \ + "fmax v12.4s, v12.4s, v20.4s \n" 
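
For leaky relu the epilogue becomes a compare, a scaled copy, and a bitwise select, which is what the `cmhs`/`fmul`/`bif` triples spell out. Expressed with a float compare (the intent of the mask), the branchless form is:

```cpp
#include <arm_neon.h>

// leaky relu: x >= 0 ? x : x * alpha, as compare + multiply + select.
static inline float32x4_t leaky_relu(float32x4_t x, float32x4_t valpha) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // lanes where x >= 0
  float32x4_t scaled = vmulq_f32(x, valpha);
  return vbslq_f32(ge_zero, x, scaled);
}
```
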
/*relu*/ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -664,7 +901,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -680,7 +917,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -690,72 +927,184 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \ \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ \ "bif v15.16b, v25.16b, v18.16b \n" \ \ "st1 {v15.4s}, [%[doutr3]], #16 \n" -#define COMPUTE_S_S1 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s}, [%[din0]], #16\n" \ - "ld1 {v1.4s}, [%[din1]], #16\n" \ - "ld1 {v2.4s}, [%[din2]], #16\n" \ - "ld1 {v3.4s}, [%[din3]], #16\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask].16b\n" \ - "bif v2.16b, %[zero].16b, %[mask].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask].16b\n" \ - \ - "ext v4.16b, %[zero].16b, v0.16b, #12\n" \ - "ext v5.16b, %[zero].16b, v1.16b, #12\n" \ - "ext v6.16b, %[zero].16b, v2.16b, #12\n" \ - "ext v7.16b, %[zero].16b, v3.16b, #12\n" \ - \ - "ext v8.16b, v0.16b, %[zero].16b, #4\n" \ - "ext v9.16b, v1.16b, %[zero].16b, #4\n" \ - "ext v10.16b, v2.16b, %[zero].16b, #4\n" \ - "ext v11.16b, v3.16b, %[zero].16b, #4\n" \ - \ - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ - \ - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ - \ - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ - \ - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ - \ - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ - \ - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ - \ - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ - \ - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v14.4s\n" \ - "fadd v12.4s, v12.4s, v16.4s\n" \ - \ - "fadd v13.4s, v13.4s, v15.4s\n" \ - "fadd v13.4s, v13.4s, v17.4s\n" \ - \ - "fadd v12.4s, v12.4s, %[bias].4s\n" \ +#define RIGHT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ 
\ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, v20.s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + "movi v1.4s, #0 \n" \ + "cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v21.16b, v20.16b \n" /* choose*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v13.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + "fmla v15.4s , 
v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + "bif v15.16b, v21.16b, v20.16b \n" \ + "bif v15.16b, v25.16b, v18.16b \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define COMPUTE_S_S1 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s}, [%[din0]], #16\n" \ + "ld1 {v1.4s}, [%[din1]], #16\n" \ + "ld1 {v2.4s}, [%[din2]], #16\n" \ + "ld1 {v3.4s}, [%[din3]], #16\n" \ + \ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" \ + \ + "ext v4.16b, %[vzero].16b, v0.16b, #12\n" \ + "ext v5.16b, %[vzero].16b, v1.16b, #12\n" \ + "ext v6.16b, %[vzero].16b, v2.16b, #12\n" \ + "ext v7.16b, %[vzero].16b, v3.16b, #12\n" \ + \ + "ext v8.16b, v0.16b, %[vzero].16b, #4\n" \ + "ext v9.16b, v1.16b, %[vzero].16b, #4\n" \ + "ext v10.16b, v2.16b, %[vzero].16b, #4\n" \ + "ext v11.16b, v3.16b, %[vzero].16b, #4\n" \ + \ + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ + \ + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ + \ + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ + \ + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ + \ + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ + \ + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ + \ + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ + \ + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v14.4s\n" \ + "fadd v12.4s, v12.4s, v16.4s\n" \ + \ + "fadd v13.4s, v13.4s, v15.4s\n" \ + "fadd v13.4s, v13.4s, v17.4s\n" \ + \ + "fadd v12.4s, v12.4s, %[bias].4s\n" \ "fadd v13.4s, v13.4s, %[bias].4s\n" #define RESULT_S_S1 \ @@ -765,16 +1114,42 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" -#define RESULT_S_S1_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[zero].4s\n" \ - "fmax v13.4s, v13.4s, %[zero].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ +#define RESULT_S_S1_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define RESULT_S_S1_RELU6 \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s\n" \ + "fmin v13.4s, v13.4s, %[vsix].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" +#define RESULT_S_S1_LEAKY_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, 
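
COMPUTE_S_S1 covers the narrow case (w_in <= 4, pad = 1): each filter row contributes a center, a left-shifted, and a right-shifted product, and the partials are reduced with two fadds before the bias. A scalar sketch of one output row, assuming the input rows are already masked so lanes at or beyond w_in are zero (names are illustrative):

```cpp
// Mirrors COMPUTE_S_S1's ext/fmul/fmla/fadd sequence for one output row.
void narrow_row_3x3s1p1(const float r0[4], const float r1[4],
                        const float r2[4], const float wr[3][3], float bias,
                        float out[4]) {
  const float* rows[3] = {r0, r1, r2};
  for (int x = 0; x < 4; ++x) {
    float acc = bias;
    for (int k = 0; k < 3; ++k) {
      float left = (x == 0) ? 0.f : rows[k][x - 1];   // "ext vzero, vN, #12"
      float right = (x == 3) ? 0.f : rows[k][x + 1];  // "ext vN, vzero, #4"
      acc += left * wr[k][0] + rows[k][x] * wr[k][1] + right * wr[k][2];
    }
    out[x] = acc;
  }
}
```
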
%[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" \ + "bif v13.16b, v21.16b, v19.16b \n" \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" #define COMPUTE_S_S1_P0 \ "prfm pldl1keep, [%[din0]]\n" \ "prfm pldl1keep, [%[din1]]\n" \ @@ -786,17 +1161,17 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ \ - "bif v0.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v0.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v2.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v2.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v4.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v5.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v4.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v6.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v7.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v6.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b\n" \ \ "ext v8.16b, v0.16b, v1.16b, #4\n" \ "ext v9.16b, v0.16b, v1.16b, #8\n" \ @@ -849,7 +1224,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, // "st1 {v12.4s}, [%[out1]]\n" \ // "st1 {v13.4s}, [%[out2]]\n" \ - #else #define INIT_S1 \ "pld [%[din0_ptr]] @ preload data\n" \ @@ -1129,6 +1503,66 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vdup.32 q5, %[bias_val] @ and \n" \ "blt 3f @ jump to main loop start point\n" +#define LEFT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + "vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + #define MID_RESULT_S1_RELU \ /* r3 */ \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ @@ -1157,6 +1591,69 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "bne 1b @ jump to main loop start point\n" +#define MID_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define MID_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + \ + "vbif q5, q6, q7 @ choose \n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + #define RIGHT_RESULT_S1_RELU \ /* r3 */ \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ @@ -1178,6 +1675,58 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" +#define RIGHT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif q4, q6, q15 @ choose \n" \ + \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + #define COMPUTE_S_S1 \ "pld [%[din0]]\n" \ "pld [%[din1]]\n" \ @@ -1251,6 +1800,36 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vst1.32 {d28-d29}, [%[out1]]\n" \ "vst1.32 {d30-d31}, [%[out2]]\n" +#define RESULT_S_S1_RELU6 \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d20-d21}, [%[six_ptr]] \n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + \ + "vmin.f32 q14, q14, q10 \n" \ + "vmin.f32 q15, q15, q10 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_LEAKY_RELU \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d18-d19}, [%[scale_ptr]] \n" \ + "vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \ + "vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q12, q14, q9 \n" \ + "vmul.f32 q13, q15, q9 \n" \ + \ + "vbif q14, q10, q12 \n" \ + "vbif q15, q11, q13 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + #define COMPUTE_S_S1_P0 \ "pld [%[din0]]\n" \ "pld [%[din1]]\n" \ @@ -1333,6 +1912,413 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vadd.f32 q15, q5, q9 @ q4 += q10 \n" #endif + +#ifdef __aarch64__ +void act_switch_3x3s1p1(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float *din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", 
+ "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vscale] "w"(vscale), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } +} +#else +void act_switch_3x3s1p1(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask_ptr, + unsigned int *rmask_ptr, + float32x4_t vzero, + float bias_val, + int cnt, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case 
lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [six_ptr] "r"(vsix), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [scale_ptr] "r"(vscale), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +} +#endif +// clang-format on /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width > 4 @@ -1349,6 +2335,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 
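
Note how the armv8 and armv7 paths pass the extra activation constants differently: armv8 has spare v-registers, so `vsix`/`vscale` ride in as `"w"` vector operands, while the armv7 kernels already use all of q4-q15 and instead take a pointer (`six_ptr`/`scale_ptr`) and reload the constant inside the asm. Roughly (a sketch of the two styles):

```cpp
#include <arm_neon.h>

void pass_relu6_bound(float clip) {
  // armv8 style: broadcast once, bound as a vector operand ([vsix] "w"(vsix)).
  float32x4_t vsix = vdupq_n_f32(clip);
  // armv7 style: spill to a small array, bind its address ([six_ptr] "r"(...))
  // and vld1.f32 it inside the asm when a q-register is free.
  float vsix_arr[4] = {clip, clip, clip, clip};
  (void)vsix;
  (void)vsix_arr;
}
```
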
pad is done implicit const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; @@ -1486,106 +2473,25 @@ void conv_depthwise_3x3s1p1_bias(float *dout, } int cnt = cnt_col; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + act_param); dout_ptr = dout_ptr + 4 * w_out; } #else @@ -1598,7 +2504,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout, doutr0 = dout_ptr; doutr1 = dout_ptr + w_out; - // unsigned int* rst_mask = rmask; if (i == 0) { din_ptr0 = zero_ptr; @@ -1635,77 +2540,314 @@ void conv_depthwise_3x3s1p1_bias(float *dout, int cnt = cnt_col; unsigned int *rmask_ptr = rmask; unsigned int *vmask_ptr = vmask; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - 
"memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + act_param); dout_ptr += 2 * w_out; } //! end of processing mid rows #endif } } } - +void act_switch_3x3s1p1_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp, + float32x4_t vzero, + float32x4_t wbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [six_ptr] "r"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vscale] "w"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [scale_ptr] "r"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width <= 4 @@ -1722,6 +2864,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 3x3s1 convolution, implemented by direct algorithm //! 
pad is done implicit @@ -1772,7 +2915,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, if (hs == -1) { dr0 = zero; } - switch (he - h_in) { case 2: dr2 = zero; @@ -1782,127 +2924,19 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, default: break; } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } -#else - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif + act_switch_3x3s1p1_s(dr0, + dr1, + dr2, + dr3, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp, + vzero, + wbias, + act_param); for (int w = 0; w < w_out; ++w) { *doutr0++ = out_buf1[w]; *doutr1++ = out_buf2[w]; @@ -1916,6 +2950,490 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, } // end of processing batchs } +#ifdef __aarch64__ +void act_switch_3x3s1p0(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float *din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + int remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU + "cmp 
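
Vertical padding stays implicit throughout: rows that fall above or below the image are never materialized; their pointers are simply redirected to a shared zero buffer (`dr0 = zero`, `dr2`/`dr3 = zero`), so the asm reads zeros without branching. The pattern, sketched with an illustrative helper:

```cpp
// kZeroRow is sized for the narrow kernels, which read at most 8 floats.
static const float kZeroRow[8] = {0.f};

inline const float* row_ptr(const float* din, int row, int h_in,
                            int row_stride) {
  // Out-of-range rows read from the shared zero buffer instead of the input.
  return (row < 0 || row >= h_in) ? kZeroRow : din + row * row_stride;
}
```
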
%w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU6 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+        asm volatile(
+            INIT_S1
+            "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
+            "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
+            "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
+            "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
+            "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
+            "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
+            MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
+            "cmp %w[remain], #1 \n"
+            "blt 0f \n" RIGHT_COMPUTE_S1
+            RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
+            : [cnt] "+r"(cnt),
+              [din_ptr0] "+r"(din_ptr0),
+              [din_ptr1] "+r"(din_ptr1),
+              [din_ptr2] "+r"(din_ptr2),
+              [din_ptr3] "+r"(din_ptr3),
+              [din_ptr4] "+r"(din_ptr4),
+              [din_ptr5] "+r"(din_ptr5),
+              [doutr0] "+r"(doutr0),
+              [doutr1] "+r"(doutr1),
+              [doutr2] "+r"(doutr2),
+              [doutr3] "+r"(doutr3)
+            : [w0] "w"(wr0),
+              [w1] "w"(wr1),
+              [w2] "w"(wr2),
+              [vscale] "w"(vscale),
+              [bias_val] "r"(vbias),
+              [vmask] "r"(vmask),
+              [rmask] "r"(rmask),
+              [remain] "r"(remain)
+            : "cc", "memory",
+              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+              "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
+              "v24", "v25");
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param.active_type)
+                   << " fuse not supported";
+    }
+  } else {
+    asm volatile(
+        INIT_S1
+        "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
+        "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
+        "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
+        "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
+        "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
+        "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
+        MID_COMPUTE_S1 MID_RESULT_S1
+        "cmp %w[remain], #1 \n"
+        "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
+        "0: \n"
+        : [cnt] "+r"(cnt),
+          [din_ptr0] "+r"(din_ptr0),
+          [din_ptr1] "+r"(din_ptr1),
+          [din_ptr2] "+r"(din_ptr2),
+          [din_ptr3] "+r"(din_ptr3),
+          [din_ptr4] "+r"(din_ptr4),
+          [din_ptr5] "+r"(din_ptr5),
+          [doutr0] "+r"(doutr0),
+          [doutr1] "+r"(doutr1),
+          [doutr2] "+r"(doutr2),
+          [doutr3] "+r"(doutr3)
+        : [w0] "w"(wr0),
+          [w1] "w"(wr1),
+          [w2] "w"(wr2),
+          [bias_val] "r"(vbias),
+          [vmask] "r"(vmask),
+          [rmask] "r"(rmask),
+          [vzero] "w"(vzero),
+          [remain] "r"(remain)
+        : "cc", "memory",
+          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+          "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+          "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
+          "v24", "v25");
+  }
+}
+#else
+void act_switch_3x3s1p0(const float *din_ptr0,
+                        const float *din_ptr1,
+                        const float *din_ptr2,
+                        const float *din_ptr3,
+                        float *doutr0,
+                        float *doutr1,
+                        float32x4_t wr0,
+                        float32x4_t wr1,
+                        float32x4_t wr2,
+                        unsigned int *vmask_ptr,
+                        unsigned int *rmask_ptr,
+                        float32x4_t vzero,
+                        float bias_val,
+                        int cnt,
+                        int remain,
+                        const operators::ActivationParam act_param) {
+  bool has_active = act_param.has_active;
+  if (has_active) {
+    float tmp = act_param.Relu_clipped_coef;
+    float ss = act_param.Leaky_relu_alpha;
+    float vsix[4] = {tmp, tmp, tmp, tmp};
+    float vscale[4] = {ss, ss, ss, ss};
+
+    switch (act_param.active_type) {
+      case lite_api::ActivationType::kRelu:
+        asm volatile(INIT_S1
+                     "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "vext.32 q6, q8, q9, #1
@ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+        asm volatile(INIT_S1
+                     "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
+                     "vext.32 q6, q8, q9, #1 @ 0012\n"
+                     "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
+                     MID_RESULT_S1_LEAKY_RELU
+                     "cmp %[remain], #1 \n"
+                     "blt 0f \n" RIGHT_COMPUTE_S1
+                     RIGHT_RESULT_S1_LEAKY_RELU
+                     "0: \n"
+                     : [dout_ptr1] "+r"(doutr0),
+                       [dout_ptr2] "+r"(doutr1),
+                       [din0_ptr] "+r"(din_ptr0),
+                       [din1_ptr] "+r"(din_ptr1),
+                       [din2_ptr] "+r"(din_ptr2),
+                       [din3_ptr] "+r"(din_ptr3),
+                       [cnt] "+r"(cnt),
+                       [rmask] "+r"(rmask_ptr),
+                       [vmask] "+r"(vmask_ptr)
+                     : [wr0] "w"(wr0),
+                       [wr1] "w"(wr1),
+                       [wr2] "w"(wr2),
+                       [scale_ptr] "r"(vscale),
+                       [bias_val] "r"(bias_val),
+                       [vzero] "w"(vzero),
+                       [remain] "r"(remain)
+                     : "cc", "memory",
+                       "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+                       "q12", "q13", "q14", "q15");
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param.active_type)
+                   << " fuse not supported";
+    }
+  } else {
+    asm volatile(
+        INIT_S1
+        "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
+        "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
+        "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
+        "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
+        "vext.32 q6, q8, q9, #1 @ 0012\n"
+        "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1
+        "cmp %[remain], #1 \n"
+        "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
+        "0: \n"
+        : [dout_ptr1] "+r"(doutr0),
+          [dout_ptr2] "+r"(doutr1),
+          [din0_ptr] "+r"(din_ptr0),
+          [din1_ptr] "+r"(din_ptr1),
+          [din2_ptr] "+r"(din_ptr2),
+          [din3_ptr] "+r"(din_ptr3),
+          [cnt] "+r"(cnt),
+          [rmask] "+r"(rmask_ptr),
+          [vmask] "+r"(vmask_ptr)
+        : [wr0] "w"(wr0),
+          [wr1] "w"(wr1),
+          [wr2] "w"(wr2),
+          [bias_val] "r"(bias_val),
+          [vzero] "w"(vzero),
+          [remain] "r"(remain)
+        : "cc", "memory",
+          "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+          "q12", "q13", "q14", "q15");
+  }
+}
+#endif
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 4
@@ -1932,6 +3450,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
                                 const int w_in,
                                 const int h_out,
                                 const int w_out,
+                                const operators::ActivationParam act_param,
                                 ARMContext *ctx) {
  //!
pad is done implicit const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; @@ -2060,15 +3579,16 @@ void conv_depthwise_3x3s1p0_bias(float *dout, } int cnt = tile_w; + /* if (flag_relu) { asm volatile( INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) + "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 + "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 + "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) + "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) MID_COMPUTE_S1 MID_RESULT_S1_RELU "cmp %w[remain], #1 \n" "blt 0f \n" RIGHT_COMPUTE_S1 @@ -2123,12 +3643,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout, } else { asm volatile( INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) + "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 + "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 + "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) + "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) MID_COMPUTE_S1 MID_RESULT_S1 "cmp %w[remain], #1 \n" "blt 0f \n" RIGHT_COMPUTE_S1 @@ -2181,6 +3701,27 @@ void conv_depthwise_3x3s1p0_bias(float *dout, "v24", "v25"); } + */ + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + remain, + act_param); dout_ptr = dout_ptr + 4 * w_out; } #else @@ -2219,6 +3760,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout, int cnt = tile_w; unsigned int *rmask_ptr = rmask; unsigned int *vmask_ptr = vmask; + /* if (flag_relu) { asm volatile(INIT_S1 "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" @@ -2301,13 +3843,328 @@ void conv_depthwise_3x3s1p0_bias(float *dout, "q13", "q14", "q15"); - } + }*/ + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + remain, + act_param); dout_ptr += 2 * w_out; } //! 
end of processing mid rows #endif } } } +void act_switch_3x3s1p0_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + float32x4_t vzero, + float32x4_t wbias, + unsigned int *vmask_ptr, + float bias_val, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/
+#ifdef __aarch64__
+      asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
+                   : [din0] "+r"(din_ptr0),
+                     [din1] "+r"(din_ptr1),
+                     [din2] "+r"(din_ptr2),
+                     [din3] "+r"(din_ptr3)
+                   : [wr0] "w"(wr0),
+                     [wr1] "w"(wr1),
+                     [wr2] "w"(wr2),
+                     [vbias] "w"(wbias),
+                     [mask1] "w"(vmask_rp1),
+                     [mask2] "w"(vmask_rp2),
+                     [vzero] "w"(vzero),
+                     [vscale] "w"(vscale),
+                     [out1] "r"(doutr0),
+                     [out2] "r"(doutr1)
+                   : "cc", "memory",
+                     "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                     "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
+      break;
+#else
+      asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
+                   : [din0] "+r"(din_ptr0),
+                     [din1] "+r"(din_ptr1),
+                     [din2] "+r"(din_ptr2),
+                     [din3] "+r"(din_ptr3),
+                     [vmask] "+r"(vmask_ptr)
+                   : [wr0] "w"(wr0),
+                     [wr1] "w"(wr1),
+                     [wr2] "w"(wr2),
+                     [vzero] "w"(vzero),
+                     [scale_ptr] "r"(vscale),
+                     [bias_val] "r"(bias_val),
+                     [out1] "r"(doutr0),
+                     [out2] "r"(doutr1)
+                   : "cc", "memory",
+                     "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+                     "q12", "q13", "q14", "q15");
+      break;
+#endif
+    default:
+      LOG(FATAL) << "this act_type: "
+                 << static_cast<int>(act_param.active_type)
+                 << " fuse not supported";
+  }
+  } else {
+#ifdef __aarch64__
+    asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
+                 : [din0] "+r"(din_ptr0),
+                   [din1] "+r"(din_ptr1),
+                   [din2] "+r"(din_ptr2),
+                   [din3] "+r"(din_ptr3)
+                 : [wr0] "w"(wr0),
+                   [wr1] "w"(wr1),
+                   [wr2] "w"(wr2),
+                   [vbias] "w"(wbias),
+                   [mask1] "w"(vmask_rp1),
+                   [mask2] "w"(vmask_rp2),
+                   [vzero] "w"(vzero),
+                   [out1] "r"(doutr0),
+                   [out2] "r"(doutr1)
+                 : "cc", "memory",
+                   "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                   "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
+#else
+    asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
+                 : [din0] "+r"(din_ptr0),
+                   [din1] "+r"(din_ptr1),
+                   [din2] "+r"(din_ptr2),
+                   [din3] "+r"(din_ptr3),
+                   [vmask] "+r"(vmask_ptr)
+                 : [wr0] "w"(wr0),
+                   [wr1] "w"(wr1),
+                   [wr2] "w"(wr2),
+                   [vzero] "w"(vzero),
+                   [bias_val] "r"(bias_val),
+                   [out1] "r"(doutr0),
+                   [out2] "r"(doutr1)
+                 : "cc", "memory",
+                   "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+                   "q12", "q13", "q14", "q15");
+#endif
+  }
+}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 4
@@ -2324,6 +4181,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out,
+                                   const operators::ActivationParam act_param,
                                   ARMContext *ctx) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //!
pad is done implicit @@ -2355,15 +4213,22 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, float32x4_t wr1 = vld1q_f32(weight_ptr + 3); float32x4_t wr2 = vld1q_f32(weight_ptr + 6); -#ifdef __aarch64__ + // #ifdef __aarch64__ + // float32x4_t wbias; + // if (flag_bias) { + // wbias = vdupq_n_f32(bias[i]); + // } else { + // wbias = vdupq_n_f32(0.f); + // } + // #endif // __aarch64__ float32x4_t wbias; + float bias_val = 0.f; if (flag_bias) { wbias = vdupq_n_f32(bias[i]); + bias_val = bias[i]; } else { wbias = vdupq_n_f32(0.f); } -#endif // __aarch64__ - float out_buf1[4]; float out_buf2[4]; float trash_buf[4]; @@ -2396,135 +4261,154 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, break; } } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } -#else + /* + #ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } + #else + unsigned int *vmask_ptr = vmask; + float bias_val = flag_bias ? 
bias[i] : 0.f; + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + #endif + */ unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif + act_switch_3x3s1p0_s(dr0, + dr1, + dr2, + dr3, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + vzero, + wbias, + vmask_ptr, + bias_val, + act_param); for (int w = 0; w < w_out; ++w) { *doutr0++ = out_buf1[w]; *doutr1++ = out_buf2[w]; diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index 08e5efecd751bcca534ba7a47035c5f70fa1f6bf..fd54e214cf27e001e21efcf255b09113bbe12d19 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -25,6 +25,785 @@ namespace paddle { namespace lite { namespace arm { namespace math { +// clang-format off +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ \ + "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ \ + /* r0, r1, mul w0, get out r0, r1 */ \ + "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ \ + "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ \ + "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ \ + "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ \ + "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ \ + "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ \ + "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ \ + "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ \ + /* r0, r1, mul w1, get out r0, r1 */ 
\ + "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ \ + "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ \ + "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ \ + "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ \ + "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ \ + "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ \ + "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ \ + /* r0, r1, mul w2, get out r0, r1 */ \ + "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ \ + "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ \ + "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ \ + "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ \ + "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ \ + "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ \ + "fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ \ + "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ \ + /* r1, r2, mul w3, get out r0, r1 */ \ + "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ \ + "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ \ + "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ \ + "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ \ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ \ + "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ \ + "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ \ + "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ \ + /* r1, r2, mul w4, get out r0, r1 */ \ + "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ \ + "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ \ + "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ \ + "ldp x0, x1, [%[outl]] \n" \ + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ \ + "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ \ + "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ \ + "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ \ + /* r1, r2, mul w5, get out r0, r1 */ \ + "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ \ + "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ \ + "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ \ + "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ \ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ \ + "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ \ + "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ \ + "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ \ + /* r2, r3, mul w6, get out r0, r1 */ \ + "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ \ + "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ \ + "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ \ + "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ \ + "ldp x2, x3, [%[outl], #16] \n" \ + "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ \ + "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ \ + "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ \ + 
"fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ \ + /* r2, r3, mul w7, get out r0, r1 */ \ + "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ \ + "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ \ + "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ \ + "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ \ + "ldp x4, x5, [%[outl], #32] \n" \ + "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ \ + "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ \ + "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ \ + "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ \ + /* r2, r3, mul w8, get out r0, r1 */ \ + "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ \ + "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ \ + "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ \ + "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ \ + "ldp x6, x7, [%[outl], #48] \n" \ + "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ \ + "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ \ + "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ \ + "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ \ + \ + "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ \ + /* transpose */ \ + "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ \ + "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ \ + "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ \ + "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ \ + "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ \ + "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ \ + "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ \ + "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ \ + "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ + +#define RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #80]\n" \ + "fmax v15.4s, v15.4s, v0.4s\n" \ + "fmax v16.4s, v16.4s, v0.4s\n" \ + "fmax v17.4s, v17.4s, v0.4s\n" \ + "fmax v18.4s, v18.4s, v0.4s\n" \ + "ld1 {v1.4s}, [x0]\n" \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" \ + "ldr x0, [%[outl]]\n" \ + +#define RELU6 \ + "fmin v15.4s, v15.4s, v1.4s\n" \ + "fmin v16.4s, v16.4s, v1.4s\n" \ + "fmin v17.4s, v17.4s, v1.4s\n" \ + "fmin v18.4s, v18.4s, v1.4s\n" \ + "fmin v19.4s, v19.4s, v1.4s\n" \ + "fmin v20.4s, v20.4s, v1.4s\n" \ + "fmin v21.4s, v21.4s, v1.4s\n" \ + "fmin v22.4s, v22.4s, v1.4s\n" + +#define LEAKY_RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #88]\n" \ + "cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \ + "ld1 {v9.4s}, [x0] \n" \ + "cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v4.4s, v18.4s, v0.4s \n" /* 
vcgeq_u32 */ \ + "ldr x0, [%[outl]] \n" \ + "fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v17.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v18.4s, v9.4s \n" /* mul */ \ + "bif v15.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \ + "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v22.4s, v9.4s \n" /* mul */ \ + "bif v19.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v20.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v21.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v22.16b, v8.16b, v4.16b \n" /* choose*/ + +#define STORE \ + "cbnz %w[flag_mask], 1f\n" \ + "str q15, [x0]\n" /* save outc00 */ \ + "str q16, [x4]\n" /* save outc01 */ \ + "str q17, [x1]\n" /* save outc10 */ \ + "str q18, [x5]\n" /* save outc11 */ \ + "str q19, [x2]\n" /* save outc20 */ \ + "str q20, [x6]\n" /* save outc21 */ \ + "str q21, [x3]\n" /* save outc30 */ \ + "str q22, [x7]\n" /* save outc31 */ \ + "b 2f\n" \ + "1:\n" \ + "str q15, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q17, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q19, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q21, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q16, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q18, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q20, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q22, [%[out]], #16 \n" /* save remain to pre_out */ \ + "2:\n" +#else +#define COMPUTE \ + /* load weights */ \ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" \ + /* load r0, r1 */ \ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" \ + "vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" \ + /* main loop */ \ + "0: @ main loop\n" \ + /* mul r0 with w0, w1, w2, get out r0 */ \ + "vmul.f32 q8, q5, q0 @ w0 * inr00\n" \ + "vmul.f32 q9, q5, q1 @ w0 * inr01\n" \ + "vmul.f32 q10, q5, q2 @ w0 * inr02\n" \ + "vmul.f32 q11, q5, q3 @ w0 * inr03\n" \ + "vmla.f32 q8, q6, q1 @ w1 * inr01\n" \ + "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" \ + "vmla.f32 q9, q6, q2 @ w1 * inr02\n" \ + "vmla.f32 q10, q6, q3 @ w1 * inr03\n" \ + "vmla.f32 q11, q6, q0 @ w1 * inr04\n" \ + "vmla.f32 q8, q7, q2 @ w2 * inr02\n" \ + "vmla.f32 q9, q7, q3 @ w2 * inr03\n" \ + "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" \ + "vmla.f32 q10, q7, q0 @ w2 * inr04\n" \ + "vmla.f32 q11, q7, q1 @ w2 * inr05\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" \ + /* mul r1 with w0-w5, get out r0, r1 */ \ + "vmul.f32 q12, q5, q2 @ w0 * inr10\n" \ + "vmul.f32 q13, q5, q3 @ w0 * inr11\n" \ + "vmul.f32 q14, q5, q0 @ w0 * inr12\n" \ + "vmul.f32 q15, q5, q1 @ w0 * inr13\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! 
@ load w4 to q5\n" \ + "vmla.f32 q8, q4, q2 @ w3 * inr10\n" \ + "vmla.f32 q9, q4, q3 @ w3 * inr11\n" \ + "vmla.f32 q10, q4, q0 @ w3 * inr12\n" \ + "vmla.f32 q11, q4, q1 @ w3 * inr13\n" \ + /* mul r1 with w1, w4, get out r1, r0 */ \ + "vmla.f32 q8, q5, q3 @ w4 * inr11\n" \ + "vmla.f32 q12, q6, q3 @ w1 * inr11\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" \ + "vmla.f32 q9, q5, q0 @ w4 * inr12\n" \ + "vmla.f32 q13, q6, q0 @ w1 * inr12\n" \ + "vmla.f32 q10, q5, q1 @ w4 * inr13\n" \ + "vmla.f32 q14, q6, q1 @ w1 * inr13\n" \ + "vmla.f32 q11, q5, q2 @ w4 * inr14\n" \ + "vmla.f32 q15, q6, q2 @ w1 * inr14\n" \ + "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" \ + /* mul r1 with w2, w5, get out r1, r0 */ \ + "vmla.f32 q12, q7, q0 @ w2 * inr12\n" \ + "vmla.f32 q13, q7, q1 @ w2 * inr13\n" \ + "vmla.f32 q8, q6, q0 @ w5 * inr12\n" \ + "vmla.f32 q9, q6, q1 @ w5 * inr13\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" \ + "vmla.f32 q14, q7, q2 @ w2 * inr14\n" \ + "vmla.f32 q15, q7, q3 @ w2 * inr15\n" \ + "vmla.f32 q10, q6, q2 @ w5 * inr14\n" \ + "vmla.f32 q11, q6, q3 @ w5 * inr15\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" \ + /* mul r2 with w3-w8, get out r0, r1 */ \ + "vmla.f32 q12, q4, q0 @ w3 * inr20\n" \ + "vmla.f32 q13, q4, q1 @ w3 * inr21\n" \ + "vmla.f32 q14, q4, q2 @ w3 * inr22\n" \ + "vmla.f32 q15, q4, q3 @ w3 * inr23\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" \ + "vmla.f32 q8, q7, q0 @ w6 * inr20\n" \ + "vmla.f32 q9, q7, q1 @ w6 * inr21\n" \ + "vmla.f32 q10, q7, q2 @ w6 * inr22\n" \ + "vmla.f32 q11, q7, q3 @ w6 * inr23\n" \ + /* mul r2 with w4, w7, get out r1, r0 */ \ + "vmla.f32 q8, q4, q1 @ w7 * inr21\n" \ + "vmla.f32 q12, q5, q1 @ w4 * inr21\n" \ + "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" \ + "vmla.f32 q9, q4, q2 @ w7 * inr22\n" \ + "vmla.f32 q13, q5, q2 @ w4 * inr22\n" \ + "vmla.f32 q10, q4, q3 @ w7 * inr23\n" \ + "vmla.f32 q14, q5, q3 @ w4 * inr23\n" \ + "vmla.f32 q11, q4, q0 @ w7 * inr24\n" \ + "vmla.f32 q15, q5, q0 @ w4 * inr24\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" \ + /* mul r1 with w5, w8, get out r1, r0 */ \ + "vmla.f32 q12, q6, q2 @ w5 * inr22\n" \ + "vmla.f32 q13, q6, q3 @ w5 * inr23\n" \ + "vmla.f32 q8, q5, q2 @ w8 * inr22\n" \ + "vmla.f32 q9, q5, q3 @ w8 * inr23\n" \ + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" \ + "ldr r4, [%[outl], #32] @ load bias addr to r4\n" \ + "vmla.f32 q14, q6, q0 @ w5 * inr24\n" \ + "vmla.f32 q15, q6, q1 @ w5 * inr25\n" \ + "vmla.f32 q10, q5, q0 @ w8 * inr24\n" \ + "vmla.f32 q11, q5, q1 @ w8 * inr25\n" \ + "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, q0, q1\n" \ + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \ + /* mul r3 with w6, w7, w8, get out r1 */ \ + "vmla.f32 q12, q7, q2 @ w6 * inr30\n" \ + "vmla.f32 q13, q7, q3 @ w6 * inr31\n" \ + "vmla.f32 q14, q7, q0 @ w6 * inr32\n" \ + "vmla.f32 q15, q7, q1 @ w6 * inr33\n" \ + "vmla.f32 q12, q4, q3 @ w7 * inr31\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" \ + "vld1.32 {d12-d13}, [r4] @ load bias\n" \ + "vmla.f32 q13, q4, q0 @ w7 * inr32\n" \ + "vmla.f32 q14, q4, q1 @ w7 * inr33\n" \ + "vmla.f32 q15, q4, q2 @ w7 * inr34\n" \ + "ldr r0, [%[outl]] @ load outc00 to r0\n" \ + "vmla.f32 q12, q5, q0 @ w8 * inr32\n" \ + "vmla.f32 q13, q5, q1 @ w8 * inr33\n" \ + "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \ + "vmla.f32 q14, q5, q2 @ w8 * inr34\n" \ + "vmla.f32 q15, q5, q3 @ w8 * inr35\n" \ + "ldr r1, [%[outl], #4] @ load outc10 to r1\n" \ + "vadd.f32 q8, q8, q6 @ r00 add bias\n" \ + "vadd.f32 q9, q9, q6 @ r01 add bias\n" \ + "vadd.f32 q10, q10, q6 @ r02 add bias\n" \ + "vadd.f32 q11, q11, q6 @ r03 add bias\n" \ + "ldr r2, [%[outl], #8] @ load outc20 to r2\n" \ + "vadd.f32 q12, q12, q6 @ r10 add bias\n" \ + "vadd.f32 q13, q13, q6 @ r11 add bias\n" \ + "vadd.f32 q14, q14, q6 @ r12 add bias\n" \ + "vadd.f32 q15, q15, q6 @ r13 add bias\n" \ + "ldr r3, [%[outl], #12] @ load outc30 to r3\n" \ + "vmov.u32 q7, #0 @ mov zero to q7\n" +#define RELU \ + "vmax.f32 q8, q8, q7 @ r00 relu\n" \ + "vmax.f32 q9, q9, q7 @ r01 relu\n" \ + "vmax.f32 q10, q10, q7 @ r02 relu\n" \ + "vmax.f32 q11, q11, q7 @ r03 relu\n" \ + "vmax.f32 q12, q12, q7 @ r10 relu\n" \ + "vmax.f32 q13, q13, q7 @ r11 relu\n" \ + "vmax.f32 q14, q14, q7 @ r12 relu\n" \ + "vmax.f32 q15, q15, q7 @ r13 relu\n" + +#define RELU6 \ + "ldr r4, [%[outl], #40] @ load six to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vmin.f32 q8, q8, q6 @ r00 relu\n" \ + "vmin.f32 q9, q9, q6 @ r01 relu\n" \ + "vmin.f32 q10, q10, q6 @ r02 relu\n" \ + "vmin.f32 q11, q11, q6 @ r03 relu\n" \ + "vmin.f32 q12, q12, q6 @ r10 relu\n" \ + "vmin.f32 q13, q13, q6 @ r11 relu\n" \ + "vmin.f32 q14, q14, q6 @ r12 relu\n" \ + "vmin.f32 q15, q15, q6 @ r13 relu\n" + +#define LEAKY_RELU \ + "ldr r4, [%[outl], #44] @ load scale to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vcge.f32 q0, q8, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q9, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q8, q6 \n" \ + "vmul.f32 q5, q9, q6 \n" \ + "vcge.f32 q2, q10, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q11, q7 @ q0 > 0 \n" \ + "vbif q8, q4, q0 @ choose \n" \ + "vbif q9, q5, q1 @ choose \n" \ + "vmul.f32 q4, q10, q6 \n" \ + "vmul.f32 q5, q11, q6 \n" \ + "vbif q10, q4, q2 @ choose \n" \ + "vbif q11, q5, q3 @ choose \n" \ + "vcge.f32 q0, q12, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q13, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q12, q6 \n" \ + "vmul.f32 q5, q13, q6 \n" \ + "vcge.f32 q2, q14, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q15, q7 @ q0 > 0 \n" \ + "vbif q12, q4, q0 @ choose \n" \ + "vbif q13, q5, q1 @ choose \n" \ + "vmul.f32 q4, q14, q6 \n" \ + "vmul.f32 q5, q15, q6 \n" \ + "vbif q14, q4, q2 @ choose \n" \ + "vbif q15, q5, q3 @ choose \n" + +#define STORE \ + "ldr r4, [%[outl], #16] @ load outc01 to r4\n" \ + "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" \ + "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \ + "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \ + "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \ + "ldr r5, [%[outl], #20] @ load outc11 to r5\n" \ + "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \ + "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: 
d0d1d2d3 \n" \ + "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \ + "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" \ + "cmp %[flag_mask], #0 @ cmp flag mask\n" \ + "bne 2f\n" \ + "vst1.32 {d16-d17}, [r0] @ save outc00\n" \ + "vst1.32 {d18-d19}, [r1] @ save outc10\n" \ + "vst1.32 {d20-d21}, [r2] @ save outc20\n" \ + "vst1.32 {d22-d23}, [r3] @ save outc30\n" \ + "vst1.32 {d24-d25}, [r4] @ save outc01\n" \ + "vst1.32 {d26-d27}, [r5] @ save outc11\n" \ + "ldr r0, [%[outl], #24] @ load outc21 to r0\n" \ + "ldr r1, [%[outl], #28] @ load outc31 to r1\n" \ + "vst1.32 {d28-d29}, [r0] @ save outc21\n" \ + "vst1.32 {d30-d31}, [r1] @ save outc31\n" \ + "b 3f @ branch end\n" \ + "2: \n" \ + "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" \ + "3: \n" +#endif +// clang-format on +void act_switch_3x3s1(const float* inr0, + const float* inr1, + const float* inr2, + const float* inr3, + float* out0, + const float* weight_c, + float flag_mask, + void* outl_ptr, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + float32x4_t w7, + float32x4_t w8, + float32x4_t vbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); 
+#else
+      asm volatile(COMPUTE RELU RELU6 STORE
+                   : [r0] "+r"(inr0),
+                     [r1] "+r"(inr1),
+                     [r2] "+r"(inr2),
+                     [r3] "+r"(inr3),
+                     [out0] "+r"(out0),
+                     [wc0] "+r"(weight_c)
+                   : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                   : "cc", "memory",
+                     "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+                     "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+                     "r0", "r1", "r2", "r3", "r4", "r5");
+#endif
+      break;
+    case lite_api::ActivationType::kLeakyRelu:
+#ifdef __aarch64__
+      asm volatile(COMPUTE LEAKY_RELU STORE
+                   : [inr0] "+r"(inr0),
+                     [inr1] "+r"(inr1),
+                     [inr2] "+r"(inr2),
+                     [inr3] "+r"(inr3),
+                     [out] "+r"(out0)
+                   : [w0] "w"(w0),
+                     [w1] "w"(w1),
+                     [w2] "w"(w2),
+                     [w3] "w"(w3),
+                     [w4] "w"(w4),
+                     [w5] "w"(w5),
+                     [w6] "w"(w6),
+                     [w7] "w"(w7),
+                     [w8] "w"(w8),
+                     [vbias] "w"(vbias),
+                     [outl] "r"(outl_ptr),
+                     [flag_mask] "r"(flag_mask)
+                   : "cc", "memory",
+                     "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                     "v8", "v9", "v10", "v11", "v15", "v16", "v17", "v18",
+                     "v19", "v20", "v21", "v22",
+                     "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7");
+#else
+      asm volatile(COMPUTE LEAKY_RELU STORE
+                   : [r0] "+r"(inr0),
+                     [r1] "+r"(inr1),
+                     [r2] "+r"(inr2),
+                     [r3] "+r"(inr3),
+                     [out0] "+r"(out0),
+                     [wc0] "+r"(weight_c)
+                   : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                   : "cc", "memory",
+                     "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+                     "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+                     "r0", "r1", "r2", "r3", "r4", "r5");
+#endif
+      break;
+    default:
+      LOG(FATAL) << "this act_type: "
+                 << static_cast<int>(act_param.active_type)
+                 << " fuse not supported";
+  }
+  } else {
+#ifdef __aarch64__
+    asm volatile(COMPUTE STORE
+                 : [inr0] "+r"(inr0),
+                   [inr1] "+r"(inr1),
+                   [inr2] "+r"(inr2),
+                   [inr3] "+r"(inr3),
+                   [out] "+r"(out0)
+                 : [w0] "w"(w0),
+                   [w1] "w"(w1),
+                   [w2] "w"(w2),
+                   [w3] "w"(w3),
+                   [w4] "w"(w4),
+                   [w5] "w"(w5),
+                   [w6] "w"(w6),
+                   [w7] "w"(w7),
+                   [w8] "w"(w8),
+                   [vbias] "w"(vbias),
+                   [outl] "r"(outl_ptr),
+                   [flag_mask] "r"(flag_mask)
+                 : "cc", "memory",
+                   "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                   "v8", "v9", "v10", "v11", "v15", "v16", "v17", "v18",
+                   "v19", "v20", "v21", "v22",
+                   "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7");
+#else
+    asm volatile(COMPUTE STORE
+                 : [r0] "+r"(inr0),
+                   [r1] "+r"(inr1),
+                   [r2] "+r"(inr2),
+                   [r3] "+r"(inr3),
+                   [out0] "+r"(out0),
+                   [wc0] "+r"(weight_c)
+                 : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                 : "cc", "memory",
+                   "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+                   "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+                   "r0", "r1", "r2", "r3", "r4", "r5");
+#endif
+  }
+}
void conv_3x3s1_depthwise_fp32(const float* i_data,
                               float* o_data,
                               int bs,
@@ -37,6 +816,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
                               const float* weights,
                               const float* bias,
                               const operators::ConvParam& param,
+                              const operators::ActivationParam act_param,
                               ARMContext* ctx) {
  int threads = ctx->threads();
@@ -78,6 +858,31 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
  remain = remain > 0 ? remain : 0;
  int row_len = win_round * out_c_block;
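// Editor's note: relu_ptr/six_ptr/scale_ptr below are appended to the outl[]
// pointer table after the eight output pointers and bias_local, so the
// activation macros can fetch them at fixed byte offsets from [%[outl]]:
// with 8-byte pointers on aarch64 that is bias at #64, relu at #72, six at
// #80 and scale at #88; with 4-byte pointers on armv7 the same slots are
// #32, #36, #40 and #44, matching the "ldr r4, [%[outl], #40]" style loads
// in RELU6/LEAKY_RELU.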
+  float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
+  float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
+  float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
+  if (act_param.has_active) {
+    switch (act_param.active_type) {
+      case lite_api::ActivationType::kRelu:
+        break;
+      case lite_api::ActivationType::kRelu6:
+        six_ptr[0] = act_param.Relu_clipped_coef;
+        six_ptr[1] = act_param.Relu_clipped_coef;
+        six_ptr[2] = act_param.Relu_clipped_coef;
+        six_ptr[3] = act_param.Relu_clipped_coef;
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+        scale_ptr[0] = act_param.Leaky_relu_alpha;
+        scale_ptr[1] = act_param.Leaky_relu_alpha;
+        scale_ptr[2] = act_param.Leaky_relu_alpha;
+        scale_ptr[3] = act_param.Leaky_relu_alpha;
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param.active_type)
+                   << " fuse not supported";
+    }
+  }
  for (int n = 0; n < bs; ++n) {
    const float* din_batch = i_data + n * ic * size_in_channel;
    float* dout_batch = o_data + n * oc * size_out_channel;
@@ -147,6 +952,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
          outc21 = ptr_write;
          outc31 = ptr_write;
        }
+
        float* outl[] = {outc00,
                         outc10,
                         outc20,
                         outc30,
                         outc01,
                         outc11,
                         outc21,
                         outc31,
                         reinterpret_cast<float*>(bias_local),
-                        reinterpret_cast<float*>(flag_relu)};
+                        reinterpret_cast<float*>(relu_ptr),
+                        reinterpret_cast<float*>(six_ptr),
+                        reinterpret_cast<float*>(scale_ptr)};
        void* outl_ptr = reinterpret_cast<void*>(outl);
        for (int w = 0; w < w_loop; ++w) {
          bool flag_mask = (w == w_loop - 1) && flag_remain;
          float* out0 = pre_out;
-// clang-format off
#ifdef __aarch64__
-          asm volatile(
-              "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
-              "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/
-              "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
-              "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/
-              "ldp q4, q5, [%[inr0]]\n" /* load input r0*/
-              "ldp q10, q11, [%[inr1]]\n" /* load input r1*/
-              /* r0, r1, mul w0, get out r0, r1 */
-              "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/
-              "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/
-              "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/
-              "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/
-              "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/
-              "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/
-              "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/
-              "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/
-              /* r0, r1, mul w1, get out r0, r1 */
-              "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/
-              "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/
-              "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/
-              "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/
-              "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/
-              "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/
-              "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/
-              "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/
-              "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/
-              /* r0, r1, mul w2, get out r0, r1 */
-              "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/
-              "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/
-              "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/
-              "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/
-              "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/
-              "ldp q4, q5, [%[inr2]]\n" /* load input r2*/
-              "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/
-              "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/
-              "fmla
v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ - "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ - /* r1, r2, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ - "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ - "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ - "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ - "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ - "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ - "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ - /* r1, r2, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ - "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ - "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ - "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ - "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ - "ldp x0, x1, [%[outl]] \n" - "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ - "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ - "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ - "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ - /* r1, r2, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ - "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ - "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ - "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ - "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ - "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ - "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ - "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ - "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ - /* r2, r3, mul w6, get out r0, r1 */ - "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ - "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ - "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ - "ldp x2, x3, [%[outl], #16] \n" - "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ - "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ - "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ - "fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ - /* r2, r3, mul w7, get out r0, r1 */ - "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ - "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ - "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ - "ldp x4, x5, [%[outl], #32] \n" - "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ - "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ - "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ - "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ - /* r2, r3, mul w8, get out r0, r1 */ - "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ - "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ - "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ - "ldp x6, x7, [%[outl], #48] \n" - "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ - "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ - "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ - "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * 
r3[1]*/ - - "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ - "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ - "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ - "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ - "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ - "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ - "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ - "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ - - /* transpose */ - "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ - "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ - "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ - "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ - "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ - "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ - "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ - "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ - "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ - - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v15.4s, v15.4s, v0.4s\n" - "fmax v16.4s, v16.4s, v0.4s\n" - "fmax v17.4s, v17.4s, v0.4s\n" - "fmax v18.4s, v18.4s, v0.4s\n" - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - "0:\n" - "cbnz %w[flag_mask], 1f\n" - "str q15, [x0]\n" /* save outc00 */ - "str q16, [x4]\n" /* save outc01 */ - "str q17, [x1]\n" /* save outc10 */ - "str q18, [x5]\n" /* save outc11 */ - "str q19, [x2]\n" /* save outc20 */ - "str q20, [x6]\n" /* save outc21 */ - "str q21, [x3]\n" /* save outc30 */ - "str q22, [x7]\n" /* save outc31 */ - "b 2f\n" - "1:\n" - "str q15, [%[out]], #16 \n" /* save remain to pre_out */ - "str q17, [%[out]], #16 \n" /* save remain to pre_out */ - "str q19, [%[out]], #16 \n" /* save remain to pre_out */ - "str q21, [%[out]], #16 \n" /* save remain to pre_out */ - "str q16, [%[out]], #16 \n" /* save remain to pre_out */ - "str q18, [%[out]], #16 \n" /* save remain to pre_out */ - "str q20, [%[out]], #16 \n" /* save remain to pre_out */ - "str q22, [%[out]], #16 \n" /* save remain to pre_out */ - "2:\n" - :[inr0] "+r"(inr0), [inr1] "+r"(inr1), - [inr2] "+r"(inr2), [inr3] "+r"(inr3), - [out]"+r"(out0) - :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), - [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), - [vbias]"w" (vbias), [outl] "r" (outl_ptr), - [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v8", "v9", "v10", "v11", "v15", - "v16","v17","v18","v19","v20","v21","v22", - "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" - ); + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + w7, + w8, + vbias, + act_param); #else - asm volatile( - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" - /* load r0, r1 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" - "vld1.32 {d4-d7}, [%[r0]]! 
@ load r0, q2, q3\n" - /* main loop */ - "0: @ main loop\n" - /* mul r0 with w0, w1, w2, get out r0 */ - "vmul.f32 q8, q5, q0 @ w0 * inr00\n" - "vmul.f32 q9, q5, q1 @ w0 * inr01\n" - "vmul.f32 q10, q5, q2 @ w0 * inr02\n" - "vmul.f32 q11, q5, q3 @ w0 * inr03\n" - "vmla.f32 q8, q6, q1 @ w1 * inr01\n" - "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" - "vmla.f32 q9, q6, q2 @ w1 * inr02\n" - "vmla.f32 q10, q6, q3 @ w1 * inr03\n" - "vmla.f32 q11, q6, q0 @ w1 * inr04\n" - "vmla.f32 q8, q7, q2 @ w2 * inr02\n" - "vmla.f32 q9, q7, q3 @ w2 * inr03\n" - "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" - "vmla.f32 q10, q7, q0 @ w2 * inr04\n" - "vmla.f32 q11, q7, q1 @ w2 * inr05\n" - "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" - "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" - /* mul r1 with w0-w5, get out r0, r1 */ - "vmul.f32 q12, q5, q2 @ w0 * inr10\n" - "vmul.f32 q13, q5, q3 @ w0 * inr11\n" - "vmul.f32 q14, q5, q0 @ w0 * inr12\n" - "vmul.f32 q15, q5, q1 @ w0 * inr13\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" - "vmla.f32 q8, q4, q2 @ w3 * inr10\n" - "vmla.f32 q9, q4, q3 @ w3 * inr11\n" - "vmla.f32 q10, q4, q0 @ w3 * inr12\n" - "vmla.f32 q11, q4, q1 @ w3 * inr13\n" - /* mul r1 with w1, w4, get out r1, r0 */ - "vmla.f32 q8, q5, q3 @ w4 * inr11\n" - "vmla.f32 q12, q6, q3 @ w1 * inr11\n" - "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" - "vmla.f32 q9, q5, q0 @ w4 * inr12\n" - "vmla.f32 q13, q6, q0 @ w1 * inr12\n" - "vmla.f32 q10, q5, q1 @ w4 * inr13\n" - "vmla.f32 q14, q6, q1 @ w1 * inr13\n" - "vmla.f32 q11, q5, q2 @ w4 * inr14\n" - "vmla.f32 q15, q6, q2 @ w1 * inr14\n" - "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" - /* mul r1 with w2, w5, get out r1, r0 */ - "vmla.f32 q12, q7, q0 @ w2 * inr12\n" - "vmla.f32 q13, q7, q1 @ w2 * inr13\n" - "vmla.f32 q8, q6, q0 @ w5 * inr12\n" - "vmla.f32 q9, q6, q1 @ w5 * inr13\n" - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" - "vmla.f32 q14, q7, q2 @ w2 * inr14\n" - "vmla.f32 q15, q7, q3 @ w2 * inr15\n" - "vmla.f32 q10, q6, q2 @ w5 * inr14\n" - "vmla.f32 q11, q6, q3 @ w5 * inr15\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" - /* mul r2 with w3-w8, get out r0, r1 */ - "vmla.f32 q12, q4, q0 @ w3 * inr20\n" - "vmla.f32 q13, q4, q1 @ w3 * inr21\n" - "vmla.f32 q14, q4, q2 @ w3 * inr22\n" - "vmla.f32 q15, q4, q3 @ w3 * inr23\n" - "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" - "vmla.f32 q8, q7, q0 @ w6 * inr20\n" - "vmla.f32 q9, q7, q1 @ w6 * inr21\n" - "vmla.f32 q10, q7, q2 @ w6 * inr22\n" - "vmla.f32 q11, q7, q3 @ w6 * inr23\n" - /* mul r2 with w4, w7, get out r1, r0 */ - "vmla.f32 q8, q4, q1 @ w7 * inr21\n" - "vmla.f32 q12, q5, q1 @ w4 * inr21\n" - "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" - "vmla.f32 q9, q4, q2 @ w7 * inr22\n" - "vmla.f32 q13, q5, q2 @ w4 * inr22\n" - "vmla.f32 q10, q4, q3 @ w7 * inr23\n" - "vmla.f32 q14, q5, q3 @ w4 * inr23\n" - "vmla.f32 q11, q4, q0 @ w7 * inr24\n" - "vmla.f32 q15, q5, q0 @ w4 * inr24\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" - /* mul r1 with w5, w8, get out r1, r0 */ - "vmla.f32 q12, q6, q2 @ w5 * inr22\n" - "vmla.f32 q13, q6, q3 @ w5 * inr23\n" - "vmla.f32 q8, q5, q2 @ w8 * inr22\n" - "vmla.f32 q9, q5, q3 @ w8 * inr23\n" - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" - "ldr r4, [%[outl], #32] @ load bias addr to r4\n" - "vmla.f32 q14, q6, q0 @ w5 * inr24\n" - "vmla.f32 q15, q6, q1 @ w5 * inr25\n" - "vmla.f32 q10, q5, q0 @ w8 * inr24\n" - "vmla.f32 q11, q5, q1 @ w8 * inr25\n" - "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, q0, q1\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" - /* mul r3 with w6, w7, w8, get out r1 */ - "vmla.f32 q12, q7, q2 @ w6 * inr30\n" - "vmla.f32 q13, q7, q3 @ w6 * inr31\n" - "vmla.f32 q14, q7, q0 @ w6 * inr32\n" - "vmla.f32 q15, q7, q1 @ w6 * inr33\n" - "vmla.f32 q12, q4, q3 @ w7 * inr31\n" - "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" - "vld1.32 {d12-d13}, [r4] @ load bias\n" - "vmla.f32 q13, q4, q0 @ w7 * inr32\n" - "vmla.f32 q14, q4, q1 @ w7 * inr33\n" - "vmla.f32 q15, q4, q2 @ w7 * inr34\n" - "ldr r0, [%[outl]] @ load outc00 to r0\n" - "vmla.f32 q12, q5, q0 @ w8 * inr32\n" - "vmla.f32 q13, q5, q1 @ w8 * inr33\n" - "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" - "vmla.f32 q14, q5, q2 @ w8 * inr34\n" - "vmla.f32 q15, q5, q3 @ w8 * inr35\n" - "ldr r1, [%[outl], #4] @ load outc10 to r1\n" - "vadd.f32 q8, q8, q6 @ r00 add bias\n" - "vadd.f32 q9, q9, q6 @ r01 add bias\n" - "vadd.f32 q10, q10, q6 @ r02 add bias\n" - "vadd.f32 q11, q11, q6 @ r03 add bias\n" - "ldr r2, [%[outl], #8] @ load outc20 to r2\n" - "vadd.f32 q12, q12, q6 @ r10 add bias\n" - "vadd.f32 q13, q13, q6 @ r11 add bias\n" - "vadd.f32 q14, q14, q6 @ r12 add bias\n" - "vadd.f32 q15, q15, q6 @ r13 add bias\n" - "ldr r3, [%[outl], #12] @ load outc30 to r3\n" - "vmov.u32 q7, #0 @ mov zero to q7\n" - "cmp r5, #0 @ cmp flag relu\n" - "beq 1f @ skip relu\n" - "vmax.f32 q8, q8, q7 @ r00 relu\n" - "vmax.f32 q9, q9, q7 @ r01 relu\n" - "vmax.f32 q10, q10, q7 @ r02 relu\n" - "vmax.f32 q11, q11, q7 @ r03 relu\n" - "vmax.f32 q12, q12, q7 @ r10 relu\n" - "vmax.f32 q13, q13, q7 @ r11 relu\n" - "vmax.f32 q14, q14, q7 @ r12 relu\n" - "vmax.f32 q15, q15, q7 @ r13 relu\n" - "1:\n" - "ldr r4, [%[outl], #16] @ load outc01 to r4\n" - "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" - "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" - "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" - "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" - "ldr r5, [%[outl], #20] @ load outc11 to r5\n" - "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" - "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" - "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" - "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" - "cmp %[flag_mask], #0 @ cmp flag mask\n" - "bne 2f\n" - "vst1.32 {d16-d17}, [r0] @ save outc00\n" - "vst1.32 {d18-d19}, [r1] @ save outc10\n" - "vst1.32 {d20-d21}, [r2] @ save outc20\n" - "vst1.32 {d22-d23}, [r3] @ save outc30\n" - "vst1.32 {d24-d25}, [r4] @ save outc01\n" - "vst1.32 {d26-d27}, [r5] @ save outc11\n" - "ldr r0, [%[outl], #24] @ load outc21 to r0\n" - "ldr r1, [%[outl], #28] @ load outc31 to r1\n" - "vst1.32 {d28-d29}, [r0] @ save outc21\n" - "vst1.32 {d30-d31}, [r1] @ save outc31\n" - "b 3f @ branch end\n" - "2: \n" - "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d30-d31}, [%[out0]]! 
@ save remain to pre_out\n" - "3: \n" - : [r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [r3] "+r"(inr3), - [out0] "+r"(out0), [wc0] "+r"(weight_c) - : [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr) - : "cc", "memory", - "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5" - ); -#endif // __arch64__ - // clang-format on + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + act_param); +#endif outl[0] += 4; outl[1] += 4; outl[2] += 4; @@ -519,6 +1018,10 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outl[5] += 4; outl[6] += 4; outl[7] += 4; + inr0 += 16; + inr1 += 16; + inr2 += 16; + inr3 += 16; if (flag_mask) { memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc index 807135f57dfadf690277ab57bd5597e9470ae549..f5b196efcca3f3f35367f2fea5e8f475b7147f48 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc @@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, //! prepack input to tmp buffer //! write output to tmp buffer auto paddings = *param.paddings; + auto act_param = param.activation_param; const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); const int pad_w = paddings[2]; @@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } #pragma omp parallel for num_threads(threads) @@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } } } diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index 455781e37e0747950e6740f6db45c1ce8c0e96c8..602239a1fe1675c6eecb5b45a8e526ada98a56bb 100644 --- a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" #define LEFT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -244,53 +242,52 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "blt 1f \n" -#define MID_COMPUTE_S2 \ - "2: \n" /* r0 */ \ - "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ - "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ - \ - "ext v10.16b, v2.16b, v18.16b, #4 \n" \ - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ - "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ - "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v4.16b, v19.16b, #4 \n" \ - \ - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ - "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ - "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ - \ - "fmul v14.4s, v5.4s, %[w0].s[1] 
\n" \ - "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ - \ - "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ - "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ext v10.16b, v6.16b, v20.16b, #4 \n" \ - \ - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ - "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ - "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v8.16b, v21.16b, #4 \n" \ - \ - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ - \ - "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" #define MID_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "fadd v16.4s, v16.4s, v11.4s \n" \ "fadd v16.4s, v16.4s, v12.4s \n" \ - "ld1 {v1.4s}, [%[outptr1]] \n" + "ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" #define RIGHT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ "bif v16.16b, v0.16b, %[wmask].16b \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "4: \n" #define LEFT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "fmax v16.4s, v16.4s, %[vzero].4s \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "blt 1f \n" #define MID_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ 
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "bne 2b \n" #define RIGHT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index e4279d9a728bc7af0f14a00b781db449fc426582..c4fe965d0b17fa56d76812af14b40bddbc5b313a 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -20,6 +20,7 @@ #include "lite/backends/arm/math/sgemm.h" #include "lite/backends/arm/math/type_trans.h" #include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" namespace paddle { @@ -28,6 +29,7 @@ namespace arm { namespace math { #define LITEMAX(a, b) ((a) > (b) ? (a) : (b)) +#define LITEMIN(a, b) ((a) < (b) ? (a) : (b)) #define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) template @@ -589,7 +591,238 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din, } } } +// clang-format off +#ifdef __aarch64__ +#define NCHWC1_TRANS_FP32_COMPUTE \ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "movi v20.4s, #0 \n" /* for relu */ \ + "1: \n" /* main loop*/ + +#define NCHWC1_TRANS_FP32_RELU \ + "fmax v0.4s, v0.4s, v20.4s \n" /*relu*/ \ + "fmax v1.4s, v1.4s, v20.4s \n" /*relu*/ \ + "fmax v2.4s, v2.4s, v20.4s \n" /*relu*/ \ + "fmax v3.4s, v3.4s, v20.4s \n" /*relu*/ + +#define NCHWC1_TRANS_FP32_RELU6 \ + "fmin v0.4s, v0.4s, %[six].4s \n" /* relu6 */ \ + "fmin v1.4s, v1.4s, %[six].4s \n" /* relu6 */ \ + "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \ + "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ + +#define NCHWC1_TRANS_FP32_LEAKY_RELU \ + "cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ + "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ + "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \ + "fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \ + "bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \ + "bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v7.16b \n" /* choose*/ + +#define NCHWC1_TRANS_FP32_STORE \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + \ + "str q0, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q3, [%[doutc0r0]], #16 \n" /* store c2r0*/ \ + "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + \ + "bne 1b \n" /* jump to main loop*/ +#else +#define NCHWC1_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0 \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! 
@ load data, c0r0 \n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" +#define NCHWC1_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" \ + "vmax.f32 q2, q2, q15 @ relu\n" \ + "vmax.f32 q3, q3, q15 @ relu\n" + +#define NCHWC1_TRANS_FP32_RELU6 \ + "vmin.f32 q0, q0, %q[six] @ relu6 \n" \ + "vmin.f32 q1, q1, %q[six] @ relu6 \n" \ + "vmin.f32 q2, q2, %q[six] @ relu6 \n" \ + "vmin.f32 q3, q3, %q[six] @ relu6 \n" + +#define NCHWC1_TRANS_FP32_LEAKY_RELU \ + "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \ + "vcge.f32 q7, q2, q15 @ q0 > 0 \n" \ + "vcge.f32 q8, q3, q15 @ q0 > 0 \n" \ + "vmul.f32 q9, q0, %q[scale] \n" \ + "vmul.f32 q10, q1, %q[scale] \n" \ + "vmul.f32 q11, q2, %q[scale] \n" \ + "vmul.f32 q12, q3, %q[scale] \n" \ + "vbif q0, q9, q5 @ choose \n" \ + "vbif q1, q10, q6 @ choose \n" \ + "vbif q2, q11, q7 @ choose \n" \ + "vbif q3, q12, q8 @ choose \n" + +#define NCHWC1_TRANS_FP32_STORE \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \ + "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \ + "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \ + "vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n" \ + \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif +// clang-format on +inline void act_switch_c1_fp32(const float* din_ptr, + float* doutc0_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", "v1", "v2", "v3", "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} /*wirte result in outputs * input din: [n, c, h, w], output dout: [n, c, h, w] */ @@ -605,13 +838,14 @@ inline bool write_to_output_c1_fp32(const float* din, int height, int width, bool flag_relu, - float* trash_ptr) { + float* trash_ptr, + operators::ActivationParam* act_param) { if (cs > channel) { return true; } const int c1 = 1; - const int w4 = 4; + const int w4 = 16; int size_c_out = width * height; @@ -623,98 +857,53 @@ inline bool write_to_output_c1_fp32(const float* din, int w_round = we - ws; int cnt = (width - ws) / w4; - + int remain = (width - ws) % w4; for (int i = 0; i < size_h; i++) { int size_w = i * width; float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; const float* din_hei_ptr = ptr_din + i * w_round * c1; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "fmax v1.4s, v0.4s, v20.4s \n" /*relu*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v20"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - - "vmax.f32 q1, q0, q15 @ relu\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - - "vst1.32 {d2-d3}, [%[doutc0r0]]! 
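Every asm variant selected by these `act_switch_*` helpers, and every scalar remainder loop below them, must agree on the same activation semantics. As a reference (sketch; `apply_act` is an illustrative name, the enum and parameter fields are the ones used in this file):

```cpp
#include <algorithm>

// Reference semantics of the three fused activations dispatched above.
inline float apply_act(float x, lite_api::ActivationType t, float six,
                       float alpha) {
  switch (t) {
    case lite_api::ActivationType::kRelu:
      return std::max(x, 0.f);
    case lite_api::ActivationType::kRelu6:
      return std::min(std::max(x, 0.f), six);  // clamp to [0, six]
    case lite_api::ActivationType::kLeakyRelu:
      return x >= 0.f ? x : x * alpha;
    default:
      return x;  // no fused activation
  }
}
```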
@ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "1: \n" /* main loop*/ - "str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c0r1, c0r2, c0r3\n" - "1: @ main loop\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0"); -#endif - } + act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, act_param); } - if (we > width) { + if (remain > 0) { int offset = i * w_round * c1 + c1 * w4 * cnt; din_hei_ptr = ptr_din + offset; - int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - din_hei_ptr++; + doutc0_ptr += w4 * cnt; + int j = w4 * cnt; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < width; ++j) { + float tmp = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp, six); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+          for (; j < width; ++j) {
+            if (din_hei_ptr[0] >= 0) {
+              *(doutc0_ptr++) = din_hei_ptr[0];
+            } else {
+              *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+            }
+            din_hei_ptr++;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
      }
    } else {
      for (; j < width; ++j) {
@@ -725,6 +914,7 @@ inline bool write_to_output_c1_fp32(const float* din,
   }
   return true;
 }
+// clang-format off
 #ifdef __aarch64__
 #define NCHWC2_TRANS_FP32_COMPUTE \
   "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \
@@ -740,6 +930,18 @@ inline bool write_to_output_c1_fp32(const float* din,
   "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ \
   "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/

+#define NCHWC2_TRANS_FP32_RELU6 \
+  "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
+
+#define NCHWC2_TRANS_FP32_LEAKY_RELU \
+  "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
+  "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
+  "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
+  "bif v3.16b, v5.16b, v7.16b \n" /* choose*/
+
 #define NCHWC2_TRANS_FP32_STORE \
   "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \
   \
@@ -749,8 +951,7 @@ inline bool write_to_output_c1_fp32(const float* din,
   "bne 1b \n" /* jump to main loop*/
 #else
 #define NCHWC2_TRANS_FP32_COMPUTE \
-  "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " \
-  "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, c1r0 \n" \
   "vmov.u32 q15, #0 @ dump zero\n" \
   "1: @ main loop\n" \
   "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " \
@@ -764,11 +965,21 @@ inline bool write_to_output_c1_fp32(const float* din,
   "vmax.f32 q0, q0, q15 @ relu\n" \
   "vmax.f32 q1, q1, q15 @ relu\n"

+#define NCHWC2_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6 \n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6 \n"
+
+#define NCHWC2_TRANS_FP32_LEAKY_RELU \
+  "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
+  "vmul.f32 q9, q0, %q[scale] \n" \
+  "vmul.f32 q10, q1, %q[scale] \n" \
+  "vbif q0, q9, q5 @ choose \n" \
+  "vbif q1, q10, q6 @ choose \n"
+
 #define NCHWC2_TRANS_FP32_STORE \
-  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \
-  "pointer\n" \
-  "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " \
-  "pointer\n" \
+  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d2-d3}, [%[doutc1r0]]!
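Both ISAs implement leaky relu without branches: build an all-ones lane mask where x >= 0 (`fcmge` on AArch64, `vcge.f32` on ARMv7), compute x*scale unconditionally, then bit-select with `bif`/`vbif`. The same triple with intrinsics (a sketch, not this file's code path):

```cpp
#include <arm_neon.h>

// Branchless leaky relu on one vector: keep x where x >= 0, else x*alpha.
static inline float32x4_t leaky_relu_f32x4(float32x4_t x, float32x4_t alpha) {
  uint32x4_t ge0 = vcgeq_f32(x, vdupq_n_f32(0.f));  // per-lane all-ones/zero
  float32x4_t scaled = vmulq_f32(x, alpha);
  return vbslq_f32(ge0, x, scaled);                 // bit-select, like bif/vbif
}
```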
@ store result, add pointer\n" \ \ "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ \ @@ -776,6 +987,151 @@ inline bool write_to_output_c1_fp32(const float* din, \ "bne 1b @ jump to main loop\n" #endif +// clang-format on +inline void act_switch_c2_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} /*wirte result in outputs * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] */ @@ -791,11 +1147,11 @@ inline bool write_to_output_c2_fp32(const float* din, int height, int width, bool flag_relu, - float* trash_ptr) { + float* trash_ptr, + operators::ActivationParam* act_param) { if (cs > channel) { return true; } - const int c2 = 2; const int w4 = 4; @@ -828,55 +1184,56 @@ inline bool write_to_output_c2_fp32(const float* din, const float* din_hei_ptr = ptr_din + i * w_round * c2; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU - NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); -#else - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU - NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5"); -#else - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } + act_switch_c2_fp32( + din_hei_ptr, doutc0_ptr, doutc1_ptr, cnt_loop, act_param); } if (we > width) { int offset = i * w_round * c2 + c2 * w4 * cnt; din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - din_hei_ptr += 2; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + 
case lite_api::ActivationType::kRelu:
+            for (; j < width; ++j) {
+              *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
+              *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
+              din_hei_ptr += 2;
+            }
+            break;
+          case lite_api::ActivationType::kRelu6:
+            /* 0 <= din <= 6 */
+            for (; j < width; ++j) {
+              float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
+              float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
+              *(doutc0_ptr++) = LITEMIN(tmp1, six);
+              *(doutc1_ptr++) = LITEMIN(tmp2, six);
+              din_hei_ptr += 2;
+            }
+            break;
+          case lite_api::ActivationType::kLeakyRelu:
+            /*din = din >= 0 ? din : din * scale*/
+            for (; j < width; ++j) {
+              if (din_hei_ptr[0] >= 0) {
+                *(doutc0_ptr++) = din_hei_ptr[0];
+              } else {
+                *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+              }
+              if (din_hei_ptr[1] >= 0) {
+                *(doutc1_ptr++) = din_hei_ptr[1];
+              } else {
+                *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+              }
+              din_hei_ptr += 2;
+            }
+            break;
+          default:
+            LOG(FATAL) << "this act_type: "
+                       << static_cast<int>(act_param->active_type)
+                       << " fuse not support";
        }
      } else {
        for (; j < width; ++j) {
@@ -888,7 +1245,7 @@ inline bool write_to_output_c2_fp32(const float* din,
   }
   return true;
 }
-
+// clang-format off
 #ifdef __aarch64__
 #define NCHWC4_TRANS_FP32_COMPUTE \
   "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
@@ -912,6 +1269,26 @@ inline bool write_to_output_c2_fp32(const float* din,
   "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \
   "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/

+#define NCHWC4_TRANS_FP32_RELU6 \
+  "fmin v16.4s, v16.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v17.4s, v17.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
+
+#define NCHWC4_TRANS_FP32_LEAKY_RELU \
+  "fcmge v8.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v9.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v10.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v11.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
+  "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
+  "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
+  "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
+  "bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
+  "bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
+  "bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
+  "bif v19.16b, v7.16b, v11.16b \n" /* choose*/
+
 #define NCHWC4_TRANS_FP32_STORE \
   "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
   "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
@@ -940,6 +1317,26 @@ inline bool write_to_output_c2_fp32(const float* din,
   "vmax.f32 q2, q2, q15 @ relu\n" \
   "vmax.f32 q3, q3, q15 @ relu\n"

+#define NCHWC4_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6 \n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6 \n" \
+  "vmin.f32 q2, q2, %q[six] @ relu6 \n" \
+  "vmin.f32 q3, q3, %q[six] @ relu6 \n"
+
+#define NCHWC4_TRANS_FP32_LEAKY_RELU \
+  "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
+  "vcge.f32 q7, q2, q15 @ q0 > 0 \n" \
+  "vcge.f32 q8, q3, q15 @ q0 > 0 \n" \
+  "vmul.f32 q9, q0, %q[scale] \n" \
+  "vmul.f32 q10, q1, %q[scale] \n" \
+  "vmul.f32 q11, q2, %q[scale] \n" \
+  "vmul.f32 q12, q3, %q[scale] \n" \
+  "vbif q0, q9, q5 @ choose \n" \
+  "vbif q1, q10, q6 @ choose \n" \
+  "vbif q2, q11, q7 @ choose \n" \
+  "vbif q3, q12, q8 @ choose \n"
+
 #define NCHWC4_TRANS_FP32_STORE \
   "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
   "vst1.32 {d2-d3}, [%[doutc1r0]]!
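These c4 macros exist because the kernels emit channel-blocked tiles (four channels interleaved per pixel) while the output tensor is plain NCHW; the vector shuffles above are a vectorized form of this scalar de-interleave (reference sketch, helper name illustrative):

```cpp
// Scalar reference of the c4 write-back: din interleaves 4 channels per
// pixel; dout0..dout3 are four contiguous NCHW channel rows.
inline void unblock_c4(const float* din, float* dout0, float* dout1,
                       float* dout2, float* dout3, int w) {
  for (int j = 0; j < w; ++j) {
    dout0[j] = din[4 * j + 0];
    dout1[j] = din[4 * j + 1];
    dout2[j] = din[4 * j + 2];
    dout3[j] = din[4 * j + 3];
  }
}
```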
@ store result, add pointer\n" \ @@ -953,68 +1350,19 @@ inline bool write_to_output_c2_fp32(const float* din, \ "bne 1b @ jump to main loop\n" #endif -/*wirte result in outputs -* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] -*/ -inline bool write_to_output_c4_fp32(const float* din, - float* dout, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int height, - int width, - bool flag_relu, - float* trash_ptr) { - const int c4 = 4; - const int w4 = 4; - const int w_round = we - ws; - const int ch_n = ce - cs; - if (ch_n != 4) { - LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is " - "more than zero"; - return false; - } - int size_c_out = width * height; - - float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; - float* doutc1r0 = doutc0r0 + size_c_out; - float* doutc2r0 = doutc1r0 + size_c_out; - float* doutc3r0 = doutc2r0 + size_c_out; - - const float* ptr_din = din; - - int size_h = (he > height ? height : he) - hs; // size_h == hei_n - - int valid_we = we > width ? width : we; - int cnt = (valid_we - ws) / w4; - int remain = valid_we - ws - cnt * w4; - - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - const float* din_hei_ptr = ptr_din + i * w_round * ch_n; - if (cnt > 0) { - int cnt_loop = cnt; - if (flag_relu) { +// clang-format on +inline void act_switch_c4_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + float* doutc2_ptr, + float* doutc3_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: #ifdef __aarch64__ asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU NCHWC4_TRANS_FP32_STORE @@ -1023,7 +1371,7 @@ inline bool write_to_output_c4_fp32(const float* din, [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) + [ptr_din] "+r"(din_ptr) : : "v0", "v1", @@ -1052,57 +1400,290 @@ inline bool write_to_output_c4_fp32(const float* din, [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) : : "q0", "q1", "q2", "q3", "q15"); #endif - } else { + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ #ifdef __aarch64__ - asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) : "v0", "v1", "v2", "v3", + "v4", + "v5", + "v6", + "v7", "v8", "v9", "v10", "v11", + "v12", + "v13", + "v14", "v16", "v17", "v18", - "v19"); + "v19", + "v20"); #else - asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + asm 
volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3"); + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); #endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v8", + "v9", + "v10", + "v11", + "v16", + "v17", + "v18", + "v19"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} +/*wirte result in outputs +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c4_fp32(const float* din, + float* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr, + operators::ActivationParam* act_param) { + const int c4 = 4; + const int w4 = 4; + const int w_round = we - ws; + const int ch_n = ce - cs; + + if (ch_n != 4) { + LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is " + "more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int valid_we = we > width ? 
width : we; + int cnt = (valid_we - ws) / w4; + int remain = valid_we - ws - cnt * w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; } } + const float* din_hei_ptr = ptr_din + i * w_round * ch_n; + if (cnt > 0) { + int cnt_loop = cnt; + act_switch_c4_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + cnt_loop, + act_param); + } if (remain > 0) { int offset = i * w_round * c4 + c4 * w4 * cnt; din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; int j = 0; - if (flag_relu) { - for (; j < remain; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - din_hei_ptr += w4; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < remain; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < remain; ++j) { + float tmp1 = LITEMAX(din_hei_ptr[0], 0.f); + float tmp2 = LITEMAX(din_hei_ptr[1], 0.f); + float tmp3 = LITEMAX(din_hei_ptr[2], 0.f); + float tmp4 = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp1, six); + *(doutc1_ptr++) = LITEMIN(tmp2, six); + *(doutc2_ptr++) = LITEMIN(tmp3, six); + *(doutc3_ptr++) = LITEMIN(tmp4, six); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
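When the channel-block end `ce` runs past `channel`, the extra rows are redirected to `trash_ptr` so the full-width vector stores stay unconditional; the `switch` just below relies on deliberate fallthrough so that `ce - channel == k` retargets the last k pointers. The same intent written with an explicit loop (sketch only):

```cpp
// Equivalent of the fallthrough switch: send the last (ce - channel)
// channel rows to the scratch sink so full-width stores remain legal.
float* outs[4] = {doutc0_ptr, doutc1_ptr, doutc2_ptr, doutc3_ptr};
for (int k = 0; k < ce - channel; ++k) {
  outs[3 - k] = trash_ptr;  // retarget from the last row backwards
}
```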
din : din * scale*/
+          for (; j < remain; ++j) {
+            if (din_hei_ptr[0] >= 0) {
+              *(doutc0_ptr++) = din_hei_ptr[0];
+            } else {
+              *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+            }
+            if (din_hei_ptr[1] >= 0) {
+              *(doutc1_ptr++) = din_hei_ptr[1];
+            } else {
+              *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+            }
+            if (din_hei_ptr[2] >= 0) {
+              *(doutc2_ptr++) = din_hei_ptr[2];
+            } else {
+              *(doutc2_ptr++) = din_hei_ptr[2] * scale;
+            }
+            if (din_hei_ptr[3] >= 0) {
+              *(doutc3_ptr++) = din_hei_ptr[3];
+            } else {
+              *(doutc3_ptr++) = din_hei_ptr[3] * scale;
+            }
+            din_hei_ptr += 4;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
      }
    } else {
      for (; j < remain; ++j) {
@@ -1110,14 +1691,14 @@ inline bool write_to_output_c4_fp32(const float* din,
         *(doutc1_ptr++) = din_hei_ptr[1];
         *(doutc2_ptr++) = din_hei_ptr[2];
         *(doutc3_ptr++) = din_hei_ptr[3];
-        din_hei_ptr += w4;
+        din_hei_ptr += 4;
       }
     }
   }
   return true;
 }
-
+// clang-format off
 #ifdef __aarch64__
 #define NCHWC8_TRANS_FP32_COMPUTE \
   "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
@@ -1161,6 +1742,48 @@ inline bool write_to_output_c4_fp32(const float* din,
   "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
   "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/

+#define NCHWC8_TRANS_FP32_RELU6 \
+  "fmin v16.4s, v16.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v17.4s, v17.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v18.4s, v18.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v19.4s, v19.4s, %[six].4s \n" /*relu6*/ \
+  \
+  "fmin v8.4s, v8.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v9.4s, v9.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v12.4s, v12.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/
+
+#define NCHWC8_TRANS_FP32_LEAKY_RELU \
+  "fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v11.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v14.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v15.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
+  \
+  "fcmge v21.4s, v8.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v22.4s, v9.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v23.4s, v12.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_f32 */ \
+  \
+  "fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \
+  "fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \
+  "fmul v27.4s, v18.4s, %[scale].4s \n" /* mul */ \
+  "fmul v28.4s, v19.4s, %[scale].4s \n" /* mul */ \
+  \
+  "fmul v29.4s, v8.4s, %[scale].4s \n" /* mul */ \
+  "fmul v30.4s, v9.4s, %[scale].4s \n" /* mul */ \
+  "fmul v31.4s, v12.4s, %[scale].4s \n" /* mul */ \
+  \
+  "bif v16.16b, v25.16b, v10.16b \n" /* choose*/ \
+  "bif v17.16b, v26.16b, v11.16b \n" /* choose*/ \
+  "bif v18.16b, v27.16b, v14.16b \n" /* choose*/ \
+  "bif v19.16b, v28.16b, v15.16b \n" /* choose*/ \
+  "fmul v25.4s, v13.4s, %[scale].4s \n" /* mul */ \
+  \
+  "bif v8.16b, v29.16b, v21.16b \n" /* choose*/ \
+  "bif v9.16b, v30.16b, v22.16b \n" /* choose*/ \
+  "bif v12.16b, v31.16b, v23.16b \n" /* choose*/ \
+  "bif v13.16b, v25.16b, v24.16b \n" /* choose*/
+
 #define NCHWC8_TRANS_FP32_STORE \
   "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
   "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
@@ -1174,6 +1797,7 @@ inline bool write_to_output_c4_fp32(const float* din,
   "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ \
   \
   "bne 1b \n" /* jump to main loop*/
+
 #else
 #define NCHWC8_TRANS_FP32_COMPUTE \
   "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \
@@ -1203,6 +1827,48 @@ inline bool write_to_output_c4_fp32(const float* din,
   "vmax.f32 q6, q6, q15 @ relu\n" \
   "vmax.f32 q7, q7, q15 @ relu\n"

+#define NCHWC8_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6\n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6\n" \
+  "vmin.f32 q2, q2, %q[six] @ relu6\n" \
+  "vmin.f32 q3, q3, %q[six] @ relu6\n" \
+  \
+  "vmin.f32 q4, q4, %q[six] @ relu6\n" \
+  "vmin.f32 q5, q5, %q[six] @ relu6\n" \
+  "vmin.f32 q6, q6, %q[six] @ relu6\n" \
+  "vmin.f32 q7, q7, %q[six] @ relu6\n"
+
+#define NCHWC8_TRANS_FP32_LEAKY_RELU \
+  "vmov.u32 q15, #0 @ re-zero, q15 is scratched below\n" \
+  "vcge.f32 q9, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q10, q1, q15 @ q0 > 0 \n" \
+  "vcge.f32 q11, q2, q15 @ q0 > 0 \n" \
+  "vcge.f32 q12, q3, q15 @ q0 > 0 \n" \
+  "vmul.f32 q13, q0, %q[scale] \n" \
+  "vmul.f32 q14, q1, %q[scale] \n" \
+  "vmul.f32 q15, q2, %q[scale] \n" \
+  \
+  "vbif q0, q13, q9 @ choose \n" \
+  "vmul.f32 q9, q3, %q[scale] \n" \
+  \
+  "vbif q1, q14, q10 @ choose \n" \
+  "vbif q2, q15, q11 @ choose \n" \
+  "vbif q3, q9, q12 @ choose \n" \
+  \
+  "vmov.u32 q15, #0 @ restore zero for the compares below\n" \
+  "vcge.f32 q9, q4, q15 @ q0 > 0 \n" \
+  "vcge.f32 q10, q5, q15 @ q0 > 0 \n" \
+  "vcge.f32 q11, q6, q15 @ q0 > 0 \n" \
+  "vcge.f32 q12, q7, q15 @ q0 > 0 \n" \
+  "vmul.f32 q13, q4, %q[scale] \n" \
+  "vmul.f32 q14, q5, %q[scale] \n" \
+  "vmul.f32 q15, q6, %q[scale] \n" \
+  \
+  "vbif q4, q13, q9 @ choose \n" \
+  "vmul.f32 q9, q7, %q[scale] \n" \
+  \
+  "vbif q5, q14, q10 @ choose \n" \
+  "vbif q6, q15, q11 @ choose \n" \
+  "vbif q7, q9, q12 @ choose \n"
+
 #define NCHWC8_TRANS_FP32_STORE \
   "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
   "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \
@@ -1232,84 +1898,23 @@ inline bool write_to_output_c8_fp32(const float* din,
   "bne 1b @ jump to main loop\n"
 #endif
-/*wirte result in outputs
-* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
-*/
-inline bool write_to_output_c8_fp32(const float* din,
-                                    float* dout,
-                                    int ch_n,
-                                    int hei_n,
-                                    int cs,
-                                    int ce,
-                                    int hs,
-                                    int he,
-                                    int ws,
-                                    int we,
-                                    int channel,
-                                    int height,
-                                    int width,
-                                    bool flag_relu,
-                                    float* trash_ptr) {
-  if (ch_n != 8 || hei_n <= 0) {
-    LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
-    return false;
-  }
-  int size_c_out = width * height;
-
-  float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
-  float* doutc1r0 = doutc0r0 + size_c_out;
-  float* doutc2r0 = doutc1r0 + size_c_out;
-  float* doutc3r0 = doutc2r0 + size_c_out;
-  float* doutc4r0 = doutc3r0 + size_c_out;
-  float* doutc5r0 = doutc4r0 + size_c_out;
-  float* doutc6r0 = doutc5r0 + size_c_out;
-  float* doutc7r0 = doutc6r0 + size_c_out;
-
-  const float* ptr_din = din;
-
-  int size_h = (he > height ?
height : he) - hs; // size_h == hei_n - - int valid_w = we - ws; - int cnt = valid_w / 4; - - if (we > width) { - cnt--; - } - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const float* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; +// clang-format on +inline void act_switch_c8_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + float* doutc2_ptr, + float* doutc3_ptr, + float* doutc4_ptr, + float* doutc5_ptr, + float* doutc6_ptr, + float* doutc7_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: #ifdef __aarch64__ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU NCHWC8_TRANS_FP32_STORE @@ -1322,9 +1927,10 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) + [ptr_din] "+r"(din_ptr) : - : "v1", + : "v0", + "v1", "v2", "v3", "v4", @@ -1338,7 +1944,6 @@ inline bool write_to_output_c8_fp32(const float* din, "v12", "v13", "v14", - "v15", "v16", "v17", "v18", @@ -1355,66 +1960,17 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc5r0] "+r"(doutc5_ptr), [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) : - : "q0", "q1", "q2", "q3", "q4", "q15"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15"); #endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f); - *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f); - *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f); - *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f); - din_hei_ptr += 8; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = 
trash_ptr;
-        case 4:
-          doutc4_ptr = trash_ptr;
-        case 3:
-          doutc5_ptr = trash_ptr;
-        case 2:
-          doutc6_ptr = trash_ptr;
-        case 1:
-          doutc7_ptr = trash_ptr;
-        default:
-          break;
-        }
-      }
-      ptr_din = din + i * valid_w * ch_n;
-      const float* din_hei_ptr = ptr_din;
-      if (cnt > 0) {
-        int cnt_loop = cnt;
+      break;
+    case lite_api::ActivationType::kRelu6:
+/* 0 <= din <= 6 */
 #ifdef __aarch64__
-        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+      asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
+                   NCHWC8_TRANS_FP32_RELU6 NCHWC8_TRANS_FP32_STORE
                    : [doutc0r0] "+r"(doutc0_ptr),
                    [doutc1r0] "+r"(doutc1_ptr),
                    [doutc2r0] "+r"(doutc2_ptr),
@@ -1424,8 +1980,8 @@ inline bool write_to_output_c8_fp32(const float* din,
                    [doutc6r0] "+r"(doutc6_ptr),
                    [doutc7r0] "+r"(doutc7_ptr),
                    [cnt] "+r"(cnt_loop),
-                   [ptr_din] "+r"(din_hei_ptr)
-                   :
+                   [ptr_din] "+r"(din_ptr)
+                   : [six] "w"(six)
                    : "v0",
                    "v1",
                    "v2",
@@ -1441,14 +1997,29 @@ inline bool write_to_output_c8_fp32(const float* din,
                    "v12",
                    "v13",
                    "v14",
-                   "v15",
                    "v16",
                    "v17",
                    "v18",
                    "v19",
                    "v20");
 #else
-        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+      asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
+                   NCHWC8_TRANS_FP32_RELU6 NCHWC8_TRANS_FP32_STORE
+                   : [doutc0r0] "+r"(doutc0_ptr),
+                   [doutc1r0] "+r"(doutc1_ptr),
+                   [doutc2r0] "+r"(doutc2_ptr),
+                   [doutc3r0] "+r"(doutc3_ptr),
+                   [doutc4r0] "+r"(doutc4_ptr),
+                   [doutc5r0] "+r"(doutc5_ptr),
+                   [doutc6r0] "+r"(doutc6_ptr),
+                   [doutc7r0] "+r"(doutc7_ptr),
+                   [ptr_din] "+r"(din_ptr),
+                   [cnt] "+r"(cnt_loop)
+                   : [six] "w"(six)
+                   : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
+#endif
+      break;
+    case lite_api::ActivationType::kLeakyRelu:
+/*din = din >= 0 ? din : din * scale*/
+#ifdef __aarch64__
+      asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
+                   NCHWC8_TRANS_FP32_STORE
                    : [doutc0r0] "+r"(doutc0_ptr),
                    [doutc1r0] "+r"(doutc1_ptr),
                    [doutc2r0] "+r"(doutc2_ptr),
@@ -1457,16 +2028,323 @@ inline bool write_to_output_c8_fp32(const float* din,
                    [doutc5r0] "+r"(doutc5_ptr),
                    [doutc6r0] "+r"(doutc6_ptr),
                    [doutc7r0] "+r"(doutc7_ptr),
-                   [ptr_din] "+r"(din_hei_ptr),
+                   [cnt] "+r"(cnt_loop),
+                   [ptr_din] "+r"(din_ptr)
+                   : [scale] "w"(scale)
+                   : "v0",
+                   "v1",
+                   "v2",
+                   "v3",
+                   "v4",
+                   "v5",
+                   "v6",
+                   "v7",
+                   "v8",
+                   "v9",
+                   "v10",
+                   "v11",
+                   "v12",
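All of these asm statements share one operand pattern: `"+r"` read-write pointer registers that post-index addressing advances, `"w"` inputs pinning `float32x4_t` values such as `six` and `scale` in SIMD registers, and an explicit clobber list for every temporary. Reduced to its skeleton (an AArch64 sketch for illustration, not code from this patch):

```cpp
#include <arm_neon.h>

// Minimal example of the constraint pattern used throughout this file:
// clamp 4 floats at `p` in place against `six`, advancing the pointer.
static inline void clamp4_inplace(float* p, float32x4_t six) {
  asm volatile(
      "ldr q0, [%[p]] \n"
      "fmin v0.4s, v0.4s, %[six].4s \n"
      "str q0, [%[p]], #16 \n"
      : [p] "+r"(p)     // read-write pointer, advanced by post-index
      : [six] "w"(six)  // value kept in a SIMD register
      : "v0", "memory");
}
```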
+ "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15"); #endif + } +} + +/*wirte result in outputs +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c8_fp32(const float* din, + float* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr, + operators::ActivationParam* act_param) { + if (ch_n != 8 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + float* doutc4r0 = doutc3r0 + size_c_out; + float* doutc5r0 = doutc4r0 + size_c_out; + float* doutc6r0 = doutc5r0 + size_c_out; + float* doutc7r0 = doutc6r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int w4 = 4; + int cnt = valid_w / 4; + + if (we > width) { + cnt--; + } + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + float* doutc4_ptr = doutc4r0 + size_w; + float* doutc5_ptr = doutc5r0 + size_w; + float* doutc6_ptr = doutc6r0 + size_w; + float* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; + } + ptr_din = din + i * valid_w * ch_n; + const float* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; + act_switch_c8_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + doutc4_ptr, + doutc5_ptr, + doutc6_ptr, + doutc7_ptr, + cnt_loop, + act_param); + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; + doutc4_ptr += w4 * cnt; + doutc5_ptr += w4 * cnt; + doutc6_ptr += w4 * cnt; + doutc7_ptr += w4 * cnt; + int i = we - 4; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc4_ptr++) = 
LITEMAX(din_hei_ptr[4], 0.f); + *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f); + *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f); + *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f); + din_hei_ptr += 8; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; i < width; ++i) { + float tmp1 = LITEMAX(din_hei_ptr[0], 0.f); + float tmp2 = LITEMAX(din_hei_ptr[1], 0.f); + float tmp3 = LITEMAX(din_hei_ptr[2], 0.f); + float tmp4 = LITEMAX(din_hei_ptr[3], 0.f); + float tmp5 = LITEMAX(din_hei_ptr[4], 0.f); + float tmp6 = LITEMAX(din_hei_ptr[5], 0.f); + float tmp7 = LITEMAX(din_hei_ptr[6], 0.f); + float tmp8 = LITEMAX(din_hei_ptr[7], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp1, six); + *(doutc1_ptr++) = LITEMIN(tmp2, six); + *(doutc2_ptr++) = LITEMIN(tmp3, six); + *(doutc3_ptr++) = LITEMIN(tmp4, six); + *(doutc4_ptr++) = LITEMIN(tmp5, six); + *(doutc5_ptr++) = LITEMIN(tmp6, six); + *(doutc6_ptr++) = LITEMIN(tmp7, six); + *(doutc7_ptr++) = LITEMIN(tmp8, six); + din_hei_ptr += 8; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + for (; i < width; ++i) { + if (din_hei_ptr[0] >= 0) { + *(doutc0_ptr++) = din_hei_ptr[0]; + } else { + *(doutc0_ptr++) = din_hei_ptr[0] * scale; + } + if (din_hei_ptr[1] >= 0) { + *(doutc1_ptr++) = din_hei_ptr[1]; + } else { + *(doutc1_ptr++) = din_hei_ptr[1] * scale; + } + if (din_hei_ptr[2] >= 0) { + *(doutc2_ptr++) = din_hei_ptr[2]; + } else { + *(doutc2_ptr++) = din_hei_ptr[2] * scale; + } + if (din_hei_ptr[3] >= 0) { + *(doutc3_ptr++) = din_hei_ptr[3]; + } else { + *(doutc3_ptr++) = din_hei_ptr[3] * scale; + } + if (din_hei_ptr[4] >= 0) { + *(doutc4_ptr++) = din_hei_ptr[4]; + } else { + *(doutc4_ptr++) = din_hei_ptr[4] * scale; + } + if (din_hei_ptr[4] >= 0) { + *(doutc5_ptr++) = din_hei_ptr[5]; + } else { + *(doutc5_ptr++) = din_hei_ptr[5] * scale; + } + if (din_hei_ptr[6] >= 0) { + *(doutc6_ptr++) = din_hei_ptr[6]; + } else { + *(doutc6_ptr++) = din_hei_ptr[6] * scale; + } + if (din_hei_ptr[7] >= 0) { + *(doutc7_ptr++) = din_hei_ptr[7]; + } else { + *(doutc7_ptr++) = din_hei_ptr[7] * scale; + } + din_hei_ptr += 8; + } + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { for (; i < width; ++i) { *(doutc0_ptr++) = din_hei_ptr[0]; *(doutc1_ptr++) = din_hei_ptr[1]; diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index b6c3478880d5cb59999d23ff03e2e342708ca95b..503dab29b6c4f0b9d3ff30a89060e473194216a9 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_3x3s2_depthwise_fp32(const float* i_data, @@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din, int pad, bool flag_bias, bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2_fp32(const float* din, diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index dc68e65f42a799d7fa7e8be75f5afcf3166b1df3..642d1c2c1b964b9553e522d70a086531f1706420 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din, ARMContext* ctx, const float* scale) { auto paddings = *param.paddings; + auto 
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
   int stride = param.strides[1];
@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              pad,
                              flag_bias,
                              flag_relu,
+                             act_param,
                              ctx);
   } else {
     conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                               reinterpret_cast<const float*>(weights),
                               bias,
                               param,
+                              act_param,
                               ctx);
   }
diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h
index f4d00039aaa635d0ffb31846fd9ff9077ac0c621..60f74b7feecc91a2fe8262a1fea4dce26430031d 100644
--- a/lite/backends/arm/math/conv_impl.h
+++ b/lite/backends/arm/math/conv_impl.h
@@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor,
                     int channel_size);
 
 // new winograd
-void weight_trans_c4(
+void weight_trans_c4_8x8(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
     float* dest, const float* src, int ic, int oc, void* workspace);
 void conv_compute_6x6_3x3(const float* input,
                           float* output,
@@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input,
                           const float* bias,
                           const operators::ConvParam& param,
                           ARMContext* ctx);
+void conv_compute_2x2_3x3(const float* input,
+                          float* output,
+                          int num,
+                          int chout,
+                          int hout,
+                          int wout,
+                          int chin,
+                          int hin,
+                          int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
+void conv_compute_2x2_3x3_small(const float* input,
+                                float* output,
+                                int num,
+                                int chout,
+                                int hout,
+                                int wout,
+                                int chin,
+                                int hin,
+                                int win,
+                                const float* weight,
+                                const float* bias,
+                                const operators::ConvParam& param,
+                                ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index a4c61f9a9d181924c28cdd009f8412278d44f5bb..186ad19735799dcb91641354af4b4f09692bfce9 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -557,6 +557,52 @@ void elementwise_mul<float>(const float* dinx,
   }
 }
 
+template <>
+void elementwise_mul<int>(const int* dinx,
+                          const int* diny,
+                          int* dout,
+                          int num) {
+  int cnt = num >> 4;
+  int remain = num % 16;
+#pragma omp parallel for
+  for (int i = 0; i < cnt; ++i) {
+    const int* dinx_ptr = dinx + (i << 4);
+    const int* diny_ptr = diny + (i << 4);
+    int* dout_ptr = dout + (i << 4);
+
+    int32x4_t dinx0 = vld1q_s32(dinx_ptr);
+    int32x4_t dinx1 = vld1q_s32(dinx_ptr + 4);
+    int32x4_t dinx2 = vld1q_s32(dinx_ptr + 8);
+    int32x4_t dinx3 = vld1q_s32(dinx_ptr + 12);
+
+    int32x4_t diny0 = vld1q_s32(diny_ptr);
+    int32x4_t diny1 = vld1q_s32(diny_ptr + 4);
+    int32x4_t diny2 = vld1q_s32(diny_ptr + 8);
+    int32x4_t diny3 = vld1q_s32(diny_ptr + 12);
+
+    dinx0 = vmulq_s32(dinx0, diny0);
+    dinx1 = vmulq_s32(dinx1, diny1);
+    dinx2 = vmulq_s32(dinx2, diny2);
+    dinx3 = vmulq_s32(dinx3, diny3);
+
+    vst1q_s32(dout_ptr, dinx0);
+    vst1q_s32(dout_ptr + 4, dinx1);
+    vst1q_s32(dout_ptr + 8, dinx2);
+    vst1q_s32(dout_ptr + 12, dinx3);
+  }
+  if (remain > 0) {
+    const int* dinx_ptr = dinx + (cnt << 4);
+    const int* diny_ptr = diny + (cnt << 4);
+    int* dout_ptr = dout + (cnt << 4);
+    for (int i = 0; i < remain; i++) {
+      *dout_ptr = *dinx_ptr * *diny_ptr;
+      dout_ptr++;
+      dinx_ptr++;
+      diny_ptr++;
+    }
+  }
+}
+
 template <>
 void elementwise_mul_relu<float>(const float* dinx,
                                  const float* diny,
@@ -678,6 +724,73 @@ void elementwise_mul_broadcast<float>(const float* dinx,
   }
 }
 
+template <>
+void elementwise_mul_broadcast<int>(const int* dinx,
+                                    const int* diny,
+                                    int* dout,
+                                    int batch,
+                                    int channels,
+                                    int num) {
+#pragma omp parallel for collapse(2)
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const int* din_ptr = dinx + offset;
+      const int diny_data = diny[j];
+      int* dout_ptr = dout + offset;
+
+      int cnt = num >> 4;
+      int remain = num % 16;
+      int32x4_t rb = vdupq_n_s32(diny_data);
+      for (int k = 0; k < cnt; ++k) {
+        int32x4_t din0 = vld1q_s32(din_ptr);
+        int32x4_t din1 = vld1q_s32(din_ptr + 4);
+        int32x4_t din2 = vld1q_s32(din_ptr + 8);
+        int32x4_t din3 = vld1q_s32(din_ptr + 12);
+
+        din0 = vmulq_s32(din0, rb);
+        din1 = vmulq_s32(din1, rb);
+        din2 = vmulq_s32(din2, rb);
+        din3 = vmulq_s32(din3, rb);
+
+        vst1q_s32(dout_ptr, din0);
+        vst1q_s32(dout_ptr + 4, din1);
+        vst1q_s32(dout_ptr + 8, din2);
+        vst1q_s32(dout_ptr + 12, din3);
+
+        din_ptr += 16;
+        dout_ptr += 16;
+      }
+      if (remain >= 8) {
+        int32x4_t din0 = vld1q_s32(din_ptr);
+        int32x4_t din1 = vld1q_s32(din_ptr + 4);
+        din0 = vmulq_s32(din0, rb);
+        din1 = vmulq_s32(din1, rb);
+        vst1q_s32(dout_ptr, din0);
+        vst1q_s32(dout_ptr + 4, din1);
+        din_ptr += 8;
+        dout_ptr += 8;
+        remain -= 8;
+      }
+      if (remain >= 4) {
+        int32x4_t din0 = vld1q_s32(din_ptr);
+        din0 = vmulq_s32(din0, rb);
+        vst1q_s32(dout_ptr, din0);
+        din_ptr += 4;
+        dout_ptr += 4;
+        remain -= 4;
+      }
+      if (remain > 0) {
+        for (int p = 0; p < remain; ++p) {
+          *dout_ptr = *din_ptr * diny_data;
+          dout_ptr++;
+          din_ptr++;
+        }
+      }
+    }
+  }
+}
+
 template <>
 void elementwise_mul_relu_broadcast<float>(const float* dinx,
                                            const float* diny,
diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h
index 8977b5712c13dec0088d83db4cbfef8494785301..6fb64138221ea4ca4d70ddf04f53b5bd4cdf4a92 100644
--- a/lite/backends/arm/math/funcs.h
+++ b/lite/backends/arm/math/funcs.h
@@ -51,6 +51,7 @@
 #include "lite/backends/arm/math/prior_box.h"
 #include "lite/backends/arm/math/reduce_max.h"
 #include "lite/backends/arm/math/reduce_mean.h"
+#include "lite/backends/arm/math/reduce_prod.h"
 #include "lite/backends/arm/math/scale.h"
 #include "lite/backends/arm/math/sequence_expand.h"
 #include "lite/backends/arm/math/sequence_pool.h"
@@ -61,6 +62,7 @@
 #include "lite/backends/arm/math/slice.h"
 #include "lite/backends/arm/math/softmax.h"
 #include "lite/backends/arm/math/split.h"
+#include "lite/backends/arm/math/split_merge_lod_tenosr.h"
 #include "lite/backends/arm/math/stack.h"
 #include "lite/backends/arm/math/topk.h"
 #include "lite/backends/arm/math/yolo_box.h"
diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc
index e9e18043dfc09001ebba23f952a59474630e54aa..1c53142fc53bc785efcbf28fa007d403ad99ab70 100644
--- a/lite/backends/arm/math/interpolate.cc
+++ b/lite/backends/arm/math/interpolate.cc
@@ -477,17 +477,23 @@ void nearest_interp(const float* src,
   float scale_h_new = (with_align)
                           ? (static_cast<float>(h_in - 1) / (h_out - 1))
                           : (static_cast<float>(h_in) / (h_out));
-
-#pragma omp parallel for collapse(2) schedule(static)
-  for (int h = 0; h < h_out; ++h) {
-    for (int w = 0; w < w_out; ++w) {
-      int near_x = (with_align) ? static_cast<int>(scale_w_new * w + 0.5)
-                                : static_cast<int>(scale_w_new * w);
-      int near_y = (with_align) ? static_cast<int>(scale_h_new * h + 0.5)
-                                : static_cast<int>(scale_h_new * h);
-      near_x = near_x < 0 ? 0 : near_x;
-      near_y = near_y < 0 ? 0 : near_y;
-      dst[h * w_out + w] = src[near_y * w_in + near_x];
+  if (with_align) {
+    for (int h = 0; h < h_out; ++h) {
+      float* dst_p = dst + h * w_out;
+      int near_y = static_cast<int>(scale_h_new * h + 0.5);
+      for (int w = 0; w < w_out; ++w) {
+        int near_x = static_cast<int>(scale_w_new * w + 0.5);
+        *dst_p++ = src[near_y * w_in + near_x];
+      }
+    }
+  } else {
+    for (int h = 0; h < h_out; ++h) {
+      float* dst_p = dst + h * w_out;
+      int near_y = static_cast<int>(scale_h_new * h);
+      for (int w = 0; w < w_out; ++w) {
+        int near_x = static_cast<int>(scale_w_new * w);
+        *dst_p++ = src[near_y * w_in + near_x];
+      }
     }
   }
 }
@@ -520,9 +526,9 @@ void interpolate(lite::Tensor* X,
   }
   auto out_size = OutSize;
   if (out_size != nullptr) {
-    auto out_size_data = get_new_data_from_tensor<float>(out_size);
-    out_height = static_cast<int>(out_size_data[0]);
-    out_width = static_cast<int>(out_size_data[1]);
+    auto out_size_data = get_new_data_from_tensor<int>(out_size);
+    out_height = out_size_data[0];
+    out_width = out_size_data[1];
   }
 }
 float height_scale = scale;
@@ -544,8 +550,10 @@ void interpolate(lite::Tensor* X,
   int out_w = Out->dims()[3];
   int spatial_in = in_h * in_w;
   int spatial_out = out_h * out_w;
-  for (int i = 0; i < count; ++i) {
-    if ("Bilinear" == interpolate_type) {
+
+  if ("Bilinear" == interpolate_type) {
+#pragma omp parallel for
+    for (int i = 0; i < count; ++i) {
       bilinear_interp(din + spatial_in * i,
                       in_w,
                       in_h,
@@ -555,7 +563,10 @@ void interpolate(lite::Tensor* X,
                       1.f / width_scale,
                       1.f / height_scale,
                       with_align);
-    } else if ("Nearest" == interpolate_type) {
+    }
+  } else if ("Nearest" == interpolate_type) {
+#pragma omp parallel for
+    for (int i = 0; i < count; ++i) {
       nearest_interp(din + spatial_in * i,
                      in_w,
                      in_h,
diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc
index 8087e0337bda0866f5d399a07ecb674f0fa55a3e..af4934e85756f03ec197520b2b5c130e27bdcad6 100644
--- a/lite/backends/arm/math/packed_sgemm_c4.cc
+++ b/lite/backends/arm/math/packed_sgemm_c4.cc
@@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M,
     }
   }
 }
-
 void sgemm_prepack_c4_small(int M,
                             int N,
                             int K,
@@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M,
   }
 }
 
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            ARMContext* ctx) {
+  const int m_round = (M + 3) / 4 * 4;
+  const int k_round = (K + 3) / 4 * 4;
+  const int mloop = m_round >> 2;
+  const int lda = 4 * k_round;
+  const int ldb_byte = 4 * N * sizeof(float);
+  const int kcnt = k_round >> 2;
+#ifdef __aarch64__
+  float32x4_t vzero = vdupq_n_f32(0.f);
+#endif
+  for (int m = 0; m < mloop; ++m) {
+    const float* b = B;
+    int n = N;
+#ifdef __aarch64__
+    for (; n > 7; n -= 8) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      // clang-format off
+      asm volatile(
+        "0:\n"
+        /* load a0, a1 */
+        "ld1 {v16.4s, v17.4s}, [%[a]], #32  \n"
+        /* load b0, b1 */
+        "ld1 {v0.4s, v1.4s},   [%[b]], #32  \n"
+        /* load b2, b3 */
+        "ld1 {v2.4s, v3.4s},   [%[b]], #32  \n"
+        /* load a2, a3 */
+        "fmul v8.4s,  v16.4s, v0.s[0]       \n"
+        "fmul v9.4s,  v16.4s, v1.s[0]       \n"
+        "fmul v10.4s, v16.4s, v2.s[0]       \n"
+        "fmul v11.4s, v16.4s, v3.s[0]       \n"
+        "ld1 {v18.4s, v19.4s}, [%[a]], #32  \n"
+        "prfm pldl1keep, [%[b]]             \n"
+        "fmla v8.4s,  v17.4s, v0.s[1]       \n"
+        "fmla v9.4s,  v17.4s, v1.s[1]       \n"
+        "fmla v10.4s, v17.4s, v2.s[1]       \n"
+        "fmla v11.4s, v17.4s, v3.s[1]       \n"
+        /* load b4, b5 */
+        "ld1 {v4.4s, v5.4s},   [%[b]], #32  \n"
+        "fmla v8.4s,  v18.4s, v0.s[2]       \n"
+        "fmla v9.4s,  v18.4s, v1.s[2]       \n"
+        "fmla v10.4s, v18.4s, v2.s[2]       \n"
+
"fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmul v12.4s, v16.4s, v4.s[0] \n" + "fmul v13.4s, v16.4s, v5.s[0] \n" + "fmul v14.4s, v16.4s, v6.s[0] \n" + "fmul v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmla v12.4s, v16.4s, v4.s[0] \n" + "fmla v13.4s, v16.4s, v5.s[0] \n" + "fmla v14.4s, v16.4s, v6.s[0] \n" + "fmla v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v16.4s, v1.s[0] \n" + "fmul v10.4s, v16.4s, v2.s[0] \n" + "fmul v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + 
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v8", "v9", + "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "bne 1b \n" + "2:\n" + "fadd v8.4s, v8.4s, v9.4s \n" + "st1 {v8.4s}, [%[c]], #16 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v8", "v9", "v16", "v17", + "v18", "v19", "cc", "memory" + ); + b += 4; + } +#else + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vld1.32 {d0-d3}, [%[b]]! \n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! 
\n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q12, q4, d0[0] \n" + "vmul.f32 q13, q4, d2[0] \n" + "vmul.f32 q14, q4, d4[0] \n" + "vmul.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q12, q4, d0[0] \n" + "vmla.f32 q13, q4, d2[0] \n" + "vmla.f32 q14, q4, d4[0] \n" + "vmla.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! 
\n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]!\n" + "vst1.32 {d20-d23}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmul.f32 q5, q1, d0[0] \n" + "vmul.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmla.f32 q5, q1, d0[0] \n" + "vmla.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! 
\n" + "bne 1b \n" + "2:\n" + "vadd.f32 q5, q5, q6 \n" + "vst1.32 {d10-d11}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "cc", "memory" + ); + // clang-format on + b += 4; + } +#endif + A_packed += lda; + } +} + void sgemm_prepack_c4(int M, int N, int K, diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h index 21e5af634315a7da66914bb04775088fec55550c..3229ff3e0774ce8bff02b12d79d7ec50ed873cea 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.h +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M, bool has_bias, bool has_relu, ARMContext* ctx); +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index 8524d7376f2bb7e337dfc11b890c00e281d2e880..9d42fd98df3ccec33457dd6d20ecb3b11684e04c 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -167,7 +167,7 @@ void pooling_basic(const float* din, "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n" \ "fmax v6.4s, v4.4s, v5.4s \n" \ "subs %w[cnt], %w[cnt], #1 \n" \ - "fmax %w[vmax].4s, %w[vmax].4s, v6.4s \n" \ + "fmax %[vmax].4s, %[vmax].4s, v6.4s \n" \ "bne 1b \n" #define GLOBAL_AVG \ "1: \n" \ @@ -176,7 +176,7 @@ void pooling_basic(const float* din, "ld1 {v0.4s-v1.4s}, [%[data_in_channel]], #32 \n" \ "fadd %[vsum].4s, %[vsum].4s, v3.4s \n" \ "subs %w[cnt], %w[cnt], #1 \n" \ - "fadd %w[vsum].4s, %w[vsum].4s, v4.4s \n" \ + "fadd %[vsum].4s, %[vsum].4s, v4.4s \n" \ "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n" \ "bne 1b \n" diff --git a/lite/backends/arm/math/reduce_prod.cc b/lite/backends/arm/math/reduce_prod.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7b3f7095f2087af365d0765f49df7902df42bb9 --- /dev/null +++ b/lite/backends/arm/math/reduce_prod.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/arm/math/reduce_prod.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math {} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/reduce_prod.h b/lite/backends/arm/math/reduce_prod.h new file mode 100644 index 0000000000000000000000000000000000000000..6c8898288fa498a6f97709a27306e6975dffc975 --- /dev/null +++ b/lite/backends/arm/math/reduce_prod.h @@ -0,0 +1,185 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void reduce_prod_n(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int hw_size = height_in * width_in;
+  int chw_size = channel_in * hw_size;
+  int data_index, src_index, src_index0;
+  for (int c = 0; c < channel_in; ++c) {
+    for (int h = 0; h < height_in; ++h) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = c * hw_size + h * width_in + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int n = 0; n < num_in; ++n) {
+          src_index = n * chw_size + data_index;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_c(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int hw_size = height_in * width_in;
+  int chw_size = hw_size * channel_in;
+  int data_index, src_index0, src_index;
+  for (int n = 0; n < num_in; ++n) {
+    for (int h = 0; h < height_in; ++h) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = n * hw_size + h * width_in + w;
+        src_index0 = n * chw_size + h * width_in + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int c = 0; c < channel_in; ++c) {
+          src_index = src_index0 + c * hw_size;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_h(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int cw_size = channel_in * width_in;
+  int chw_size = cw_size * height_in;
+  int hw_size = height_in * width_in;
+  int data_index, src_index, src_index0;
+  for (int n = 0; n < num_in; ++n) {
+    for (int c = 0; c < channel_in; ++c) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = n * cw_size + c * width_in + w;
+        src_index0 = n * chw_size + c * hw_size + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int h = 0; h < height_in; ++h) {
+          src_index = src_index0 + h * width_in;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_w(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int ch_size = channel_in * height_in;
+  int hw_size = height_in * width_in;
+  int chw_size = ch_size * width_in;
+  int data_index = 0;
+  int src_index0 = 0;
+  int src_index = 0;
+  for (int n = 0; n < num_in; ++n) {
+    for (int c = 0; c < channel_in; ++c) {
+      for (int h = 0; h < height_in; ++h) {
+        data_index = n * ch_size + c * height_in + h;
+        src_index0 = n * chw_size + c * hw_size + h * width_in;
+        dst[data_index] = static_cast<T>(1);
+        for (int w = 0; w < width_in; ++w) {
+          src_index = src_index0 + w;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_nc(const T* src,
+                    T* dst,
+                    int num_in,
+                    int channel_in,
+                    int height_in,
+                    int width_in) {
+  // reduce n first.
+  DDimLite ddimA({1, channel_in, height_in, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  auto* tmp_out = tensor_tmp.mutable_data<T>();
+  reduce_prod_n(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_prod_c(tmp_out, dst, 1, channel_in, height_in, width_in);
+}
+
+template <typename T>
+void reduce_prod_ch(const T* src,
+                    T* dst,
+                    int num_in,
+                    int channel_in,
+                    int height_in,
+                    int width_in) {
+  // reduce c first
+  DDimLite ddimA({num_in, 1, height_in, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  auto* tmp_out = tensor_tmp.mutable_data<T>();
+  reduce_prod_c(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_prod_h(tmp_out, dst, num_in, 1, height_in, width_in);
+}
+
+template <typename T>
+void reduce_prod_hw(const T* src,
+                    T* dst,
+                    int num_in,
+                    int channel_in,
+                    int height_in,
+                    int width_in) {
+  // reduce h first
+  DDimLite ddimA({num_in, channel_in, 1, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  auto* tmp_out = tensor_tmp.mutable_data<T>();
+  reduce_prod_h(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_prod_w(tmp_out, dst, num_in, channel_in, 1, width_in);
+}
+
+template <typename T>
+void reduce_prod_all(const T* src, T* dst, int64_t total_num) {
+  dst[0] = static_cast<T>(1);
+  for (int n = 0; n < total_num; ++n) {
+    dst[0] *= src[n];
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/slice.cc b/lite/backends/arm/math/slice.cc
index 8b9a7690509260ed4c6c0e14750d849f657d2fa8..67ca567fea988acfc9e20e2bfc929e9c3a0bbcb8 100644
--- a/lite/backends/arm/math/slice.cc
+++ b/lite/backends/arm/math/slice.cc
@@ -86,6 +86,13 @@ template void slice<int>(const int* input,
                          std::vector<int> ends,
                          int* out,
                          Context<TARGET(kARM)>* ctx);
+template void slice<float>(const float* input,
+                           std::vector<int64_t> dims,
+                           std::vector<int> axes,
+                           std::vector<int> starts,
+                           std::vector<int> ends,
+                           float* out,
+                           Context<TARGET(kARM)>* ctx);
 
 }  // namespace math
 }  // namespace arm
diff --git a/lite/backends/arm/math/split_merge_lod_tenosr.cc b/lite/backends/arm/math/split_merge_lod_tenosr.cc
new file mode 100644
index 0000000000000000000000000000000000000000..35dc4a455b7c51e0aab1a45c48460ccc513b9a08
--- /dev/null
+++ b/lite/backends/arm/math/split_merge_lod_tenosr.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
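+//
+// Worked example (illustrative only, derived from the implementation
+// below): for a two-level lod = {{0, 2, 5}, {0, 3, 7, 9, 12, 15}},
+// GetSubLoDAndAbsoluteOffset(lod, 1, 2, 0) first takes level 0
+// (length lod[0][2] - lod[0][1] = 3), then maps [1, 2) down to [2, 5)
+// at level 1 (lengths {2, 3, 3}), and returns
+// sub_lod = {{3}, {2, 3, 3}} with absolute offset {7, 15}.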
+
+#include "lite/backends/arm/math/split_merge_lod_tenosr.h"
+#include <utility>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+using LoDAndOffset = std::pair<LoD, std::pair<size_t, size_t>>;
+LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod,
+                                        size_t start_idx,
+                                        size_t end_idx,
+                                        size_t start_level) {
+  LoD sub_lod;
+  for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) {
+    CHECK(start_idx <= end_idx);
+    CHECK(end_idx < lod[level_idx].size());
+    std::vector<uint64_t> level_lens;
+    for (size_t i = start_idx; i < end_idx; ++i) {
+      level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]);
+    }
+    sub_lod.emplace_back(level_lens);
+    start_idx = lod[level_idx][start_idx];
+    end_idx = lod[level_idx][end_idx];
+  }
+  return LoDAndOffset{sub_lod, {start_idx, end_idx}};
+}
+
+void AppendLoD(LoD *lod, const LoD &lod_length) {
+  CHECK(lod->empty() || lod->size() == lod_length.size());
+  if (lod->empty()) {
+    for (size_t i = 0; i < lod_length.size(); ++i) {
+      lod->emplace_back(std::vector<uint64_t>({0}));
+    }
+  }
+  for (size_t i = 0; i < lod->size(); ++i) {
+    auto &level = (*lod)[i];
+    for (auto len : lod_length[i]) {
+      level.push_back(level.back() + len);
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/split_merge_lod_tenosr.h b/lite/backends/arm/math/split_merge_lod_tenosr.h
new file mode 100644
index 0000000000000000000000000000000000000000..47c484aa4a203ed1819a7e810f71858f4ef0b4dd
--- /dev/null
+++ b/lite/backends/arm/math/split_merge_lod_tenosr.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <utility>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
+    const LoD &lod, size_t start_idx, size_t end_idx, size_t start_level);
+
+void AppendLoD(LoD *lod, const LoD &lod_length);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc
index a4f33f467feb8626696595e95a29fde7b636919d..5dd53084f4079ae68c6fda0530fb5de8cf1d3717 100644
--- a/lite/backends/cuda/math/cudnn_conv.cc
+++ b/lite/backends/cuda/math/cudnn_conv.cc
@@ -89,9 +89,15 @@ bool CudnnConv2D<Ptype_out>::create(const operators::ConvParam& param,
         this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
   }
 
+#if CUDNN_VERSION_MIN(7, 0, 0)
+  cudnnMathType_t math_type =
+      use_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
+  CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type));
+#endif
+
   if (ic == param.groups && ic == oc && ic != 1) {
     this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-  } else if (1) {
+  } else if (!param.var_length) {
     const auto* i_data = param.x->data<float>();
     const auto* w_data = param.filter->data<float>();
     auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
diff --git a/lite/backends/cuda/math/gemm.h b/lite/backends/cuda/math/gemm.h
index 12194d54b08a533a3812e10b5d2f78134c19da24..85576e65018a0e1bdec6f2bd2fdc590bd35e9656 100644
--- a/lite/backends/cuda/math/gemm.h
+++ b/lite/backends/cuda/math/gemm.h
@@ -55,6 +55,8 @@ class Gemm {
            PtypeOut* c,
            Context<TARGET(kCUDA)>* ctx);
 
+  cublasHandle_t get_handle() const { return cu_handle_; }
+
  private:
   cudaStream_t exe_stream_;
   cublasHandle_t cu_handle_;
diff --git a/lite/backends/cuda/math/transpose.cu b/lite/backends/cuda/math/transpose.cu
index cebcece812dc584d0921edea2fef8f129e430b56..c50840fe269657965db8c58b171fce6819009775 100644
--- a/lite/backends/cuda/math/transpose.cu
+++ b/lite/backends/cuda/math/transpose.cu
@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
                               const int W,
                               const T* input,
                               T* out,
-                              CUDAContext* ctx) {
+                              cudaStream_t* stream) {
   const int dh = (H + kTileDim - 1) / kTileDim;
   const int dw = (W + kTileDim - 1) / kTileDim;
   BatchTranspose2DCUDAKernel<
-      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>(
+      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>(
       N, H, W, dh, dw, input, out);
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
-#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T)           \
-  template <>                                        \
-  void NCHW2NHWC<T>(const int N,                     \
-                    const int C,                     \
-                    const int HxW,                   \
-                    const T* X,                      \
-                    T* Y,                            \
-                    CUDAContext* ctx) {              \
-    BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
-  }
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
-
-#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T)           \
-  template <>                                        \
-  void NHWC2NCHW<T>(const int N,                     \
-                    const int C,                     \
-                    const int HxW,                   \
-                    const T* X,                      \
-                    T* Y,                            \
-                    CUDAContext* ctx) {              \
-    BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
-  }
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
-
 template <typename T>
 __global__ void TransposeCUDAKernel(const int size,
                                     const int ndim,
@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
                        const std::vector<int>& axes,
                        const T* X,
                        T* Y,
-                       CUDAContext* ctx) {
+                       lite::Tensor* Y_dims_,
+                       lite::Tensor* strides_,
+                       cudaStream_t* stream) {
   CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
   int ndim = X_dims.size();
   std::vector<int> strides(ndim, 0);
@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
     size *= X_dims[i];
   }
 
-  lite::Tensor Y_dims_, strides_;
-  Y_dims_.Resize(std::vector<int64_t>({ndim}));
-  int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(
-      d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD);
+  Y_dims_->Resize(std::vector<int64_t>({ndim}));
+  int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_y_dims,
+                                 Y_dims.data(),
+                                 sizeof(int) * Y_dims.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
-  strides_.Resize(std::vector<int64_t>({ndim}));
-  int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(d_strides,
-                          strides.data(),
-                          sizeof(int) * strides.size(),
-                          IoDirection::HtoD);
+  strides_->Resize(std::vector<int64_t>({ndim}));
+  int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_strides,
+                                 strides.data(),
+                                 sizeof(int) * strides.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
   const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
+  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>(
       size, ndim, d_strides, d_y_dims, X, Y);
   auto e = cudaGetLastError();
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
 }
 
-#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T)                  \
-  template <>                                               \
-  void Transpose<T>(const std::vector<int64_t>& X_dims,     \
-                    const std::vector<int>& axes,           \
-                    const T* X,                             \
-                    T* Y,                                   \
-                    CUDAContext* ctx) {                     \
-    TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx);          \
-  }
-TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
-#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF
+template <typename T>
+void Transpose<T>::NCHW2NHWC(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::NHWC2NCHW(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::transpose(T* dst,
+                             const T* src,
+                             const std::vector<int64_t>& src_dims,
+                             const std::vector<int>& axes,
+                             cudaStream_t* stream) {
+  TransposeCUDAImpl<T>(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
+}
+
+// template <typename T>
+// void Transpose<T>::transpose(T* dst,
+//                              const T* src,
+//                              const std::vector<int>& src_dims,
+//                              const std::vector<int>& axes,
+//                              cudaStream_t* stream) {
+//   std::vector<int64_t> _src_dims(src_dims.size(), 0);
+//   std::transform(
+//       src_dims.begin(),
+//       src_dims.end(),
+//       _src_dims.begin(),
+//       [](int data) -> int64_t { return static_cast<int64_t>(data); });
+//   TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
+//   stream);
+// }
+
+template class Transpose<int8_t>;
+template class Transpose<float>;
 
 }  // namespace math
 }  // namespace cuda
diff --git a/lite/backends/cuda/math/transpose.h b/lite/backends/cuda/math/transpose.h
index ba2464547b587f44cd9b0ce287a0d40d37d46411..ed52ba3b5590ab631c3c57a0472e16cb0ed51a91 100644
--- a/lite/backends/cuda/math/transpose.h
+++ b/lite/backends/cuda/math/transpose.h
@@ -26,17 +26,27 @@ namespace cuda {
 namespace math {
 
 template <typename T>
-void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
+class Transpose {
+ public:
+  void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
 
-template <typename T>
-void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
+  void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
 
-template <typename T>
-void Transpose(const std::vector<int64_t>& X_dims,
-               const std::vector<int>& axes,
-               const T* X,
-               T* Y,
-               CUDAContext* ctx);
+  void transpose(T* dst,
+                 const T* src,
+                 const std::vector<int64_t>& src_dims,
+                 const std::vector<int>& axes,
+                 cudaStream_t* stream);
+
+  // void transpose(T* dst,
+  //                const T* src,
+  //                const std::vector<int>& src_dims,
+  //                const std::vector<int>& axes,
+  //                cudaStream_t* stream);
+
+ private:
+  lite::Tensor Y_dims_, strides_;  // for transpose.
+}; } // namespace math } // namespace cuda diff --git a/lite/backends/fpga/CMakeLists.txt b/lite/backends/fpga/CMakeLists.txt index b12fd85caf7e0c79de830b45569e02ba916c34e6..a5207c01a4d5e7b8d05490bd7c9be0dcc01f365e 100644 --- a/lite/backends/fpga/CMakeLists.txt +++ b/lite/backends/fpga/CMakeLists.txt @@ -3,13 +3,35 @@ if (NOT LITE_WITH_FPGA) endif() set(LITE_FPGA_KD_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD") +set(LITE_FPGA_KD_LLAPI_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD/llapi") +set(LITE_FPGA_KD_PE_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD/pes") set(LITE_FPGA_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga") message("fpga_kd_path ${LITE_FPGA_KD_PATH}") message("fpga_path ${LITE_FPGA_PATH}") -file(GLOB_RECURSE KD_CPP *.cpp *.cc) +file(GLOB KD_CPP "${LITE_FPGA_KD_PATH}/*.cpp") +file(GLOB PE_CPP "${LITE_FPGA_KD_PE_PATH}/*.cpp") +file(GLOB LLAPI_CPP "${LITE_FPGA_KD_LLAPI_PATH}/*.cpp") file(GLOB FPGA_CPP "${LITE_FPGA_PATH}/*.cc") - -cc_library(kernel_fpga SRCS ${KD_CPP} ${FPGA_CPP}) +set(FPGA_ALL_CPP "") +FOREACH(FILE_PATH ${KD_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${PE_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/pes/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${LLAPI_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/llapi/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${FPGA_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list( APPEND FPGA_ALL_CPP ${FILE_NAME}) +ENDFOREACH(FILE_PATH) +message("fpga kd: ${FPGA_ALL_CPP}") +cc_library(kernel_fpga SRCS ${FPGA_ALL_CPP}) +#cc_library(kernel_fpga SRCS ${KD_CPP} ${FPGA_CPP}) cc_library(lite_tensor_fpga SRCS lite_tensor.cc DEPS memory) -cc_library(fpga_target_wrapper SRCS ${LITE_FPGA_PATH}/target_wrapper.cc DEPS kernel_fpga) +cc_library(fpga_target_wrapper SRCS target_wrapper.cc DEPS kernel_fpga) diff --git a/lite/backends/fpga/KD/debugger.hpp b/lite/backends/fpga/KD/debugger.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b9b23070616baf18f347c6b2af2d87a300d428f --- /dev/null +++ b/lite/backends/fpga/KD/debugger.hpp @@ -0,0 +1,140 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
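+//
+// Layout note (illustrative only, derived from chw_to_hwc below): the
+// helper copies NCHW-ordered data into NHWC order, moving element
+// (n, c, h, w) from source index ((n * C + c) * H + h) * W + w to
+// destination index ((n * H + h) * W + w) * C + c. For example, with
+// C = 2, H = 1, W = 2 the row {c0w0, c0w1, c1w0, c1w1} becomes
+// {c0w0, c1w0, c0w1, c1w1}.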
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+#define FPGA_PRINT_TENSOR
+
+class Debugger {
+ public:
+  static Debugger& get_instance() {
+    static Debugger s_instance;
+    return s_instance;
+  }
+
+  void registerOutput(std::string op_type, zynqmp::Tensor* tensor) {
+    if (op_type != "conv") {  // NOLINT
+    }
+  }
+
+ private:
+  std::unordered_map<std::string, bool> op_config;
+  Debugger() {
+    op_config["concat"] = true;
+    op_config["conv"] = true;
+    op_config["crop"] = true;
+  }
+};
+
+inline void chw_to_hwc(Tensor* t, float* dst) {
+  int num = t->dims()[0];
+  int channel = t->dims()[1];
+
+  int height = 1;
+  int width = 1;
+  if (t->dims().size() > 2) {
+    height = t->dims()[2];
+  }
+  if (t->dims().size() > 3) {
+    width = t->dims()[3];
+  }
+  const float* chw_data = t->data<float>();
+  float* hwc_data = dst;
+
+  int chw = channel * height * width;
+  int wc = width * channel;
+  int index = 0;
+  for (int n = 0; n < num; n++) {
+    for (int c = 0; c < channel; c++) {
+      for (int h = 0; h < height; h++) {
+        for (int w = 0; w < width; w++) {
+          hwc_data[n * chw + h * wc + w * channel + c] = chw_data[index];
+          index++;
+        }
+      }
+    }
+  }
+}
+
+inline void read_from_file(lite::Tensor* t, const std::string& path) {
+  std::ifstream file_stream;
+  file_stream.open(path);
+  if (!file_stream) {
+    return;
+  }
+  float* data = t->mutable_data<float>();
+  int num = t->numel();
+  for (int i = 0; i < num; ++i) {
+    float value = 0;
+    file_stream >> value;
+    data[i] = value;
+  }
+}
+
+inline void save_float(float* data, const std::string& name, int len) {
+  static int counter = 0;
+  std::string old_string = std::to_string(counter);
+  std::string new_string =
+      std::string(3 - old_string.length(), '0') + old_string;
+
+  std::string file = "arm_" + new_string + name;
+  counter++;
+
+  std::ofstream ofs;
+  ofs.open(file);
+  for (int i = 0; i < len; i++) {
+    float value = data[i];
+    ofs << value << std::endl;
+  }
+  ofs.close();
+}
+
+inline void save_tensor(lite::Tensor* t,
+                        const std::string& name,
+                        bool convert = true) {
+  float* data = const_cast<float*>(t->data<float>());
+  float* dst = new float[t->numel()];
+  if (convert) {
+    chw_to_hwc(t, dst);
+    data = dst;
+  }
+
+  save_float(data, name, t->numel());
+  delete[] dst;
+}
+
+inline void save_tensor(const lite::Tensor* t,
+                        const std::string& name,
+                        bool convert = true) {
+  float* data = const_cast<float*>(t->data<float>());
+  float* dst = new float[t->numel()];
+  if (convert) {
+    chw_to_hwc(const_cast<Tensor*>(t), dst);
+    data = dst;
+  }
+
+  save_float(data, name, t->numel());
+
+  delete[] dst;
+}
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/fpga/KD/dl_engine.cpp b/lite/backends/fpga/KD/dl_engine.cpp
old mode 100644
new mode 100755
index 9849e4275b5d0f59346b9684530610853f1a560c..ea503518a0f39671e77157f14788a1cadb4579f3
--- a/lite/backends/fpga/KD/dl_engine.cpp
+++ b/lite/backends/fpga/KD/dl_engine.cpp
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
 #include "lite/backends/fpga/KD/dl_engine.hpp"
+
 namespace paddle {
 namespace zynqmp {
 
 DLEngine::DLEngine() {
   open_device();
-  struct DeviceInfo info;
-  int ret = get_device_info(info);
-  filter::set_filter_capacity(info.filter_cap);
+  int ret = get_device_info(info_);
+  filter::set_filter_capacity(info_.filter_cap);
+  filter::set_colunm(info_.colunm);
 }
 
 }  // namespace zynqmp
diff --git a/lite/backends/fpga/KD/dl_engine.hpp b/lite/backends/fpga/KD/dl_engine.hpp
old mode 100644
new mode 100755
index 829f41dfebfabfe5642bd4cf107fc6c54f3ffd86..eddf5ca454cdc9e91f87d6e4f2c8dfc13f35fdc6
--- a/lite/backends/fpga/KD/dl_engine.hpp
+++ b/lite/backends/fpga/KD/dl_engine.hpp
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once
 #include
-
 #include "lite/backends/fpga/KD/llapi/filter.h"
 #include "lite/backends/fpga/KD/llapi/zynqmp_api.h"
 
@@ -29,8 +28,15 @@ class DLEngine {
     return s_instance;
   }
 
+  DeviceInfo& deviceInfo();
+
+  bool isZU3() { return info_.device_type / 100 == 3; }
+
+  float* out_data = nullptr;
+
  private:
   DLEngine();
+  DeviceInfo info_;
 };
 }  // namespace zynqmp
 }  // namespace paddle
diff --git a/lite/backends/fpga/KD/layout.hpp b/lite/backends/fpga/KD/layout.hpp
index 74819cd2120630def0114422b04efe076e1d6cb2..c6b5c911872b6b22633a4319ea708ed23c7e7e36 100644
--- a/lite/backends/fpga/KD/layout.hpp
+++ b/lite/backends/fpga/KD/layout.hpp
@@ -22,6 +22,7 @@ namespace paddle {
 namespace zynqmp {
 
 enum LayoutType {
+  None,
   N,
   NC,
   NCHW,
@@ -39,6 +40,15 @@ class Layout {
   virtual int elementCount(const std::vector<int>& dims) = 0;
 };
 
+struct None : Layout {
+  int numIndex() { return -1; }
+  int channelIndex() { return -1; }
+  int heightIndex() { return -1; }
+  int widthIndex() { return -1; }
+  int alignedElementCount(const std::vector<int>& dims) { return 16; }
+  virtual int elementCount(const std::vector<int>& dims) { return 1; }
+};
+
 struct NCHW : Layout {
   int numIndex() { return 0; }
   int channelIndex() { return 1; }
diff --git a/lite/backends/fpga/KD/llapi/bias_scale.cpp b/lite/backends/fpga/KD/llapi/bias_scale.cpp
index cd60f27f9896e857f8ad566d285a9b9aea1d4721..339a442207e811be31161ff25f60a080572efe8d 100644
--- a/lite/backends/fpga/KD/llapi/bias_scale.cpp
+++ b/lite/backends/fpga/KD/llapi/bias_scale.cpp
@@ -14,6 +14,7 @@ limitations under the License.
*/ #include +#include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/bias_scale.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -54,7 +55,7 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) { *data_in = ptr_aligned; } -void interleave(float **data_in, int num_after_alignment) { +size_t interleave(float **data_in, int num_after_alignment) { float *ptr_uninterleaved = *data_in; float *ptr_interleaved = (float *)fpga_malloc(2 * num_after_alignment * sizeof(float)); // NOLINT @@ -69,6 +70,7 @@ void interleave(float **data_in, int num_after_alignment) { fpga_free(ptr_uninterleaved); *data_in = ptr_interleaved; + return 2 * num_after_alignment * sizeof(float); } void format_bias_scale_array(float **bias_scale_array, @@ -78,8 +80,9 @@ void format_bias_scale_array(float **bias_scale_array, int div_num = (num + element_num_per_division - 1) / element_num_per_division; int element_num_after_division = align_to_x(element_num_per_division, BS_NUM_ALIGNMENT); - interleave(bias_scale_array, div_num * element_num_after_division); - fpga_flush(*bias_scale_array, 2 * element_num_after_division * sizeof(float)); + size_t mem = + interleave(bias_scale_array, div_num * element_num_after_division); + fpga_flush(*bias_scale_array, mem); } void format_bias_array(float **bias_array, int num) { float *ptr_unaligned = *bias_array; diff --git a/lite/backends/fpga/KD/llapi/bias_scale.h b/lite/backends/fpga/KD/llapi/bias_scale.h index 83f30df18fc7e5967d727ed8ce275d63e1cb29e0..d47d082ccdc6b41cf43860495e43076c17b13ac3 100644 --- a/lite/backends/fpga/KD/llapi/bias_scale.h +++ b/lite/backends/fpga/KD/llapi/bias_scale.h @@ -19,7 +19,7 @@ namespace zynqmp { namespace bias_scale { void align_element(float** data_in, int num_per_div_before_alignment, int num); -void interleave(float** data_in, int num_after_alignment); +size_t interleave(float** data_in, int num_after_alignment); void format_bias_scale_array(float** bias_scale_array, int element_num_per_division, int num); diff --git a/lite/backends/fpga/KD/llapi/filter.cpp b/lite/backends/fpga/KD/llapi/filter.cpp index 0e41a204a854b0b57e1a8c98fb3cc8d5224c807c..30250969b6fbe6e9e5ce7e9f96f963e8bee89224 100644 --- a/lite/backends/fpga/KD/llapi/filter.cpp +++ b/lite/backends/fpga/KD/llapi/filter.cpp @@ -15,6 +15,8 @@ limitations under the License. 
*/ #include "lite/backends/fpga/KD/llapi/filter.h" #include #include +#include +#include #include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -23,11 +25,41 @@ namespace zynqmp { namespace filter { static int FILTER_SIZE = 2048; +static int COLUMN = 4; + +void saveToFile(std::string name, void* data_in, int size) { + std::ofstream ofs; + ofs.open(name); + + int8_t* data = static_cast data_in; + for (int i = 0; i < size; i++) { + float value = data[i]; + ofs << value << std::endl; + } + ofs.close(); +} + +void saveFloatToFile(std::string name, float* data_in, int size) { + std::ofstream ofs; + ofs.open(name); + + for (int i = 0; i < size; i++) { + float value = data_in[i]; + ofs << value << std::endl; + } + ofs.close(); +} void set_filter_capacity(uint32_t cap) { FILTER_SIZE = cap; } +void set_colunm(uint32_t column) { COLUMN = column; } + +// replace zynqmp_api.h #define FILTER_NUM_ALIGNMENT +int get_filter_num_alignment() { return COLUMN * 4; } + int calc_division_capacity(int chw) { - int n = FILTER_SIZE / ((chw + 15) / 16) * 32; + int filter_num_alignment = get_filter_num_alignment(); + int n = FILTER_SIZE / ((chw + 15) / 16) * filter_num_alignment; return n < FILTER_SIZE ? n : FILTER_SIZE; } @@ -52,28 +84,28 @@ int calc_num_per_div(int num, int group_num, int division_capacity) { } } -void convert_to_hwc( - char **data_in, int num, int channel, int height, int width) { - char *tmp = *data_in; +void convert_to_hwc(int8_t* chw_data, + int8_t* hwc_data, + int num, + int channel, + int height, + int width) { int chw = channel * height * width; - char *data_tmp = (char *)fpga_malloc(chw * num * sizeof(char)); // NOLINT + int wc = width * channel; + int index = 0; for (int n = 0; n < num; n++) { - int64_t amount_per_row = width * channel; for (int c = 0; c < channel; c++) { for (int h = 0; h < height; h++) { - int64_t offset_height = h * amount_per_row; for (int w = 0; w < width; w++) { - *(data_tmp + n * chw + offset_height + w * channel + c) = - *((*data_in)++); + hwc_data[n * chw + h * wc + w * channel + c] = chw_data[index]; + index++; } } } } - *data_in = data_tmp; - fpga_free(tmp); } -float find_max(float *data_in, int data_size) { +float find_max(float* data_in, int data_size) { float max = 0.0; for (int i = 0; i < data_size; ++i) { float value = data_in[i]; @@ -83,166 +115,178 @@ float find_max(float *data_in, int data_size) { return max; } -signed char float_to_int8(float fdata) { +int8_t float_to_int8(float fdata) { if (fdata < 0.0) { fdata -= 0.5; } else { fdata += 0.5; } - return (signed char)fdata; + return (int8_t)fdata; } -void quantize(float **data_in, int data_size, float max) { - float *tmp = *data_in; +void quantize(float* src, int8_t* dst, int len, float max) { float fix_range = 127; float scale = fix_range / max; - - signed char *tmp_data = (signed char *)fpga_malloc(data_size * sizeof(char)); - for (int i = 0; i < data_size; i++) { - tmp_data[i] = float_to_int8( - (*data_in)[i] * scale); // (signed char)((*data_in)[i] * scale); + for (size_t i = 0; i < len; i++) { + dst[i] = float_to_int8(src[i] * scale); } - *data_in = (float *)tmp_data; // NOLINT - fpga_free(tmp); } -void align_element(char **data_in, int num, int chw) { - int j = 0; +bool should_align_chw(int chw) { int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - if (align_chw != chw) { - char *tmp = *data_in; - char *data_tmp = - (char *)fpga_malloc(num * align_chw * sizeof(char)); // NOLINT - - memset(data_tmp, 0, num * align_chw); - for (j = 0; j < num; 
j++) { - memcpy(data_tmp + j * align_chw, (*data_in) + j * chw, chw); - } - *data_in = data_tmp; - fpga_free(tmp); + return align_chw != chw; +} + +void align_chw(int8_t* src, int8_t* dst, int num, int chw) { + int aligned_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + memset(dst, 0, num * aligned_chw); + for (int j = 0; j < num; j++) { + memcpy((dst + j * aligned_chw), (src + j * chw), chw); } } -void align_num(char **data_in, +void align_num(int8_t* src, + int8_t* dst, int num_per_div_before_alignment, int num, - int chw) { - int i = 0; - int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + int align_chw) { + int filter_num_alignment = get_filter_num_alignment(); int num_per_div_after_alignment = - align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + align_to_x(num_per_div_before_alignment, filter_num_alignment); - char *tmp = *data_in; int div_num = (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int num_element = div_num * num_per_div_after_alignment * align_chw; - char *data_tmp = (char *)fpga_malloc(num_element * sizeof(char)); // NOLINT - - memset(data_tmp, 0, num_element * sizeof(char)); + memset(dst, 0, num_element * sizeof(int8_t)); + int i = 0; for (i = 0; i < div_num - 1; i++) { - memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, - *data_in + num_per_div_before_alignment * align_chw * i, + memcpy(dst + num_per_div_after_alignment * align_chw * i, + src + num_per_div_before_alignment * align_chw * i, num_per_div_before_alignment * align_chw); } - memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, - *data_in + num_per_div_before_alignment * align_chw * i, + memcpy(dst + num_per_div_after_alignment * align_chw * i, + src + num_per_div_before_alignment * align_chw * i, (num - (div_num - 1) * num_per_div_before_alignment) * align_chw); - - *data_in = data_tmp; - fpga_free(tmp); } -void reorder(char **data_in, int num_after_alignment, int chw) { +void reorder(int8_t* src, int8_t* dst, int num_after_alignment, int chw) { int index = 0; int new_index = 0; - + int filter_num_alignment = get_filter_num_alignment(); int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - - char *data_tmp = - (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT - sizeof(char)); - char *tmp = *data_in; for (index = 0; index < num_after_alignment; index++) { - new_index = index / 32 * 32 + (index % 16 / 4 * 8) + (index % 16 % 4) + - (index / 16 % 2 * 4); - memcpy(data_tmp + index * chw_align, - *data_in + new_index * chw_align, - chw_align); + new_index = index / filter_num_alignment * filter_num_alignment + + (index % (filter_num_alignment / 2) / 4 * 8) + + (index % (filter_num_alignment / 2) % 4) + + (index / (filter_num_alignment / 2) % 2 * 4); + memcpy((dst + index * chw_align), (src + new_index * chw_align), chw_align); } - *data_in = data_tmp; - fpga_free(tmp); } -size_t interleave(char **data_in, int num_after_alignment, int chw) { - int i = 0; - int j = 0; - int k = 0; +void interleave(int8_t* src, int8_t* dst, int num_after_alignment, int chw) { int interleave_per_num = 16; - int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - char *data_tmp = - (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT - sizeof(char)); - char *tmp = *data_in; int interleave_num = chw_align * 2 / interleave_per_num; - for (i = 0; i < num_after_alignment; i += 2) { - for (j = 0, k = 0; j < interleave_num; j += 2, k++) { - memcpy(data_tmp + i * chw_align + interleave_per_num * j, - *data_in + i * chw_align + 
interleave_per_num * k, + for (int i = 0; i < num_after_alignment; i += 2) { + for (int j = 0, k = 0; j < interleave_num; j += 2, k++) { + memcpy(dst + i * chw_align + interleave_per_num * j, + src + i * chw_align + interleave_per_num * k, interleave_per_num); - memcpy(data_tmp + i * chw_align + interleave_per_num * (j + 1), - *data_in + (i + 1) * chw_align + interleave_per_num * k, + memcpy(dst + i * chw_align + interleave_per_num * (j + 1), + src + (i + 1) * chw_align + interleave_per_num * k, interleave_per_num); } } - *data_in = data_tmp; - fpga_free(tmp); - return chw_align * num_after_alignment; } -size_t format_filter(float **data_in, - int num, - int channel, - int height, - int width, - int group_num, - float max) { +int8_t* format_filter(float* data_in, + int& mem_size_a, // NOLINT + int num, + int channel, + int height, + int width, + int group_num, + float max, + std::vector& filter_max) { // NOLINT int data_size = channel * height * width * num; int chw = channel * height * width; int division_capacity = calc_division_capacity(chw); + int filter_num_alignment = get_filter_num_alignment(); int num_per_div_before_alignment = calc_num_per_div(num, group_num, division_capacity); int num_per_div_after_alignment = - align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + align_to_x(num_per_div_before_alignment, filter_num_alignment); int div_num = (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int residual = num % num_per_div_before_alignment; int num_after_alignment = num_per_div_after_alignment * ((residual == 0) ? div_num : (div_num - 1)) + - align_to_x(residual, FILTER_NUM_ALIGNMENT); - quantize(data_in, data_size, max); - char **quantize_data = (char **)data_in; // NOLINT - convert_to_hwc(quantize_data, num, channel, height, width); - align_element(quantize_data, num, chw); - if (num_after_alignment != num) { - align_num(quantize_data, num_per_div_before_alignment, num, chw); + align_to_x(residual, filter_num_alignment); + + int8_t* quantized_data = + reinterpret_cast(fpga_malloc(data_size * sizeof(int8_t))); + + for (int n = 0; n < num; n++) { + float* filter_start = data_in + n * chw; + float f_max = find_max(filter_start, chw); + int8_t* quantized_start = quantized_data + n * chw; + quantize(filter_start, quantized_start, chw, max); + filter_max.push_back(max); } - reorder(quantize_data, num_after_alignment, chw); - size_t mem_size = interleave(quantize_data, num_after_alignment, chw); - fpga_flush(*quantize_data, + int8_t* hwc_data = + reinterpret_cast(fpga_malloc(data_size * sizeof(int8_t))); + convert_to_hwc(quantized_data, hwc_data, num, channel, height, width); + fpga_free(quantized_data); + + int8_t* temp_data = hwc_data; // NOLINT + int chw_aligned = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + if (should_align_chw(chw)) { + int8_t* hwc_aligned_data = reinterpret_cast( + fpga_malloc(num * chw_aligned * sizeof(int8_t))); + align_chw(hwc_data, hwc_aligned_data, num, chw); + + temp_data = hwc_aligned_data; + fpga_free(hwc_data); + } + if (num_after_alignment != num) { + int filter_num_alignment = get_filter_num_alignment(); + int num_per_div_after_alignment = + align_to_x(num_per_div_before_alignment, filter_num_alignment); + int num_element = div_num * num_per_div_after_alignment * chw_aligned; + int8_t* num_aligned_data = + reinterpret_cast(fpga_malloc(num_element * sizeof(int8_t))); + align_num(temp_data, + num_aligned_data, + num_per_div_before_alignment, + num, + chw_aligned); + + fpga_free(temp_data); + temp_data = 
num_aligned_data; + } + int8_t* aligned_data = + reinterpret_cast(fpga_malloc(num_after_alignment * chw_aligned)); + reorder(temp_data, aligned_data, num_after_alignment, chw); + fpga_free(temp_data); + int8_t* interleaved_data = + reinterpret_cast(fpga_malloc(num_after_alignment * chw_aligned)); + interleave(aligned_data, interleaved_data, num_after_alignment, chw); + fpga_free(aligned_data); + fpga_flush(interleaved_data, align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) * num_after_alignment * sizeof(char)); - return mem_size; + mem_size_a = num_after_alignment * chw_aligned; + return interleaved_data; } -void convert_to_hwn(int16_t **data_in, int num, int height, int width) { - int16_t *tmp = *data_in; - int16_t *data_tmp = - (int16_t *)fpga_malloc(height * width * num * sizeof(int16_t)); // NOLINT +void convert_to_hwn(int16_t** data_in, int num, int height, int width) { + int16_t* tmp = *data_in; + int16_t* data_tmp = + (int16_t*)fpga_malloc(height * width * num * sizeof(int16_t)); // NOLINT for (int n = 0; n < num; n++) { for (int h = 0; h < height; h++) { for (int w = 0; w < width; w++) { @@ -254,16 +298,16 @@ void convert_to_hwn(int16_t **data_in, int num, int height, int width) { fpga_free(tmp); } -size_t align_element_n(int16_t **data_in, int num, int height, int width) { +size_t align_element_n(int16_t** data_in, int num, int height, int width) { int unalign_n = num; int align_n = align_to_x(num, FILTER_ELEMENT_ALIGNMENT); int num_element = height * width * align_n; if (unalign_n != align_n) { - int16_t *tmp = *data_in; + int16_t* tmp = *data_in; int num_element = height * width * align_n; - int16_t *data_tmp = - (int16_t *)fpga_malloc(num_element * sizeof(int16_t)); // NOLINT + int16_t* data_tmp = + (int16_t*)fpga_malloc(num_element * sizeof(int16_t)); // NOLINT memset(data_tmp, 0, num_element * sizeof(int16_t)); for (int h = 0; h < height; h++) { @@ -276,17 +320,37 @@ size_t align_element_n(int16_t **data_in, int num, int height, int width) { } } *data_in = data_tmp; - free(tmp); + fpga_free(tmp); } return num_element * sizeof(int16_t); } +void to_fp16(float* src, + float16* dst, + int num, + int height, + int width, + float* scale_ptr) { + int size = num * height * width; + for (int n = 0; n < num; n++) { + float scale_val = scale_ptr[n]; + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + int index = n * height * width + h * width + w; + float value = src[index] * scale_val; + dst[index] = float_to_half(value); + } + } + } + fpga_flush(dst, size * sizeof(int16_t)); +} + void quantize_to_fp16( - float **data_in, int num, int height, int width, float *scale_ptr) { - float *tmp = *data_in; + float** data_in, int num, int height, int width, float* scale_ptr) { + float* tmp = *data_in; int size = num * height * width; - float16 *tmp_data = (float16 *)fpga_malloc(size * sizeof(float16)); // NOLINT + float16* tmp_data = (float16*)fpga_malloc(size * sizeof(float16)); // NOLINT for (int n = 0; n < num; n++) { float scale_val = scale_ptr[n]; for (int h = 0; h < height; h++) { @@ -298,13 +362,14 @@ void quantize_to_fp16( } } fpga_flush(tmp_data, size * sizeof(int16_t)); - *data_in = (float *)tmp_data; // NOLINT + *data_in = (float*)tmp_data; // NOLINT fpga_free(tmp); } size_t format_dwconv_filter( - float **data_in, int num, int height, int width, float *scale_ptr) { + float** data_in, int num, int height, int width, float* scale_ptr) { quantize_to_fp16(data_in, num, height, width, scale_ptr); - int16_t **quantize_data = (int16_t **)data_in; // NOLINT + int16_t** 
quantize_data = reinterpret_cast(data_in); + convert_to_hwn(quantize_data, num, height, width); size_t size = align_element_n(quantize_data, num, height, width); fpga_flush(*quantize_data, diff --git a/lite/backends/fpga/KD/llapi/filter.h b/lite/backends/fpga/KD/llapi/filter.h index 7d9c6c2e015250cbcba2d1dba71b7c1f3554d9f0..6e056ce0da0d8e731abf7dc418800a8e3d94969a 100644 --- a/lite/backends/fpga/KD/llapi/filter.h +++ b/lite/backends/fpga/KD/llapi/filter.h @@ -18,38 +18,33 @@ limitations under the License. */ #include #include +#include + namespace paddle { namespace zynqmp { namespace filter { void set_filter_capacity(uint32_t cap); +void set_colunm(uint32_t column); +int get_filter_num_alignment(); int calc_division_capacity(int chw); int calc_split_num(int num, int division_capacity); int calc_division_number(int num, int group_num, int division_capacity); int calc_num_per_div(int num, int group_num, int division_capacity); -void convert_to_hwc( - char** data_in, int num, int channel, int height, int width); + float find_max(float* data_in, int data_size); -void quantize(float** data_in, int data_size, float max); -void align_element(char** data_in, int num, int chw); -void align_num(char** data_in, - int num_per_div_before_alignment, - int num, - int chw); -void reorder(char** data_in, int num_after_alignment, int chw); -size_t interleave(char** data_in, int num_after_alignment, int chw); -size_t format_filter(float** data_in, - int num, - int channel, - int height, - int width, - int group_num, - float max); +int8_t* format_filter(float* data_in, + int& mem_size, // NOLINT + int num, + int channel, + int height, + int width, + int group_num, + float max, // NOLINT + std::vector& filter_max); // NOLINT void convert_to_hwn(int16_t** data_in, int num, int height, int width); size_t align_element_n(int16_t** data_in, int num, int height, int width); -void quantize_to_fp16( - float** data_in, int num, int height, int width, float* scale_ptr); size_t format_dwconv_filter( float** data_in, int num, int height, int width, float* scale_ptr); diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp old mode 100644 new mode 100755 index 1f1226ead3d4e9b50100f4de574104a5d6f777b2..06488469d97c077a34b3cfdb8a049c8cd61dfc93 --- a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp +++ b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp @@ -23,13 +23,12 @@ limitations under the License. 
*/ #include #include -#include "lite/backends/fpga/KD/llapi/config.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" namespace paddle { namespace zynqmp { -#define PADDLE_LITE_OS_LINUX +#define PADDLE_OS_LINUX static int fd = -1; static const char *device_path = "/dev/fpgadrv0"; @@ -39,14 +38,10 @@ static size_t memory_size_max = 0; static size_t memory_size = 0; static inline int do_ioctl(uint64_t req, const void *arg) { - int ret = -1; -#ifdef PADDLE_LITE_OS_LINUX - ret = ioctl(fd, req, arg); - if (ret != 0) { - throw - 1; - } +#ifdef PADDLE_OS_LINUX + return ioctl(fd, req, arg); #else - return ret; + return -1; #endif } @@ -66,7 +61,9 @@ void reset_device() { // memory management; void *fpga_malloc(size_t size) { -#ifdef PADDLE_LITE_OS_LINUX +#ifdef ENABLE_DEBUG +#endif +#ifdef PADDLE_OS_LINUX void *ptr = reinterpret_cast( mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0)); if (ptr == NULL) { @@ -105,11 +102,8 @@ void fpga_free(void *ptr) { size = iter->second; memory_map.erase(iter); } - memory_size -= size; - -#ifdef PADDLE_LITE_OS_LINUX - +#ifdef PADDLE_OS_LINUX munmap(ptr, size); #else free(ptr); @@ -150,6 +144,11 @@ void fpga_copy(void *dest, const void *src, size_t num) { memcpy(dest, src, num); } +int fpga_reset() { + struct FpgaResetArgs args; + return do_ioctl(IOCTL_FPGA_RESET, &args); +} + int ioctl_conv(const struct ConvArgs &args) { return do_ioctl(IOCTL_CONFIG_CONV, &args); } @@ -166,7 +165,6 @@ int compute_fpga_conv(const struct SplitConvArgs &args) { } if (split_num > 1) { - std::cout << "Split num > 1 !!!!!!!!!!!!!!!!!!" << std::endl; exit(-1); } return ret; @@ -186,6 +184,7 @@ int get_device_info(const struct DeviceInfo &args) { } int perform_bypass(const struct BypassArgs &args) { + int ret = -1; int size = args.image.channels * args.image.width * args.image.height; int max_size = 1 << 21; @@ -213,7 +212,7 @@ int perform_bypass(const struct BypassArgs &args) { reinterpret_cast(input_address + i * max_size * type_size); bypassArgs.output.address = reinterpret_cast(output_address + i * max_size * out_type_size); - int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); + ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); scale = std::max(scale, scales[0]); if (ret != 0) { @@ -222,13 +221,15 @@ int perform_bypass(const struct BypassArgs &args) { } int remainder = size - max_size * count; - bypassArgs.image.channels = remainder; - bypassArgs.image.address = - reinterpret_cast(input_address + count * max_size * type_size); - bypassArgs.output.address = reinterpret_cast( - output_address + count * max_size * out_type_size); - int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); - scale = std::max(scale, scales[0]); + if (remainder > 0) { + bypassArgs.image.channels = remainder; + bypassArgs.image.address = + reinterpret_cast(input_address + count * max_size * type_size); + bypassArgs.output.address = reinterpret_cast( + output_address + count * max_size * out_type_size); + ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); + scale = std::max(scale, scales[0]); + } args.output.scale_address[0] = scale; args.output.scale_address[1] = 1.0f / scale; return ret; @@ -237,52 +238,17 @@ int perform_bypass(const struct BypassArgs &args) { int compute_fpga_concat(const struct ConcatArgs &args) { return -1; } int compute_fpga_scale(const struct ScaleArgs &args) { -#ifdef ENABLE_DEBUG - std::cout << "======Compute Scale======"; - std::cout << "scale_address:" << args.scale_address << std::endl; - std::cout << "bias_address:" << args.bias_address << std::endl; - - std::cout 
<< "wc_alignment:" << args.wc_alignment << std::endl; - std::cout << "channel_alignment:" << args.channel_alignment << std::endl; - - std::cout << " image_address:" << args.image.address - << " image_scale_address:" << args.image.scale_address - << " image_channels:" << args.image.channels - << " image_height:" << args.image.height - << " image_width:" << args.image.width - << " pad_height:" << args.image.pad_height - << " pad_width:" << args.image.pad_width; - - std::cout << " out_address:" << args.output.address - << " out_scale_address:" << args.output.scale_address; - -#endif return do_ioctl(IOCTL_CONFIG_SCALE, &args); } int compute_fpga_dwconv(const struct DWconvArgs &args) { -#ifdef ENABLE_DEBUG - std::cout << "======Compute Basic Conv======"; - std::cout << " relu_enabled:" << args.relu_enabled - << " filter_address:" << args.filter_address; - std::cout << " image_address:" << args.image.address - << " image_scale_address:" << args.image.scale_address - << " image_channels:" << args.image.channels - << " image_height:" << args.image.height - << " image_width:" << args.image.width - << " pad_height:" << args.image.pad_height - << " pad_width:" << args.image.pad_width; - std::cout << " kernel_height:" << args.kernel.height - << " kernel_width:" << args.kernel.width - << " stride_h:" << args.kernel.stride_h - << " stride_w:" << args.kernel.stride_w; - std::cout << " out_address:" << args.output.address - << " out_scale_address:" << args.output.scale_address; - -#endif return do_ioctl(IOCTL_CONFIG_DWCONV, &args); } +int config_activation(const struct ActiveParamterArgs &args) { + return do_ioctl(IOCTL_CONFIG_ACTIVATION_PARAMETER, &args); +} + int config_inplace(const struct InplaceArgs &args) { return do_ioctl(IOCTL_CONFIG_INPLACE, &args); } diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.h b/lite/backends/fpga/KD/llapi/zynqmp_api.h old mode 100644 new mode 100755 index 7d22de95a2272862c6fe781295bdaab7177a92fe..9489c24730e52fb778ed341e0ce452b7ef86edf9 --- a/lite/backends/fpga/KD/llapi/zynqmp_api.h +++ b/lite/backends/fpga/KD/llapi/zynqmp_api.h @@ -14,6 +14,9 @@ limitations under the License. 
*/ #pragma once +#ifndef PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H +#define PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H + #include #include #include @@ -40,6 +43,13 @@ enum DLayoutType { LAYOUT_HWC = 0, }; +enum ActiveType { + TYPE_RELU = 0, + TYPE_RELU6 = 1, + TYPE_LEAK_RELU = 2, + TYPE_SIGMOID = 3, +}; + struct VersionArgs { void* buffer; }; @@ -48,7 +58,7 @@ struct DeviceInfo { uint32_t filter_cap; uint32_t version; uint16_t device_type; - uint32_t reserved0; + uint32_t colunm; uint32_t reserved1; uint32_t reserved2; uint32_t reserved3; @@ -108,6 +118,7 @@ struct ConvArgs { void* filter_scale_address; uint32_t filter_num; uint32_t group_num; + uint32_t dilation; struct KernelArgs kernel; struct ImageInputArgs image; // input image; @@ -199,9 +210,16 @@ struct NormalizeParameterArgs { uint32_t hight_width; }; +struct ActiveParamterArgs { + ActiveType type; + uint16_t leaky_relu_factor; +}; + struct InplaceArgs { bool leaky_relu_enable; bool relu_enable; + bool sigmoid_enable; + bool relu6_enable; bool power_enable; bool normalize_enable; }; @@ -216,7 +234,9 @@ struct FpgaRegReadArgs { uint64_t value; }; -struct FpgaResetArgs {}; +struct FpgaResetArgs { + uint32_t val; +}; #define IOCTL_FPGA_MAGIC (('F' + 'P' + 'G' + 'A') / 4) @@ -248,6 +268,8 @@ struct FpgaResetArgs {}; _IOW(IOCTL_FPGA_MAGIC, 41, struct PowerParameterArgs) #define IOCTL_CONFIG_NORMALIZE_PARAMETER \ _IOW(IOCTL_FPGA_MAGIC, 42, struct NormalizeParameterArgs) +#define IOCTL_CONFIG_ACTIVATION_PARAMETER \ + _IOW(IOCTL_FPGA_MAGIC, 43, struct ActiveParamterArgs) #define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 50, struct FpgaRegReadArgs) #define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 51, struct FpgaRegWriteArgs) #define IOCTL_FPGA_RESET _IOW(IOCTL_FPGA_MAGIC, 52, struct FpgaResetArgs) @@ -331,6 +353,7 @@ int compute_fpga_scale(const struct ScaleArgs& args); int compute_fpga_concat(const struct ConcatArgs& args); int compute_fpga_resize(const struct ResizeArgs& args); +int config_activation(const struct ActiveParamterArgs& args); int config_power(const struct PowerArgs& args); int compute_fpga_dwconv(const struct DWconvArgs& args); int config_norm_param(const struct NormalizeParameterArgs& args); @@ -341,7 +364,11 @@ int config_inplace(const struct InplaceArgs& args); int flush_cache(void* addr, int size); int invalidate_cache(void* addr, int size); +int fpga_reset(); + int16_t fp32_2_fp16(float fp32_num); float fp16_2_fp32(int16_t fp16_num); } // namespace zynqmp } // namespace paddle + +#endif // PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H diff --git a/lite/backends/fpga/KD/pe.hpp b/lite/backends/fpga/KD/pe.hpp index d1dc3c4caa18cbfeba74fac26cca9e19230e2c21..2796124341012574dc719ae9f30633d1d9524680 100644 --- a/lite/backends/fpga/KD/pe.hpp +++ b/lite/backends/fpga/KD/pe.hpp @@ -32,6 +32,5 @@ class PE { virtual ~PE() {} }; - } // namespace zynqmp } // namespace paddle diff --git a/lite/backends/fpga/KD/pe_params.hpp b/lite/backends/fpga/KD/pe_params.hpp index 709f04d399793c6f21c34fc1265f7ed8b5818314..9dc295a58d4bbfd50a0b9ecbdb06a22c8900cef7 100644 --- a/lite/backends/fpga/KD/pe_params.hpp +++ b/lite/backends/fpga/KD/pe_params.hpp @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include +#include #include #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -26,6 +27,7 @@ namespace zynqmp { struct ReLUParam { public: bool enabled = false; + float leaky_relu_factor = 0.0f; }; struct PEParam { @@ -98,6 +100,24 @@ struct DepthwiseConvParam : ConvParam { Tensor* quantizedFilter_ = new Tensor(); }; +struct GRUParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* h0 = nullptr; + Tensor* weight = nullptr; + Tensor* bias = nullptr; + + Tensor* batch_gate = nullptr; + Tensor* batch_reset_hidden_prev = nullptr; + Tensor* batch_hidden = nullptr; + Tensor* hidden = nullptr; + + std::string gate_activation = "sigmoid"; + std::string activation = "tanh"; + bool is_reverse = false; + bool origin_mode = false; +}; + enum PoolingType : int { MAX = 0, AVERAGE = 1, @@ -133,6 +153,12 @@ struct ElementwiseAddParam : PEParam { EWAddArgs ewargs; }; +struct ElementwiseMulParam : PEParam { + public: + std::vector inputs; + Tensor* output = nullptr; +}; + struct FullyConnectedParam : PEParam { public: Tensor* input = nullptr; diff --git a/lite/backends/fpga/KD/pes/conv_pe.hpp b/lite/backends/fpga/KD/pes/conv_pe.hpp index e897f82280fa57f904bd7c749e371d8ec9219b51..fb15eaf77822eed076ec2001bace6871e93587ff 100644 --- a/lite/backends/fpga/KD/pes/conv_pe.hpp +++ b/lite/backends/fpga/KD/pes/conv_pe.hpp @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include "lite/backends/fpga/KD/pe.hpp" @@ -49,7 +50,108 @@ class ConvPE : public PE { concatPE_.init(); concatPE_.apply(); } + + if (DLEngine::get_instance().isZU3() && + param_.input->shape().dimSize() == 4 && + param_.input->shape().width() == 1 && + param_.input->shape().width() >= 2048) { + use_cpu_ = true; + } + + if (param_.filter->shape().width() == 1 && + param_.filter->shape().height() == 1) { // NOLINT + } + if (!use_cpu_) { // NOLINT + } } + + void cpu_conv_hwc() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + Tensor float_output; + float* image_addr = float_input.mutableData(FP32, input->shape()); + float_input.copyFrom(input); + float_input.syncToCPU(); + float* out = float_output.mutableData(FP32, output->shape()); + + int out_width = output->shape().width(); + int out_channel = output->shape().channel(); + int in_channel = input->shape().channel(); + + float* filter_data = param_.filter->data(); + + int image_height = input->shape().height(); + int image_width = input->shape().width(); + int image_channels = input->shape().channel(); + int image_pad_h = param_.paddings[0]; + int image_pad_w = param_.paddings[1]; + int kernel_height = param_.filter->shape().height(); + int kernel_width = param_.filter->shape().width(); + int kernel_step_h = param_.strides[0]; + int kernel_step_w = param_.strides[1]; + int pooled_height_ = output->shape().height(); + int pooled_width_ = out_width; + int filter_chw = image_channels * kernel_height * kernel_width; + + float max = 0; + + for (int ph = 0; ph < pooled_height_; ph++) { + for (int pw = 0; pw < pooled_width_; pw++) { + int hstart = ph * kernel_step_h - image_pad_h; + int wstart = pw * kernel_step_w - image_pad_w; + int hend = + std::min(hstart + kernel_height, static_cast(image_height)); + int wend = + std::min(wstart + kernel_width, static_cast(image_width)); + hstart = std::max(hstart, static_cast(0)); + wstart = std::max(wstart, static_cast(0)); + for (int oc = 0; oc < out_channel; oc++) { + float sum = 0.0f; + const int pool_index = (ph * 
pooled_width_ + pw) * out_channel + oc; + for (int c = 0; c < image_channels; c++) { + for (int h = hstart; h < hend; h++) { + int hi = 0; + if (ph == 0) { + hi = h - hstart + image_pad_h; + } else { + hi = h - hstart; + } + for (int w = wstart; w < wend; w++) { + int wi = 0; + if (pw == 0) { + wi = w - wstart + image_pad_w; + } else { + wi = w - wstart; + } + const int index = (h * image_width + w) * image_channels + c; + int weight_index = oc * filter_chw + + kernel_width * kernel_height * c + + kernel_width * hi + wi; + float value = image_addr[index] * filter_data[weight_index]; + sum += value; + } + } + } + + if (param_.relu.enabled && sum < 0) { + sum = -sum; + } + if (sum > max) { + max = sum; + } + out[pool_index] = sum; + } + } + } + float_output.flush(); + output->copyFrom(&float_output); + output->scale()[0] = max / 127; + output->scale()[1] = 127 / max; + } + void cpu_compute() { Tensor* input = param_.input; Tensor* output = param_.output; @@ -59,43 +161,78 @@ class ConvPE : public PE { Tensor float_output; float* image_addr = float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); + float_input.syncToCPU(); + float* out = float_output.mutableData(FP32, output->shape()); + float* bias_data = param_.bias()->data(); + + int out_width = output->shape().width(); int out_channel = output->shape().channel(); int in_channel = input->shape().channel(); float* filter_data = param_.filter->data(); float* mi = new float[in_channel]; + float max = 0; + int out_index = 0; for (int i = 0; i < out_channel; i++) { float* image = image_addr; float* filter_ptr = filter_data + i * in_channel; float* out_ptr = mi; -#pragma omp parallel for - for (int j = 0; j < in_channel; j++) { - float value = image_addr[j] * filter_ptr[j]; - mi[j] = value; - } - float sum = 0; - for (int j = 0; j < in_channel; j++) { - sum += mi[j]; + for (int h = 0; h < output->shape().height(); h++) { + for (int w = 0; w < output->shape().width(); w++) { + float sum = 0; + + // #pragma omp parallel for + for (int j = 0; j < in_channel; j++) { + int image_index = h * out_width * in_channel + w * in_channel + j; + float value = image_addr[image_index] * filter_ptr[j]; + sum += value; + } + + sum += bias_data[i]; + + if (param_.relu.enabled && sum < 0) { + sum = 0; + } + if (sum > max) { + max = sum; + } + out_index = h * out_width * out_channel + w * out_channel + i; + out[out_index] = sum; + } } - out[i] = sum; } delete[] mi; float_output.flush(); output->copyFrom(&float_output); + output->scale()[0] = max / 127; + output->scale()[1] = 127 / max; } bool dispatch() { - inplace_.relu_enable = param_.relu.enabled; + if (use_cpu_) { + cpu_compute(); + return true; + } + + inplace_.leaky_relu_enable = + (param_.relu.leaky_relu_factor != 0) ? true : false; + inplace_.relu_enable = + inplace_.leaky_relu_enable ? 
false : param_.relu.enabled; + inplace_.power_enable = false; inplace_.normalize_enable = false; - - if (param_.relu.enabled) { - inplace_.relu_enable = param_.relu.enabled; + if (inplace_.relu_enable || inplace_.leaky_relu_enable) { config_inplace(inplace_); + if (inplace_.leaky_relu_enable) { + activeParamterArgs.type = TYPE_LEAK_RELU; + activeParamterArgs.leaky_relu_factor = + fp32_2_fp16(param_.relu.leaky_relu_factor); + config_activation(activeParamterArgs); + } } std::vector& params = param_.splitParams(); @@ -104,9 +241,16 @@ class ConvPE : public PE { ret |= compute_fpga_conv_basic(conv_param->args); } - if (param_.relu.enabled) { + if (inplace_.relu_enable || inplace_.leaky_relu_enable) { inplace_.relu_enable = false; + inplace_.leaky_relu_enable = false; config_inplace(inplace_); + + if (inplace_.leaky_relu_enable) { + activeParamterArgs.type = TYPE_LEAK_RELU; + activeParamterArgs.leaky_relu_factor = fp32_2_fp16(0); + config_activation(activeParamterArgs); + } } size_t size = params.size(); @@ -127,11 +271,13 @@ class ConvPE : public PE { ConvParam& param() { return param_; } private: + bool use_cpu_ = false; ConvParam param_; ConcatPE concatPE_; ElementwiseAddPE addPE_; int split_axis = 0; InplaceArgs inplace_ = {0}; + ActiveParamterArgs activeParamterArgs; }; } // namespace zynqmp diff --git a/lite/backends/fpga/KD/pes/conv_process.hpp b/lite/backends/fpga/KD/pes/conv_process.hpp old mode 100644 new mode 100755 index 23332b422df65250f8cadf07f5e0d95e970d316a..ecee45569c8df3d3e3926b2ca78cb49da8415aa4 --- a/lite/backends/fpga/KD/pes/conv_process.hpp +++ b/lite/backends/fpga/KD/pes/conv_process.hpp @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#ifndef conv_process_hpp +#define conv_process_hpp + #include #include #include @@ -45,7 +48,9 @@ inline int get_split_num(Tensor* filter) { filter->shape().width(); auto num = filter->shape().num(); int div_capacity = filter::calc_division_capacity(chw); - return filter::calc_split_num(num, div_capacity); + int filter_num_alignment = filter::get_filter_num_alignment(); + int aligned_num = align_to_x(num, filter_num_alignment); + return filter::calc_split_num(aligned_num, div_capacity); } inline void fill_scale_bias_const(ConvParam* param_) { @@ -126,41 +131,85 @@ inline void format_scale_bias(Tensor* scale, bias_data = bias->data(); } int channel = filter->shape().num(); - Shape bias_scale_shape(N, {2 * channel}); + int scale_bias_len = align_to_x(channel / group, BS_NUM_ALIGNMENT) * group; + + int c_per_group = channel / group; + int aligned_c_per_group = align_to_x(channel / group, BS_NUM_ALIGNMENT); + + Shape bias_scale_shape(N, {2 * scale_bias_len}); float* bs_data = scale_bias->mutableData(FP32, bias_scale_shape); - for (int i = 0; i < channel; i++) { - float scale_value = scale_data == nullptr ? 1 : scale_data[i]; - float bias_value = bias_data == nullptr ? 
0 : bias_data[i]; - bs_data[i + channel] = scale_value; - bs_data[i] = bias_value; + float* temp_data = + reinterpret_cast(fpga_malloc(2 * scale_bias_len * sizeof(float))); + memset(temp_data, 0, 2 * scale_bias_len * sizeof(float)); + + std::vector scales; + if (scale_data != nullptr) { + for (int i = 0; i < channel; ++i) { + scales.push_back(scale_data[i]); + } + for (int i = 0; i < scale_bias_len - channel; i++) { + scales.push_back(1); + } + } else { + for (int i = 0; i < scale_bias_len; i++) { + scales.push_back(1); + } } - int element_num_per_div = get_filter_num_per_div(filter, group); - bias_scale::format_bias_scale_array(&bs_data, element_num_per_div, channel); + for (int i = 0; i < scale_bias_len; ++i) { + temp_data[i + scale_bias_len] = 1; + temp_data[i] = 0; + } + + for (int g = 0; g < group; g++) { + for (int c = 0; c < c_per_group; c++) { + int src_index = g * c_per_group + c; + int dst_index = g * aligned_c_per_group + c; + float scale_value = scales[src_index]; + float bias_value = bias_data == nullptr ? 0 : bias_data[src_index]; + temp_data[dst_index + scale_bias_len] = scale_value; + temp_data[dst_index] = bias_value; + } + } + + bias_scale::format_bias_scale_array( + &temp_data, scale_bias_len / group, scale_bias_len); + memcpy(bs_data, temp_data, 2 * scale_bias_len * sizeof(float)); } -inline void format_filter(Tensor* filter, Tensor* quantized_filter, int group) { +inline void format_filter(Tensor* filter, + Tensor* quantized_filter, + int group, + std::vector& scales) { // NOLINT float max_value = find_max(*filter); Shape& filter_shape = filter->shape(); + + int mem_size; + std::vector max_values; + int8_t* quantized_data = filter::format_filter(filter->data(), + mem_size, + filter_shape.num(), + filter_shape.channel(), + filter_shape.height(), + filter_shape.width(), + group, + max_value, + max_values); + + float mem_factor = mem_size * 1.0f / filter->shape().numel(); + quantized_filter->setMemScale(mem_factor); + quantized_filter->setAligned(true); - quantized_filter->mutableData(INT8, filter->shape()); + int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); quantized_filter->scale()[0] = max_value / 127.0f; quantized_filter->scale()[1] = 127.0f / max_value; - auto memory_size = filter->shape().memorySize(sizeof(float)); - auto new_data = reinterpret_cast(fpga_malloc(memory_size)); - memcpy(new_data, filter->data(), memory_size); - size_t mem_size = filter::format_filter(&new_data, - filter_shape.num(), - filter_shape.channel(), - filter_shape.height(), - filter_shape.width(), - group, - max_value); - int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); - memcpy(src, new_data, mem_size); - fpga_free(new_data); + memcpy(src, quantized_data, mem_size); quantized_filter->flush(); + + for (size_t i = 0; i < max_values.size(); i++) { + scales.push_back(max_values[i] / max_value); + } } inline void format_dw_filter(Tensor* filter, @@ -207,10 +256,18 @@ inline void split_filter_num(const ConvParam& c_param) { Tensor* out = param.output; Tensor* filter = param.filter; auto channel = out->shape().channel(); - int split_num = param.groups == 1 ? 
get_split_num(param.filter) : 1; int filter_num_per_div = get_filter_num_per_div(filter, param.groups); + auto chw = filter->shape().channel() * filter->shape().height() * + filter->shape().width(); + auto num = filter->shape().num(); + int div_capacity = filter::calc_division_capacity(chw); + int filter_num_alignment = filter::get_filter_num_alignment(); + int aligned_num = + align_to_x(num / param.groups, filter_num_alignment) * param.groups; + split_num = filter::calc_split_num(aligned_num, div_capacity); + Shape& out_shape = out->shape(); for (int i = 0; i < split_num; i++) { BasicConvParam* conv_param = new BasicConvParam(); @@ -251,9 +308,17 @@ inline void split_filter_num(const ConvParam& c_param) { filter->data() + i * filter_num_per_div * filter_hwc, filter_num * filter_hwc * sizeof(float)); new_filter.flush(); - conv_param->filter.mutableData(FP32, f_shape); - format_filter(&new_filter, &(conv_param->filter), param.groups); + + if (param.groups != 1) { + int mem_factor = + 32 / filter_num_per_div; // TODO(chonwhite): change 32 to param; + conv_param->filter.setMemScale(mem_factor); + } + + std::vector v; // TODO(chonwhite): change local variable name + format_filter(&new_filter, &(conv_param->filter), param.groups, v); + conv_param->filter.setDataType(INT8); int sb_num = 2 * align_to_x(filter_num, BS_NUM_ALIGNMENT); Tensor scale; @@ -265,7 +330,7 @@ inline void split_filter_num(const ConvParam& c_param) { float* scale_data = scale.mutableData(FP32, s_shape); float* bias_data = bias.mutableData(FP32, s_shape); for (int n = 0; n < filter_num; n++) { - scale_data[n] = param.scale()->data()[n + chnnnel_start]; + scale_data[n] = param.scale()->data()[n + chnnnel_start] * v[n]; } for (int n = 0; n < filter_num; n++) { bias_data[n] = param.bias()->data()[n + chnnnel_start]; @@ -276,11 +341,14 @@ inline void split_filter_num(const ConvParam& c_param) { &conv_param->filter, &conv_param->scaleBias, param.groups); + conv_param->scaleBias.flush(); + float* bs_data = conv_param->scaleBias.data(); args.group_num = param.groups; args.relu_enabled = param.relu.enabled; args.sb_address = conv_param->scaleBias.data(); + args.sb_address = bs_data; args.kernel.stride_h = param.strides[1]; args.kernel.stride_w = param.strides[0]; args.kernel.height = new_filter.shape().height(); @@ -294,17 +362,12 @@ inline void split_filter_num(const ConvParam& c_param) { args.image.channels = input->shape().channel(); args.image.width = input->shape().width(); args.image.height = input->shape().height(); - auto paddings = *param.padding; - args.image.pad_width = param.paddings[2]; + args.image.pad_width = param.paddings[1]; args.image.pad_height = param.paddings[0]; + args.dilation = param.dilations[0]; + args.output.address = out_address; args.output.scale_address = out_scale_address; - bool pad_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); - if (!pad_equal) { - LOG(FATA) << "This pad not support ! 
" << paddings[0] << ", " - << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; - } param.splitParams().push_back(conv_param); } } @@ -317,7 +380,7 @@ inline void split_channel(const ConvParam& c_param) { int num = ceil(input->shape().channel() * 1.0f / 2047); int channel = input->shape().channel() / num; - std::cout << "channel::" << channel << "num::" << num << std::endl; + Shape bs_shape(N, {channel}); for (int i = 0; i < num; i++) { @@ -331,6 +394,7 @@ inline void split_channel(const ConvParam& c_param) { // filter transformation; Shape f_shape(NCHW, {param.filter->shape().num(), channel, 1, 1}); + Tensor new_filter; float* dst = new_filter.mutableData(FP32, f_shape); @@ -341,7 +405,8 @@ inline void split_channel(const ConvParam& c_param) { src += param.filter->shape().channel(); } new_filter.flush(); - format_filter(&new_filter, &(conv_param->filter), param.groups); + std::vector scales; + format_filter(&new_filter, &(conv_param->filter), param.groups, scales); Tensor bias; Tensor scale; @@ -379,18 +444,11 @@ inline void split_channel(const ConvParam& c_param) { args.image.channels = conv_param->input.shape().channel(); args.image.width = conv_param->input.shape().width(); args.image.height = conv_param->input.shape().height(); - auto paddings = *param.paddings; - args.image.pad_width = paddings[2]; - args.image.pad_height = paddings[0]; - + args.image.pad_width = param.paddings[1]; + args.image.pad_height = param.paddings[0]; + args.dilation = param.dilations[0]; args.output.address = conv_param->output.mutableData(); args.output.scale_address = conv_param->output.scale(); - bool pad_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); - if (!pad_equal) { - LOG(FATA) << "This pad not support ! " << paddings[0] << ", " - << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; - } param.splitParams().push_back(conv_param); } } @@ -422,7 +480,6 @@ inline bool compute_conv(const ConvParam& c_conv_params) { for (int i = 0; i < 1; i++) { for (int i = 0; i < img.shape().numel(); i++) { float value = half_to_float(img.data()[i]); - std::cout << "value:" << value << std::endl; } } } @@ -431,3 +488,5 @@ inline bool compute_conv(const ConvParam& c_conv_params) { } // namespace zynqmp } // namespace paddle + +#endif /* conv_process_hpp */ diff --git a/lite/backends/fpga/KD/pes/crop_pe.cpp b/lite/backends/fpga/KD/pes/crop_pe.cpp old mode 100644 new mode 100755 index c29df623aa610d329a46ee337cdcb1abd801881c..1438aaba6565cefa72f863d5fc3af0a389fc95e0 --- a/lite/backends/fpga/KD/pes/crop_pe.cpp +++ b/lite/backends/fpga/KD/pes/crop_pe.cpp @@ -14,8 +14,6 @@ limitations under the License. */ #include "lite/backends/fpga/KD/pes/crop_pe.hpp" -#include - namespace paddle { namespace zynqmp { diff --git a/lite/backends/fpga/KD/pes/crop_pe.hpp b/lite/backends/fpga/KD/pes/crop_pe.hpp index 6ebbcdb31f1afb7939c75a2ba9254c0b31f67d31..ccd1e0c98968375ebd840c7e8b15aedd6ad7ef77 100755 --- a/lite/backends/fpga/KD/pes/crop_pe.hpp +++ b/lite/backends/fpga/KD/pes/crop_pe.hpp @@ -14,6 +14,7 @@ limitations under the License. 
 */
 
 #pragma once
+#include <algorithm>
 #include <string>
 #include <vector>
 
diff --git a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
old mode 100644
new mode 100755
index f86806102d4a217ae4bb7355b36ca10d96ca4a05..0efca2ec2e60e8973d92f41463b0444722f2a73b
--- a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
+++ b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
@@ -37,18 +37,36 @@ class DepthwiseConvPE : public PE {
     Tensor* output = param.output;
     int channel = output->shape().channel();
 
-    float* new_scale_data = param_.scale()->data<float>();
-    float* new_bias_data = param_.bias()->data<float>();
-
     float16* b_data = bias_.mutableData<float16>(FP16, param_.bias()->shape());
-    for (int i = 0; i < channel; i++) {
-      b_data[i] = float_to_half(new_bias_data[i]);
+    if (param_.bias()->dataType() == FP32) {
+      float* new_bias_data = param_.bias()->data<float>();
+      // convert the bias from float to float16
+      for (int i = 0; i < channel; i++) {
+        b_data[i] = float_to_half(new_bias_data[i]);
+      }
+      bias_.flush();
+    } else {
+      float16* new_bias_data = param_.bias()->data<float16>();
+      memcpy(b_data, new_bias_data, channel * sizeof(float16));
+      bias_.flush();
     }
-    bias_.flush();
 
-    Tensor* quantized_filter = param.quantizedFilter();
-    quantized_filter->mutableData<float16>(FP16, param.filter->shape());
-    format_dw_filter(param.filter, param.quantizedFilter(), new_scale_data);
+    if (param_.scale()->dataType() == FP32) {
+      float* new_scale_data = param_.scale()->data<float>();
+      Tensor* quantized_filter = param.quantizedFilter();
+      quantized_filter->mutableData<float16>(FP16, param.filter->shape());
+      format_dw_filter(param.filter, param.quantizedFilter(), new_scale_data);
+
+    } else {
+      // used when the filter is all ones and the channel count is already
+      // aligned
+      float16* scale_data = param_.scale()->data<float16>();
+      float16* filter_data = param.quantizedFilter()->mutableData<float16>(
+          FP16, param.filter->shape());
+      memcpy(filter_data,
+             scale_data,
+             param.filter->shape().numel() * sizeof(float16));
+      param.quantizedFilter()->flush();
+    }
 
     DWconvArgs args = {0};
     args.bias_address = b_data;
@@ -61,21 +79,14 @@ class DepthwiseConvPE : public PE {
     args.image.channels = input->shape().channel();
     args.image.height = input->shape().height();
     args.image.width = input->shape().width();
-    auto paddings = *param.paddings;
-    args.image.pad_width = param.paddings[2];
-    args.image.pad_height = param.paddings[0];
+    args.image.pad_width = param.paddings[0];
+    args.image.pad_height = param.paddings[1];
     args.image.scale_address = input->scale();
     args.output.address = output->data<void>();
     args.output.scale_address = output->scale();
     args.out_width = param.output->shape().width();
     args.out_height = param.output->shape().height();
     args.sub_conv_num = 1;
-    bool pad_equal =
-        ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
-    if (!pad_equal) {
-      LOG(FATAL) << "This pad is not supported: " << paddings[0] << ", "
-                 << paddings[1] << ", " << paddings[2] << ", " << paddings[3];
-    }
     param.args = args;
 
     inplace_.relu_enable = param_.relu.enabled;
diff --git a/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp b/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..0505e78b61e3b0130c876880894cec29c78406f2
--- /dev/null
+++ b/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp
@@ -0,0 +1,77 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "lite/backends/fpga/KD/pe.hpp"
+#include "lite/backends/fpga/KD/pe_params.hpp"
+namespace paddle {
+namespace zynqmp {
+
+class ElementwiseMulPE : public PE {
+ public:
+  bool init() {
+    Tensor* output = param_.output;
+    output->setAligned(true);
+    output->setDataLocation(Device);
+    return true;
+  }
+
+  void apply() {
+    Tensor* input = param_.inputs[0];
+    Tensor* output = param_.output;
+
+    int wc_aligned = align_to_x(param_.inputs[0]->shape().numel(), 32);
+
+    // The elementwise multiply is mapped onto the FPGA scale op
+    // (y = x * scale + bias): the second input acts as the scale and
+    // the bias tensor is fixed to zero.
+    Shape s(N, {wc_aligned});
+    float16* bias_data = bias_tensor.mutableData<float16>(FP16, s);
+    memset(bias_data, 0, wc_aligned * sizeof(float16));
+
+    ScaleArgs& args = args_;
+    args.scale_address = param_.inputs[1]->data<void>();
+    args.bias_address = bias_tensor.data<void>();
+    args.wc_alignment = wc_aligned;
+    args.channel_alignment = wc_aligned;
+    args.image.address = input->data<void>();
+    args.image.scale_address = input->scale();
+    args.image.channels = wc_aligned;
+    args.image.height = 1;
+    args.image.width = 1;
+    args.image.pad_width = 0;
+    args.image.pad_height = 0;
+    args.output.address = output->data<void>();
+    args.output.scale_address = output->scale();
+  }
+
+  void updateInput(Tensor* t, int index) {
+    if (index == 0) {
+      args_.scale_address = t->data<void>();  // replace inputs?
+    }
+  }
+
+  bool dispatch() { return compute_fpga_scale(args_) == 0; }
+
+  ElementwiseMulParam& param() { return param_; }
+
+ private:
+  ElementwiseMulParam param_;
+  ScaleArgs args_ = {0};
+  Tensor bias_tensor;
+};
+
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
old mode 100644
new mode 100755
index 2179a142ad3b3a990512b3ea1cd202bc5ce502f1..db3e05276171607da4cea421dd554846a00314a6
--- a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
+++ b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
@@ -37,7 +37,10 @@ class FullyConnectedPE : public PE {
     ConvParam& convParam_ = convPE_.param();
     Tensor* input = param_.input;
     convParam_.input = param_.input;
+    num_ = param_.input->shape().num();
+
     convParam_.output = param_.output;
+
     convParam_.groups = 1;
     convParam_.strides = {1, 1};
     convParam_.paddings = {0, 0};
@@ -63,7 +66,6 @@ class FullyConnectedPE : public PE {
         new_filter_data[i * chw + j] = scale;
       }
     }
-
     conv_filter->flush();
 
     convParam_.filter = conv_filter;
@@ -89,6 +91,8 @@ class FullyConnectedPE : public PE {
  private:
   FullyConnectedParam param_;
   ConvPE convPE_;
+  Tensor tempOut_;
+  int num_ = 1;
 };
 }  // namespace zynqmp
 }  // namespace paddle
diff --git a/lite/backends/fpga/KD/pes/gru_pe.hpp b/lite/backends/fpga/KD/pes/gru_pe.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..dcacab4eeef32b245d4126b72597b398a6627ba6
--- /dev/null
+++ b/lite/backends/fpga/KD/pes/gru_pe.hpp
@@ -0,0 +1,191 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/backends/arm/math/sgemm.h" +#include "lite/backends/fpga/KD/pe.hpp" +#include "lite/backends/fpga/KD/pe_params.hpp" +#include "lite/backends/fpga/KD/pes/elementwise_add_pe.hpp" +#include "lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp" +#include "lite/backends/fpga/KD/pes/fully_connected_pe.hpp" +#include "lite/backends/fpga/KD/pes/relu_pe.hpp" + +#include "lite/api/paddle_place.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace zynqmp { + +struct GRUTensors { + Tensor* gate; + Tensor* pre_output; + Tensor* output; + Tensor* reset_output; +}; + +class GRUPE : public PE { + public: + bool init() { return true; } + + void apply() { + auto hidden = param_.hidden; + int frame_size = hidden->shape().channel(); + + zynqmp::Shape hidden_shape{zynqmp::NCHW, {1, frame_size, 1, 1}}; + float16* prev_hidden_data = + prev_hidden_.mutableData(zynqmp::FP16, hidden_shape); + memset(prev_hidden_data, 0, hidden_shape.numel() * sizeof(float16)); + + zynqmp::Shape weight_shape{zynqmp::NC, {frame_size, frame_size * 2}}; + float* weight_data = weight_.mutableData(zynqmp::FP32, weight_shape); + memset(weight_data, 0, weight_shape.numel() * sizeof(float)); + weight_data = weight_.mutableData(zynqmp::FP32, weight_shape); + memcpy(weight_data, + param_.weight->data(), + weight_shape.numel() * sizeof(float)); + + Shape gate_shape(zynqmp::NC, {1, frame_size * 2}); + gate_ping_.mutableData(FP32, gate_shape); + gate_pong_.mutableData(FP16, gate_shape); + + zynqmp::FullyConnectedParam& pre_out_param = pre_out_pe_.param(); + pre_out_param.input = &prev_hidden_; + pre_out_param.output = &gate_pong_; + pre_out_param.filter = &weight_; + pre_out_param.bias = &gate_ping_; + pre_out_pe_.init(); + pre_out_pe_.apply(); + + reset_gate_.mutableData(FP16, hidden_shape); + prev_hidden_.mutableData(FP16, hidden_shape); + reset_hidden_.mutableData(FP16, hidden_shape); + + ElementwiseMulParam& mul_param = mul_pe_.param(); + mul_param.inputs = {&reset_gate_, &prev_hidden_}; + mul_param.output = &reset_hidden_; + mul_pe_.init(); + mul_pe_.apply(); + } + + bool dispatch() { return true; } + + void gru_unit_reset_act(const lite_api::ActivationType active_gate, + GRUTensors& value, // NOLINT + int frame_size, + int batch_size) { + int stride_update = 3 * frame_size; + int stride_cell_state = 3 * frame_size; + int stride_hidden_prev = frame_size; + int stride_hidden = frame_size; + + float* update_gate_data = gate_ping_.data(); + float* reset_gate_data = update_gate_data + frame_size; + + for (int b = 0; b < batch_size; b++) { + Tensor tmp; + Shape s(NC, {1, frame_size}); + float* tmp_data = tmp.mutableData(FP32, s); + + for (int i = 0; i < frame_size; i++) { + update_gate_data[i] = + lite::arm::math::active_f32( + update_gate_data[i]); + reset_gate_data[i] = + lite::arm::math::active_f32( + reset_gate_data[i]); + } + memcpy(tmp_data, reset_gate_data, frame_size * sizeof(float)); + tmp.flush(); + reset_gate_.copyFrom(&tmp); + + Tensor* hidden_prev = value.pre_output; + if (hidden_prev) { + // TODO(chonwhite): change to pre_out; + 
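+          // For orientation, a sketch of the GRU step these helpers build up
+          // (the actual activations come from active_gate/active_node rather
+          // than being fixed to sigmoid/tanh):
+          //   u_t = act(x_u + U_u * h_{t-1})   update gate
+          //   r_t = act(x_r + U_r * h_{t-1})   reset gate
+          //   reset_hidden = r_t .* h_{t-1}    (computed by mul_pe_ below)
+          //   h~_t = act(x_c + U_c * reset_hidden)
+          //   h_t  = (1 - u_t) .* h_{t-1} + u_t .* h~_t
+          // gru_unit_reset_act covers the first three lines; gru_unit_out_act
+          // is left to complete the remaining two.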
prev_hidden_.copyFrom(value.pre_output); + prev_hidden_.saveToFile("prev_.txt"); + } + + mul_pe_.dispatch(); + reset_hidden_.saveToFile("reset_hidden_.txt"); + update_gate_data += stride_update; + reset_gate_data += stride_update; + + // reset_hidden_prev += stride_hidden;// TODO + } + } + + void gru_unit_out_act(const lite_api::ActivationType active_node, + bool origin_mode, + GRUTensors& value, // NOLINT + int frame_size, + int batch_size) {} + + void copy_input(GRUTensors& value) { // NOLINT + float max = find_max(*(value.gate)); + gate_ping_.mutableData(FP32, value.gate->shape()); + gate_ping_.copyFrom(value.gate); + // update input pointer? + } + + void GRUCOmpute(GRUTensors& value, // NOLINT + int frame_size, + int batch_size, + const lite_api::ActivationType active_node, + const lite_api::ActivationType active_gate, + bool origin_mode) { + copy_input(value); + + if (value.pre_output) { + // copy by batch; + pre_out_pe_.dispatch(); + gate_ping_.copyFrom(&gate_pong_); + } + + gru_unit_reset_act(active_gate, value, frame_size, batch_size); + } + + GRUParam& param() { return param_; } + + Tensor* updateGate() { return &update_gate_; } + + Tensor* resetGate() { return &reset_gate_; } + + private: + GRUParam param_; + zynqmp::Tensor gate_ping_; + zynqmp::Tensor gate_pong_; + zynqmp::Tensor bias_; + zynqmp::Tensor weight_; + zynqmp::Tensor state_weight_; + zynqmp::Tensor update_gate_; + zynqmp::Tensor reset_gate_; + zynqmp::Tensor cell_state_; + zynqmp::Tensor prev_hidden_; + zynqmp::Tensor reset_hidden_; + + Tensor tempTensor; + + ReluPE update_relu_pe_; + ReluPE reset_relu_pe_; + zynqmp::ElementwiseMulPE mul_pe_; + zynqmp::FullyConnectedPE pre_out_pe_; + zynqmp::FullyConnectedPE reset_out_pe_; + + zynqmp::ElementwiseAddPE bias_ew_pe_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h b/lite/backends/fpga/KD/pes/gru_util.hpp similarity index 71% rename from lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h rename to lite/backends/fpga/KD/pes/gru_util.hpp index 3c76e0e8b5cf0842cb8d5a613cef7aee3cd13bdb..d49169846f4f18e4d8e30f3658c2173157678f81 100644 --- a/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h +++ b/lite/backends/fpga/KD/pes/gru_util.hpp @@ -14,13 +14,10 @@ #pragma once -#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/backends/arm/math/gru_utils.h" -USE_XPU_BRIDGE(relu); -USE_XPU_BRIDGE(conv2d); -USE_XPU_BRIDGE(depthwise_conv2d); -USE_XPU_BRIDGE(elementwise_add); -USE_XPU_BRIDGE(pool2d); -USE_XPU_BRIDGE(softmax); -USE_XPU_BRIDGE(mul); -USE_XPU_BRIDGE(batch_norm); +namespace paddle { +namespace lite { +namespace fpga {} +} +} diff --git a/lite/backends/fpga/KD/pes/output_pe.hpp b/lite/backends/fpga/KD/pes/output_pe.hpp old mode 100644 new mode 100755 index 1c99386ab19f485c07723c7fcc8501bdf5556f6c..2944691693b135a2d2df7b91ecbe0ef249b015d8 --- a/lite/backends/fpga/KD/pes/output_pe.hpp +++ b/lite/backends/fpga/KD/pes/output_pe.hpp @@ -25,6 +25,8 @@ class OutputPE : public PE { bool init() { Tensor* output = param_.output; output->setAligned(false); + DLEngine::get_instance().out_data = reinterpret_cast( + fpga_malloc(output->shape().numel() * sizeof(float))); return true; } @@ -41,6 +43,15 @@ class OutputPE : public PE { } else { output->copyFrom(input); } + // + output->syncToCPU(); + if (DLEngine::get_instance().out_data == nullptr) { + DLEngine::get_instance().out_data = reinterpret_cast( + fpga_malloc(output->shape().numel() * sizeof(float))); + } + memcpy(DLEngine::get_instance().out_data, + 
output->data(), + output->shape().numel() * sizeof(float)); return true; } diff --git a/lite/backends/fpga/KD/pes/pooling_pe.hpp b/lite/backends/fpga/KD/pes/pooling_pe.hpp old mode 100644 new mode 100755 index 5bb4f5285a48c7696b1f0f78a9b1c4fe6a9d76c5..a8725b51a690e0e134785fcfdb3dd70edeffd441 --- a/lite/backends/fpga/KD/pes/pooling_pe.hpp +++ b/lite/backends/fpga/KD/pes/pooling_pe.hpp @@ -35,24 +35,25 @@ class PoolingPE : public PE { Tensor* input = param_.input; Tensor* output = param_.output; - uint32_t k_width = param_.kernelSize[0]; - uint32_t k_height = param_.kernelSize[1]; + uint32_t k_height = param_.kernelSize[0]; + uint32_t k_width = param_.kernelSize[1]; if (param_.globalPooling) { k_width = input->shape().width(); k_height = input->shape().height(); + param_.kernelSize[0] = k_height; + param_.kernelSize[1] = k_width; } PoolingArgs args = {0}; args.mode = param_.type; - auto paddings = *param_.paddings; args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.image.address = input->data(); args.image.channels = input->shape().channel(); args.image.height = input->shape().height(); args.image.width = input->shape().width(); - args.image.pad_height = paddings[0]; - args.image.pad_width = paddings[2]; + args.image.pad_height = param_.paddings[0]; + args.image.pad_width = param_.paddings[1]; args.image.scale_address = input->scale(); args.output.address = output->mutableData(); args.output.scale_address = output->scale(); @@ -66,6 +67,7 @@ class PoolingPE : public PE { use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && (k_width > 7 || k_height > 7); + use_cpu_ = param_.type == AVERAGE; } void compute() { @@ -77,13 +79,12 @@ class PoolingPE : public PE { float* image_addr = float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); float16* data_out = output->data(); - auto paddings = *param_.paddings; int image_height = input->shape().height(); int image_width = input->shape().width(); int image_channels = input->shape().channel(); - int image_pad_h = paddings[0]; - int image_pad_w = paddings[2]; + int image_pad_h = param_.paddings[0]; + int image_pad_w = param_.paddings[1]; int kernel_height = param_.kernelSize[1]; int kernel_width = param_.kernelSize[0]; int kernel_step_h = param_.strides[0]; @@ -129,7 +130,7 @@ class PoolingPE : public PE { output->flush(); } - void cpu_compute() { + void cpu_compute1() { Tensor* input = param_.input; Tensor* output = param_.output; input->syncToCPU(); @@ -138,7 +139,6 @@ class PoolingPE : public PE { float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); float16* data_out = output->data(); - int kernel_hw = param_.kernelSize[0] * param_.kernelSize[1]; float scale_max = 0; @@ -154,7 +154,35 @@ class PoolingPE : public PE { } output->scale()[0] = scale_max / 127.0f; output->scale()[1] = 127.0f / scale_max; - std::cout << "pool scale:" << scale_max / 127.0f << std::endl; + output->flush(); + } + + void cpu_compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + float* float_input_data = + float_input.mutableData(FP32, input->shape()); + float_input.copyFrom(input); + + float16* data_out = output->data(); + + int kernel_hw = param_.kernelSize[0] * param_.kernelSize[1]; + + float scale_max = 0; + for (int i = 0; i < output->shape().channel(); i++) { + float sum = 0; + for (int j = 0; j < kernel_hw; j++) { + sum += float_input_data[i * kernel_hw + j]; + } + float value = sum / kernel_hw; + data_out[i] = 
float_to_half(value);
+      scale_max = std::max(scale_max, std::abs(value));
+    }
+    output->scale()[0] = scale_max / 127.0f;
+    output->scale()[1] = 127.0f / scale_max;
     output->flush();
   }
 
diff --git a/lite/backends/fpga/KD/pes/prior_box_pe.cpp b/lite/backends/fpga/KD/pes/prior_box_pe.cpp
index d6a503a31d4e0736724740ce1875c916969d93e0..00dfe1830f6f44cbf6a30708fa5783563470c686 100644
--- a/lite/backends/fpga/KD/pes/prior_box_pe.cpp
+++ b/lite/backends/fpga/KD/pes/prior_box_pe.cpp
@@ -253,9 +253,8 @@ bool PriorBoxPE::dispatch() {
   if (cachedBoxes_ == nullptr) {
     cachedBoxes_ = new Tensor();
     cachedVariances_ = new Tensor();
-    cachedBoxes_->mutableData<float16>(FP16, param_.outputBoxes->shape());
-    cachedVariances_->mutableData<float16>(FP16,
-                                           param_.outputVariances->shape());
+    cachedBoxes_->mutableData<float>(FP32, param_.outputBoxes->shape());
+    cachedVariances_->mutableData<float>(FP32, param_.outputVariances->shape());
     cachedBoxes_->setDataLocation(CPU);
     cachedVariances_->setDataLocation(CPU);
     compute_prior_box();
diff --git a/lite/backends/fpga/KD/pes/scale_pe.hpp b/lite/backends/fpga/KD/pes/scale_pe.hpp
old mode 100755
new mode 100644
index d5e16615d9943a1771dfabe916433768ecf16319..cc89ac943f90cb20062a3d6ef9a46b705193ad04
--- a/lite/backends/fpga/KD/pes/scale_pe.hpp
+++ b/lite/backends/fpga/KD/pes/scale_pe.hpp
@@ -14,11 +14,16 @@ limitations under the License. */
 
 #pragma once
 
+#include <algorithm>
+
 #include "lite/backends/fpga/KD/pe.hpp"
 #include "lite/backends/fpga/KD/pe_params.hpp"
+#include "lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp"
+#include "lite/backends/fpga/KD/tensor.hpp"
 
 namespace paddle {
 namespace zynqmp {
+
 class ScalePE : public PE {
  public:
   inline int gcd(int a, int b) {
@@ -42,6 +47,8 @@ class ScalePE : public PE {
     Tensor* input = param_.input;
     Tensor* output = param_.output;
     Shape& input_shape = input->shape();
+    DepthwiseConvParam& dw_param = dw_pe_.param();
+
     int channel = input_shape.channel();
     int repeat = 1;
     int alignment = 16;
@@ -51,70 +58,142 @@ class ScalePE : public PE {
       int c_lcm = lcm(channel, alignment);
       repeat = c_lcm / (channel);
     }
+
+    // FPGA limits: H > 2047, W > 1023, or W * C > 65536 require falling
+    // back to the CPU implementation.
     Shape shape(N, {channel * repeat});
-    param_.alignedBias()->mutableData<float16>(FP16, shape);
-    param_.alignedScale()->mutableData<float16>(FP16, shape);
-    float16* bias_data = param_.alignedBias()->data<float16>();
-    float16* scale_data = param_.alignedScale()->data<float16>();
+    float* filter_data = filter.mutableData<float>(FP32, shape);
+    std::fill_n(filter_data, input->shape().channel(), 1.0f);
 
-    if (param_.bias != nullptr) {
-      float* bias_data_float = param_.bias->data<float>();
+    Tensor* scale = dw_param.scale();
+    float16* scale_data = scale->mutableData<float16>(FP16, shape);
+
+    Tensor* bias = dw_param.bias();
+    float16* bias_data = bias->mutableData<float16>(FP16, shape);
+    std::fill_n(bias_data, input->shape().channel(), 0);
+
+    if (param_.scale->dataType() == FP32) {
+      if (param_.bias != nullptr) {
+        float* bias_data_float = param_.bias->data<float>();
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            float16 value = float_to_half(bias_data_float[j]);
+            bias_data[i * length + j] = value;
+          }
+        }
+      } else {
+        float16 zero = float_to_half(0.0f);
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            bias_data[i * length + j] = zero;
+          }
+        }
+      }
+
+      float* scale_data_float = param_.scale->data<float>();
       for (int i = 0; i < repeat; i++) {
         for (int j = 0; j < length; j++) {
-          float16 value = float_to_half(bias_data_float[j]);
-          bias_data[i * length + j] = value;
+          float16 value = float_to_half(scale_data_float[j]);
+          scale_data[i * length + j]
= value; } } } else { - float16 zero = float_to_half(0.0f); + if (param_.bias != nullptr) { + float16* bias_data_float = param_.bias->data(); + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + float16 value = bias_data_float[j]; + bias_data[i * length + j] = value; + } + } + } else { + float16 zero = float_to_half(0.0f); + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + bias_data[i * length + j] = zero; + } + } + } + + float16* scale_data_float = param_.scale->data(); for (int i = 0; i < repeat; i++) { for (int j = 0; j < length; j++) { - bias_data[i * length + j] = zero; + float16 value = scale_data_float[j]; + scale_data[i * length + j] = value; } } } - float* scale_data_float = param_.scale->data(); - for (int i = 0; i < repeat; i++) { - for (int j = 0; j < length; j++) { - float16 value = float_to_half(scale_data_float[j]); - scale_data[i * length + j] = value; + dw_param.input = param_.input; + dw_param.output = param_.output; + dw_param.filter = &filter; + + dw_param.strides = {1, 1}; + dw_param.paddings = {0, 0}; + dw_param.kernelSize = {1, 1}; + dw_param.dilations = {1, 1}; + + dw_pe_.init(); + dw_pe_.apply(); + } + + void cpu_compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + Tensor float_input; + float* image_addr = float_input.mutableData(FP32, input->shape()); + input->syncToCPU(); + float_input.copyFrom(input); + float16* data_out = output->data(); + + float* scale_data = param_.scale->data(); + + int wh = input->shape().width() * input->shape().height(); + + float16* in_data = input->data(); + + float max = 0; + + for (int i = 0; i < wh; i++) { + for (int c = 0; c < input->shape().channel(); c++) { + int index = i * input->shape().channel() + c; + float value = half_to_float(in_data[index]) * scale_data[c]; + data_out[index] = float_to_half(value); + + if (value < 0) { + value = -value; + } + if (value > max) { + max = value; + } } } - - param_.alignedScale()->flush(); - param_.alignedBias()->flush(); - - int wc = input_shape.width() * input_shape.channel(); - int wc_aligned = align_image(wc); - - ScaleArgs& args = param_.args; - args.scale_address = param_.alignedScale()->data(); - args.bias_address = param_.alignedBias()->data(); - args.wc_alignment = wc_aligned; - args.channel_alignment = channel * repeat; - - args.image.address = input->data(); - args.image.scale_address = input->scale(); - args.image.channels = channel; - args.image.height = input_shape.height(); - args.image.width = input_shape.width(); - args.image.pad_width = 0; - args.image.pad_height = 0; - args.output.address = output->data(); - args.output.scale_address = output->scale(); + output->flush(); + output->scale()[0] = max / 127.0f; + output->scale()[1] = 127.0f / max; } bool dispatch() { + if (param_.scale->dataType() == FP16) { + DepthwiseConvParam& dw_param = dw_pe_.param(); + memcpy(dw_param.quantizedFilter()->mutableData(), + param_.scale->data(), + param_.scale->shape().numel() * sizeof(float16)); + dw_param.quantizedFilter()->scale()[0] = param_.scale->scale()[0]; + dw_param.quantizedFilter()->scale()[1] = param_.scale->scale()[1]; + + dw_param.quantizedFilter()->flush(); + } param_.input->syncToDevice(); - return compute_fpga_scale(param_.args) == 0; + return dw_pe_.dispatch(); } ScaleParam& param() { return param_; } private: ScaleParam param_; + Tensor filter; + DepthwiseConvPE dw_pe_; }; } // namespace zynqmp } // namespace paddle diff --git a/lite/backends/fpga/KD/shape.hpp b/lite/backends/fpga/KD/shape.hpp index 
566ad8e6ff2eff32301e83b6cdb5b1addd0117fe..c25c3315145137a147928a164fcabd2923b09e87 100755 --- a/lite/backends/fpga/KD/shape.hpp +++ b/lite/backends/fpga/KD/shape.hpp @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace zynqmp { +static struct None none_; static struct NCHW nchw_; static struct NHWC nhwc_; static struct NC nc_; @@ -82,6 +83,9 @@ class Shape { void setLayoutType(LayoutType layout) { this->layoutType_ = layout; switch (layout) { + case None: + layout_ = &none_; + break; case NCHW: layout_ = &nchw_; break; diff --git a/lite/backends/fpga/KD/tensor.hpp b/lite/backends/fpga/KD/tensor.hpp index f003ded33eb51136ae0ae0a2c21988460232f89a..f1b07d02622fad32e99205667424a4cb3c9fb46d 100644 --- a/lite/backends/fpga/KD/tensor.hpp +++ b/lite/backends/fpga/KD/tensor.hpp @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -24,13 +25,10 @@ limitations under the License. */ #include #include -// #include "lite/core/tensor.h" - #include "lite/backends/fpga/KD/dl_engine.hpp" #include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" #include "lite/backends/fpga/KD/shape.hpp" -// #include "lite/backends/fpga/KD/types.hpp" namespace paddle { namespace zynqmp { @@ -117,7 +115,8 @@ class Tensor { template Dtype* mutableData() { - size_t memorySize = shape_->memorySize(CellSize(dataType_)); + size_t memorySize = + shape_->memorySize(CellSize(dataType_)) * mem_scale_factor_; if (placeHolder_ != nullptr) { if (memorySize > placeHolder_->memorySize()) { placeHolder_.reset(new PlaceHolder(memorySize)); @@ -241,6 +240,10 @@ class Tensor { } } + void setMemScale(float scale_factor) { + this->mem_scale_factor_ = scale_factor; + } + void shareDataWith(Tensor* src) { shareDataWith(src, src->shape()); } void shareDataWith(Tensor* src, const Shape& shape, int offset = 0) { @@ -276,9 +279,11 @@ class Tensor { .height = 1, .pad_width = 0u, .pad_height = 0u}; - args.output = { + + ImageOutputArgs output = { .address = data(), .scale_address = scale(), }; + args.output = output; src->syncToDevice(); size_t aligned_remainder = src->shape().numel() % 16; if (aligned_remainder > 0) { @@ -294,10 +299,16 @@ class Tensor { this->invalidate(); } - void flush() { fpga_flush(placeHolder_->data(), placeHolder_->memorySize()); } + void flush() { + size_t memorySize = + shape_->memorySize(CellSize(dataType_)) * mem_scale_factor_; + fpga_flush(placeHolder_->data(), memorySize); + } void invalidate() { - fpga_invalidate(placeHolder_->data(), placeHolder_->memorySize()); + size_t memorySize = + shape_->memorySize(CellSize(dataType_)) * mem_scale_factor_; + fpga_invalidate(placeHolder_->data(), memorySize); } void sync() { @@ -339,6 +350,8 @@ class Tensor { } } + void printScale(std::string type) { printScale(); } + std::string dimsFileName() { return std::to_string(shape_->num()) + "_" + std::to_string(shape_->channel()) + "_" + @@ -358,29 +371,9 @@ class Tensor { saveToFile(path); } - friend std::ostream& operator<<(std::ostream& os, Tensor& tensor) { - os << "tensor:" - << "\n"; - os << "dims: {"; - for (int i = 0; i < tensor.shape().dimSize(); ++i) { - os << tensor.shape()[i] << " "; - } - os << "}\n"; - for (int i = 0; i < tensor.shape().numel(); i++) { - float value = 0; - if (tensor.dataType() == FP32) { - value = tensor.data()[i]; - } else { - value = half_to_float(tensor.data()[i]); - } - os << value << " "; - } - os << "\n"; - return os; - } - void saveToFile(std::string path) { syncToCPU(); + 
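// saveToFile reads the tensor through its CPU pointer, so the data is
+    // synced and then invalidated first; this assumes fpga_invalidate()
+    // (called by invalidate() above) discards stale CPU cachelines for the
+    // tensor's placeholder.
+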
invalidate(); std::ofstream ofs; static int counter = 0; std::string npath = std::to_string(counter) + "_" + path; @@ -389,17 +382,18 @@ class Tensor { } void save_file_with_name(std::string path) { - // return; invalidate(); std::ofstream ofs; - ofs.open(path); + for (int i = 0; i < shape_->numel(); i++) { float value = 0; if (dataType_ == FP32) { value = data()[i]; - } else { + } else if (dataType_ == FP16) { value = half_to_float(data()[i]); + } else { + value = data()[i]; } ofs << value << std::endl; } @@ -415,18 +409,49 @@ class Tensor { int num = shape_->numel(); invalidate(); float max = 0.0f; - float16* data = mutableData(); - for (int i = 0; i < num; ++i) { - float value = 0; - file_stream >> value; - max = std::max(std::abs(value), max); - data[i] = float_to_half(value); + if (dataType_ == FP16) { + float16* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + max = std::max(std::abs(value), max); + data[i] = float_to_half(value); + } + } else { + float* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + max = std::max(std::abs(value), max); + data[i] = value; + } } flush(); placeHolder_->scale_[0] = max / 127.0f; placeHolder_->scale_[1] = 127.0f / max; } + friend std::ostream& operator<<(std::ostream& os, Tensor& tensor) { + os << "tensor:" + << "\n"; + os << "dims: {"; + for (int i = 0; i < tensor.shape().dimSize(); ++i) { + os << tensor.shape()[i] << " "; + } + os << "}\n"; + for (int i = 0; i < tensor.shape().numel(); i++) { + float value = 0; + if (tensor.dataType() == FP32) { + value = tensor.data()[i]; + } else { + value = half_to_float(tensor.data()[i]); + } + os << value << " "; + } + os << "\n"; + return os; + } + ~Tensor() { if (shape_ != nullptr) { delete shape_; @@ -436,6 +461,7 @@ class Tensor { private: int offset = 0; + float mem_scale_factor_ = 1.0f; std::shared_ptr placeHolder_; Shape* shape_ = nullptr; DataType dataType_ = FP32; diff --git a/lite/backends/fpga/lite_tensor.cc b/lite/backends/fpga/lite_tensor.cc old mode 100644 new mode 100755 index 43218173fd05626fb46495bb254b250c14e5417a..7f1e8d3e17f97315e77532b77bbcfcc8331edd4f --- a/lite/backends/fpga/lite_tensor.cc +++ b/lite/backends/fpga/lite_tensor.cc @@ -95,16 +95,14 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { dims_ = other.dims_; target_ = other.target_; lod_ = other.lod_; - // memory_size_ = other.memory_size_; - // buffer_->CopyDataFrom(*other.buffer_, memory_size_); - zynq_tensor_->mutableData(other.zynq_tensor_->dataType(), - other.zynq_tensor_->shape()); -} + auto dt = zynq_tensor_->dataType(); -// template -// void TensorLite::mutable_data_internal() { + auto shape = other.zynq_tensor_->shape(); -// } + Resize(other.dims()); + zynq_tensor_->mutableData(zynq_tensor_->dataType(), shape); + this->ZynqTensor()->copyFrom(other.ZynqTensor()); +} } // namespace lite } // namespace paddle diff --git a/lite/backends/fpga/lite_tensor.h b/lite/backends/fpga/lite_tensor.h index 2f9df3abb08dd15641323f4a3c59d6175f2e481b..311fc8a98400e5a6916ba1b9c8de1e6e0bcec4c0 100644 --- a/lite/backends/fpga/lite_tensor.h +++ b/lite/backends/fpga/lite_tensor.h @@ -106,7 +106,7 @@ class TensorLite { // For other devices, T and R may be the same type. 
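  // NOTE: data() below is offset-aware: the raw pointer obtained from the
  // underlying zynqmp::Tensor is advanced by offset_ (counted in elements
  // of R) before being returned, so sliced views read from the right
  // position in the parent's storage.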
template const R *data() const { - return zynq_tensor_->data(); + return zynq_tensor_->data() + offset_; } void Resize(const DDimLite &ddim) { dims_ = ddim; } @@ -125,6 +125,7 @@ class TensorLite { bool persistable() const { return persistable_; } void set_persistable(bool persistable) { persistable_ = persistable; } + // T is the data type and R is the return type // For OpenCL, the return type can be cl::Buffer // and the data type can be float/int8_t. @@ -147,6 +148,8 @@ class TensorLite { size_t memory_size() const { return zynq_tensor_->memorySize(); } + size_t offset() const { return offset_; } + bool IsInitialized() const { return buffer_->data(); } // Other share data to this. @@ -157,6 +160,9 @@ class TensorLite { template TensorLite Slice(int64_t begin, int64_t end) const; + template + void Slice(TensorLite &dst, int64_t begin, int64_t end) const; // NOLINT + TargetType target() const { return target_; } zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_; } @@ -173,16 +179,21 @@ class TensorLite { private: TargetType target_{TargetType::kHost}; + + // precision_ and persistable_ are only used for persistable vars. + // If your tensor wants to be saved and loaded correctly, you must + // set values of precision_ and persistable_ after updating it. + // If your tensor is just a temp tensor, such as activations, + // you can ignore these two attributes. + PrecisionType precision_{PrecisionType::kUnk}; + bool persistable_{false}; + DDimLite dims_; std::shared_ptr buffer_; LoD lod_; size_t memory_size_{}; - size_t offset_{0}; - PrecisionType precision_{PrecisionType::kUnk}; - bool persistable_{false}; - zynqmp::Tensor *zynq_tensor_ = new zynqmp::Tensor(); template @@ -197,6 +208,9 @@ R *TensorLite::mutable_data() { } zynqmp::LayoutType layout_type = zynqmp::NCHW; switch (v.size()) { + case 0: + layout_type = zynqmp::None; + break; case 1: layout_type = zynqmp::N; break; @@ -228,24 +242,60 @@ R *TensorLite::mutable_data(TargetType target) { return mutable_data(); } -template -bool TensorCompareWith(const TensorT &a, const TensorT &b) { - if (a.dims() != b.dims()) return false; - if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; - return true; -} template TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { - int64_t base = numel() / dims_[0]; + throw - 1; + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); + if (dims_[0] == 1) { + return *this; + } else { + int64_t base = numel() / dims_[0]; + + TensorLite dst; + dst.target_ = target_; + auto dst_dims = dims_; + dst_dims[0] = end - begin; + dst.Resize(dst_dims); + void *dst_data = dst.mutable_data(); + + T *src_data = const_cast(data()); + memcpy(dst_data, + src_data + static_cast(begin * base) * sizeof(T), + dst_dims.production() * sizeof(T)); + dst.ZynqTensor()->saveToFile("_slice", true); + + return dst; + } +} + +template +void TensorLite::Slice(TensorLite &dst, int64_t begin, int64_t end) const { + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); - TensorLite dst; - dst.buffer_ = buffer_; dst.target_ = target_; auto dst_dims = dims_; dst_dims[0] = end - begin; dst.Resize(dst_dims); - dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); - return dst; + void *dst_data = dst.mutable_data(); + + int64_t base = numel() / dims_[0]; + + T *src_data = const_cast(data()); + memcpy(dst_data, + src_data + static_cast(begin * dst_dims.production()), + dst_dims.production() * sizeof(T)); } + +template +bool TensorCompareWith(const TensorT &a, const 
TensorT &b) { + if (a.dims() != b.dims()) return false; + if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; + return true; +} + } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/CMakeLists.txt b/lite/backends/npu/CMakeLists.txt index 426ff5698146c773c818b2bfd598d6bbbdf7867f..1540741d331097961dcf7cd791c9785a9c53ddd1 100644 --- a/lite/backends/npu/CMakeLists.txt +++ b/lite/backends/npu/CMakeLists.txt @@ -2,5 +2,4 @@ if(NOT LITE_WITH_NPU) return() endif() -lite_cc_library(npu_runtime SRCS runtime.cc DEPS ${npu_runtime_libs}) -lite_cc_library(npu_builder SRCS builder.cc DEPS ${npu_builder_libs} npu_runtime tensor op scope) +lite_cc_library(device_npu SRCS device.cc DEPS ${npu_builder_libs} ${npu_runtime_libs}) diff --git a/lite/backends/npu/device.cc b/lite/backends/npu/device.cc new file mode 100644 index 0000000000000000000000000000000000000000..e63939264214bc619814f06c7cf0de1b56f71ee6 --- /dev/null +++ b/lite/backends/npu/device.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/npu/device.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace npu { + +std::unique_ptr Device::Build( + std::string& model_name, // NOLINT + std::vector& input_nodes, // NOLINT + std::vector& output_nodes // NOLINT + ) { + VLOG(3) << "[NPU] Build model"; + // Build the HiAI IR graph to the HiAI om model + ge::Graph ir_graph("graph"); + ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes); + ge::Model om_model("model", "model"); + om_model.SetGraph(ir_graph); + domi::HiaiIrBuild ir_build; + domi::ModelBufferData om_model_buf; + if (!ir_build.CreateModelBuff(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] CreateModelBuff failed!"; + return nullptr; + } + if (!ir_build.BuildIRModel(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] BuildIRModel failed!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + // Create a HiAI model manager client to load the HiAI om model + std::unique_ptr model_client( + new hiai::AiModelMngerClient()); + if (model_client->Init(nullptr) != hiai::AI_SUCCESS) { + LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + model_name = "model_" + std::to_string(model_count_++) + ".om"; + auto model_desc = std::make_shared( + model_name, freq_level(), framework_type(), model_type(), device_type()); + model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length); + std::vector> model_descs; + model_descs.push_back(model_desc); + if (model_client->Load(model_descs) != hiai::AI_SUCCESS) { + LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + ir_build.ReleaseModelBuff(om_model_buf); + return model_client; +} + +} // namespace npu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/npu/runtime.h 
b/lite/backends/npu/device.h similarity index 66% rename from lite/backends/npu/runtime.h rename to lite/backends/npu/device.h index 8b1ad51518d8626d9a6ecd6203a70b2637bb6004..3eba0b77e4bdeb26cdff869771645a5ce7637ae4 100644 --- a/lite/backends/npu/runtime.h +++ b/lite/backends/npu/device.h @@ -13,38 +13,47 @@ // limitations under the License. #pragma once + #include #include +#include +#include #include "ai_ddk_lib/include/HiAiModelManagerService.h" -#include "lite/core/tensor.h" +#include "ai_ddk_lib/include/hiai_ir_build.h" namespace paddle { namespace lite { namespace npu { -class DeviceInfo { +class Device { public: - static DeviceInfo &Global() { - static DeviceInfo x; + static Device& Global() { + static Device x; return x; } - DeviceInfo() {} + Device() {} int freq_level() { return freq_level_; } int framework_type() { return framework_type_; } int model_type() { return model_type_; } int device_type() { return device_type_; } + // Build the HiAI IR graph to om model, return HiAI model manager client to + // load om model and run inference. + std::unique_ptr Build( + std::string& model_name, // NOLINT + std::vector& input_nodes, // NOLINT + std::vector& output_nodes // NOLINT + ); // NOLINT + private: int freq_level_{3}; int framework_type_{0}; int model_type_{0}; int device_type_{0}; + int model_count_{0}; }; -bool LoadModel(const lite::Tensor &model_data, - std::shared_ptr *model_client, - std::string *model_name); } // namespace npu } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/runtime.cc b/lite/backends/npu/runtime.cc deleted file mode 100644 index 3485f63c7c8bb91081fd1969d0d41733417149d9..0000000000000000000000000000000000000000 --- a/lite/backends/npu/runtime.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/backends/npu/runtime.h" -#include -#include -#include "lite/utils/cp_logging.h" - -namespace paddle { -namespace lite { -namespace npu { - -// Create hiai model manager to load om model from lite tensor, and return the -// manager and an unique model name -bool LoadModel(const lite::Tensor &model_data, - std::shared_ptr *model_client, - std::string *model_name) { - LOG(INFO) << "[NPU] Load model."; - auto model_data_ptr = model_data.data(); - auto model_data_size = model_data.numel() * sizeof(int8_t); - if (model_data_ptr == nullptr || model_data_size == 0) { - return false; - } - *model_client = std::make_shared(); - int ret = (*model_client)->Init(nullptr); - if (ret != hiai::AI_SUCCESS) { - LOG(WARNING) << "[NPU] AiModelMngerClient init failed(" << ret << ")!"; - return false; - } - *model_name = "model.om"; - auto model_desc = std::make_shared( - *model_name, - DeviceInfo::Global().freq_level(), - DeviceInfo::Global().framework_type(), - DeviceInfo::Global().model_type(), - DeviceInfo::Global().device_type()); - model_desc->SetModelBuffer(model_data_ptr, model_data_size); - std::vector> model_descs; - model_descs.push_back(model_desc); - if ((*model_client)->Load(model_descs) != hiai::AI_SUCCESS) { - LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; - return false; - } - return true; -} - -} // namespace npu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/opencl/CMakeLists.txt b/lite/backends/opencl/CMakeLists.txt index 1acb98321844191832fd55b640a9b56d3d51b400..dd7f6b417e0d6416eec9bb3e60ef088432776112 100644 --- a/lite/backends/opencl/CMakeLists.txt +++ b/lite/backends/opencl/CMakeLists.txt @@ -11,8 +11,8 @@ lite_cc_library(cl_image SRCS cl_image.cc DEPS tensor cl_image_converter cl_runt lite_cc_library(cl_caller SRCS cl_caller.cc DEPS cl_context cl_image) lite_cc_library(cl_target_wrapper SRCS target_wrapper.cc DEPS cl_runtime) lite_cc_test(test_cl_functions SRCS cl_functions_test.cc DEPS cl_context cl_image cl_caller cl_wrapper cl_target_wrapper - ARGS --cl_path=${CMAKE_SOURCE_DIR}/paddle/fluid/lite/backends/opencl) + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_cl_im2col SRCS cl_im2col_test.cc DEPS tensor cl_context cl_wrapper cl_target_wrapper - ARGS --cl_path=${CMAKE_SOURCE_DIR}/paddle/fluid/lite/backends/opencl) + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) add_dependencies(cl_wrapper opencl_clhpp) diff --git a/lite/backends/opencl/cl_caller.cc b/lite/backends/opencl/cl_caller.cc index 4926a53c43d54b4e2b4d802a7d8ef289c7e87fc5..6b9cab1056beaa6f516a0d3a202a7816c911f1b2 100644 --- a/lite/backends/opencl/cl_caller.cc +++ b/lite/backends/opencl/cl_caller.cc @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace lite { + static void CopyImageData(CLContext* context, const CLImage& cl_image, float* out) { @@ -51,119 +52,5 @@ bool InitOpenCLRuntime(std::string cl_path) { return runtime->IsInitSuccess(); } -void elementwise_add(CLContext* context, - const float* in, - const DDim& in_dim, - const float* bias, - const DDim& bias_dim, - float* out, - const DDim& out_dim) { - if (!(bias_dim.size() == 1 || bias_dim.size() == 4)) { - LOG(FATAL) << "Error: bias dims is error"; - return; - } - auto kernel = bias_dim.size() == 1 ? 
context->GetKernel("channel_add") - : context->GetKernel("elementwise_add"); - CLImage in_image; - in_image.set_tensor_data(in, in_dim); - in_image.InitNormalCLImage(context->GetContext()); - VLOG(3) << " --- Inpu image: " << in_image << " --- "; - CLImage bias_image; - bias_image.set_tensor_data(bias, bias_dim); - bias_image.InitCLImage(context->GetContext()); - VLOG(3) << " --- Bias image: " << bias_image << " --- "; - CLImage out_image; - out_image.InitEmptyImage(context->GetContext(), out_dim); - cl_int status; - status = kernel.setArg(0, *in_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(1, *bias_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(2, *out_image.cl_image()); - CL_CHECK_FATAL(status); - - if (bias_dim.size() == 1) { - int tensor_w = in_dim[3]; - status = kernel.setArg(3, tensor_w); - CL_CHECK_FATAL(status); - } - size_t width = in_image.ImageWidth(); - size_t height = in_image.ImageHeight(); - auto global_work_size = cl::NDRange{width, height}; - status = context->GetCommandQueue().enqueueNDRangeKernel( - kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().finish(); - CL_CHECK_FATAL(status); - VLOG(3) << " --- Out image: " << out_image << " --- "; - CopyImageData(context, out_image, out); -} - -void pool(CLContext* context, - const std::string pooling_type, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int ksize_h, - const int ksize_w, - const float* in, - const DDim& in_dim, - float* out, - const DDim& out_dim) { - auto kernel = - context->GetKernel(string_format("pool_%s", pooling_type.c_str())); - CLImage in_image; - in_image.set_tensor_data(in, in_dim); - in_image.InitNormalCLImage(context->GetContext()); - VLOG(3) << " --- Inpu image: " << in_image << " --- "; - CLImage out_image; - out_image.InitEmptyImage(context->GetContext(), out_dim); - auto global_work_size = context->DefaultWorkSize(out_image); - auto* in_converter = - dynamic_cast(in_image.image_converter()); - auto* out_converter = - dynamic_cast(out_image.image_converter()); - const int in_height = in_converter->HeightOfOneBlock(); - const int in_width = in_converter->WidthOfOneBlock(); - const int out_height = out_converter->HeightOfOneBlock(); - const int out_width = out_converter->WidthOfOneBlock(); - cl_int status; - status = kernel.setArg(0, in_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(1, in_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(2, out_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(3, out_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(4, pad_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(5, pad_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(6, stride_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(7, stride_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(8, ksize_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(9, ksize_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(10, *in_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(11, *out_image.cl_image()); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().enqueueNDRangeKernel( - kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().finish(); - CL_CHECK_FATAL(status); - VLOG(3) << " --- Out image: " << out_image << " --- "; - CopyImageData(context, out_image, 
out); -} - } // namespace lite } // namespace paddle diff --git a/lite/backends/opencl/cl_caller.h b/lite/backends/opencl/cl_caller.h index ed5c9153d3cedf140cbf0570b7f71393fb918bf9..1817db9f6bd6d9ecf21978b8293bd9534328de0f 100644 --- a/lite/backends/opencl/cl_caller.h +++ b/lite/backends/opencl/cl_caller.h @@ -23,30 +23,5 @@ namespace lite { bool InitOpenCLRuntime(std::string cl_path); -/// An elementwise_add method to embed OpenCL logic inside, it is used as a -/// black box so that the framework can remain simple. -/// NOTE Currently, these methods are quite expensive, we will optimize them -/// latter. -void elementwise_add(CLContext* context, - const float* in, - const DDim& in_dim, - const float* bias, - const DDim& bias_dim, - float* out, - const DDim& out_dim); - -void pool(CLContext* context, - const std::string pooling_type, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int ksize_h, - const int ksize_w, - const float* in, - const DDim& in_dim, - float* out, - const DDim& out_dim); - } // namespace lite } // namespace paddle diff --git a/lite/backends/opencl/cl_functions_test.cc b/lite/backends/opencl/cl_functions_test.cc index b9f6648c9956e1952b65f66abfa40d912a99ee67..70f47b47946641edf4d023437b48d46cae93ca6e 100644 --- a/lite/backends/opencl/cl_functions_test.cc +++ b/lite/backends/opencl/cl_functions_test.cc @@ -41,9 +41,10 @@ TEST(cl_test, runtime_test) { auto &context = runtime->context(); auto program = runtime->CreateProgram( context, - runtime->cl_path() + "/cl_kernel/" + "image/elementwise_add_kernel.cl"); + runtime->cl_path() + "/cl_kernel/" + "buffer/elementwise_add_kernel.cl"); auto event = runtime->CreateEvent(context); - CHECK(runtime->BuildProgram(program.get())); + const std::string build_option("-DCL_DTYPE_float"); + CHECK(runtime->BuildProgram(program.get(), build_option)); } TEST(cl_test, context_test) { @@ -51,9 +52,11 @@ TEST(cl_test, context_test) { CHECK(runtime->IsInitSuccess()); runtime->set_cl_path(FLAGS_cl_path); CLContext context; - context.AddKernel("pool_max", "image/pool_kernel.cl", ""); - context.AddKernel("elementwise_add", "image/elementwise_add_kernel.cl", ""); - context.AddKernel("elementwise_add", "image/elementwise_add_kernel.cl", ""); + context.AddKernel("pool_max", "image/pool_kernel.cl", "-DCL_DTYPE_float"); + context.AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); + context.AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); } TEST(cl_test, kernel_test) { @@ -61,9 +64,11 @@ TEST(cl_test, kernel_test) { CHECK(runtime->IsInitSuccess()); runtime->set_cl_path(FLAGS_cl_path); std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("pool_max", "image/pool_kernel.cl"); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); + context->AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); + context->AddKernel("pool_max", "image/pool_kernel.cl", "-DCL_DTYPE_float"); + context->AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); auto kernel = context->GetKernel(2); std::unique_ptr in_data(new float[4 * 3 * 256 * 512]); @@ -115,203 +120,12 @@ TEST(cl_test, kernel_test) { LOG(INFO) << out_image; } -TEST(cl_test, channel_add_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 16, 256, 
512}); - std::unique_ptr in_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - in_data[i] = dist(engine); - } - - const DDim bias_dim = DDim(std::vector{16}); - std::unique_ptr bias_data(new float[16]); - for (int i = 0; i < 16; i++) { - bias_data[i] = dist(engine); - } - - std::unique_ptr out_ref(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - float b = bias_data[j]; - for (int k = 0; k < 256 * 512; k++) { - int index = (i * 16 + j) * 256 * 512 + k; - out_ref[index] = in_data[index] + b; - } - } - } - - const DDim out_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr out(new float[4 * 16 * 256 * 512]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("channel_add", "image/channel_add_kernel.cl"); - elementwise_add(context.get(), - in_data.get(), - in_dim, - bias_data.get(), - bias_dim, - out.get(), - out_dim); - - int stride = 4 * 16 * 256 * 512 / 20; - for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) { - std::cout << out[i] << " "; - } - std::cout << std::endl; - - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - -TEST(cl_test, elementwise_add_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr in_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - in_data[i] = dist(engine); - } - - const DDim bias_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr bias_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - bias_data[i] = dist(engine); - } - - std::unique_ptr out_ref(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - out_ref[i] = in_data[i] + bias_data[i]; - } - - const DDim out_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr out(new float[4 * 16 * 256 * 512]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("channel_add", "image/channel_add_kernel.cl"); - elementwise_add(context.get(), - in_data.get(), - in_dim, - bias_data.get(), - bias_dim, - out.get(), - out_dim); - - int stride = 4 * 16 * 256 * 512 / 20; - for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) { - std::cout << out[i] << " "; - } - std::cout << std::endl; - - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - -void pool_avg(const int padding_height, - const int padding_width, - const int stride_height, - const int stride_width, - const int ksize_height, - const int ksize_width, - const float *input_data, - const DDim &in_dim, - float *output_data, - const DDim &out_dim) { - const int batch_size = in_dim[0]; - const int input_height = in_dim[2]; - const int input_width = in_dim[3]; - const int output_channels = out_dim[1]; - const int output_height = out_dim[2]; - const int output_width = out_dim[3]; - - const size_t input_spatial_size = input_height * input_width; - const size_t output_spatial_size = output_height * output_width; - - for (int i = 0; i < batch_size; i++) { - for (int c = 0; c < 
output_channels; ++c) { - int channel = i * output_channels + c; - const float *input_ptr = input_data + channel * input_spatial_size; - float *output_ptr = output_data + channel * output_spatial_size; - - for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); - for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); - - float val = 0.f; - int count = 0; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - val += input_ptr[h * input_width + w]; - ++count; - } - } - output_ptr[ph * output_width + pw] = - (count > 0) ? val * (1.f / count) : 0.f; - } - } - } - } -} - -TEST(cl_test, pool_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 1024, 7, 7}); - std::unique_ptr in_data(new float[4 * 1024 * 7 * 7]); - for (int i = 0; i < 4 * 1024 * 7 * 7; i++) { - in_data[i] = dist(engine); - } - - const DDim out_dim = DDim(std::vector{4, 1024, 1, 1}); - std::unique_ptr out(new float[4 * 1024 * 1 * 1]); - std::unique_ptr out_ref(new float[4 * 1024 * 1 * 1]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("pool_max", "image/pool_kernel.cl"); - context->AddKernel("pool_avg", "image/pool_kernel.cl"); - pool(context.get(), - "avg", - 0, - 0, - 1, - 1, - 7, - 7, - in_data.get(), - in_dim, - out.get(), - out_dim); - pool_avg(0, 0, 1, 1, 7, 7, in_data.get(), in_dim, out_ref.get(), out_dim); - - for (int i = 0; i < 4 * 1024 * 1 * 1; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - TEST(cl_test, target_wrapper_buffer_test) { bool inited = InitOpenCLRuntime(FLAGS_cl_path); CHECK(inited) << "Fail to initialize OpenCL runtime."; std::unique_ptr context(new CLContext); std::string kernel_name = "elementwise_add"; - std::string build_options = "-DCL_DTYPE=float"; + std::string build_options = "-DCL_DTYPE_float"; context->AddKernel( kernel_name, "buffer/elementwise_add_kernel.cl", build_options); std::vector h_a; @@ -396,10 +210,13 @@ TEST(cl_test, target_wrapper_buffer_test) { TEST(cl_test, target_wrapper_image_test) { const size_t cl_image2d_width = 28; const size_t cl_image2d_height = 32; + const size_t cl_image2d_elem_size = + cl_image2d_width * cl_image2d_height * 4; // 4 for RGBA channels const size_t cl_image2d_row_pitch{0}; const size_t cl_image2d_slice_pitch{0}; auto *d_image = static_cast( TargetWrapperCL::MallocImage(cl_image2d_width, cl_image2d_height)); + // Map/Unmap test auto *h_image = static_cast(TargetWrapperCL::MapImage(d_image, @@ -407,15 +224,11 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_height, cl_image2d_row_pitch, cl_image2d_slice_pitch)); - CHECK_EQ( - cl_image2d_row_pitch, - cl_image2d_width * 4 * - 4); // row_pitch = 448 = 28 * 4 (RGBA: 4 floats) * 4 (float in bytes) - CHECK_EQ(cl_image2d_slice_pitch, 0); // slice_pitch = 0 + CHECK_EQ(cl_image2d_slice_pitch, 0); LOG(INFO) << "cl_image2d_row_pitch = " << cl_image2d_row_pitch << ", cl_image2d_slice_pitch " << cl_image2d_slice_pitch; - for (int i = 0; i < 10; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { h_image[i] = 3.14f * i; } TargetWrapperCL::Unmap(d_image, h_image); @@ -426,15 +239,14 @@ TEST(cl_test, 
target_wrapper_image_test) { cl_image2d_height, cl_image2d_row_pitch, cl_image2d_slice_pitch)); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { EXPECT_NEAR(h_ptr[i], 3.14f * i, 1e-6); } TargetWrapperCL::Unmap(d_image, h_ptr); // Imagecpy test - std::vector h_image_cpy(cl_image2d_width * 4 * - cl_image2d_height); // 4 for RGBA channels - for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) { + std::vector h_image_cpy(cl_image2d_elem_size); + for (int i = 0; i < cl_image2d_elem_size; i++) { h_image_cpy[i] = 3.14f; } TargetWrapperCL::ImgcpySync(d_image, @@ -446,6 +258,8 @@ TEST(cl_test, target_wrapper_image_test) { IoDirection::HtoD); auto *d_image_cpy = static_cast( TargetWrapperCL::MallocImage(cl_image2d_width, cl_image2d_height)); + + // device to device TargetWrapperCL::ImgcpySync(d_image_cpy, d_image, cl_image2d_width, @@ -454,6 +268,8 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_slice_pitch, IoDirection::DtoD); std::fill(h_image_cpy.begin(), h_image_cpy.end(), 0); + + // host to device TargetWrapperCL::ImgcpySync(h_image_cpy.data(), d_image_cpy, cl_image2d_width, @@ -461,7 +277,7 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_row_pitch, cl_image2d_slice_pitch, IoDirection::DtoH); - for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { EXPECT_NEAR(h_image_cpy[i], 3.14f, 1e-6); } diff --git a/lite/backends/opencl/cl_image_converter.h b/lite/backends/opencl/cl_image_converter.h index 6faa8045576f06d8c636372de644e6b5c164a5f4..962eb8d3ef35bdb603aa4a56181b1124885d5506 100644 --- a/lite/backends/opencl/cl_image_converter.h +++ b/lite/backends/opencl/cl_image_converter.h @@ -103,6 +103,7 @@ class CLImageConverterNormal : public CLImageConverterBase { }; class CLImageConverterNWBlock : public CLImageConverterBase { + public: DDim InitImageDimInfoWith(const DDim &tensor_dim) override; void NCHWToImage(float *tensor, float *image, @@ -113,6 +114,7 @@ class CLImageConverterNWBlock : public CLImageConverterBase { const DDim &tensor_dim) override; }; class CLImageConverterDWBlock : public CLImageConverterBase { + public: DDim InitImageDimInfoWith(const DDim &tensor_dim) override; void NCHWToImage(float *tensor, float *image, diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl index c9c16581d67db0c9143e91e13249edfd5901ddb8..532f947dd342b1ee4db69a084111a97ec014237f 100644 --- a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl +++ b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl @@ -61,6 +61,57 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in, write_imagef(output_image, output_pos, output); } +// buffer -> image2d_nw +__kernel void buffer_to_image2d_nw(__global CL_DTYPE* in, + __write_only image2d_t output_image, + __private const int out_H, + __private const int out_W, + __private const int out_N, + __private const int Stride0, + __private const int Stride1, + __private const int Stride2) { + const int out_n = get_global_id(0); + const int out_w = get_global_id(1); + const int out_ch = get_global_id(2); + + const int out_c = out_ch / out_H; + const int out_h = out_ch % out_H; + + const int in_c = out_c; // index of c in h direction + + const int in_n0 = out_n * 4 + 0; + const int in_n1 = out_n * 4 + 1; + const int in_n2 = out_n * 4 + 2; + const int in_n3 = out_n * 4 + 3; + + const int in_h = out_h; + const int in_w = out_w; + + int input_pos0 = in_n0 * Stride2 + in_c * 
Stride1 + in_h * Stride0 + in_w; + int input_pos1 = in_n1 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + int input_pos2 = in_n2 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + int input_pos3 = in_n3 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + + int2 output_pos; + output_pos.x = out_n * out_W + out_w; + output_pos.y = out_ch; + + CL_DTYPE4 output = (CL_DTYPE4)0.0f; + output.x = convert_float(in[input_pos0]); + if (out_N - 4 * out_n >= 2) { + output.y = convert_float(in[input_pos1]); + } + if (out_N - 4 * out_n >= 3) { + output.z = convert_float(in[input_pos2]); + } + if (out_N - 4 * out_n >= 4) { + output.w = convert_float(in[input_pos3]); + } + write_imagef(output_image, output_pos, output); +} + + + // image2d -> buffer __kernel void image2d_to_buffer(__read_only image2d_t input, __private const int in_width, diff --git a/lite/backends/opencl/cl_kernel/cl_common.h b/lite/backends/opencl/cl_kernel/cl_common.h index 7f901fc994ffd82ccfe99f59614a3422260d0dc5..f193ab82d78fcd21165100658e9a0edefdbd5e0a 100644 --- a/lite/backends/opencl/cl_kernel/cl_common.h +++ b/lite/backends/opencl/cl_kernel/cl_common.h @@ -14,8 +14,17 @@ limitations under the License. */ #pragma once +///////////////////////////////// +// fp16 enabled, MAX_VALUE, MIN_VALUE +///////////////////////////////// #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define MAX_VALUE FLT_MAX +#define MIN_VALUE -FLT_MAX + +///////////////////////////////// +// CL_DTYPE_float / CL_DTYPE_half +///////////////////////////////// // Data type: pass one of macros on host: [CL_DTYPE_float, CL_DYPE_half] #ifdef CL_DTYPE_float #define CL_DTYPE float @@ -27,24 +36,36 @@ limitations under the License. */ #define CL_DTYPE_CHAR h #endif +///////////////////////////////// +// GET_VEC_TYPE +///////////////////////////////// // Note: macro name replacement need twice parser #define GET_VEC_TYPE(type__, size__) type__##size__ #define VECTORIZED_TYPE(type__, size__) GET_VEC_TYPE(type__, size__) #define CL_DTYPE4 VECTORIZED_TYPE(CL_DTYPE, 4) +///////////////////////////////// +// CONVERT_TYPE_TO +///////////////////////////////// #define _CONVERT_TYPE_TO(value, type) convert_##type(value) #define CONVERT_TYPE_TO(value, type) _CONVERT_TYPE_TO(value, type) +///////////////////////////////// +// WRITE_IMG_TYPE / READ_IMG_TYPE +///////////////////////////////// #define _WRITE_IMG_TYPE(type_char, img, pos, value) \ write_image##type_char(img, pos, value) #define WRITE_IMG_TYPE(type_char, img, pos, value) \ _WRITE_IMG_TYPE(type_char, img, pos, value) -#define _READ_IMG_TYPE(type_char, img, pos, sampler) \ +#define _READ_IMG_TYPE(type_char, img, sampler, pos) \ read_image##type_char(img, sampler, pos) -#define READ_IMG_TYPE(type_char, img, pos, sampler) \ - _READ_IMG_TYPE(type_char, img, pos, sampler) +#define READ_IMG_TYPE(type_char, img, sampler, pos) \ + _READ_IMG_TYPE(type_char, img, sampler, pos) +///////////////////////////////// +// activation / activation_type4 +///////////////////////////////// inline CL_DTYPE activation(CL_DTYPE in #ifdef PRELU , @@ -61,3 +82,20 @@ inline CL_DTYPE activation(CL_DTYPE in #endif return output; } + +inline CL_DTYPE4 activation_type4(CL_DTYPE4 in +#ifdef PRELU + , + CL_DTYPE4 prelu_alpha +#endif + ) { + CL_DTYPE4 output; +#ifdef PRELU + output = select(prelu_alpha * in, in, in >= (CL_DTYPE4)0.0); +#endif + +#ifdef RELU + output = fmax(in, (CL_DTYPE4)0); +#endif + return output; +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl 
b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..6fe5596a4cf5cbce5b50c9a3d53be164aad8a0b5 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl @@ -0,0 +1,216 @@ +#include + +__kernel void conv2d_1x1(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int input_c_origin, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int old_w) { + CL_DTYPE zero = 0.0f; + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int out_w0 = out_w; + int out_w1 = out_w + global_size_dim1; + int out_w2 = out_w + global_size_dim1 * 2; + int out_w3 = out_w + global_size_dim1 * 3; + + int outpos_main = mul24(out_c, old_w); + int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); + int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); + int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); + int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 stride_xy = (int2)(stride, stride); + + int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); + int2 in_pos_in_one_block0 = + ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); + int2 in_pos_in_one_block1 = + ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); + int2 in_pos_in_one_block2 = + ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); + int2 in_pos_in_one_block3 = + ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; +#elif defined(BIASE_ELE) + CL_DTYPE4 output0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos0); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; + +#else + CL_DTYPE4 output0 = 0.0f; + CL_DTYPE4 output1 = 0.0f; + CL_DTYPE4 output2 = 0.0f; + CL_DTYPE4 output3 = 0.0f; +#endif + + int max_w_bound = input_c * input_width; + int burndary_index = input_c * 4 - input_c_origin; + bool burndary_index_w = + burndary_index == 1 || burndary_index == 2 || burndary_index == 3; + bool burndary_index_z = burndary_index == 2 || burndary_index == 3; + bool burndary_index_y = burndary_index == 3; + + for (int i = 0; i < input_c; ++i) { + // ------------0--------------- + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, + in_pos_in_one_block0.y); + CL_DTYPE4 input0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + CL_DTYPE4 weight0 = + 
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 0)); + CL_DTYPE4 weight1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 1)); + CL_DTYPE4 weight2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 2)); + CL_DTYPE4 weight3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 3)); + int bound_gap = max_w_bound - pos_in.x - 1; + + bool outof_bound = bound_gap < input_width && bound_gap >= 0; + input0.w = select(input0.w, zero, outof_bound && burndary_index_w); + input0.z = select(input0.z, zero, outof_bound && burndary_index_z); + input0.y = select(input0.y, zero, outof_bound && burndary_index_y); + + output0 = mad(input0.x, weight0, output0); + output0 = mad(input0.y, weight1, output0); + output0 = mad(input0.z, weight2, output0); + output0 = mad(input0.w, weight3, output0); + // -------------1-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, + in_pos_in_one_block1.y); + CL_DTYPE4 input1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input1.w = select(input1.w, zero, outof_bound && burndary_index_w); + input1.z = select(input1.z, zero, outof_bound && burndary_index_z); + input1.y = select(input1.y, zero, outof_bound && burndary_index_y); + + output1 = mad(input1.x, weight0, output1); + output1 = mad(input1.y, weight1, output1); + output1 = mad(input1.z, weight2, output1); + output1 = mad(input1.w, weight3, output1); + + // -------------2-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, + in_pos_in_one_block2.y); + CL_DTYPE4 input2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input2.w = select(input2.w, zero, outof_bound && burndary_index_w); + input2.z = select(input2.z, zero, outof_bound && burndary_index_z); + input2.y = select(input2.y, zero, outof_bound && burndary_index_y); + + output2 = mad(input2.x, weight0, output2); + output2 = mad(input2.y, weight1, output2); + output2 = mad(input2.z, weight2, output2); + output2 = mad(input2.w, weight3, output2); + + // -------------3-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, + in_pos_in_one_block3.y); + CL_DTYPE4 input3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input3.w = + select(input3.w, + zero, + outof_bound && (burndary_index == 1 || burndary_index == 2 || + burndary_index == 3)); + input3.z = + select(input3.z, + zero, + outof_bound && (burndary_index == 2 || burndary_index == 3)); + input3.y = select(input3.y, zero, outof_bound && burndary_index == 3); + + output3 = mad(input3.x, weight0, output3); + output3 = mad(input3.y, weight1, output3); + output3 = mad(input3.z, weight2, output3); + output3 = mad(input3.w, weight3, output3); + } + +#ifdef BATCH_NORM + output0 = output0 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output1 = output1 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output2 = output2 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, 
+                          new_biase, sampler, (int2)(out_c, 0));
+
+  output3 = output3 * READ_IMG_TYPE(
+                          CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) +
+            READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
+#endif
+
+#ifdef RELU
+  output0 = activation_type4(output0);
+  output1 = activation_type4(output1);
+  output2 = activation_type4(output2);
+  output3 = activation_type4(output3);
+#endif
+
+  if (out_w0 < old_w) {
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0);
+  }
+
+  if (out_w1 < old_w) {
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos1, output1);
+  }
+
+  if (out_w2 < old_w) {
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos2, output2);
+  }
+
+  if (out_w3 < old_w) {
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos3, output3);
+  }
+}
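
[Editor's note: each conv2d_1x1 work-item accumulates four output columns spaced global_size_dim1 apart and masks the tail with the `out_wX < old_w` guards, so the host presumably launches a W dimension that is a quarter of the output width. A host-side sketch of that assumed launch geometry, with illustrative names, not part of the patch:

    // Assumed global work size for conv2d_1x1 (illustrative):
    const size_t gws0 = (out_channels + 3) / 4;  // RGBA channel blocks
    const size_t gws1 = (old_w + 3) / 4;         // quarter of the output width
    const size_t gws2 = batch * out_height;      // N * H rows
    // A work-item then covers columns out_w, out_w + gws1,
    // out_w + 2 * gws1, out_w + 3 * gws1.
]
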
diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl
new file mode 100755
index 0000000000000000000000000000000000000000..1e3586b7fde8d79fe49327185c623ac613cd080d
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl
@@ -0,0 +1,322 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+
+#include <cl_common.h>
+
+__kernel void depth_conv2d_3x3(__private const int global_size_dim0,
+                               __private const int global_size_dim1,
+                               __private const int global_size_dim2,
+                               __read_only image2d_t input,
+                               __read_only image2d_t filter,
+#if defined(BIASE_CH) || defined(BIASE_ELE)
+                               __read_only image2d_t bias,
+#endif
+#ifdef BATCH_NORM
+                               __read_only image2d_t new_scale,
+                               __read_only image2d_t new_biase,
+#endif
+                               __write_only image2d_t output_image,
+                               __private const int stride,
+                               __private const int offset,
+                               __private const int dilation,
+                               __private const int input_c,
+                               __private const int input_width,  /* of one block */
+                               __private const int input_height, /* of one block */
+                               __private const int output_width,
+                               __private const int output_height) {
+
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+
+  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+
+  const int batch_index = out_nh / output_height;
+
+  const int out_nh_in_one_batch = out_nh % output_height;
+
+  int2 stride_xy = (int2)(stride, stride);
+  int2 ouput_pos_in_one_block = (int2)(out_w, out_nh_in_one_batch);
+
+  int2 in_pos_in_one_block = ouput_pos_in_one_block * stride_xy + (int2)(offset, offset);
+
+#ifdef BIASE_CH
+  CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0));
+#elif defined(BIASE_ELE)
+  CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos);
+#else
+  CL_DTYPE4 output = 0.0f;
+#endif
+
+  const int filter_width = 3;
+  const int filter_height = 3;
+
+  int2 pos_in_input_block = (int2)(out_c * input_width, batch_index * input_height);
+
+  int2 pos_in_filter_block =
+      (int2)(out_c * filter_width, batch_index * filter_height);
+
+  int filter_x = pos_in_filter_block.x;
+  int filter_y = pos_in_filter_block.y;
+
+  CL_DTYPE4 inputs[9];
+
+  inputs[0] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y - 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y - 1 >= input_height) << 15));
+
+  inputs[1] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y - 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - 1 >= input_height) << 15));
+
+  inputs[2] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y - 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x + 1 >= input_width || in_pos_in_one_block.y - 1 >= input_height) << 15));
+
+  inputs[3] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y >= input_height) << 15));
+  /*
+  if (output_pos.x == 112 && output_pos.y == 0) {
+    CL_DTYPE4 input1 = inputs[3];
+    float4 in = (float4)(input1.x, input1.y, input1.z, input1.w);
+    printf(" input4 3 - %v4hlf \n", in);
+    printf(" --- %d ---\n", in_pos_in_one_block.x - 1);
+  }
+  */
+
+  inputs[4] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15));
+
+  inputs[5] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + 1 >= input_width || in_pos_in_one_block.y >= input_height) << 15));
+
+  inputs[6] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y + 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15));
+
+  inputs[7] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y + 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15));
+
+  inputs[8] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y + 1)),
+                     (CL_DTYPE4)(0.0f),
+                     (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x + 1 >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15));
+
+  CL_DTYPE4 filters[9];
+  filters[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y));
+  filters[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y));
+  filters[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y));
+  filters[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y + 1));
+  filters[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y + 1));
+  filters[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y + 1));
+  filters[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y + 2));
+  filters[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y + 2));
+  filters[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y + 2));
+
+  for (int i = 0; i < 9; i++) {
+    output += inputs[i] * filters[i];
+  }
+#ifdef BATCH_NORM
+  output = output * READ_IMG_TYPE(CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
+#endif
+
+#ifdef RELU
+  output = activation(output);
+#endif
+
+  /*
+
+  if (output_pos.x == 112 && output_pos.y == 0) {
+
+    for (int i = 0; i < 9; ++i) {
+      CL_DTYPE4 input1 = inputs[i];
+      float4 in = (float4)(input1.x, input1.y, input1.z, input1.w);
+      printf(" input4 %d - %v4hlf \n", i, in);
+    }
+
+    float4 out = (float4)(output.x, output.y, output.z, output.w);
+    printf(" depth wise output output4 = %v4hlf \n", out);
+    printf(" pos_in_input_block -x %d \n ", pos_in_input_block.x);
+    printf(" pos_in_input_block -y %d \n ", pos_in_input_block.y);
+    printf(" in_pos_in_one_block - x %d \n", in_pos_in_one_block.x);
+    printf(" in_pos_in_one_block - y %d \n", in_pos_in_one_block.y);
+  }
+
+  */
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
+
+}
+
+
+
+__kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk,
+                                 __private const int ou_w_blk,
+                                 __private const int ou_nh,
+                                 __read_only image2d_t input,
+                                 __read_only image2d_t filter,
+#if defined(BIASE_CH) || defined(BIASE_ELE)
+                                 __read_only image2d_t bias,
+#endif
+#ifdef BATCH_NORM
+                                 __read_only image2d_t new_scale,
+                                 __read_only image2d_t new_biase,
+#endif
+                                 __write_only image2d_t output_image,
+                                 __private const int stride,
+                                 __private const int pad,
+                                 __private const int dilation,
+                                 __private const int in_ch,
+                                 __private const int in_w,  /* of one block */
+                                 __private const int in_h,  /* of one block */
+                                 __private const int ou_w,
+                                 __private const int ou_h) {
+
+  const int ou_ch_blk_id = get_global_id(0);
+  const int ou_w_blk_id = get_global_id(1);
+  const int ou_nh_id = get_global_id(2);
+  const int w_blk_size = 2;
+
+  const int batch_id = ou_nh_id / ou_h;
+  int ou_col_id = ou_w_blk_id * w_blk_size;
+  int ou_row_id = ou_nh_id % ou_h;
+  int ou_x = mad24(ou_ch_blk_id, ou_w, ou_col_id);
+
+  // input pos in one block and on batch
+  int col_id = ou_col_id - pad;
+  int row_id = ou_row_id - pad;
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+
+#ifdef BIASE_CH
+  CL_DTYPE4 output[2];
+  output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(ou_ch_blk_id, 0));
+  output[1] = output[0];
+#elif defined(BIASE_ELE)
+  CL_DTYPE4 output[2];
+  output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
+                            (int2)(ou_x, ou_nh_id));
+  if (ou_col_id + 1 < ou_w) {
+    output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(ou_x + 1, ou_nh_id));
+  }
+#else
+  CL_DTYPE4 output[2] = {0.0f};
+#endif
+
+  CL_DTYPE4 inputs[12];
+
+  int filter_x = ou_ch_blk_id * 3;
+  int filter_y = 0;
+  CL_DTYPE4 filters[9];
+  filters[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y));
+  filters[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y));
+  filters[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y));
+
+  int in_x = mad24(ou_ch_blk_id, in_w, col_id);
+  int in_y = mad24(batch_id, in_h, row_id);
+
+  int y0 = select(in_y, -1, row_id < 0 || row_id >= in_h);
+  int x0 = select(in_x, -1, col_id < 0 || col_id >= in_w);
+  inputs[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y0));
+  int x1 = select(in_x + 1, -1, col_id + 1 < 0 || col_id + 1 >= in_w);
+  inputs[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y0));
+  int x2 = select(in_x + 2, -1, col_id + 2 < 0 || col_id + 2 >= in_w);
+  inputs[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y0));
+  int x3 = select(in_x + 3, -1, col_id + 3 < 0 || col_id + 3 >= in_w);
+  inputs[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y0));
+
+  output[0] = mad(inputs[0], filters[0], output[0]);
+  output[1] = mad(inputs[1], filters[0], output[1]);
+
+  output[0] = mad(inputs[1], filters[1], output[0]);
+  output[1] = mad(inputs[2], filters[1], output[1]);
+
+  output[0] = mad(inputs[2], filters[2], output[0]);
+  output[1] = mad(inputs[3], filters[2], output[1]);
+
+  filters[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y + 1));
+  filters[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y + 1));
+  filters[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y + 1));
+
+  int y1 = select(in_y + 1, -1, row_id + 1 < 0 || row_id + 1 >= in_h);
+  inputs[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y1));
+  inputs[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y1));
+  inputs[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y1));
+  inputs[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y1));
+
+  output[0] = mad(inputs[4], filters[3], output[0]);
+  output[1] = mad(inputs[5], filters[3], output[1]);
+
+  output[0] = mad(inputs[5], filters[4], output[0]);
+  output[1] = mad(inputs[6], filters[4], output[1]);
+
+  output[0] = mad(inputs[6], filters[5], output[0]);
+  output[1] = mad(inputs[7], filters[5], output[1]);
+
+  filters[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x, filter_y + 2));
+  filters[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 1, filter_y + 2));
+  filters[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + 2, filter_y + 2));
+
+  int y2 = select(in_y + 2, -1, row_id + 2 < 0 || row_id + 2 >= in_h);
+  inputs[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y2));
+  inputs[9] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y2));
+  inputs[10] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y2));
+  inputs[11] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y2));
+
+  output[0] = mad(inputs[8], filters[6], output[0]);
+  output[1] = mad(inputs[9], filters[6], output[1]);
+
+  output[0] = mad(inputs[9], filters[7], output[0]);
+  output[1] = mad(inputs[10], filters[7], output[1]);
+
+  output[0] = mad(inputs[10], filters[8],
+                  output[0]);
+  output[1] = mad(inputs[11], filters[8], output[1]);
+#ifdef BATCH_NORM
+  CL_DTYPE4 scale = READ_IMG_TYPE(CL_DTYPE_CHAR, new_scale, sampler, (int2)(ou_ch_blk_id, 0));
+  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(ou_ch_blk_id, 0));
+  output[0] = mad(scale, output[0], biase);
+  if (ou_col_id + 1 < ou_w) {
+    output[1] = mad(scale, output[1], biase);
+  }
+#endif
+
+#ifdef RELU
+  output[0] = activation(output[0]);
+  output[1] = activation(output[1]);
+#endif
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x, ou_nh_id), output[0]);
+  if (ou_col_id + 1 < ou_w) {
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x + 1, ou_nh_id), output[1]);
+  }
+
+}
+
diff --git a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl
index ecf719ae9316ed14743e872a1c2cde4b254b35ff..a95c6c6897944c9c943f65b72e51a2ced94befa6 100644
--- a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl
+++ b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cl_common.h>
+
 __kernel void elementwise_add(__read_only image2d_t input,
                               __read_only image2d_t bias,
                               __write_only image2d_t outputImage) {
   int x = get_global_id(0);
   int y = get_global_id(1);
diff --git a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl
index 0ca3b9141daf671737af8d24cd03e59587e33350..775166261d01dc639cd5af8cee49f7e7fb30cb19 100644
--- a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl
+++ b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl
@@ -12,15 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#define MIN_VALUE -FLT_MAX
-
-__kernel void pool_max(
-    __private const int in_height, __private const int in_width,
-    __private const int out_height, __private const int out_width,
-    __private const int pad_top, __private const int pad_left,
-    __private const int stride_h, __private const int stride_w,
-    __private const int ksize_h, __private const int ksize_w,
-    __read_only image2d_t input, __write_only image2d_t output) {
+#include <cl_common.h>
+
+__kernel void pool_max(__read_only image2d_t input,
+                       __write_only image2d_t output,
+                       __private const int in_height,
+                       __private const int in_width,
+                       __private const int out_height,
+                       __private const int out_width,
+                       __private const int ksize_h,
+                       __private const int ksize_w,
+                       __private const int stride_h,
+                       __private const int stride_w,
+                       __private const int pad_top,
+                       __private const int pad_left) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
   const int out_nh = get_global_id(2);
@@ -40,25 +45,30 @@ __kernel void pool_max(
   const int pos_in_x = out_c * in_width;
   const int pos_in_y = out_n * in_height;
-  float4 max_value = (float4)(MIN_VALUE);
+  CL_DTYPE4 max_value = (CL_DTYPE4)(MIN_VALUE);
   for (int y = start_h; y < end_h; ++y) {
     for (int x = start_w; x < end_w; ++x) {
-      float4 tmp = read_imagef(input, sampler, (int2)(pos_in_x + x, pos_in_y + y));
+      CL_DTYPE4 tmp = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_x + x, pos_in_y + y));
       max_value = max(max_value, tmp);
     }
   }
   const int pos_out_x = mad24(out_c, out_width, out_w);
-  write_imagef(output, (int2)(pos_out_x, out_nh), max_value);
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), max_value);
 }
-__kernel void pool_avg(
-    __private const int in_height, __private const int in_width,
-    __private const int out_height, __private const int out_width,
-    __private const int pad_top, __private const int pad_left,
-    __private const int stride_h, __private const int stride_w,
-    __private const int ksize_h, __private const int ksize_w,
-    __read_only image2d_t input, __write_only image2d_t output) {
+__kernel void pool_avg(__read_only image2d_t input,
+                       __write_only image2d_t output,
+                       __private const int in_height,
+                       __private const int in_width,
+                       __private const int out_height,
+                       __private const int out_width,
+                       __private const int ksize_h,
+                       __private const int ksize_w,
+                       __private const int stride_h,
+                       __private const int stride_w,
+                       __private const int pad_top,
+                       __private const int pad_left) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
   const int out_nh = get_global_id(2);
@@ -76,15 +86,14 @@ __kernel void pool_avg(
   const int pos_in_x = out_c * in_width;
   const int pos_in_y = out_n * in_height;
-  float4 sum = (float4)(0.0f);
-  int num = 0;
+  CL_DTYPE4 sum = (CL_DTYPE4)(0.0f);
+
   for (int y = start_h; y < end_h; ++y) {
     for (int x = start_w; x < end_w; ++x) {
-      sum += read_imagef(input, sampler, (int2)(pos_in_x + x, pos_in_y + y));
-      num++;
+      sum += READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_x + x, pos_in_y + y));
     }
   }
-  float4 avg = sum / num;
+  CL_DTYPE4 avg = sum / (ksize_h * ksize_w);
   const int pos_out_x = mad24(out_c, out_width, out_w);
-  write_imagef(output, (int2)(pos_out_x, out_nh), avg);
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), avg);
 }
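
[Editor's note: the rewritten pool_avg divides by the full window area ksize_h * ksize_w instead of by the number of in-bounds elements, i.e. count-include-pad averaging. A minimal scalar sketch of those semantics, illustrative only and not part of the patch:

    // Padded positions contribute zero to the sum, but the divisor
    // stays ksize_h * ksize_w (count-include-pad averaging).
    float pool_avg_ref(const float* in, int in_h, int in_w, int start_h,
                       int start_w, int ksize_h, int ksize_w) {
      float sum = 0.f;
      for (int y = start_h; y < start_h + ksize_h; ++y)
        for (int x = start_w; x < start_w + ksize_w; ++x)
          if (y >= 0 && y < in_h && x >= 0 && x < in_w) sum += in[y * in_w + x];
      return sum / (ksize_h * ksize_w);
    }
]
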
diff --git a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl
index a99ac79d32bcedb48354d2e179ef6c8c1ff7f997..43a27067c2f2c418d314f9bce95bccbbb51a9be0 100644
--- a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl
+++ b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl
@@ -24,7 +24,7 @@ __kernel void relu(__read_only image2d_t input,
                             CLK_ADDRESS_CLAMP |
                             CLK_FILTER_NEAREST;
 
-  CL_DTYPE4 in = read_imagef(input, sampler, (int2)(x, y));
+  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   in = max((CL_DTYPE4)(0.0f), in);
-  write_imagef(output, (int2)(x, y), in);
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
 }
diff --git a/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl b/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..314be875d29d2125f9573d33010ee9d33317ea71
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl
@@ -0,0 +1,162 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cl_common.h>
+
+__kernel void reshape(__read_only image2d_t input_image,
+                      __write_only image2d_t output_image,
+                      __private const int out_C,
+                      __private const int out_H,
+                      __private const int out_W,
+                      __private const int in_W,
+                      __private const int in_H,
+                      __private const int in_Stride0,
+                      __private const int in_Stride1,
+                      __private const int in_Stride2,
+                      __private const int out_Stride0,
+                      __private const int out_Stride1,
+                      __private const int out_Stride2) {
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+  const int out_n = out_nh / out_H;
+  const int out_h = out_nh % out_H;
+  const int out_c0 = out_c * 4;
+  const int out_c1 = out_c * 4 + 1;
+  const int out_c2 = out_c * 4 + 2;
+  const int out_c3 = out_c * 4 + 3;
+
+  int count0 =
+      out_n * out_Stride2 + out_c0 * out_Stride1 + out_h * out_Stride0 + out_w;
+  int count1 =
+      out_n * out_Stride2 + out_c1 * out_Stride1 + out_h * out_Stride0 + out_w;
+  int count2 =
+      out_n * out_Stride2 + out_c2 * out_Stride1 + out_h * out_Stride0 + out_w;
+  int count3 =
+      out_n * out_Stride2 + out_c3 * out_Stride1 + out_h * out_Stride0 + out_w;
+
+  int in_n0 = count0 / in_Stride2;
+  int in_n1 = count1 / in_Stride2;
+  int in_n2 = count2 / in_Stride2;
+  int in_n3 = count3 / in_Stride2;
+
+  count0 = count0 % in_Stride2;
+  count1 = count1 % in_Stride2;
+  count2 = count2 % in_Stride2;
+  count3 = count3 % in_Stride2;
+
+  int in_c0 = count0 / in_Stride1;
+  int in_c1 = count1 / in_Stride1;
+  int in_c2 = count2 / in_Stride1;
+  int in_c3 = count3 / in_Stride1;
+
+  int in_h0 = (count0 % in_Stride1) / in_Stride0;
+  int in_h1 = (count1 % in_Stride1) / in_Stride0;
+  int in_h2 = (count2 % in_Stride1) / in_Stride0;
+  int in_h3 = (count3 % in_Stride1) / in_Stride0;
+
+  int in_w0 = (count0 % in_Stride1) % in_Stride0;
+  int in_w1 = (count1 % in_Stride1) % in_Stride0;
+  int in_w2 = (count2 % in_Stride1) % in_Stride0;
+  int in_w3 = (count3 % in_Stride1) % in_Stride0;
+
+  int2 input_pos0;
+  int2 input_pos1;
+  int2 input_pos2;
+  int2 input_pos3;
+
+  input_pos0.x = (in_c0 / 4) * in_W + in_w0;
+  input_pos0.y = in_n0 * in_H + in_h0;
+
+  input_pos1.x = (in_c1 / 4) * in_W + in_w1;
+  input_pos1.y = in_n1 * in_H + in_h1;
+
+  input_pos2.x = (in_c2 / 4) * in_W + in_w2;
+  input_pos2.y = in_n2 * in_H + in_h2;
+
+  input_pos3.x = (in_c3 / 4) * in_W + in_w3;
+  input_pos3.y = in_n3 * in_H + in_h3;
+
+  int2 output_pos;
+  output_pos.x = out_c * out_W + out_w;
+  output_pos.y = out_nh;
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  CL_DTYPE4 input0;
+  CL_DTYPE4 input1;
+  CL_DTYPE4 input2;
+  CL_DTYPE4 input3;
+  CL_DTYPE4 output;
+
+  input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos0);
+  if (in_c0 % 4 == 0) {
+    output.x = input0.x;
+  } else if (in_c0 % 4 == 1) {
+    output.x = input0.y;
+  } else if (in_c0 % 4 == 2) {
+    output.x = input0.z;
+  } else {
+    output.x = input0.w;
+  }
+  if (out_C - out_c * 4 >= 2) {
+    input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos1);
+    if (in_c1 % 4 == 0) {
+      output.y = input1.x;
+    } else if (in_c1 % 4 == 1) {
+      output.y = input1.y;
+    } else if (in_c1 % 4 == 2) {
+      output.y = input1.z;
+    } else {
+      output.y = input1.w;
+    }
+
+  } else {
+    output.y = 0.0f;
+  }
+
+  if (out_C - out_c * 4 >= 3) {
+    input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos2);
+
+    if (in_c2 % 4 == 0) {
+      output.z = input2.x;
+    } else if (in_c2 % 4 == 1) {
+      output.z = input2.y;
+    } else if (in_c2 % 4 == 2) {
+      output.z = input2.z;
+    } else {
+      output.z = input2.w;
+    }
+  } else {
+    output.z = 0.0f;
+  }
+
+  if (out_C - out_c * 4 >= 4) {
+    input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos3);
+    if (in_c3 % 4 == 0) {
+      output.w = input3.x;
+    } else if (in_c3 % 4 == 1) {
+      output.w = input3.y;
+    } else if (in_c3 % 4 == 2) {
+      output.w = input3.z;
+    } else {
+      output.w = input3.w;
+    }
+  } else {
+    output.w = 0.0f;
+  }
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
+}
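
[Editor's note: the reshape kernel relies on the image2d layout convention used throughout these kernels: four consecutive channels of an NCHW tensor are packed into one RGBA texel. A small sketch of the coordinate mapping it computes, as an illustrative helper rather than part of the patch:

    // Maps an NCHW coordinate to its texel position and RGBA lane in the
    // image2d layout assumed by the OpenCL kernels above.
    struct ImagePos { int x, y, lane; };
    ImagePos nchw_to_image(int n, int c, int h, int w, int W, int H) {
      ImagePos p;
      p.x = (c / 4) * W + w;  // column: channel block * width + w
      p.y = n * H + h;        // row: batch * height + h
      p.lane = c % 4;         // which of .x/.y/.z/.w inside the texel
      return p;
    }
]
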
diff --git a/lite/backends/opencl/cl_runtime.cc b/lite/backends/opencl/cl_runtime.cc
index c2504ab611e93399c70169f3f123d4a0514c07ad..0c7b2f8575a88082f6d79a5392c4468715a701b9 100644
--- a/lite/backends/opencl/cl_runtime.cc
+++ b/lite/backends/opencl/cl_runtime.cc
@@ -103,6 +103,7 @@ std::unique_ptr<cl::UserEvent> CLRuntime::CreateEvent(
 bool CLRuntime::BuildProgram(cl::Program* program, const std::string& options) {
   std::string build_option = options + " -cl-fast-relaxed-math -I " +
                              CLRuntime::Global()->cl_path() + "/cl_kernel";
+  VLOG(4) << "OpenCL build_option: " << build_option;
   status_ = program->build({*device_}, build_option.c_str());
   CL_CHECK_ERROR(status_);
diff --git a/lite/backends/opencl/target_wrapper.cc b/lite/backends/opencl/target_wrapper.cc
index 575f87d0f8d0192345c6ab111db46715a809a976..310567baa539697f6a67b59f6c0e5f29ce46a80e 100644
--- a/lite/backends/opencl/target_wrapper.cc
+++ b/lite/backends/opencl/target_wrapper.cc
@@ -24,6 +24,8 @@ static cl_channel_type GetCLChannelType(const PrecisionType type) {
   switch (type) {
     case PRECISION(kFloat):
       return CL_FLOAT;
+    case PRECISION(kFP16):
+      return CL_HALF_FLOAT;
     case PRECISION(kInt32):
       return CL_SIGNED_INT32;
     case PRECISION(kInt8):
@@ -58,17 +60,18 @@ void TargetWrapperCL::Free(void *ptr) {
 
 template <>
 void *TargetWrapperCL::MallocImage<float>(const size_t cl_image2d_width,
-                                          const size_t cl_image2d_height) {
+                                          const size_t cl_image2d_height,
+                                          void *host_ptr) {
   cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFloat)));
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
                       0,
-                      nullptr,
+                      host_ptr,
                       &status);
   if (status != CL_SUCCESS) {
     delete cl_image;
@@ -78,19 +81,20 @@ void *TargetWrapperCL::MallocImage<float>(const size_t cl_image2d_width,
   return cl_image;
 }
 
-template <>
-void *TargetWrapperCL::MallocImage<int8_t>(const size_t cl_image2d_width,
-                                           const size_t cl_image2d_height) {
-  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt8)));
+template <>  // use int16_t to represent half float
+void *TargetWrapperCL::MallocImage<int16_t>(const size_t cl_image2d_width,
+                                            const size_t cl_image2d_height,
+                                            void *host_ptr) {
+  cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFP16)));
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
                       0,
-                      nullptr,
+                      host_ptr,
                       &status);
   if (status != CL_SUCCESS) {
     delete cl_image;
@@ -102,17 +106,18 @@ void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width,
 
 template <>
 void *TargetWrapperCL::MallocImage<int32_t>(const size_t cl_image2d_width,
-                                            const size_t cl_image2d_height) {
+                                            const size_t cl_image2d_height,
+                                            void *host_ptr) {
   cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt32)));
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
                       0,
-                      nullptr,
+                      host_ptr,
                       &status);
   if (status != CL_SUCCESS) {
     delete cl_image;
diff --git a/lite/backends/opencl/target_wrapper.h b/lite/backends/opencl/target_wrapper.h
index 7753448052e17ac739f730c9fabcaf9533e0045e..c5ff9e900a70fd96ccb461c74fb61e33815a5e81 100644
--- a/lite/backends/opencl/target_wrapper.h
+++ b/lite/backends/opencl/target_wrapper.h
@@ -48,7 +48,8 @@ class TargetWrapper {
   template <typename T>
   static void* MallocImage(const size_t cl_image2d_width,
-                           const size_t cl_image2d_height);
+                           const size_t cl_image2d_height,
+                           void* host_ptr = nullptr);
   static void FreeImage(void* image);
 
   static void* Map(void* buffer, size_t offset, size_t size);
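
[Editor's note: with the new host_ptr parameter, callers can upload initial data at allocation time via CL_MEM_COPY_HOST_PTR instead of allocating with CL_MEM_ALLOC_HOST_PTR and writing afterwards. A hedged caller sketch; the buffer sizing assumes the RGBA texel packing described above:

    // Allocate a float image and copy host data into it in one step.
    std::vector<float> src(width * height * 4, 0.f);  // 4 floats per RGBA texel
    void* image = TargetWrapperCL::MallocImage<float>(width, height, src.data());
    // Passing nullptr (the default) keeps the old allocate-only behavior.
]
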
diff --git a/lite/backends/xpu/CMakeLists.txt b/lite/backends/xpu/CMakeLists.txt
index f911f8e0e7c61481e1d4e309bc0635718be11206..4491fdeaefe9f16265bdee2c07ebb02b86a2b038 100644
--- a/lite/backends/xpu/CMakeLists.txt
+++ b/lite/backends/xpu/CMakeLists.txt
@@ -2,5 +2,4 @@ if(NOT LITE_WITH_XPU)
   return()
 endif()
 
-lite_cc_library(xpu_runtime SRCS runtime.cc DEPS ${xpu_runtime_libs})
-lite_cc_library(xpu_builder SRCS builder.cc DEPS ${xpu_builder_libs} xpu_runtime tensor op scope)
+lite_cc_library(device_xpu SRCS device.cc DEPS ${xpu_builder_libs} ${xpu_runtime_libs})
diff --git a/lite/backends/xpu/device.cc b/lite/backends/xpu/device.cc
new file mode 100644
index 0000000000000000000000000000000000000000..74a5681aa98f2c2d3d4025d91207f24f0733a19e
--- /dev/null
+++ b/lite/backends/xpu/device.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/xpu/device.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace xpu {
+
+std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
+    xtcl::network::xNetworkBuilder* builder,
+    xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
+    std::vector<xtcl::xExpr*>* outputs) {
+  VLOG(3) << "[XPU] Build model";
+  CHECK(builder != nullptr);
+  CHECK(outputs != nullptr);
+  CHECK_GT(outputs->size(), 0);
+
+  // The XPU compiler builds the graph and fills in all of the constant
+  // params; only one output is supported for now.
+  xtcl::Array<xtcl::xExpr> all_outs;
+  for (size_t i = 0; i < outputs->size(); i++) {
+    all_outs.push_back(*outputs->at(i));
+  }
+  xtcl::xNetwork network =
+      builder->FinalizeNetwork(xtcl::relay::TupleNode::make(all_outs));
+  auto target = xtcl::Target::Create(device_name_);
+  auto compiler = xtcl::network::xTensorCompiler(network, target);
+  compiler.SetParams(*params);  // Set the data of constant tensors
+  compiler.Build();
+  return std::unique_ptr<xtcl::network::xRuntimeInstance>(
+      new xtcl::network::xRuntimeInstance(compiler.CreateRuntimeInstance()));
+}
+
+}  // namespace xpu
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/xpu/device.h b/lite/backends/xpu/device.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf9a8bf76af168a8a73f8f497b793df88f48f96b
--- /dev/null
+++ b/lite/backends/xpu/device.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <xtcl/xtcl.h>
+
+namespace paddle {
+namespace lite {
+namespace xpu {
+
+class Device {
+ public:
+  static Device& Global() {
+    static Device x;
+    return x;
+  }
+  Device() {}
+
+  // Build the XPU graph and return an XPU runtime instance that can be
+  // used to run inference.
+  std::unique_ptr<xtcl::network::xRuntimeInstance> Build(
+      xtcl::network::xNetworkBuilder* builder,
+      xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
+      std::vector<xtcl::xExpr*>* outputs);
+
+ private:
+  // Keep reserved fields
+  int device_id_{0};
+  std::string device_name_{"llvm"};
+};
+
+}  // namespace xpu
+}  // namespace lite
+}  // namespace paddle
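
[Editor's note: a sketch of how a subgraph bridge might drive the new Device API; the builder, params, and outputs are assumed to have been populated by op converters (illustrative, not part of the patch):

    // Compile an XPU graph into a runtime instance and keep it for inference.
    xtcl::network::xNetworkBuilder builder;
    xtcl::network::xTensorCompiler::ParamNDArrayMap params;
    std::vector<xtcl::xExpr*> outputs;
    // ... op converters fill builder/params/outputs ...
    auto runtime =
        paddle::lite::xpu::Device::Global().Build(&builder, &params, &outputs);
]
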
diff --git a/lite/backends/xpu/runtime.cc b/lite/backends/xpu/runtime.cc
deleted file mode 100644
index a2c34b95758e8abf81c8294507d0ca60aad7c021..0000000000000000000000000000000000000000
--- a/lite/backends/xpu/runtime.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/backends/xpu/runtime.h"
-#include <string>
-#include "lite/utils/cp_logging.h"
-
-namespace paddle {
-namespace lite {
-namespace xpu {
-
-// Extract the model data and recover the XPU model for inference, the function
-// is called by the graph computing kernel when the graph op is executed.
-// Due to the lack of XPU APIs for loading and recovering the XPU model from
-// memory, the key name is obtained from the weight tensor of graph op, to get
-// the runtime object for inference from the global variable 'DeviceInfo'.
-// TODO(hong19860320) Recover the XPU model from the weight tensor of graph op.
-bool LoadModel(const lite::Tensor &model,
-               std::shared_ptr<xtcl::network::xRuntimeInstance> *runtime) {
-  LOG(INFO) << "[XPU] Load Model.";
-  CHECK_GT(model.dims().production(), 0);
-  std::string name(reinterpret_cast<const char *>(model.data()));
-  LOG(INFO) << "[XPU] Model Name: " << name;
-  CHECK(runtime != nullptr);
-  *runtime = DeviceInfo::Global().Find(name);
-  if (*runtime == nullptr) {
-    LOG(WARNING) << "[XPU] Load Model failed!";
-    return false;
-  }
-  return true;
-}
-
-}  // namespace xpu
-}  // namespace lite
-}  // namespace paddle
diff --git a/lite/backends/xpu/runtime.h b/lite/backends/xpu/runtime.h
deleted file mode 100644
index 4ff8d75bce6156d51a4988d427058da34460443f..0000000000000000000000000000000000000000
--- a/lite/backends/xpu/runtime.h
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <xtcl/xtcl.h>
-#include "lite/core/tensor.h"
-
-namespace paddle {
-namespace lite {
-namespace xpu {
-
-class DeviceInfo {
- public:
-  static DeviceInfo& Global() {
-    static DeviceInfo x;
-    return x;
-  }
-  DeviceInfo() {}
-
-  void Insert(const std::string& name,
-              std::shared_ptr<xtcl::network::xRuntimeInstance> runtime) {
-    if (runtimes_.find(name) != runtimes_.end()) {
-      LOG(WARNING) << "[XPU] Model " << name << " already exists.";
-      return;
-    }
-    runtimes_.emplace(std::make_pair(name, runtime));
-  }
-
-  void Clear() { runtimes_.clear(); }
-
-  std::shared_ptr<xtcl::network::xRuntimeInstance> Find(
-      const std::string& name) const {
-    if (runtimes_.find(name) != runtimes_.end()) {
-      return runtimes_.at(name);
-    } else {
-      return nullptr;
-    }
-  }
-
- private:
-  int device_id_{0};
-  std::string device_name_{"default"};
-  std::unordered_map<std::string,
-                     std::shared_ptr<xtcl::network::xRuntimeInstance>>
-      runtimes_;
-};
-
-bool LoadModel(const lite::Tensor& model,
-               std::shared_ptr<xtcl::network::xRuntimeInstance>* runtime);
-
-}  // namespace xpu
-}  // namespace lite
-}  // namespace paddle
diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt
index a93b962a4723b2677defc16fdaf1d0922f1b48fa..34d9deff6a5262c16c2f74301771b73479f3ae30 100644
--- a/lite/core/CMakeLists.txt
+++ b/lite/core/CMakeLists.txt
@@ -33,9 +33,9 @@ lite_cc_library(scope SRCS scope.cc DEPS tensor)
 lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
 
 if (LITE_WITH_ARM)
-lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime)
+lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags)
else()
-lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags XPU_DEPS xpu_runtime)
+lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
endif()
 
#-------------------------------------------- GET CODE META INFO ------------------------------------------
diff --git a/lite/core/arena/CMakeLists.txt b/lite/core/arena/CMakeLists.txt
index bc77afd81e0859b9492b2068ce681098a9393923..6c0c917a3e6b18f926a5fa768131e36296301432 100644
--- a/lite/core/arena/CMakeLists.txt
+++ b/lite/core/arena/CMakeLists.txt
@@ -5,6 +5,6 @@ endif()
 
 lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
 
-if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
-  lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+  lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc
index 561a508d20f1db9283a410b8ee35dd851149429c..fe36f1e1ba16ad85c44136b09a0d2e5d3fadf688 100644
--- a/lite/core/arena/framework.cc
+++ b/lite/core/arena/framework.cc
@@ -14,13 +14,38 @@
 
 #include "lite/core/arena/framework.h"
 #include "lite/core/context.h"
+#include "lite/operators/subgraph_op.h"
 
 namespace paddle {
 namespace lite {
 namespace arena {
 
 void TestCase::CreateInstruction() {
-  auto op = LiteOpRegistry::Global().Create(op_desc().Type());
+  std::shared_ptr<OpLite> op = nullptr;
+  if (place_.target == TARGET(kNPU) || place_.target == TARGET(kXPU)) {
+    // Create a new block desc to wrap the original op desc
+    int sub_block_idx = 0;
+    auto sub_block_desc = new cpp::BlockDesc();
+    sub_block_desc->ClearOps();
+    sub_block_desc->ClearVars();
+    auto sub_block_op_desc = sub_block_desc->AddOp();
+    *sub_block_op_desc = *op_desc_;
+    // Add the block desc into the subgraph op which is used to replace the
+    // original op
+    op_desc_.reset(new cpp::OpDesc());
+    op_desc_->SetType("subgraph");
+    op_desc_->SetAttr("sub_block", sub_block_idx);
+    auto in_names = sub_block_op_desc->input_vars();
+    auto out_names = sub_block_op_desc->output_vars();
+    op_desc_->SetInput("Inputs", in_names);
+    op_desc_->SetOutput("Outputs", out_names);
+    op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
+    op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
+    op = LiteOpRegistry::Global().Create(op_desc().Type());
+    static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc);
+  } else {
+    op = LiteOpRegistry::Global().Create(op_desc().Type());
+  }
   CHECK(op) << "no op for " << op_desc().Type();
   op->Attach(*op_desc_, inst_scope_);
   auto kernels = op->CreateKernels({place_});
@@ -68,6 +93,19 @@ void TestCase::PrepareInputsForInstruction() {
   }
 }
 
+TestCase::~TestCase() {
+  if (op_desc_->Type() == "subgraph") {
+    // Release the sub-block desc of the subgraph op
+    auto subgraph_op = const_cast<operators::SubgraphOp*>(
+        static_cast<const operators::SubgraphOp*>(instruction_->op()));
+    CHECK(subgraph_op);
+    auto sub_block_desc = subgraph_op->GetSubBlock();
+    if (sub_block_desc) {
+      delete sub_block_desc;
+    }
+  }
+}
+
 }  // namespace arena
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h
index 412ac0c167b8abe6d196dc25d1bc5b193d02965d..05af21bbdbfd6d00aa0eb3992fa732cf8f2e0fab 100644
--- a/lite/core/arena/framework.h
+++ b/lite/core/arena/framework.h
@@ -21,6 +21,7 @@
 #include <cmath>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "lite/core/op_registry.h"
@@ -42,7 +43,7 @@ class TestCase {
       : place_(place), scope_(new Scope), alias_(alias) {
     ctx_ = ContextScheduler::Global().NewContext(place_.target);
   }
-  virtual ~TestCase() {}
+  virtual ~TestCase();
 
   void Prepare() {
     PrepareScopes();
@@ -77,6 +78,20 @@ class TestCase {
   // kernel registry.
   void CheckKernelConsistWithDefinition() {}
 
+  // Get the real precision of the output for precision checking. When the
+  // declared precision obtained from the kernel is kAny, the real precision
+  // of the output must have been set in the test case.
+  bool GetPrecisionType(const std::string& var_name,
+                        PrecisionType* precision_type) {
+    auto res = precision_type_map_.find(var_name);
+    if (res == precision_type_map_.end()) {
+      return false;
+    } else {
+      *precision_type = precision_type_map_.at(var_name);
+      return true;
+    }
+  }
+
   Scope& scope() { return *scope_; }
 
   Scope* baseline_scope() { return base_scope_; }
@@ -105,6 +120,19 @@ class TestCase {
   // Prepare for the operator.
   virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
 
+  // Set the real precision of the output for precision checking. When the
+  // declared precision obtained from the kernel is kAny, the test case
+  // should set the real precision of the output.
+  void SetPrecisionType(const std::string& var_name,
+                        const PrecisionType& precision_type) {
+    auto res = precision_type_map_.find(var_name);
+    if (res == precision_type_map_.end()) {
+      precision_type_map_.insert({var_name, precision_type});
+    } else {
+      precision_type_map_.at(var_name) = precision_type;
+    }
+  }
+
  public:
   const Instruction& instruction() { return *instruction_; }
 
@@ -148,6 +176,7 @@ class TestCase {
   Scope* base_scope_{};
   std::unique_ptr<cpp::OpDesc> op_desc_;
   std::unique_ptr<Instruction> instruction_;
+  std::unordered_map<std::string, PrecisionType> precision_type_map_;
 };
 
 class Arena {
@@ -159,13 +188,17 @@ class Arena {
     tester_->Prepare();
   }
 
-  bool TestPrecision() {
+  bool TestPrecision(const std::vector<std::string>& exclude_outs = {}) {
     tester_->RunBaseline(tester_->baseline_scope());
     tester_->RunInstruction();
 
     bool success = true;
    for (auto& out : tester_->op_desc().OutputArgumentNames()) {
       for (auto& var : tester_->op_desc().Output(out)) {
+        if (std::find(exclude_outs.begin(), exclude_outs.end(), var) !=
+            exclude_outs.end()) {
+          continue;
+        }
         success = success && CompareTensor(out, var);
       }
     }
@@ -189,8 +222,11 @@ class Arena {
     // get tensor type.
     const Type* type =
         tester_->instruction().kernel()->GetOutputDeclType(arg_name);
-
-    switch (type->precision()) {
+    auto precision_type = type->precision();
+    if (precision_type == PRECISION(kAny)) {
+      CHECK(tester_->GetPrecisionType(var_name, &precision_type));
+    }
+    switch (precision_type) {
       case PRECISION(kFloat):
         return tester_->CheckPrecision<float>(var_name, abs_error_);
       case PRECISION(kInt8):
        return tester_->CheckPrecision<int8_t>(var_name, abs_error_);
       case PRECISION(kInt32):
         return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
       case PRECISION(kBool):
         return tester_->CheckPrecision<bool>(var_name, abs_error_);
-
       default:
         LOG(FATAL) << "not support type " << PrecisionToStr(type->precision());
         return false;
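
[Editor's note: a sketch of how a test might use the new precision hooks when a kernel declares PRECISION(kAny) for its output; the variable names are illustrative:

    // In a TestCase subclass: tell the arena the real output precision so
    // CompareTensor() can pick the right CheckPrecision<T> instantiation.
    SetPrecisionType("out_var", PRECISION(kFloat));
    // In the test body: skip outputs that should not be compared.
    arena.TestPrecision({"intermediate_out"});  // exclude_outs
]
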
diff --git a/lite/core/context.h b/lite/core/context.h
index eb25e7e1d980de9e8f633591fc1320f2a7cd476d..2830bca5c1b1e3dce151e498dd502e6636e54950 100644
--- a/lite/core/context.h
+++ b/lite/core/context.h
@@ -25,12 +25,6 @@
 #include "lite/backends/opencl/cl_context.h"
 #include "lite/backends/opencl/cl_runtime.h"
 #endif
-#ifdef LITE_WITH_NPU
-#include "lite/backends/npu/runtime.h"
-#endif
-#ifdef LITE_WITH_XPU
-#include "lite/backends/xpu/runtime.h"
-#endif
 
 #include <map>
 #include <memory>
@@ -93,7 +87,7 @@ template <>
 class Context<TargetType::kXPU> {
  public:
   Context() {}
-  explicit Context(const NPUContext& ctx);
+  explicit Context(const XPUContext& ctx);
   // NOTE: InitOnce should only be used by ContextScheduler
   void InitOnce() {}
   void CopySharedTo(XPUContext* ctx) {}
diff --git a/lite/core/memory.h b/lite/core/memory.h
index cb4ac044e7af6994e5e404f379eeb12290e34778..18b9958911a6173c088b415369555235d63d184d 100644
--- a/lite/core/memory.h
+++ b/lite/core/memory.h
@@ -100,13 +100,14 @@ class Buffer {
   template <typename T>
   void ResetLazyImage2D(TargetType target,
                         const size_t img_w,
-                        const size_t img_h) {
+                        const size_t img_h,
+                        void* host_ptr = nullptr) {
     size_t size =
         sizeof(T) * img_w * img_h * 4;  // 4 for RGBA, un-used for opencl Image2D
     if (target != target_ || cl_image2d_width_ < img_w ||
         cl_image2d_height_ < img_h) {
       Free();
-      data_ = TargetWrapperCL::MallocImage<T>(img_w, img_h);
+      data_ = TargetWrapperCL::MallocImage<T>(img_w, img_h, host_ptr);
       target_ = target;
       space_ = size;  // un-used for opencl Image2D
       cl_image2d_width_ = img_w;
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index a44b8348716449519486d37f6784e31ecc39f554..810ff0f875168da1c4411471b7ea3ea6617a9b4f 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -16,10 +16,12 @@ lite_cc_library(mir_passes
     fusion/interpolate_fuse_pass.cc
     fusion/conv_elementwise_fuse_pass.cc
     fusion/conv_activation_fuse_pass.cc
+    fusion/var_conv_2d_activation_fuse_pass.cc
     fusion/conv_bn_fuse_pass.cc
     fusion/elementwise_add_activation_fuse_pass.cc
     fusion/quant_dequant_fuse_pass.cc
     elimination/identity_scale_eliminate_pass.cc
+    elimination/elementwise_mul_constant_eliminate_pass.cc
     static_kernel_pick_pass.cc
     variable_place_inference_pass.cc
     type_target_cast_pass.cc
@@ -32,7 +34,7 @@ lite_cc_library(mir_passes
     demo_pass.cc
     runtime_context_assign_pass.cc
     memory_optimize_pass.cc
-    DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
+    DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
 
# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope op
diff --git a/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..863c01ef0646794b5cbe54d7a81a8f26dbf164ae
--- /dev/null
+++ b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+class ElementwiseMulConstantEliminator : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* pre_op = OpNode("preop");    // the previous op's output needs update
+    auto* post_op = OpNode("postop");  // the post op's output needs update
+    // TODO(Superjomn) check has only one output
+    auto* x =
+        VarNode("x")->assert_is_op_input("elementwise_mul", "X")->AsOutput();
+    auto* y = VarNode("Y")->assert_is_op_input("elementwise_mul", "Y");
+
+    // create op nodes
+    auto* mul = OpNode("mul", "elementwise_mul")
+                    ->assert_is_op("elementwise_mul")
+                    ->AsIntermediate();
+
+    auto* fill_constant = OpNode("fill_constant", "fill_constant")
+                              ->assert_is_op("fill_constant")
+                              ->assert_op_attr("value", 1.)
+                              ->AsIntermediate();
+    // create output node
+    auto* mul_out =
+        VarNode("output")->assert_is_op_output("elementwise_mul", "Out");
+    // create topology.
+    std::vector<PMNode*> add_inputs{x, y};
+    *pre_op >> *x;
+    *fill_constant >> *y;
+    add_inputs >> *mul >> *mul_out;
+    *mul_out >> *post_op;
+
+    // The multiply-by-constant-1 will be eliminated, and a new output-updated
+    // op will be inserted.
+    mul_out->AsIntermediate();  // mul_out is pre_op's output, needs update
+    y->AsIntermediate();        // needs update
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& post_op = matched.at("postop")->AsStmt();
+    auto op_info = *post_op.op_info();
+
+    op_info.UpdateAllInputs(matched.at("output")->AsArg().name,
+                            matched.at("x")->AsArg().name);
+    post_op.ResetOp(op_info, graph->valid_places());
+
+    IR_NODE_LINK_TO(matched.at("x"), matched.at("postop"));
+  }
+};
+
+}  // namespace
+
+class ElementwiseMulConstantEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    ElementwiseMulConstantEliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(elementwise_mul_constant_eliminate_pass,
+                  paddle::lite::mir::ElementwiseMulConstantEliminatePass)
+    .BindTargets({TARGET(kAny)});
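
[Editor's note: the pattern this pass removes, roughly, with illustrative variable names:

    // before: y = fill_constant(value=1.0); out = elementwise_mul(x, y); post_op(out)
    // after:  post_op(x)  // the multiply-by-one and its fill_constant are gone
]
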
diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
index acea48c742522d5b6b5f1f3b570fcbfe0c4be08d..345361047bbbad68cdd0b298a163214cbfe114fc 100644
--- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
+++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
@@ -25,7 +25,8 @@ namespace {
 class Eliminator : public FuseBase {
  public:
   void BuildPattern() override {
-    auto* pre_op = OpNode("preop");  // the previous op's output need update
+    // the previous op's output needs update
+    auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
     // TODO(Superjomn) check has only one output
     auto* x = VarNode("x")->assert_is_op_input("scale", "X");
     auto* scale_op = OpNode("scale", "scale")
diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt
index 5ac52837551f0b78d67dfe1733fe354ee2cf7f01..8699470955b663fc2562074e99529def72836794 100644
--- a/lite/core/mir/fusion/CMakeLists.txt
+++ b/lite/core/mir/fusion/CMakeLists.txt
@@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise
 lite_cc_library(fuse_conv_activation
     SRCS conv_activation_fuser.cc
     DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_var_conv_activation
+    SRCS var_conv_2d_activation_fuser.cc
+    DEPS pattern_matcher_high_api)
 lite_cc_library(fuse_conv_bn
     SRCS conv_bn_fuser.cc
     DEPS pattern_matcher_high_api)
@@ -31,6 +34,7 @@ set(mir_fusers
     fuse_shuffle_channel
     fuse_conv_elementwise
     fuse_conv_activation
+    fuse_var_conv_activation
     fuse_conv_bn
     fuse_quant_dequant
     fuse_elementwise_add_activation
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index 0d11b47db6a7f767f8cd032877d8647b0872b8d4..c5ce74e30e34b5878a534010b6cf8b86f91a1118 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -30,7 +30,7 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       break;
     }
   }
-  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+  for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     for (auto act_type : act_types) {
       for (auto has_bias : {true, false}) {
         fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index 5ab5f8c0a4797e51cce656de43883a68d4931e9b..4725ca74855d72674b922478acd1f6f3a3b59798 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -26,7 +26,8 @@ namespace mir {
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // initialize fuser params
   std::vector<bool> conv_has_bias_cases{true, false};
-  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
 
   // start fuse using params
   for (auto conv_has_bias : conv_has_bias_cases) {
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ce2248cbc23d8887a22f94c14b2507fb0cacbed
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void VarConv2dActivationFusePass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  std::vector<std::string> act_types{"relu"};
+  for (auto act_type : act_types) {
+    fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d");
+    fuser(graph.get());
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass,
+                  paddle::lite::mir::VarConv2dActivationFusePass)
+    .BindTargets({TARGET(kCUDA)});
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..7616aadef340d3e4d6bc11534dd839c91fe9ed1d
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class VarConv2dActivationFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..eabd97ae4513b84c9c002aa1587d45cce6b22e21
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void VarConvActivationFuser::BuildPattern() {
+  // create nodes.
+  auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
+  auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
+
+  auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
+
+  auto* act = OpNode("act", act_type_)->AsIntermediate();
+
+  auto* conv2d_out = VarNode("conv2d_out")
+                         ->assert_is_op_output(conv_type_, "Out")
+                         ->assert_is_op_input(act_type_, "X")
+                         ->AsIntermediate();
+  auto* conv2d_out_1 = VarNode("conv2d_out_1")
+                           ->assert_is_op_output(conv_type_, "Col")
+                           ->AsIntermediate();
+
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
+
+  // create topology.
+ std::vector conv2d_inputs{filter, input}; + conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out; + *conv2d >> *conv2d_out_1; +} + +void VarConvActivationFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create(conv_type_); + auto conv_old = matched.at("var_conv_2d")->stmt()->op(); + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("X"), new_op_node); + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info(); + op_desc.SetOutput("Out", {matched.at("output")->arg()->name}); + cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info(); + + if (act_type_ == "relu") { + op_desc.SetAttr("fuse_relu", true); + } + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.h b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..68bc89f7d13d38dc07814f3296a25bfd7dea0248 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class VarConvActivationFuser : public FuseBase { + public: + explicit VarConvActivationFuser(const std::string& act_type, + const std::string& conv_type) + : act_type_(act_type), conv_type_(conv_type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string act_type_; + std::string conv_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index 76ea9555c29a245aa9f20b158f0706557940bef8..3a27360f94d7d828e1c19214d621f1dfe4e048ca 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -36,15 +36,6 @@ std::string Visualize(mir::SSAGraph* graph) { int id = 0; std::set exists_args; - std::map graph_col; // Different colors of subgraphs - graph_col.insert({{1, "red"}, - {2, "green"}, - {3, "cyan"}, - {4, "bisque3"}, - {5, "coral"}, - {6, "darkseagreen1"}, - {7, "goldenrod1"}, - {8, "darkorchid"}}); for (auto& node : graph->mutable_nodes()) { std::string key; if (node.IsArg()) { @@ -52,24 +43,12 @@ std::string Visualize(mir::SSAGraph* graph) { } else { key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++); } - if (node.IsStmt()) { - auto& stmt = node.AsStmt(); - auto sub_id = stmt.subgraph_id(); - auto it = graph_col.find(sub_id); - if (sub_id > 0 && it != graph_col.end()) { - dot.AddNode(key, - {Dot::Attr("shape", "box"), - Dot::Attr("style", "filled"), - Dot::Attr("color", "black"), - Dot::Attr("fillcolor", it->second)}); - } else { - dot.AddNode(key, - {Dot::Attr("shape", "box"), - Dot::Attr("style", "filled"), - Dot::Attr("color", "black"), - Dot::Attr("fillcolor", "yellow")}); - } + dot.AddNode(key, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", "yellow")}); for (auto& x : node.inlinks) { auto name = x->AsArg().name; if (!exists_args.count(name)) { diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 4f41ba4a601ae763e6fa48c0a98de238252ea7c2..dbf32da2348c6aa39eb4f9d9c65b404e31fb3145 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -50,7 +50,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( "lod_reset", "concat", "yolo_box", - "graph_op", + "subgraph", "feed", "fetch"}; for (auto* tmp : node->inlinks) { diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h index 60fa1fb1ebe49e1be38a7d84cb82545389ea4aac..e2c8a68bde6ee18506de73a7531716695b3d54f1 100644 --- a/lite/core/mir/node.h +++ b/lite/core/mir/node.h @@ -64,9 +64,6 @@ class Node { return valid_kernels_; } - void ClearSubgraphID() { subgraph_id_ = -1 /* note: not 0 */; } - void SetSubgraphID(int id) { subgraph_id_ = id; } - int subgraph_id() const { return subgraph_id_; } void SetOp(const std::shared_ptr& op) { op_ = op; } const std::shared_ptr op() const { return op_; } @@ -82,11 +79,6 @@ class Node { // Description. 
std::string desc; - - protected: - // -1 means not in subgraph, 0 means supported but not one id, id started - // from 1 - int subgraph_id_{-1}; }; struct Arg { diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8e0fc55be2389244ae065b4c2809bbdd74be370c..b625919cbfb6d26ecbbd1bad36772aff86bee087 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -377,6 +377,19 @@ PMNode *PMNode::assert_is_op(const std::string &op_type) { return this; } +PMNode *PMNode::assert_is_not_op_type(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + if (op_info->Type() == op_type) { + return false; + } + } + return true; + }); + return this; +} + PMNode *PMNode::assert_is_var() { asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); return this; diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 47a0a30b5667ddc97b3783ab9edbab04281528a4..90c4359c6d3ade98cf60b5c23411e2026cdeccc9 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -123,6 +123,7 @@ struct PMNode { // Assertions, helper functions to simplify the pattern definition. PMNode* assert_is_op(); PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_not_op_type(const std::string& op_type); PMNode* assert_is_var(); PMNode* assert_var_not_persistable(); PMNode* assert_is_persistable_var(); diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc index 8f22022789046900c3c09cfb122c914968d8d87f..2b5b65ce5903ede41137311c585c0e87eaaa0e9d 100644 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -123,6 +123,9 @@ void SSAGraph::Build(const Program &program, return true; }; + std::unordered_map var_types = + program.var_data_type(); + std::unordered_map arg_update_node_map_; for (auto &op : program.ops()) { VLOG(3) << op->op_info()->Type(); @@ -137,6 +140,10 @@ void SSAGraph::Build(const Program &program, arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; } + if (var_types.count(name) && !arg_node->arg()->type) { + arg_node->arg()->type = LiteType::GetTensorTy( + TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + } if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); DirectedLink(arg_node, op_node); @@ -146,6 +153,10 @@ void SSAGraph::Build(const Program &program, auto *arg_node = &node_storage_.back(); arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; + if (var_types.count(name) && !arg_node->arg()->type) { + arg_node->arg()->type = LiteType::GetTensorTy( + TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + } if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index c49e4497099c5f04a39bf91e70ca8f48900e7ba7..1cc8942d611db389a44cbf6a244775a5b666b587 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -14,7 +14,10 @@ #include "lite/core/mir/static_kernel_pick_pass.h" #include +#include #include +#include +#include #include #include #include "lite/core/mir/graph_visualize_pass.h" @@ -43,13 +46,33 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); + std::unordered_map in_types; + std::unordered_map out_types; + 
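+    // Gather the precisions recorded on the linked var nodes (filled in
+    // earlier from the program's var descs, when available), keyed by var
+    // name, so KernelGrade below can reward kernels whose declared input
+    // and output precisions match the model.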
for (std::list::iterator i = node.inlinks.begin(); + i != node.inlinks.end(); + ++i) { + if ((*i)->arg()->type) + in_types[(*i)->arg()->name] = (*i)->arg()->type->precision(); + } + for (std::list::iterator i = node.outlinks.begin(); + i != node.outlinks.end(); + ++i) { + if ((*i)->arg()->type) + out_types[(*i)->arg()->name] = (*i)->arg()->type->precision(); + } // Get candidate kernels std::vector>> scored; CHECK(!instruct.kernels().empty()) << "No kernels found for " << instruct.op_type(); VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(instruct, *kernel, graph->valid_places()); + float score = KernelGrade(instruct, + *kernel, + graph->valid_places(), + in_types, + out_types, + instruct.op_info()->input_names(), + instruct.op_info()->output_names()); VLOG(4) << "kernel->summary():" << kernel->summary() << " score:" << score; scored.emplace_back(score, std::move(kernel)); @@ -99,7 +122,13 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { instruct.ResetOp(update_desc, graph->valid_places()); scored.clear(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(instruct, *kernel, graph->valid_places()); + float score = KernelGrade(instruct, + *kernel, + graph->valid_places(), + in_types, + out_types, + instruct.op_info()->input_names(), + instruct.op_info()->output_names()); scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h index cd54e2654c22b98cbacc9a73bef7770a029c0b30..f655b298bf2d800f4adf142ad14b8ac05ca00482 100644 --- a/lite/core/mir/static_kernel_pick_pass.h +++ b/lite/core/mir/static_kernel_pick_pass.h @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include "lite/core/mir/pass.h" #include "lite/core/types.h" @@ -48,9 +50,14 @@ class StaticKernelPickPass : public mir::StmtPass { private: // Score the kernel. - size_t KernelGrade(const lite::mir::Node::Stmt& instruct, - const lite::KernelBase& kernel, - const std::vector& places) { + size_t KernelGrade( + const lite::mir::Node::Stmt& instruct, + const lite::KernelBase& kernel, + const std::vector& places, + const std::unordered_map& in_types, + const std::unordered_map& out_types, + const std::vector& in_names, + const std::vector& out_names) { CHECK_GT(places.size(), 0) << "valid_places is empty."; float final_score{-1.}; Place winner_place{places[0]}; @@ -100,6 +107,37 @@ class StaticKernelPickPass : public mir::StmtPass { core::KernelPickFactor::Factor::DataLayoutFirst); } VLOG(4) << "[score s3]:" << score; + + // add new rules for precision: When the input types are consistent with + // kernel's input types and the output types are consistent with kernel's + // output types. Select the kernel of the precision. Note that this + // strategy is not compatible with quantization, so skip quantization op. 
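+      // e.g. for a float32 model, a kernel whose declared precision is
+      // roughly kFloat on every matched argument gets its score doubled
+      // below, while a kernel declared as kInt8 keeps its base score.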
+ if (!instruct.op_info()->HasAttr("enable_int8")) { + bool type_match = true; + for (size_t i = 0; i < in_names.size(); ++i) { + std::string tmp; + CHECK(instruct.op_info()->GetInputArgname(in_names[i], &tmp)); + if (in_types.count(in_names[i]) && + in_types.at(in_names[i]) != + kernel.GetInputDeclType(tmp)->precision()) { + type_match = false; + } + } + for (size_t i = 0; i < out_names.size(); ++i) { + std::string tmp; + CHECK(instruct.op_info()->GetOutputArgname(out_names[i], &tmp)); + if (out_types.count(out_names[i]) && + out_types.at(out_names[i]) != + kernel.GetOutputDeclType(tmp)->precision()) { + type_match = false; + } + } + if (type_match) { + score *= 2; + } + VLOG(4) << "[score s4]:" << score; + } + if (weight * score > final_score) { final_score = weight * score; winner_place = place; diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt index 95b5fe5ae13e03940bda8d83fcfc252b4ca490ab..1ac4ab346f15edf9e039d3143c0a301d49a1c0b4 100644 --- a/lite/core/mir/subgraph/CMakeLists.txt +++ b/lite/core/mir/subgraph/CMakeLists.txt @@ -1,50 +1,30 @@ - +lite_cc_library(subgraph_detector + SRCS subgraph_detector.cc + DEPS mir_pass types subgraph_op) lite_cc_library(subgraph_pass - SRCS subgraph_program_pass.cc - DEPS mir_pass types ${mir_fusers}) -lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc - DEPS subgraph_pass mir_passes gflags model_parser cxx_api - ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL) + SRCS subgraph_pass.cc + DEPS mir_pass types context ${mir_fusers} subgraph_detector) if (WITH_TESTING) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) - set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") -endif() - -set(subgraph_passes subgraph_pass) - -if(LITE_WITH_NPU) - lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc - DEPS mir_pass types context ${mir_fusers} ${npu_bridges} graph_op subgraph_pass) - list(APPEND subgraph_passes npu_pass) - lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc - DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags - ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 - --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) - if (WITH_TESTING) - add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) + lite_cc_test(test_subgraph_detector + SRCS subgraph_detector_test.cc + DEPS subgraph_detector mir_passes gflags model_parser cxx_api + ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL) + add_dependencies(test_subgraph_detector + extern_lite_download_mobilenet_v1_tar_gz + extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - endif() -endif() - -if(LITE_WITH_XPU) - lite_cc_library(xpu_pass SRCS generate_xpu_program_pass.cc - DEPS mir_pass types context ${mir_fusers} ${xpu_bridges} ${xpu_builder_libs} graph_op subgraph_pass) - list(APPEND subgraph_passes xpu_pass) - lite_cc_test(test_xpu_pass SRCS generate_xpu_program_pass_test.cc - DEPS xpu_pass mir_passes paddle_api_full gflags - ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 - --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt 
SERIAL) - if (WITH_TESTING) - add_dependencies(test_xpu_pass extern_lite_download_mobilenet_v1_tar_gz) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) + set_target_properties(test_subgraph_detector PROPERTIES LINK_FLAGS "${LINK_FLAGS}") + lite_cc_test(test_subgraph_pass + SRCS subgraph_pass_test.cc + DEPS mir_passes paddle_api_full paddle_api_light gflags + ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 + --optimized_model_dir=${LITE_MODEL_DIR}/lite_model_opt SERIAL) + add_dependencies(test_subgraph_pass + extern_lite_download_mobilenet_v1_tar_gz + extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(test_xpu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - endif() + set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") endif() -set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes") -message(STATUS "----> subgraph_passes: ${subgraph_passes}") +set(mir_subgraphs subgraph_pass CACHE INTERNAL "mir_subgraphs") +message(STATUS "----> mir_subgraphs: ${mir_subgraphs}") diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc deleted file mode 100644 index 65c29aa68f1c8c5f5702ca97d27f9579edc7a951..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/core/mir/subgraph/generate_npu_program_pass.h" -#include -#include -#include -#include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/pass_registry.h" -#include "lite/core/mir/pattern_matcher.h" - -#include "lite/backends/npu/builder.h" -#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h" -#include "lite/kernels/npu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -std::shared_ptr GenerateNPUProgramPass::CvtVarNode( - lite::mir::Node* var_node, const Scope* scope) { - CHECK(var_node->IsArg()); - const auto& arg = var_node->AsArg(); - VLOG(4) << "[NPU] Convert var node " << arg.name; - - auto* var = scope->FindVar(arg.name); - CHECK(var); - auto* tensor = var->GetMutable(); - CHECK(tensor); - auto dims = tensor->dims(); - if (arg.is_weight) { - auto wgt = std::make_shared(arg.name); - LOG(INFO) << "[NPU] Convert const var node " << arg.name; - VLOG(4) << dims; - wgt->set_attr_value(lite::npu::CvtTensor(tensor)); - return wgt; - } else { - CHECK_EQ(dims.size(), 4); - LOG(INFO) << "[NPU] Convert data var node " << arg.name; - LOG(INFO) << dims; - // TODO(xxx): support more types and dims size - ge::TensorDesc desc(ge::Shape(dims.Vectorize()), - ge::Format::FORMAT_NCHW, - ge::DataType::DT_FLOAT); - - // auto size = desc.GetShape().GetShapeSize(); - // ge::TensorUtils::SetSize(desc, size*sizeof(float)); - // ge::TensorUtils::SetRealDimCnt(desc, 4); - auto data = std::make_shared(arg.name); - data->update_input_desc_x(desc); - return data; - } - return nullptr; -} - -void GenerateNPUProgramPass::CvtAllOpNodes( - const std::vector& nodes2cvt, - lite::kernels::npu::bridges::node_map_type* converted_vars) { - const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); - const auto& cvtfunc_map = bridges.AllFunctions(); - // return record all converted vars - // op node's inputs must be found in converted_vars - for (auto& node : nodes2cvt) { - lite::kernels::npu::bridges::node_map_type node_inputs; - auto& stmt = node->AsStmt(); - for (auto& var_node : node->inlinks) { - auto& arg = var_node->AsArg(); - // weight should be handled in the converter, so skip here - if (arg.is_weight) { - continue; - } - auto var_name = arg.name; - if (!converted_vars->count(var_name)) { - converted_vars->insert( - std::make_pair(var_name, CvtVarNode(var_node, stmt.op()->scope()))); - } - node_inputs.insert(*converted_vars->find(var_name)); - } - auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs); - converted_vars->insert(node_outputs.begin(), node_outputs.end()); - } -} - -std::string GenerateNPUProgramPass::BuildNPUGraph( - const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id) { - auto ordered_nodes = GetTopologicalOrder(op_nodes); - lite::kernels::npu::bridges::node_map_type converted_vars; - CvtAllOpNodes(ordered_nodes, &converted_vars); - - std::vector in_var_names; - std::vector out_var_names; - std::vector inputs; - std::vector outputs; - for (auto i : in_data_vars) { - auto argname = i->AsArg().name; - in_var_names.push_back(argname); - inputs.push_back(*converted_vars.at(argname)); - } - for (auto i : out_data_vars) { - auto argname = i->AsArg().name; - out_var_names.push_back(argname); - outputs.push_back(*converted_vars.at(argname)); - } - - std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights"; - auto any_op = (*op_nodes.begin())->AsStmt().op(); - auto 
weight = any_op->scope()->Var(weight_var_name)->GetMutable(); - weight->set_persistable(true); - weight->set_precision(PRECISION(kInt8)); - // Compiling IR graph to NPU model and store mode data into weight tensor with - // persistable=true, Sothat the model parser can recognize it and save it to - // param files - if (!lite::npu::BuildModel(inputs, outputs, weight)) { - LOG(FATAL) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")"; - } else { - LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")"; - } - return weight_var_name; -} - -void GenerateNPUProgramPass::GenNPUSubgraph( - const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id) { - std::unordered_set in_data_vars; - std::unordered_set in_wgt_vars; - std::unordered_set out_data_vars; - std::unordered_set out_unused_vars; - FindInputOutputVars( - op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars); - - auto weight_var_name = - BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id); - - auto any_op = (*op_nodes.begin())->AsStmt().op(); - InsertNewNode(graph, - weight_var_name, - any_op->scope(), - any_op->valid_places(), - in_data_vars, - in_wgt_vars, - out_data_vars, - out_unused_vars); - - auto nodes2rm = GetNode2rm( - op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars}); - - GraphSafeRemoveNodes(graph.get(), nodes2rm); -} - -void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { - LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get()); - const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); - const auto& op_map = bridges.AllFunctions(); - std::vector supported_op_types; - for (auto& i : op_map) { - LOG(INFO) << "[NPU] Supported type: " << i.first; - supported_op_types.push_back(i.first); - } - - int num_subgraph = FuseSubgraph(graph, supported_op_types); - InferOnce(graph); - auto op_nodes_all = ClassifySubgraph(graph); - CHECK_EQ(op_nodes_all.size(), num_subgraph); - int id = 1; - for (auto& op_nodes : op_nodes_all) { - LOG(INFO) << "[NPU] Converting Subgraph " << id; - GenNPUSubgraph(graph, op_nodes.second, id); - LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n" - << Visualize(graph.get()); - id++; - } -} -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle - -REGISTER_MIR_PASS(generate_npu_program_pass, - paddle::lite::mir::subgraph::GenerateNPUProgramPass) - .BindTargets({TARGET(kNPU)}); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h deleted file mode 100644 index 5b1a98c6ed0e10f4fae8832b9ba3c5f98f3d9ed9..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include "lite/backends/npu/builder.h" -#include "lite/core/mir/pass.h" -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include "lite/kernels/npu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -class GenerateNPUProgramPass : public SubgraphProgramPass { - public: - using key2nodes_t = std::map; - - void Apply(const std::unique_ptr& graph) override; - - protected: - // nodes2cvt: op nodes to convert - // return cvted_vars: converted var nodes - void CvtAllOpNodes(const std::vector& nodes2cvt, - lite::kernels::npu::bridges::node_map_type* cvted_vars); - - std::shared_ptr CvtVarNode(lite::mir::Node* var_node, - const Scope* scope); - - std::string BuildNPUGraph(const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id); - - void GenNPUSubgraph(const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id); -}; - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.cc b/lite/core/mir/subgraph/generate_xpu_program_pass.cc deleted file mode 100644 index 4340cb4ee3cccad32db9bc333b5856386812c62a..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass.cc +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/core/mir/subgraph/generate_xpu_program_pass.h" -#include -#include -#include -#include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/pass_registry.h" -#include "lite/core/mir/pattern_matcher.h" - -#include "lite/backends/xpu/builder.h" -#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h" -#include "lite/kernels/xpu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -std::shared_ptr GenerateXPUProgramPass::CvtVarNode( - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::mir::Node* var_node, - const Scope* scope) { - CHECK(var_node->IsArg()); - const auto& arg = var_node->AsArg(); - auto var_name = arg.name; - VLOG(4) << "[XPU] Convert var node " << var_name; - - auto* var = scope->FindVar(var_name); - CHECK(var); - auto* tensor = var->GetMutable(); - CHECK(tensor); - auto dims = tensor->dims(); - auto cvted_var_node = - std::make_shared(graph_ctx->builder->CreateTensor( - var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32))); - if (arg.is_weight) { - auto cvted_var_tensor = lite::xpu::CvtTensor(tensor); - graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor)); - } - return cvted_var_node; -} - -void GenerateXPUProgramPass::CvtAllOpNodes( - const std::vector& op_nodes, - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) { - const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); - const auto& supported_lists = bridges.AllFunctions(); - // return record all converted vars - // op node's inputs must be found in converted_vars - for (auto& node : op_nodes) { - lite::kernels::xpu::bridges::node_map_type input_nodes; - auto& stmt = node->AsStmt(); - for (auto& var_node : node->inlinks) { - auto& arg = var_node->AsArg(); - // weight should be handled in the converter, so skip here - if (arg.is_weight) { - continue; - } - auto var_name = arg.name; - if (!cvted_var_nodes->count(var_name)) { - cvted_var_nodes->insert(std::make_pair( - var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope()))); - } - input_nodes.insert(*cvted_var_nodes->find(var_name)); - } - auto output_nodes = - supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes); - cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end()); - } -} - -std::string GenerateXPUProgramPass::BuildXPUGraph( - const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id) { - auto ordered_op_nodes = GetTopologicalOrder(op_nodes); - lite::kernels::xpu::bridges::graph_ctx_type graph_ctx; - graph_ctx.builder = std::make_shared(); - graph_ctx.params = - std::make_shared(); - lite::kernels::xpu::bridges::node_map_type cvted_var_nodes; - CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes); - - std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights"; - auto any_op = (*op_nodes.begin())->AsStmt().op(); - auto weight = any_op->scope()->Var(weight_var_name)->GetMutable(); - weight->set_persistable(true); - weight->set_precision(PRECISION(kInt8)); - // Compiling graph to XPU model and store mode data into weight tensor with - // persistable=true, Sothat the model parser can recognize it and save it to - // param files - std::vector> ordered_cvted_var_nodes; - for (auto out_data_var : out_data_vars) { - auto var_name = out_data_var->AsArg().name; - ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]); 
- } - if (!lite::xpu::BuildModel(graph_ctx.builder, - graph_ctx.params, - &ordered_cvted_var_nodes, - weight)) { - LOG(FATAL) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")"; - } else { - LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")"; - } - return weight_var_name; -} - -void GenerateXPUProgramPass::GenXPUSubgraph( - const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id) { - std::unordered_set in_data_vars; - std::unordered_set in_wgt_vars; - std::unordered_set out_data_vars; - std::unordered_set out_unused_vars; - FindInputOutputVars( - op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars); - - auto weight_var_name = - BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id); - - auto any_op = (*op_nodes.begin())->AsStmt().op(); - InsertNewNode(graph, - weight_var_name, - any_op->scope(), - any_op->valid_places(), - in_data_vars, - in_wgt_vars, - out_data_vars, - out_unused_vars); - - auto nodes2rm = GetNode2rm( - op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars}); - - GraphSafeRemoveNodes(graph.get(), nodes2rm); -} - -void GenerateXPUProgramPass::Apply(const std::unique_ptr& graph) { - LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get()); - const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); - const auto& op_map = bridges.AllFunctions(); - std::vector supported_op_types; - for (auto& i : op_map) { - LOG(INFO) << "[XPU] Supported type: " << i.first; - supported_op_types.push_back(i.first); - } - - int num_subgraph = FuseSubgraph(graph, supported_op_types); - InferOnce(graph); - auto op_nodes_all = ClassifySubgraph(graph); - CHECK_EQ(op_nodes_all.size(), num_subgraph); - int id = 1; - for (auto& op_nodes : op_nodes_all) { - LOG(INFO) << "[XPU] Converting Subgraph " << id; - GenXPUSubgraph(graph, op_nodes.second, id); - LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n" - << Visualize(graph.get()); - id++; - } -} -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle - -REGISTER_MIR_PASS(generate_xpu_program_pass, - paddle::lite::mir::subgraph::GenerateXPUProgramPass) - .BindTargets({TARGET(kXPU)}); diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.h b/lite/core/mir/subgraph/generate_xpu_program_pass.h deleted file mode 100644 index 777642cfb6c61671a8aeb119c70664297573d9a7..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include "lite/backends/xpu/builder.h" -#include "lite/core/mir/pass.h" -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include "lite/kernels/xpu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -class GenerateXPUProgramPass : public SubgraphProgramPass { - public: - using key2nodes_t = std::map; - - void Apply(const std::unique_ptr& graph) override; - - protected: - // nodes2cvt: op nodes to convert - // return cvted_vars: converted var nodes - void CvtAllOpNodes( - const std::vector& op_nodes, - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes); - - std::shared_ptr CvtVarNode( - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::mir::Node* var_node, - const Scope* scope); - - std::string BuildXPUGraph(const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id); - - void GenXPUSubgraph(const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id); -}; - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc b/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc deleted file mode 100644 index 728ecbc6b77666accd432b1ad82a03860588ab40..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include "lite/api/paddle_api.h" -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" -#include "lite/api/test_helper.h" -#include "lite/utils/cp_logging.h" - -DEFINE_string(model_file, "", "model file path of combined protobuf model"); -DEFINE_string(params_file, "", "params file path of combined protobuf model"); -DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); -DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors"); -DEFINE_int32(output_tensor_num, 1, "number of output tensors"); - -namespace paddle { -namespace lite { - -std::vector> ParseShape(std::string txt) { - std::vector> shape; - while (!txt.empty()) { - size_t idx = txt.find_first_of(":"); - std::string dims = txt.substr(0, idx); - std::vector s; - while (!dims.empty()) { - size_t idx = dims.find_first_of(","); - int d = atoi(dims.substr(0, idx).c_str()); - VLOG(3) << d; - s.push_back(d); - if (idx == std::string::npos) { - break; - } else { - dims = dims.substr(idx + 1); - } - } - shape.push_back(s); - if (idx == std::string::npos) { - break; - } else { - txt = txt.substr(idx + 1); - } - } - return shape; -} - -int64_t ShapeProduction(std::vector shape) { - int64_t s = 1; - for (int64_t dim : shape) { - s *= dim; - } - return s; -} - -void FillInputTensor( - const std::shared_ptr& predictor, - const std::vector>& input_tensor_shape, - const float value) { - for (int i = 0; i < input_tensor_shape.size(); i++) { - auto input_tensor = predictor->GetInput(i); - input_tensor->Resize(input_tensor_shape[i]); - auto input_tensor_data = input_tensor->mutable_data(); - auto input_tensor_size = ShapeProduction(input_tensor->shape()); - for (int j = 0; j < input_tensor_size; j++) { - input_tensor_data[j] = value; - } - } -} - -void CompareOutputTensor( - const std::shared_ptr& tar_predictor, - const std::shared_ptr& ref_predictor, - const int output_tensor_num) { - for (int i = 0; i < output_tensor_num; i++) { - auto tar_output_tensor = tar_predictor->GetOutput(i); - auto ref_output_tensor = ref_predictor->GetOutput(i); - auto tar_output_tensor_data = tar_output_tensor->data(); - auto ref_output_tensor_data = ref_output_tensor->data(); - auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape()); - auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape()); - EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size); - for (size_t j = 0; j < ref_output_tensor_size; j++) { - auto diff = - std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) / - (std::fabs(ref_output_tensor_data[j]) + 1e-6); - VLOG(3) << diff; - EXPECT_LT(diff, 0.1); - } - } -} - -std::shared_ptr TestModel( - const std::string& model_dir, - const std::string& model_file, - const std::string& params_file, - const std::vector& valid_places, - const std::vector>& input_tensor_shape, - const std::string& optimized_model_dir) { - // generate optimized model - lite_api::CxxConfig cxx_config; - cxx_config.set_model_dir(model_dir); - cxx_config.set_model_file(model_file); - cxx_config.set_param_file(params_file); - cxx_config.set_valid_places(valid_places); - auto predictor = lite_api::CreatePaddlePredictor(cxx_config); - FillInputTensor(predictor, input_tensor_shape, -1); - predictor->SaveOptimizedModel(optimized_model_dir, - lite_api::LiteModelType::kNaiveBuffer); -#if 0 // TODO(hong19860320) supports light api for XPU - // load optimized model - lite_api::MobileConfig mobile_config; - 
mobile_config.set_model_dir(optimized_model_dir); - mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); - mobile_config.set_threads(1); - predictor = lite_api::CreatePaddlePredictor(mobile_config); - FillInputTensor(predictor, input_tensor_shape, 1); -#endif - // run optimized model - for (int i = 0; i < FLAGS_warmup; i++) { - predictor->Run(); - } - for (int i = 0; i < FLAGS_repeats; i++) { - auto start = GetCurrentUS(); - predictor->Run(); - LOG(INFO) << i << ", " << GetCurrentUS() - start << "us"; - } - return predictor; -} - -TEST(XPUSubgraph, compare) { - // parsing input tensor shape, supported formats: "1,3,224,224" - // "1,3,224,224:1,80" - std::vector> input_tensor_shape = - ParseShape(FLAGS_input_tensor_shape); - // generate and run optimized CPU model - LOG(INFO) << " ================ CPU ================== "; - auto cpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kX86), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/CPU"); - // generate and run optimized XPU model - LOG(INFO) << " ================ XPU ================== "; - auto xpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, - lite_api::Place{TARGET(kX86), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/XPU"); - // verify results - CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num); -} - -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d48b053a1a4140252d35e85d2351644d3c216e9 --- /dev/null +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -0,0 +1,551 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/mir/subgraph/subgraph_detector.h" +#include +#include +#include +#include +#include +#include "lite/core/mir/dot.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher.h" +#include "lite/operators/subgraph_op.h" + +namespace paddle { +namespace lite { +namespace mir { + +using inference::analysis::Dot; + +std::string SubgraphVisualizer::operator()() { + inference::analysis::Dot dot; + const std::vector subgraph_colors{ + "red", "green", "cyan", "bisque3", + "coral", "darkseagreen1", "goldenrod1", "darkorchid", + "antiquewhite", "aquamarine", "azure", "bisque4", + "blue2", "brown1", "burlywood1", "cadetblue1", + "chartreuse1", "chocolate1", "coral1", "cornsilk", + "crimson", "cyan4", "darkgoldenrod4", "darkolivegreen2", + "darkorange2", "darkorchid2", "darkseagreen3", "darkslategray", + "deeppink2", "deepskyblue2", "dodgerblue", "firebrick", + "floralwhite", "gold1", "skyblue3", "indianred", + "indigo", "lavenderblush2", "lightblue1", "lightsalmon3", + "khaki1", "ivory4", "sandybrown", "olivedrab2", + "turquoise4", "snow3", "sienna4", "salmon2", + }; + std::unordered_map subgraph_indices; + for (int i = 0; i < subgraphs_.size(); i++) { + for (int j = 0; j < subgraphs_[i].size(); j++) { + subgraph_indices[subgraphs_[i][j]] = i; + } + } + std::unordered_map exists_ops; + std::set exists_args; + for (auto &node : graph_->StmtTopologicalOrder()) { + if (!node->IsStmt()) { + continue; + } + auto op_type = node->AsStmt().op_type(); + if (!exists_ops.count(op_type)) { + exists_ops[op_type] = 0; + } else { + exists_ops[op_type]++; + } + auto op_name = op_type + std::to_string(exists_ops[op_type]); + std::string op_color = "white"; + if (subgraph_indices.count(node)) { + auto subgraph_idx = subgraph_indices[node]; + op_name += "_subgraph_" + std::to_string(subgraph_idx); + op_color = subgraph_colors[subgraph_idx % subgraph_colors.size()]; + } + dot.AddNode(op_name, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", op_color)}); + for (auto &in_node : node->inlinks) { + auto arg_name = in_node->AsArg().name; + if (!exists_args.count(arg_name)) { + dot.AddNode(arg_name, {}); + exists_args.insert(arg_name); + } + dot.AddEdge(arg_name, op_name, {}); + } + for (auto &out_node : node->outlinks) { + auto arg_name = out_node->AsArg().name; + if (!exists_args.count(arg_name)) { + dot.AddNode(arg_name, {}); + exists_args.insert(arg_name); + } + dot.AddEdge(op_name, arg_name, {}); + } + } + + auto res = dot.Build(); + std::cout << "subgraphs: " << subgraphs_.size() << "\n" << res << std::endl; + return res; +} + +// Find the ancestor node +SubgraphDetector::node_dat_t * +SubgraphDetector::node_dat_t::UnionFindAncestor() { + node_dat_t *ancestor = this; + while (ancestor->union_find_parent != ancestor) { + ancestor = ancestor->union_find_parent; + } + return ancestor; +} + +// Merge the two adjacent nodes into one node. +// Suppose we have two adjacent nodes src and dst. +// We will perform the following operations: +// 1. add all inputs(except src) of dst to src inlinks. +// 2. add all outputs of dst to src outlinks. +// 3. change all the dst's inputs and outputs +// corresponding inlinks and outlinks to src node. +// 4. delete all dst's inlinks and outlinks. +void SubgraphDetector::node_dat_t::UnionFindCombine(node_dat_t *candidate) { + // Make this two node share the same ancestor. 
+  union_find_parent = UnionFindAncestor();
+  node_dat_t *candidate_ancestor = candidate->UnionFindAncestor();
+  candidate_ancestor->union_find_parent = union_find_parent;
+  candidate->union_find_parent = union_find_parent;
+
+  // Obtain the input and output nodes for the combined one
+  std::unordered_set<node_dat_t *> inputs(inlinks.begin(), inlinks.end());
+  std::unordered_set<node_dat_t *> outputs(candidate->outlinks.begin(),
+                                           candidate->outlinks.end());
+  for (auto *out_node : outlinks) {
+    if (out_node != candidate) {
+      outputs.insert(out_node);
+    }
+  }
+  for (auto *in_node : candidate->inlinks) {
+    if (in_node != this) {
+      inputs.insert(in_node);
+    }
+  }
+
+// Update the dst and src node's inlinks and outlinks.
+#ifdef __clang__
+  inlinks = node_set_t(inputs.begin(), inputs.end());
+  outlinks = node_set_t(outputs.begin(), outputs.end());
+  candidate->inlinks.clear();
+  candidate->outlinks.clear();
+#else
+  inlinks = std::move(node_set_t(inputs.begin(), inputs.end()));
+  outlinks = std::move(node_set_t(outputs.begin(), outputs.end()));
+  candidate->inlinks.clear();
+  candidate->outlinks.clear();
+#endif
+
+  // Redirect the inlinks and outlinks that still reference the dst
+  // (candidate) node to the src node.
+  for (auto *in_node : inlinks) {
+    for (auto *&out_node : in_node->outlinks) {
+      if (out_node == candidate) {
+        out_node = this;
+      }
+    }
+  }
+  for (auto *out_node : outlinks) {
+    for (auto *&in_node : out_node->inlinks) {
+      if (in_node == candidate) {
+        in_node = this;
+      }
+    }
+  }
+}
+
+// FlexibleDFS
+// If reverse is true, do a reverse DFS.
+// If the enter func is not nullptr, calls enter(node) before visiting any
+// children of node.
+// If the leave func is not nullptr, calls leave(node) after visiting all
+// parents of node.
+void SubgraphDetector::FlexibleDFS(
+    const node_set_t &source,
+    bool reverse,
+    const std::function<bool(const node_dat_t *)> &enter,
+    const std::function<bool(const node_dat_t *)> &leave) {
+  std::vector<std::pair<const node_dat_t *, bool>> stack;  // node, leave
+  for (auto &node : source) {
+    stack.push_back(std::pair<const node_dat_t *, bool>(node, false));
+  }
+  std::unordered_set<const node_dat_t *> visited;
+  while (!stack.empty()) {
+    auto top = stack.back();
+    stack.pop_back();
+
+    if (top.second) {
+      if (leave && !leave(top.first)) return;
+    }
+    if (visited.count(top.first)) continue;
+    visited.insert(top.first);
+
+    if (enter && !enter(top.first)) return;
+
+    if (leave)
+      stack.push_back(std::pair<const node_dat_t *, bool>(top.first, true));
+    const node_set_t iter_nodes =
+        reverse ? top.first->inlinks : top.first->outlinks;
+    for (auto *node : iter_nodes) {
+      if (!visited.count(node)) {
+        stack.push_back(std::pair<const node_dat_t *, bool>(node, false));
+      }
+    }
+  }
+}
+
+void SubgraphDetector::InitNodes(node_map_t *nodes) {
+  // Initialize and mark the subgraph detector nodes based on teller.
+  for (auto &it : *nodes) {
+    for (auto &in_node : it.first->inlinks) {
+      it.second->inlinks.push_back((*nodes)[in_node]);
+    }
+    for (auto &out_node : it.first->outlinks) {
+      it.second->outlinks.push_back((*nodes)[out_node]);
+    }
+    if (teller_(it.first)) {
+      it.second->marked = true;
+      if (it.first->IsStmt()) {
+        // If a function is inside the subgraph, mark all of its output
+        // variables to be inside too, so that two marked functions will be
+        // inside the same subgraph. Let's take an example:
+        // A_function->var->B_function. If A_function is marked, var should
+        // also be marked, so that B_function will be in the same subgraph
+        // with A_function if B_function is marked.
+        for (auto &out_node : it.first->outlinks) {
+          (*nodes)[out_node]->marked = true;
+        }
+      }
+    }
+  }
+}
+
+std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubgraphs(
+    node_map_t *nodes) {
+  for (auto &it : *nodes) {
+    node_dat_t *node = it.second;
+    if (!node->marked) {
+      continue;
+    }
+    // Our algorithm must guarantee that:
+    // 1. The graph is always a directed acyclic graph (DAG).
+    // 2. If there is a path in the subgraph from X to Y (X and Y are both
+    //    nodes in the subgraph), then all paths from X to Y are in the
+    //    subgraph.
+    //
+    // In order to achieve the above guarantees, for adjacent nodes
+    // src -> dst:
+    // 1. Get all of dst's input nodes except src.
+    // 2. Reverse DFS from those input nodes.
+    // 3. If there is a path from the input nodes to src, then the src and
+    //    dst nodes can not be fused into one node, otherwise it can be done.
+    while (true) {
+      std::unordered_set<node_dat_t *> contract_nodes;
+      for (auto *out_node : node->outlinks) {
+        // must be a candidate
+        if (!out_node->marked) continue;
+        // get all dst input nodes except src node.
+        node_set_t source_nodes;
+        for (auto *in_node : out_node->inlinks) {
+          if (in_node != node) {
+            source_nodes.push_back(in_node);
+          }
+        }
+
+        // Reverse DFS from the source_nodes.
+        bool have_excess_path = false;
+        FlexibleDFS(source_nodes,
+                    true,
+                    nullptr,
+                    [&have_excess_path, node](const node_dat_t *n) {
+                      if (n == node) {
+                        have_excess_path = true;
+                        return false;
+                      }
+                      return true;
+                    });
+        if (have_excess_path) continue;
+        contract_nodes.insert(out_node);
+      }
+      if (contract_nodes.empty()) break;
+
+      for (auto &contract_node : contract_nodes) {
+        node->UnionFindCombine(contract_node);
+      }
+    }
+  }
+
+  std::unordered_map<node_dat_t *, std::vector<Node *>> clusters;
+  for (auto &node : graph_->StmtTopologicalOrder()) {
+    if (!node->IsStmt()) continue;
+    if ((*nodes)[node]->marked) {
+      clusters[(*nodes)[node]->UnionFindAncestor()].push_back(node);
+    }
+  }
+  std::vector<std::vector<Node *>> subgraphs;
+  std::for_each(clusters.begin(),
+                clusters.end(),
+                [&](const decltype(clusters)::value_type &it) {
+                  subgraphs.push_back(it.second);
+                });
+  return subgraphs;
+}
+
+std::vector<std::vector<Node *>> SubgraphDetector::operator()() {
+  node_map_t nodes;
+  for (auto &node : graph_->mutable_nodes()) {
+    nodes[&node] = new node_dat_t(&node);
+    CHECK(nodes[&node]);
+  }
+  // Initialize and mark the subgraph detector nodes based on teller.
+  InitNodes(&nodes);
+  // Run the Extract algorithm to find all subgraphs.
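+  // Note that each returned cluster lists its op nodes following the
+  // graph's StmtTopologicalOrder(), so the ops of a subgraph stay in a
+  // valid execution order when they are later copied into a sub block.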
+  std::vector<std::vector<Node *>> subgraphs = ExtractSubgraphs(&nodes);
+  for (auto &it : nodes) {
+    CHECK(it.second);
+    delete it.second;
+  }
+  return subgraphs;
+}
+
+void SubgraphFuser::InsertNewNode(SSAGraph *graph,
+                                  int subgraph_idx,
+                                  const std::vector<Node *> &subgraph_nodes) {
+  // Create and attach a new subgraph op
+  cpp::OpDesc subgraph_op_desc;
+  subgraph_op_desc.SetType("subgraph");
+
+  // Create a new sub block desc for storing all of the Ops and Vars of the
+  // target subgraph; sub_block_idx is set as an attribute of the subgraph
+  // op, and sub_block_idx < 0 means it's a new subgraph op
+  int sub_block_idx = -(subgraph_idx + 1);
+  auto sub_block_desc = new cpp::BlockDesc();
+  sub_block_desc->ClearOps();
+  sub_block_desc->ClearVars();
+  for (auto &op_node : subgraph_nodes) {
+    auto sub_block_op_desc = sub_block_desc->AddOp();
+    *sub_block_op_desc = *op_node->AsStmt().op_info();
+    sub_block_op_desc->SetAttr(
+        kKernelTypeAttr,
+        op_node->AsStmt().picked_kernel().SerializedKernelType());
+  }
+  subgraph_op_desc.SetAttr("sub_block", sub_block_idx);
+
+  // Extract input and output nodes from the target subgraph
+  std::unordered_set<Node *> input_var_nodes;
+  std::unordered_set<Node *> weight_var_nodes;
+  std::unordered_set<Node *> output_var_nodes;
+  std::unordered_set<Node *> local_var_nodes;
+  std::unordered_set<Node *> unused_var_nodes;
+  ExtractInputsOutputs(subgraph_nodes,
+                       &input_var_nodes,
+                       &weight_var_nodes,
+                       &output_var_nodes,
+                       &local_var_nodes,
+                       &unused_var_nodes);
+
+  // Set input and output name mapping which stores the real inputs and
+  // outputs
+  std::vector<std::string> input_var_names;
+  std::vector<std::string> output_var_names;
+  for (auto &var_node : input_var_nodes) {
+    input_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : output_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
+                                                     input_var_names);
+  subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
+                                                     output_var_names);
+
+  // Set all of the inputs and outputs to the target subgraph op, to prevent
+  // vars from being removed in RuntimeProgram::UpdateVarsOfProgram()
+  for (auto &var_node : weight_var_nodes) {
+    input_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : local_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : unused_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  subgraph_op_desc.SetInput("Inputs", input_var_names);
+  subgraph_op_desc.SetOutput("Outputs", output_var_names);
+  auto subgraph_op = LiteOpRegistry::Global().Create("subgraph");
+  static_cast<operators::SubgraphOp *>(subgraph_op.get())
+      ->SetSubBlock(sub_block_desc);
+  auto any_op = (*subgraph_nodes.begin())->AsStmt().op();
+  subgraph_op->Attach(subgraph_op_desc, any_op->scope());
+
+  // Create and add a new subgraph node into the graph
+  auto subgraph_op_node =
+      graph->GraphCreateInstructNode(subgraph_op, any_op->valid_places());
+  for (auto &var_node : input_var_nodes) {
+    IR_NODE_LINK_TO(var_node, subgraph_op_node);
+  }
+  for (auto &var_node : weight_var_nodes) {
+    IR_NODE_LINK_TO(var_node, subgraph_op_node);
+  }
+  for (auto &var_node : output_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+  for (auto &var_node : local_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+  for (auto &var_node : unused_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+
+  // Create and assign the context to the picked kernel of the new subgraph
+  // node
+  auto &inst = subgraph_op_node->AsStmt();
+  inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target())); + + // Remove subgraph nodes and unused var nodes + auto nodes2rm = GetNodes2RM(subgraph_nodes, + {input_var_nodes, + weight_var_nodes, + output_var_nodes, + local_var_nodes, + unused_var_nodes}); + GraphSafeRemoveNodes(graph, nodes2rm); +} + +void SubgraphFuser::ReplaceNodesWithSubgraphs(SSAGraph *graph, + const SubgraphTeller &teller, + int min_subgraph_size) { + std::vector> subgraphs = + SubgraphDetector(graph, teller)(); + SubgraphVisualizer(graph, subgraphs)(); + for (int subgraph_idx = 0; subgraph_idx < subgraphs.size(); subgraph_idx++) { + if (subgraphs[subgraph_idx].size() >= min_subgraph_size) { + InsertNewNode(graph, subgraph_idx, subgraphs[subgraph_idx]); + } + } +} + +void SubgraphFuser::operator()() { + ReplaceNodesWithSubgraphs(graph_, teller_, min_subgraph_size_); +} + +void ExtractInputsOutputs(const std::vector &op_nodes, + std::unordered_set *input_var_nodes, + std::unordered_set *weight_var_nodes, + std::unordered_set *output_var_nodes, + std::unordered_set *local_var_nodes, + std::unordered_set *unused_var_nodes) { + for (auto &op_node : op_nodes) { + for (auto &var_node : op_node->inlinks) { + if (var_node->AsArg().is_weight) { + weight_var_nodes->insert(var_node); + continue; + } + if (!var_node->inlinks.empty()) { + // Var can only come from one op node, so use front + auto *prev_op_node = var_node->inlinks.front(); + if (std::find(op_nodes.begin(), op_nodes.end(), prev_op_node) != + op_nodes.end()) { + continue; + } + } + input_var_nodes->insert(var_node); + } + for (auto &var_node : op_node->outlinks) { + if (var_node->outlinks.empty()) { + // The next op is empty so this var is actually unused + unused_var_nodes->insert(var_node); + continue; + } + // Var can have more than one next op node, So, if any one in the + // op_nodes then continue + bool next_op_in_nodes = false; + for (auto &next_op_node : var_node->outlinks) { + if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) != + op_nodes.end()) { + next_op_in_nodes = true; + } + } + if (next_op_in_nodes) { + local_var_nodes->insert(var_node); + continue; + } + output_var_nodes->insert(var_node); + } + } +} + +std::unordered_set GetNodes2RM( + const std::vector &op_nodes, + const std::vector> &excluded_var_nodes) { + std::unordered_set nodes2rm(op_nodes.begin(), op_nodes.end()); + for (auto &op_node : op_nodes) { + for (auto &var_node : op_node->inlinks) { + if (!nodes2rm.count(var_node)) { + nodes2rm.insert(var_node); + } + } + for (auto &var_node : op_node->outlinks) { + if (!nodes2rm.count(var_node)) { + nodes2rm.insert(var_node); + } + } + } + // Excluded nodes should not be removed + for (auto &excluded_var_node : excluded_var_nodes) { + for (auto &var_node : excluded_var_node) { + if (nodes2rm.count(var_node)) { + nodes2rm.erase(var_node); + } + } + } + return nodes2rm; +} + +static void SortHelper(Node *node, + const std::unordered_set &unordered_nodes, + std::unordered_set *visited_nodes, + std::vector *ordered_nodes) { + for (auto &var_node : node->inlinks) { + if (var_node->inlinks.empty()) continue; + auto *op_node = var_node->inlinks.front(); + if (unordered_nodes.count(op_node) && !visited_nodes->count(op_node)) { + SortHelper(op_node, unordered_nodes, visited_nodes, ordered_nodes); + } + } + ordered_nodes->push_back(node); + visited_nodes->insert(node); +} + +std::vector GetTopologicalOrder( + const std::unordered_set &unordered_nodes) { + std::unordered_set visited_nodes; + std::vector ordered_nodes; + for (auto &node 
diff --git a/lite/core/mir/subgraph/subgraph_detector.h b/lite/core/mir/subgraph/subgraph_detector.h
new file mode 100644
index 0000000000000000000000000000000000000000..b6873655e976a785383269972221f001196431f8
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_detector.h
@@ -0,0 +1,127 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+using SubgraphTeller = std::function<bool(Node*)>;
+
+class SubgraphVisualizer {
+ public:
+  SubgraphVisualizer(SSAGraph* graph,
+                     const std::vector<std::vector<Node*>>& subgraphs)
+      : graph_(graph), subgraphs_(subgraphs) {}
+  std::string operator()();
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  std::vector<std::vector<Node*>> subgraphs_;
+};
+
+/*
+ * Divide the graph into subgraphs according to the specified conditions.
+ * Return the divided clusters; a cluster consists of the op nodes of one
+ * subgraph.
+ */
+class SubgraphDetector {
+ public:
+  // This is a simple representation of the graph: each node_dat_t holds a
+  // pointer to the underlying Node, so that the graph analysis does not
+  // change the original graph.
+  struct node_dat_t;
+  using node_map_t = std::unordered_map<Node*, node_dat_t*>;
+  using node_set_t = std::vector<node_dat_t*>;
+  struct node_dat_t {
+    explicit node_dat_t(Node* _node) : node(_node) {}
+    Node* node;
+    bool marked{false};
+    node_dat_t* union_find_parent{this};
+    node_set_t inlinks{};
+    node_set_t outlinks{};
+    node_dat_t* UnionFindAncestor();
+    void UnionFindCombine(node_dat_t* candidate);
+  };
+  SubgraphDetector(SSAGraph* graph, const SubgraphTeller& teller)
+      : graph_(graph), teller_(teller) {}
+  std::vector<std::vector<Node*>> operator()();
+
+  void FlexibleDFS(const node_set_t& source,
+                   bool reverse,
+                   const std::function<bool(const node_dat_t*)>& enter,
+                   const std::function<bool(const node_dat_t*)>& leave);
+  void InitNodes(node_map_t* nodes);
+  std::vector<std::vector<Node*>> ExtractSubgraphs(node_map_t* nodes);
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  SubgraphTeller teller_;
+};
+
+/*
+ * Replace each detected subgraph with a subgraph op: a block desc is added
+ * into the subgraph op to wrap the original op nodes, and the var nodes of
+ * the original op nodes are kept as the inputs and outputs of the subgraph
+ * op.
+ */
+class SubgraphFuser {
+ public:
+  SubgraphFuser(SSAGraph* graph,
+                const SubgraphTeller& teller,
+                int min_subgraph_size)
+      : graph_(graph), teller_(teller), min_subgraph_size_{min_subgraph_size} {}
+  void operator()();
+
+  // Remove the op nodes of the subgraphs and replace them with subgraph ops
+  void ReplaceNodesWithSubgraphs(SSAGraph* graph,
+                                 const SubgraphTeller& teller,
+                                 int min_subgraph_size);
+  // Create a subgraph node with a block desc to wrap the original op nodes of
+  // the subgraph
+  void InsertNewNode(SSAGraph* graph,
+                     int subgraph_idx,
+                     const std::vector<Node*>& subgraph_nodes);
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  SubgraphTeller teller_;
+  int min_subgraph_size_;
+};
+
+void ExtractInputsOutputs(const std::vector<Node*>& op_nodes,
+                          std::unordered_set<Node*>* input_var_nodes,
+                          std::unordered_set<Node*>* weight_var_nodes,
+                          std::unordered_set<Node*>* output_var_nodes,
+                          std::unordered_set<Node*>* local_var_nodes,
+                          std::unordered_set<Node*>* unused_var_nodes);
+
+std::unordered_set<Node*> GetNodes2RM(
+    const std::vector<Node*>& op_nodes,
+    const std::vector<std::unordered_set<Node*>>& excluded_var_nodes);
+
+std::vector<Node*> GetTopologicalOrder(
+    const std::unordered_set<Node*>& unordered_nodes);
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
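The `union_find_parent` / `UnionFindAncestor` / `UnionFindCombine` members above implement a union-find over the detector's shadow nodes: supported ops that may be fused end up in one cluster keyed by a common ancestor. A minimal sketch of that idiom follows; it is deliberately simplified (the real detector also has to keep merged clusters acyclic), and all names here are illustrative only.

```cpp
// Minimal union-find with path compression, mirroring the idiom of
// node_dat_t::UnionFindAncestor / UnionFindCombine.
#include <cassert>

struct UfNode {
  UfNode* parent{this};  // each node starts as its own cluster root

  UfNode* Ancestor() {
    UfNode* root = this;
    while (root->parent != root) root = root->parent;
    // Path compression: re-point the whole chain directly at the root.
    for (UfNode* p = this; p != root;) {
      UfNode* next = p->parent;
      p->parent = root;
      p = next;
    }
    return root;
  }

  void Combine(UfNode* candidate) {
    // Merge the candidate's cluster into this node's cluster.
    candidate->Ancestor()->parent = Ancestor();
  }
};

int main() {
  UfNode a, b, c;
  a.Combine(&b);  // clusters: {a, b} {c}
  b.Combine(&c);  // clusters: {a, b, c}
  assert(a.Ancestor() == b.Ancestor());
  assert(b.Ancestor() == c.Ancestor());
}
```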
diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_detector_test.cc
similarity index 65%
rename from lite/core/mir/subgraph/subgraph_program_pass_test.cc
rename to lite/core/mir/subgraph/subgraph_detector_test.cc
index 22e20b81d831ff25df090a7565e671b9139122f7..3b0d7c5cd5c8a0d0901750148359f430b6d49894 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc
+++ b/lite/core/mir/subgraph/subgraph_detector_test.cc
@@ -12,68 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include "lite/core/mir/subgraph/subgraph_detector.h"
 #include <gtest/gtest.h>
 #include <memory>
 #include <vector>
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
-#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/ssa_graph.h"
 #include "lite/core/program.h"
 #include "lite/model_parser/cpp/program_desc.h"
 #include "lite/model_parser/model_parser.h"
 
 DEFINE_string(model_dir, "", "model_dir");
+DEFINE_string(model_file, "", "model file path of combined protobuf model");
+DEFINE_string(params_file, "", "params file path of combined protobuf model");
 
 namespace paddle {
 namespace lite {
 
-TEST(SubgraphTest, models) {
-  cpp::ProgramDesc program_desc;
-  auto scope = std::make_shared<Scope>();
-  // LoadModelPb(FLAGS_model_dir,
-  //             FLAGS_model_dir + "/model",
-  //             FLAGS_model_dir + "/params",
-  //             scope.get(),
-  //             &program_desc,
-  //             true);
-  LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
-  std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat)},
-#ifdef LITE_WITH_ARM
-      Place{TARGET(kARM), PRECISION(kFloat)},
-#endif
-#ifdef LITE_WITH_NPU
-      Place{TARGET(kNPU), PRECISION(kFloat)},
-#endif
-#ifdef LITE_WITH_XPU
-      Place{TARGET(kXPU), PRECISION(kFloat)},
-#endif
-  });
-  lite::Program program(program_desc, scope, valid_places);
-  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
-  graph->Build(program, valid_places);
-
-  std::vector<std::string> supported_op_types{"concat",
-                                              "conv2d",
-                                              "depthwise_conv2d",
-                                              "batch_norm",
-                                              "scale",
-                                              "pool2d",
-                                              "mul",
-                                              "elementwise_add",
-                                              "softmax",
-                                              "split",
-                                              "relu",
-                                              "reshape2",
-                                              "transpose2"};
-  auto* pass = new mir::subgraph::SubgraphProgramPass;
-  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
-  LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
-}
-
-// return output_var_names
+// Helper functions for building a model manually
 std::vector<std::string> AddFCDesc(
     cpp::BlockDesc* block_desc,
     const std::shared_ptr<Scope>& scope,
@@ -84,24 +41,23 @@ std::vector<std::string> AddFCDesc(
   static int
id = 0; std::string prefix = "fc_" + std::to_string(id); auto* op_desc = block_desc->AddOp(); - auto* wgt = block_desc->AddVar(); - auto* bias = block_desc->AddVar(); - auto* out = block_desc->AddVar(); + auto* wgt = block_desc->AddVar(); wgt->SetName(prefix + "_W"); - bias->SetName(prefix + "_Bias"); - out->SetName(prefix + "_Out"); - std::vector out_var_names{prefix + "_Out"}; - - auto* wtensor = scope->Var(prefix + "_W")->GetMutable(); + auto* wtensor = scope->Var(prefix + "_W")->GetMutable(); wtensor->Resize(wshape); wtensor->mutable_data(); - auto* btensor = scope->Var(prefix + "_Bias")->GetMutable(); + auto* bias = block_desc->AddVar(); + bias->SetName(prefix + "_Bias"); + auto* btensor = scope->Var(prefix + "_Bias")->GetMutable(); btensor->Resize({wshape[1]}); btensor->mutable_data(); - scope->Var(prefix + "_Out")->GetMutable(); + auto* out = block_desc->AddVar(); + out->SetName(prefix + "_Out"); + std::vector out_var_names{prefix + "_Out"}; + scope->Var(prefix + "_Out")->GetMutable(); op_desc->SetType("fc"); op_desc->SetInput("Input", input_var_names); @@ -127,7 +83,7 @@ std::vector AddElementwiseAddDesc( out->SetName(prefix + "_Out"); std::vector out_var_names{prefix + "_Out"}; - scope->Var(prefix + "_Out")->GetMutable(); + scope->Var(prefix + "_Out")->GetMutable(); op_desc->SetType("elementwise_add"); op_desc->SetInput("X", input_X_names); @@ -151,7 +107,7 @@ std::vector AddFeedDesc( out->SetName(prefix + "_Out"); std::vector out_var_names{prefix + "_Out"}; - scope->Var(prefix + "_Out")->GetMutable(); + scope->Var(prefix + "_Out")->GetMutable(); op_desc->SetType("feed"); op_desc->SetInput("X", input_X_names); @@ -174,7 +130,7 @@ std::vector AddFetchDesc( out->SetName(prefix + "_Out"); std::vector out_var_names{prefix + "_Out"}; - scope->Var(prefix + "_Out")->GetMutable(); + scope->Var(prefix + "_Out")->GetMutable(); op_desc->SetType("fetch"); op_desc->SetInput("X", input_X_names); @@ -184,41 +140,88 @@ std::vector AddFetchDesc( return out_var_names; } -std::unique_ptr BuildSimpleNet( - cpp::ProgramDesc* program_desc, - const std::shared_ptr& scope, - const std::vector& valid_places) { - program_desc->ClearBlocks(); - auto* block_desc = program_desc->AddBlock(); +TEST(Subgraph, detect_simple_model) { + cpp::ProgramDesc program_desc; + std::vector valid_places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + // Build a simple network + program_desc.ClearBlocks(); + auto* block_desc = program_desc.AddBlock(); block_desc->ClearOps(); block_desc->ClearVars(); - auto* var_desc = block_desc->AddVar(); var_desc->SetName("feed_var"); - auto* feed_var = scope->Var("feed_var")->GetMutable(); + auto* feed_var = scope->Var("feed_var")->GetMutable(); feed_var->Resize({1, 4}); auto fc1_out = AddFCDesc(block_desc, scope, {"feed_var"}, {4, 5}); auto fc2_out = AddFCDesc(block_desc, scope, fc1_out, {5, 2}); - - lite::Program program(*program_desc, scope, valid_places); + Program program(program_desc, scope, valid_places); auto graph = std::unique_ptr(new mir::SSAGraph()); graph->Build(program, valid_places); - - return graph; + // Apply subgraph detector and check results + auto teller = [](mir::Node* node) { + if (!node->IsStmt()) return false; + auto& stmt = node->AsStmt(); + auto op_type = stmt.op_type(); + const std::vector supported_types = {"fc"}; + return std::find(supported_types.begin(), supported_types.end(), op_type) != + supported_types.end(); + }; + std::vector> subgraphs = + mir::SubgraphDetector(graph.get(), teller)(); + ASSERT_EQ(subgraphs.size(), 1); + 
ASSERT_EQ(graph->nodes().size(), 9);
+  mir::SubgraphVisualizer(graph.get(), subgraphs)();
 }
 
-TEST(SubGraphTest, SimpleNet) {
+TEST(Subgraph, detect_custom_model) {
+  if (FLAGS_model_dir.empty() && FLAGS_model_file.empty() &&
+      FLAGS_params_file.empty()) {
+    LOG(INFO) << "Use --model_dir, or --model_file and --params_file, to set "
+                 "the paths of the model files.";
+    return;
+  }
   cpp::ProgramDesc program_desc;
-  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
   auto scope = std::make_shared<Scope>();
-  auto graph = BuildSimpleNet(&program_desc, scope, places);
-
-  std::vector<std::string> supported_op_types{"fc"};
-  auto* pass = new mir::subgraph::SubgraphProgramPass;
-  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
-
-  ASSERT_EQ(graph->nodes().size(), 9);
-  // LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
+  LoadModelPb(FLAGS_model_dir,
+              FLAGS_model_file,
+              FLAGS_params_file,
+              scope.get(),
+              &program_desc,
+              !FLAGS_model_file.empty() && !FLAGS_params_file.empty(),
+              false);
+  std::vector<Place> valid_places({
+#ifdef LITE_WITH_ARM
+      Place{TARGET(kARM), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_X86
+      Place{TARGET(kX86), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_NPU
+      Place{TARGET(kNPU), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_XPU
+      Place{TARGET(kXPU), PRECISION(kFloat)},
+#endif
+  });
+  Program program(program_desc, scope, valid_places);
+  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
+  graph->Build(program, valid_places);
+  // Apply the subgraph detector and check the results
+  auto teller = [](mir::Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    auto op_type = stmt.op_type();
+    const std::vector<std::string> unsupported_types = {
+        "feed", "fetch", "subgraph"};
+    return std::find(unsupported_types.begin(),
+                     unsupported_types.end(),
+                     op_type) == unsupported_types.end();
+  };
+  std::vector<std::vector<mir::Node*>> subgraphs =
+      mir::SubgraphDetector(graph.get(), teller)();
+  ASSERT_EQ(subgraphs.size(), 1);
+  mir::SubgraphVisualizer(graph.get(), subgraphs)();
 }
 
 }  // namespace lite
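The NPU/XPU passes in the new subgraph_pass.cc below build their supported-op set with an X-macro: paddle_use_bridges.h is a list of `USE_SUBGRAPH_BRIDGE(...)` entries, and the includer decides what each entry expands to. A standalone sketch of the pattern, with a hypothetical two-op list standing in for the real `#include "lite/kernels/npu/bridges/paddle_use_bridges.h"`:

```cpp
// X-macro sketch: OPS_LIST stands in for the included bridge list; the real
// code achieves the same effect by #include-ing paddle_use_bridges.h.
#include <iostream>
#include <string>
#include <unordered_set>

#define OPS_LIST                          \
  USE_SUBGRAPH_BRIDGE(kNPU, conv2d)       \
  USE_SUBGRAPH_BRIDGE(kNPU, softmax)

int main() {
  std::unordered_set<std::string> supported;
// Expand each list entry into an insert() call, exactly as the pass does.
#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported.insert(#op_type);
  OPS_LIST
#undef USE_SUBGRAPH_BRIDGE
  std::cout << supported.count("conv2d") << supported.count("relu") << "\n";  // 10
}
```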
diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b974ac7043e2fc1c656c4bad69e7ca50fffaff8c
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_pass.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/subgraph/subgraph_pass.h"
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/subgraph/subgraph_detector.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> supported_lists;
+#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
+#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
+#undef USE_SUBGRAPH_BRIDGE
+  auto teller = [&](Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    return supported_lists.count(stmt.op_type()) != 0;
+  };
+  SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
+  fuser();
+}
+
+void XPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> supported_lists;
+#define USE_SUBGRAPH_BRIDGE(dev_type, op_type) supported_lists.insert(#op_type);
+#include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
+#undef USE_SUBGRAPH_BRIDGE
+  auto teller = [&](Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    return supported_lists.count(stmt.op_type()) != 0;
+  };
+  SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
+  fuser();
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(npu_subgraph_pass, paddle::lite::mir::NPUSubgraphPass)
+    .BindTargets({TARGET(kNPU)});
+REGISTER_MIR_PASS(xpu_subgraph_pass, paddle::lite::mir::XPUSubgraphPass)
+    .BindTargets({TARGET(kXPU)});
diff --git a/lite/core/mir/subgraph/subgraph_pass.h b/lite/core/mir/subgraph/subgraph_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..554f54304afcd2eac3069c101f2e19ff9391fa66
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class NPUSubgraphPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+class XPUSubgraphPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_pass_test.cc
similarity index 68%
rename from lite/core/mir/subgraph/generate_npu_program_pass_test.cc
rename to lite/core/mir/subgraph/subgraph_pass_test.cc
index 1afb54c692592ca42d8b120dcf1a91922e19149c..45c82a4262f16ab180596375cc037cc0e9febec2 100644
--- a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
+++ b/lite/core/mir/subgraph/subgraph_pass_test.cc
@@ -30,7 +30,9 @@ DEFINE_int32(output_tensor_num, 1, "number of output tensors");
 namespace paddle {
 namespace lite {
 
-std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
+// Helper functions for loading and running a model from the command line and
+// verifying the output data
+std::vector<std::vector<int64_t>> ShapeParsing(std::string txt) {
   std::vector<std::vector<int64_t>> shape;
   while (!txt.empty()) {
     size_t idx = txt.find_first_of(":");
@@ -65,7 +67,7 @@ int64_t ShapeProduction(std::vector<int64_t> shape) {
   return s;
 }
 
-void FillInputTensor(
+void FillInputTensors(
     const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
     const std::vector<std::vector<int64_t>>& input_tensor_shape,
     const float value) {
@@ -80,7 +82,7 @@
   }
 }
 
-void CompareOutputTensor(
+void CheckOutputTensors(
    const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
    const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
    const int output_tensor_num) {
@@ -96,7 +98,7 @@
       auto abs_diff =
           std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);
       auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6);
-      VLOG(3) << "val: " << tar_output_tensor_data[j]
+      VLOG(5) << "val: " << tar_output_tensor_data[j]
               << " ref: " << ref_output_tensor_data[j]
               << " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;
       EXPECT_LT(rel_diff, 0.1);
@@ -111,24 +113,23 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
     const std::vector<lite_api::Place>& valid_places,
     const std::vector<std::vector<int64_t>>& input_tensor_shape,
     const std::string& optimized_model_dir) {
-  // generate optimized model
+  // Generate optimized model
   lite_api::CxxConfig cxx_config;
   cxx_config.set_model_dir(model_dir);
   cxx_config.set_model_file(model_file);
   cxx_config.set_param_file(params_file);
   cxx_config.set_valid_places(valid_places);
   auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
-  FillInputTensor(predictor, input_tensor_shape, 1);
   predictor->SaveOptimizedModel(optimized_model_dir,
                                 lite_api::LiteModelType::kNaiveBuffer);
-  // load optimized model
+  // Load optimized model
   lite_api::MobileConfig mobile_config;
   mobile_config.set_model_dir(optimized_model_dir);
   mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
   mobile_config.set_threads(1);
   predictor = lite_api::CreatePaddlePredictor(mobile_config);
-  FillInputTensor(predictor, input_tensor_shape, 1);
-  // run optimized model
+  FillInputTensors(predictor, input_tensor_shape, 1);
+  // Run optimized model
   for (int i = 0; i < FLAGS_warmup; i++) {
     predictor->Run();
   }
@@ -140,32 +141,48 @@
   return predictor;
 }
 
-TEST(NPUSubgraph, compare) {
-  // parsing input tensor shape, supported formats: "1,3,224,224"
-  // "1,3,224,224:1,80"
+TEST(Subgraph, generate_model_and_check_precision) {
+  if (FLAGS_model_dir.empty() && 
FLAGS_model_file.empty() && + FLAGS_params_file.empty()) { + LOG(INFO) << "Using --model_dir, or --model_file and --params_file to set " + "the path of model files."; + return; + } + // Parsing the shapes of input tensors from strings, supported formats: + // "1,3,224,224" and "1,3,224,224:1,80" std::vector> input_tensor_shape = - ParseShape(FLAGS_input_tensor_shape); - // generate and run optimized CPU model - LOG(INFO) << " ================ CPU ================== "; - auto cpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/CPU"); - // generate and run optimized NPU model - LOG(INFO) << " ================ NPU ================== "; - auto npu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}, - lite_api::Place{TARGET(kARM), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/NPU"); - // verify results - CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num); + ShapeParsing(FLAGS_input_tensor_shape); + std::vector valid_places({ +#ifdef LITE_WITH_ARM + lite_api::Place{TARGET(kARM), PRECISION(kFloat)}, +#endif +#ifdef LITE_WITH_X86 + lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, +#endif + }); + // Generate and run optimized model on CPU as the reference predictor + auto ref_predictor = TestModel(FLAGS_model_dir, + FLAGS_model_file, + FLAGS_params_file, + valid_places, + input_tensor_shape, + FLAGS_optimized_model_dir + "/ref_opt_model"); +// Generate and run optimized model on NPU/XPU as the target predictor +#ifdef LITE_WITH_NPU + valid_places.push_back(lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}); +#endif +#ifdef LITE_WITH_XPU + valid_places.push_back(lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}); +#endif + auto tar_predictor = TestModel(FLAGS_model_dir, + FLAGS_model_file, + FLAGS_params_file, + valid_places, + input_tensor_shape, + FLAGS_optimized_model_dir + "/tar_opt_model"); + // Check the difference of the output tensors between reference predictor and + // target predictor + CheckOutputTensors(tar_predictor, ref_predictor, FLAGS_output_tensor_num); } } // namespace lite diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc deleted file mode 100644 index 719a01dfd892f83da5e1d9b1efa6df758612acc7..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/subgraph_program_pass.cc +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include -#include -#include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/pass_registry.h" -#include "lite/core/mir/pattern_matcher.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -std::unordered_map> -SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr& graph) { - std::unordered_map> op_nodes; - for (auto& item : graph->StmtTopologicalOrder()) { - if (!item->IsStmt()) continue; - auto& stmt = item->AsStmt(); - int sub_id = stmt.subgraph_id(); - if (sub_id < 1) continue; - if (!op_nodes.count(sub_id)) { - op_nodes[sub_id] = std::unordered_set(); - } - op_nodes.at(sub_id).insert(item); - } - return op_nodes; -} - -cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc( - const std::string& weight_var_name, - const std::vector& in_var_names, - const std::vector& out_var_names) { - cpp::OpDesc op_desc; - op_desc.SetType("graph_op"); - op_desc.SetInput("Inputs", in_var_names); - op_desc.SetInput("Weight", {weight_var_name}); - op_desc.SetOutput("Outputs", out_var_names); - return op_desc; -} - -void SubgraphProgramPass::InsertNewNode( - const std::unique_ptr& graph, - const std::string& weight_var_name, - Scope* scope, - const std::vector& valid_places, - std::unordered_set in_data_vars, - std::unordered_set in_wgt_vars, - std::unordered_set out_data_vars, - std::unordered_set out_unused_vars) { - std::vector in_var_names; - std::vector out_var_names; - for (auto i : in_data_vars) { - in_var_names.push_back(i->AsArg().name); - } - for (auto i : out_data_vars) { - out_var_names.push_back(i->AsArg().name); - } - - auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names); - - auto graph_op = LiteOpRegistry::Global().Create("graph_op"); - graph_op->Attach(op_desc, scope); - auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); - - for (auto& in_var : in_data_vars) { - IR_NODE_LINK_TO(in_var, new_op_node); - } - for (auto& in_var : in_wgt_vars) { - IR_NODE_LINK_TO(in_var, new_op_node); - } - for (auto& out_var : out_data_vars) { - IR_OP_VAR_LINK(new_op_node, out_var); - } - for (auto& out_var : out_unused_vars) { - IR_OP_VAR_LINK(new_op_node, out_var); - } - - // add weight node to store pre-compilied NPU model - auto new_weight_node = graph->NewArgumentNode(weight_var_name); - new_weight_node->AsArg().is_weight = true; - new_weight_node->AsArg().is_persist = true; - DirectedLink(new_weight_node, new_op_node); - - // assign context - auto& inst = new_op_node->AsStmt(); - inst.picked_kernel().SetContext( - ContextScheduler::Global().NewContext(inst.picked_kernel().target())); -} - -void SubgraphProgramPass::SortHelper( - Node* node, - const std::unordered_set& nodes_all, - std::unordered_set* visited_nodes, - std::vector* ret) { - for (auto& var_node : node->inlinks) { - if (var_node->inlinks.empty()) continue; - auto* op_node = var_node->inlinks.front(); - if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) { - SortHelper(op_node, nodes_all, visited_nodes, ret); - } - } - ret->push_back(node); - visited_nodes->insert(node); -} - -std::vector SubgraphProgramPass::GetTopologicalOrder( - const std::unordered_set& nodes) { - std::unordered_set visited; - std::vector ret; - for (auto& node : nodes) { - if (!node->IsStmt()) continue; - if (visited.count(node)) continue; - SortHelper(node, nodes, &visited, &ret); - } - return ret; -} - -void SubgraphProgramPass::FindInputOutputVars( - const std::unordered_set& 
op_nodes, - std::unordered_set* in_data_vars, - std::unordered_set* in_wgt_vars, - std::unordered_set* out_data_vars, - std::unordered_set* out_unused_vars) { - for (auto& op_node : op_nodes) { - for (auto& in_var : op_node->inlinks) { - if (in_var->AsArg().is_weight) { - in_wgt_vars->insert(in_var); - continue; - } - if (!in_var->inlinks.empty()) { - // var can only come from one op node, so use front - auto* pre_op_node = in_var->inlinks.front(); - if (op_nodes.count(pre_op_node)) { - continue; - } - } - in_data_vars->insert(in_var); - } - for (auto& out_var : op_node->outlinks) { - if (out_var->outlinks.empty()) { - // the next op is empty so this var is actually unused - out_unused_vars->insert(out_var); - continue; - } - // var can have more than one next op node - // so, if any one in the op_nodes then continue - bool next_op_in_nodes = false; - for (auto& next_op_node : out_var->outlinks) { - if (op_nodes.count(next_op_node)) { - next_op_in_nodes = true; - } - } - if (next_op_in_nodes) { - continue; - } - - out_data_vars->insert(out_var); - } - } -} - -std::unordered_set SubgraphProgramPass::GetNode2rm( - const std::unordered_set& op_nodes, - const std::vector>& excluded_nodes) { - std::unordered_set nodes2rm(op_nodes.begin(), op_nodes.end()); - for (auto& op_node : op_nodes) { - for (auto& in_var : op_node->inlinks) { - if (!nodes2rm.count(in_var)) { - nodes2rm.insert(in_var); - } - } - for (auto& out_var : op_node->outlinks) { - if (!nodes2rm.count(out_var)) { - nodes2rm.insert(out_var); - } - } - } - // some nodes should not be removed - for (auto& e : excluded_nodes) { - for (auto& i : e) { - if (nodes2rm.count(i)) { - nodes2rm.erase(i); - } - } - } - return nodes2rm; -} - -void SubgraphProgramPass::InferOnce(const std::unique_ptr& graph) { - for (auto& item : graph->StmtTopologicalOrder()) { - if (!item->IsStmt()) continue; - auto& stmt = item->AsStmt(); - auto& op = stmt.op(); - auto scope = op->scope(); - std::string op_type = op->op_info()->Type(); - // check the dimension of input variables in the scope, must not be empty ! 
- if (op_type == "feed") { - auto input_var_names = op->op_info()->output_names(); - CHECK_GE(input_var_names.size(), 1); - for (auto input_var_name : input_var_names) { - auto input_var = scope->FindVar(input_var_name); - CHECK(input_var) << "No input variable '" << input_var_name - << "' found in scope " << scope; - auto input = input_var->GetMutable(); - CHECK(!input->dims().empty()) << "The dimension of input variable '" - << input_var_name - << "' can not be empty."; - } - continue; - } - if (op_type == "fetch") { - continue; - } - op->CheckShape(); - op->InferShape(); - -#ifndef LITH_WITH_XPU - // TOOD(xxx): remove Launch() at last - auto& kkks = stmt.kernels(); - if (!kkks.empty()) { - auto& kk = stmt.kernels().front(); - if (kk) { - kk->Launch(); - } - } -#endif - } -} - -void SubgraphProgramPass::InitSubgraphID( - const std::unique_ptr& graph, - const std::vector& supported_op_types) { - for (auto& item : graph->StmtTopologicalOrder()) { - if (!item->IsStmt()) continue; - auto& stmt = item->AsStmt(); - stmt.ClearSubgraphID(); - if (std::find(supported_op_types.begin(), - supported_op_types.end(), - stmt.op_type()) != supported_op_types.end()) { - stmt.SetSubgraphID(0); - LOG(INFO) << "supported " << stmt.op_type(); - } else { - LOG(INFO) << "======= not supported " << stmt.op_type(); - } - } -} - -// mark current and all output supported nodes -void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node, - int to_id, - int from_id) { - if (!node) return; - if (node->IsStmt()) { - auto& stmt = node->AsStmt(); - if (stmt.subgraph_id() == from_id) { - stmt.SetSubgraphID(to_id); - for (auto& i : node->outlinks) { - ChangeAllOutConnectedID(i, to_id, from_id); - } - } else { - LOG(INFO) << "failed op type:" << stmt.op_type(); - return; - } - } else { - // this it arg node - bool all_out_op_supported = true; - for (auto& i : node->outlinks) { - if (!i->IsStmt()) return; - auto& stmt = i->AsStmt(); - if (stmt.subgraph_id() < from_id) { - all_out_op_supported = false; - } - } - if (!all_out_op_supported) { - return; - } - for (auto& i : node->outlinks) { - CHECK(i->IsStmt()); - auto& stmt = i->AsStmt(); - if (stmt.subgraph_id() == from_id) { - stmt.SetSubgraphID(to_id); - for (auto& o : i->outlinks) { - ChangeAllOutConnectedID(o, to_id, from_id); - } - } - } - } -} - -int SubgraphProgramPass::FuseSubgraphID( - const std::unique_ptr& graph) { - int sub_id = 1; // id start from 1 not 0 - for (auto& item : graph->StmtTopologicalOrder()) { - // bool inputvar = false; - if (!item->IsStmt()) continue; - auto& stmt = item->AsStmt(); - /* - if (stmt.subgraph_id() == -1) { - for (auto& i : item->outlinks) { - for (auto& j : i->outlinks) { - if (j->IsStmt()) { - auto& jstmt = j->AsStmt(); - if (jstmt.subgraph_id() == 0) inputvar = true; - } - } - } - } - */ - if (stmt.subgraph_id() != 0) continue; - ChangeAllOutConnectedID(item, sub_id); - sub_id++; - } - return sub_id - 1; -} - -int SubgraphProgramPass::FuseSubgraph( - const std::unique_ptr& graph, - const std::vector& supported_op_types) { - InitSubgraphID(graph, supported_op_types); - return FuseSubgraphID(graph); -} -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle - -REGISTER_MIR_PASS(subgraph_program_pass, - paddle::lite::mir::subgraph::SubgraphProgramPass) - .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h deleted file mode 100644 index 
24c0233bbb428a71fa5645b23573494b5067d8b1..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/subgraph_program_pass.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "lite/core/mir/pass.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -class SubgraphProgramPass : public ProgramPass { - public: - using key2nodes_t = std::map; - - // make all the linked ops in subgraph with same subgraph_id - // return the fused subgraph numbers - int FuseSubgraph(const std::unique_ptr& graph, - const std::vector& supported_op_types); - - void Apply(const std::unique_ptr& graph) override{}; - - protected: - void InferOnce(const std::unique_ptr& graph); - - // clear all subgraph id and mark all ops, which could be fuse, as id zero - void InitSubgraphID(const std::unique_ptr& graph, - const std::vector& supported_op_types); - - // make all the linked ops in subgraph with same subgraph_id - // return the fused subgraph numbers - int FuseSubgraphID(const std::unique_ptr& graph); - - // // GenerateFusedGraph: - // std::unique_ptr GenerateFusedGraph(const - // std::unique_ptr& graph, int sub_num); - void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0); - - // Below function cloud be useful in child classes // - // classify node by subgraph id - std::unordered_map> ClassifySubgraph( - const std::unique_ptr& graph); - - // generate the graph op desc - cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name, - const std::vector& in_var_names, - const std::vector& out_var_names); - - // insert a new graph op node - void InsertNewNode(const std::unique_ptr& graph, - const std::string& weight_var_name, - Scope* scope, - const std::vector& valid_places, - std::unordered_set in_data_vars, - std::unordered_set in_wgt_vars, - std::unordered_set out_data_vars, - std::unordered_set out_unused_vars); - - // Sort and return the topology order of nodes set - std::vector GetTopologicalOrder( - const std::unordered_set& nodes); - - // find all input data vars, input weight vars, - // output data vars and output vars from the nodes - void FindInputOutputVars(const std::unordered_set& op_nodes, - std::unordered_set* in_data_vars, - std::unordered_set* in_wgt_vars, - std::unordered_set* out_data_vars, - std::unordered_set* out_unused_vars); - - // return the node to remove in the subgraph - std::unordered_set GetNode2rm( - const std::unordered_set& op_nodes, - const std::vector>& excluded_nodes); - - private: - // sort nodes to operational sequence - void SortHelper(Node* node, - const std::unordered_set& nodes_all, - std::unordered_set* visited_nodes, - std::vector* ret); -}; - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index 
b008faa687474a88988adb9da81c594306298b26..ae74bd8d4d5647139a13509dfda0bb2b41ecc5c7 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include "lite/core/mir/graph_visualize_pass.h" @@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr& graph) { CHECK(!valid_places_.empty()); + // record the copied node. + std::unordered_map copied_nodes; + for (auto& node : nodes) { if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; auto inlinks = node->inlinks; for (auto* in : inlinks) { - ComplementInputs(graph.get(), node, in); + ComplementInputs(graph.get(), node, in, &copied_nodes); } } } -void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, - Node* inst_node, - Node* in) { +void TypeTargetTransformPass::ComplementInputs( + SSAGraph* graph, + Node* inst_node, + Node* in, + std::unordered_map* copied_nodes) { // If this input is out of date. if (inst_node->inlinks.end() == std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) @@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, << " for kernel " << inst.op()->DebugString() << " " << *in->AsArg().type << " -> " << *decl_arg_type; // Add an IoCopy instruction to make the input compatible with other dist. - AddIoCopyInst( - *in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_); + AddIoCopyInst(*in->AsArg().type, + *decl_arg_type, + in, + graph, + inst_node, + copied_nodes, + valid_places_); } } @@ -78,128 +89,132 @@ void TypeTargetTransformPass::AddIoCopyInst( Node* in, SSAGraph* graph, Node* inst_node, + std::unordered_map* copied_nodes, const std::vector& valid_places) { CHECK(!valid_places.empty()) << "valid_place should be set"; // var -> new_transform_op -> new_var -> inst // So there will be a new Argument node and a new IoCopy Statement Node. CHECK(in->IsArg()); + // auto node_id = [&] { return graph->nodes().size(); }; auto io_copy_output_name = string_format("%s/target_trans", in->AsArg().name.c_str()); // string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); - // TODO(MyPandaShaoxiang) should set same place with input? - auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); - // Set the place for io_copy_output_arg node, the target should be equal to - // to.target() - // The precision and layout should be equal to from.precision(), from.layout() - io_copy_output_arg->AsArg().type = - LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); - auto* io_copy_inst = graph->NewInstructNode(); - - bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; - std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy"; - io_copy_output_arg->AsArg().is_persist = in_persist; - // create Op and kernels. - auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type); - CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; - // CHECK(io_copy_op); - // Create the new var manually. - inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); - - // Create IoCopy Instruction. 
- cpp::OpDesc op_desc; - op_desc.SetType(io_copy_type); - op_desc.SetInput("Input", {in->AsArg().name}); - op_desc.SetOutput("Out", {io_copy_output_name}); - - io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); - auto kernels = io_copy_op->CreateKernels(valid_places); - // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type - bool is_found = false; - std::vector> selected_kernels; - for (auto& kernel : kernels) { - const Type* in_arg_ty = kernel->GetInputDeclType("Input"); - const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); - - VLOG(4) << "------ kernel info -------"; - VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty; - VLOG(4) << "from(last kernel output):" << from; - VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty; - VLOG(4) << "to:" << to << "\n"; - - // kernel choose branch for opencl backend - // judge inst's target whether is kOpenCL - // Note: to == *decl_arg_type == in of inst, not output of last inst - // ignore [layout check] for layout between [to] and [from] - // Because all of origin opencl insts in model, are not default layout - // NCHW, - // so skip layout check. - // detailed node info see below: - // [*in->AsArg().type] -> [from]: out of inst's previous kernel - // [*decl_arg_type] -> [to]: input of inst, not output of last - // [in_arg_ty]: in of io_copy - // [out_arg_ty]: out of io_copy - // - // noto: replace LITE_WITH_OPENCL macro with judge input and output target - // of io_copy - if ((in_arg_ty->target() == TARGET(kOpenCL) || - out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first - (TargetCompatibleTo(*in_arg_ty, from) && - PrecisionCompatibleTo(*in_arg_ty, from) && - DeviceCompatibleTo(*in_arg_ty, from) && - TargetCompatibleTo(*out_arg_ty, to))) { - VLOG(4) << "picked, opencl found"; - is_found = true; - } else if (TypeCompatible(*in_arg_ty, from) && - out_arg_ty->target() == to.target()) { - VLOG(4) << "picked"; - is_found = true; - } - if (is_found) { - selected_kernels.emplace_back(std::move(kernel)); - // we pick the kernel - io_copy_inst->AsStmt( - io_copy_type, std::move(selected_kernels), io_copy_op); - break; + if (copied_nodes->count(in->AsArg().name)) { + // Remove the old link + RemoveDirectedLink(in, inst_node); + + // Update the original instruction OpDesc. + // Update its input to the io_copy_output_name + // Add new link, newarg->inst + DirectedLink(copied_nodes->at(in->AsArg().name), + inst_node); // [io_copy kernel]'s output -> [current kernel] + + UpdateInstNode(in, graph, inst_node, io_copy_output_name); + } else { + // TODO(MyPandaShaoxiang) should set same place with input? + auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); + // Set the place for io_copy_output_arg node, the target should be equal to + // to.target() + // The precision and layout should be equal to from.precision(), + // from.layout() + io_copy_output_arg->AsArg().type = + LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); + auto* io_copy_inst = graph->NewInstructNode(); + + bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; + std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy"; + io_copy_output_arg->AsArg().is_persist = in_persist; + // create Op and kernels. + auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type); + CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; + // CHECK(io_copy_op); + // Create the new var manually. 
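+    // The io_copy output var is not declared in the original program desc,
+    // so it is created in the scope manually here, before Attach() below
+    // looks it up by name.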
+ inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); + + // Create IoCopy Instruction. + cpp::OpDesc op_desc; + op_desc.SetType(io_copy_type); + op_desc.SetInput("Input", {in->AsArg().name}); + op_desc.SetOutput("Out", {io_copy_output_name}); + + io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); + auto kernels = io_copy_op->CreateKernels(valid_places); + // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type + bool is_found = false; + std::vector> selected_kernels; + for (auto& kernel : kernels) { + const Type* in_arg_ty = kernel->GetInputDeclType("Input"); + const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); + + VLOG(4) << "------ kernel info -------"; + VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty; + VLOG(4) << "from(last kernel output):" << from; + VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty; + VLOG(4) << "to:" << to << "\n"; + + // kernel choose branch for opencl backend + // judge inst's target whether is kOpenCL + // Note: to == *decl_arg_type == in of inst, not output of last inst + // ignore [layout check] for layout between [to] and [from] + // Because all of origin opencl insts in model, are not default layout + // NCHW, + // so skip layout check. + // detailed node info see below: + // [*in->AsArg().type] -> [from]: out of inst's previous kernel + // [*decl_arg_type] -> [to]: input of inst, not output of last + // [in_arg_ty]: in of io_copy + // [out_arg_ty]: out of io_copy + // + // noto: replace LITE_WITH_OPENCL macro with judge input and output target + // of io_copy + if ((in_arg_ty->target() == TARGET(kOpenCL) || + out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first + (TargetCompatibleTo(*in_arg_ty, from) && + PrecisionCompatibleTo(*in_arg_ty, from) && + DeviceCompatibleTo(*in_arg_ty, from) && + TargetCompatibleTo(*out_arg_ty, to))) { + VLOG(4) << "picked, opencl found"; + is_found = true; + } else if (TypeCompatible(*in_arg_ty, from) && + out_arg_ty->target() == to.target()) { + VLOG(4) << "picked"; + is_found = true; + } + + if (is_found) { + selected_kernels.emplace_back(std::move(kernel)); + // we pick the kernel + io_copy_inst->AsStmt( + io_copy_type, std::move(selected_kernels), io_copy_op); + (*copied_nodes)[in->AsArg().name] = io_copy_output_arg; + break; + } + + VLOG(4) << "not picked"; } - VLOG(4) << "not picked"; - } + CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from + << ":" << in->AsArg().name << " -> " << to << ":" + << inst_node->AsStmt().op_info()->Type(); + // Remove the old link + RemoveDirectedLink(in, inst_node); - CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from - << ":" << in->AsArg().name << " -> " << to << ":" - << inst_node->AsStmt().op_info()->Type(); - // Remove the old link - RemoveDirectedLink(in, inst_node); - - // Update the original instruction OpDesc. - // Update its input to the io_copy_output_name - // Add new link, var -> new_inst, new_inst->newarg, newarg->inst - DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel] - DirectedLink( - io_copy_inst, - io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output - DirectedLink(io_copy_output_arg, - inst_node); // [io_copy kernel]'s output -> [current kernel] + // Update the original instruction OpDesc. 
+ // Update its input to the io_copy_output_name + // Add new link, var -> new_inst, new_inst->newarg, newarg->inst + DirectedLink(in, + io_copy_inst); // [last kernel]'s output -> [io_copy kernel] + DirectedLink( + io_copy_inst, + io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output + DirectedLink(io_copy_output_arg, + inst_node); // [io_copy kernel]'s output -> [current kernel] - // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - io_copy_output_name); - auto original_selected_kernel = - std::move(inst_node->AsStmt().kernels().front()); - auto update_op_info = *inst_node->AsStmt().op_info(); - // ResetOp() will change the Stmt op_info_ value, - // after that the old op_info_ value will be nullified. - // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. - // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). - // Whenever update the op_info of a stmt, we should call its ResetOp(). - inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); - inst_node->AsStmt().kernels().clear(); - inst_node->AsStmt().kernels().emplace_back( - std::move(original_selected_kernel)); + UpdateInstNode(in, graph, inst_node, io_copy_output_name); + } std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { @@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces( valid_places_ = valid_places; } +void TypeTargetTransformPass::UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name) { + // reset opdesc and update kernel information + UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), + in->AsArg().name, + io_copy_output_name); + auto original_selected_kernel = + std::move(inst_node->AsStmt().kernels().front()); + auto update_op_info = *inst_node->AsStmt().op_info(); + // ResetOp() will change the Stmt op_info_ value, + // after that the old op_info_ value will be nullified. + // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. + // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). + // Whenever update the op_info of a stmt, we should call its ResetOp(). 
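+  // The kernel picked for inst_node earlier is still valid, so
+  // UpdateInstNode() moves it aside, calls ResetOp(), and then restores it
+  // instead of re-selecting a kernel.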
+ inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); + inst_node->AsStmt().kernels().clear(); + inst_node->AsStmt().kernels().emplace_back( + std::move(original_selected_kernel)); +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/lite/core/mir/type_target_cast_pass.h b/lite/core/mir/type_target_cast_pass.h index 8a8cfaf9f9282cb477f7b9dd404d6f869333221b..e9a275882f7c2cb813c1c0b8add5cc4ca89b0c8b 100644 --- a/lite/core/mir/type_target_cast_pass.h +++ b/lite/core/mir/type_target_cast_pass.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "lite/core/mir/pass.h" #include "lite/core/op_registry.h" @@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override; - void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); + void ComplementInputs(SSAGraph* graph, + Node* inst_node, + Node* in, + std::unordered_map* copied_nodes); void AddIoCopyInst(const Type& from, const Type& to, Node* in, SSAGraph* graph, Node* inst_node, + std::unordered_map* copied_nodes, const std::vector& valid_places); void SetValidPlaces(const std::vector& valid_places); @@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass { const std::vector& valid_places() const { return valid_places_; } private: + void UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name); + std::vector valid_places_; }; diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h index 3f5d161a56aafa7fd9d058fd404e65cb04572116..875bf23082a24cb6fcae878b46cc9dcdbb2b76f7 100644 --- a/lite/core/mir/variable_place_inference_pass.h +++ b/lite/core/mir/variable_place_inference_pass.h @@ -48,6 +48,10 @@ class VariablePlaceInferencePass : public DebugPass { void CheckAllArgumentTypeDetermined(SSAGraph* graph) { for (auto& node : graph->mutable_nodes()) { if (node.IsArg()) { + if (node.inlinks.size() == 0 && node.outlinks.size() == 0) { + // empty node + continue; + } CHECK(node.AsArg().type) << "node " << node.AsArg().name << " type not determined, " << &node; } @@ -129,6 +133,17 @@ class VariablePlaceInferencePass : public DebugPass { } else { x_in->AsArg().type = type; } + } else if (x_in->AsArg().type->target() == TARGET(kUnk) && + x_in->AsArg().type->precision() != PRECISION(kUnk) && + x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) { + // If is quantization, infer the Int8 type. + if (type->precision() == PRECISION(kInt8)) { + x_in->AsArg().type = type; + } else { + PrecisionType tmp_ptype = x_in->AsArg().type->precision(); + x_in->AsArg().type = LiteType::GetTensorTy( + type->target(), tmp_ptype, type->layout()); + } } } @@ -149,6 +164,17 @@ class VariablePlaceInferencePass : public DebugPass { } else { x_out->AsArg().type = type; } + } else if (x_out->AsArg().type->target() == TARGET(kUnk) && + x_out->AsArg().type->precision() != PRECISION(kUnk) && + x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) { + // If is quantization, infer the Int8 type. 
+ if (type->precision() == PRECISION(kInt8)) { + x_out->AsArg().type = type; + } else { + PrecisionType tmp_ptype = x_out->AsArg().type->precision(); + x_out->AsArg().type = LiteType::GetTensorTy( + type->target(), tmp_ptype, type->layout()); + } } } } diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index c23d3157e0a7ec77ec26afad6092d0be9a63a436..716ce9d6a82b07270b5029f4cddf6a6b808c6c21 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -40,6 +40,18 @@ std::list> KernelRegistry::Create( return Create(op_type); \ + case DATALAYOUT(kImageDefault): \ + return Create(op_type); \ + case DATALAYOUT(kImageFolder): \ + return Create(op_type); \ + case DATALAYOUT(kImageNW): \ + return Create(op_type); \ default: \ LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \ } @@ -54,6 +66,8 @@ std::list> KernelRegistry::Create( CREATE_KERNEL1(target__, kFP16); \ case PRECISION(kAny): \ CREATE_KERNEL1(target__, kAny); \ + case PRECISION(kInt32): \ + CREATE_KERNEL1(target__, kInt32); \ case PRECISION(kInt64): \ CREATE_KERNEL1(target__, kInt64); \ default: \ @@ -136,6 +150,7 @@ KernelRegistry::KernelRegistry() INIT_FOR(kARM, kInt8, kNCHW); INIT_FOR(kARM, kAny, kNCHW); INIT_FOR(kARM, kAny, kAny); + INIT_FOR(kARM, kInt32, kNCHW); INIT_FOR(kOpenCL, kFloat, kNCHW); INIT_FOR(kOpenCL, kFloat, kNHWC); @@ -144,6 +159,17 @@ KernelRegistry::KernelRegistry() INIT_FOR(kOpenCL, kFloat, kAny); INIT_FOR(kOpenCL, kInt8, kNCHW); INIT_FOR(kOpenCL, kAny, kAny); + INIT_FOR(kOpenCL, kFP16, kNCHW); + INIT_FOR(kOpenCL, kFP16, kNHWC); + INIT_FOR(kOpenCL, kFP16, kImageDefault); + INIT_FOR(kOpenCL, kFP16, kImageFolder); + INIT_FOR(kOpenCL, kFP16, kImageNW); + INIT_FOR(kOpenCL, kFloat, kImageDefault); + INIT_FOR(kOpenCL, kFloat, kImageFolder); + INIT_FOR(kOpenCL, kFloat, kImageNW); + INIT_FOR(kOpenCL, kAny, kImageDefault); + INIT_FOR(kOpenCL, kAny, kImageFolder); + INIT_FOR(kOpenCL, kAny, kImageNW); INIT_FOR(kNPU, kFloat, kNCHW); INIT_FOR(kNPU, kInt8, kNCHW); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index d78ae690f9b019dff7728bd3e95c0b1406bea463..0df5cb41ecc4c631e8540f9595c3182122b99f5f 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -145,6 +145,9 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTarget *, // @@ -173,6 +176,39 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTarget GenRuntimeProgram() { - // Extra passes are applied for NPU and XPU, they depends on the shapes - // of input tensors. so GenRuntimeProgram() must be called after the shapes - // of input tensors are determined. 
- std::vector subgraph_passes{"generate_npu_program_pass", - "generate_xpu_program_pass"}; - RunPasses(subgraph_passes); - auto pass = mir::PassManager::Global().LookUp( "generate_program_pass"); pass->Apply(graph_); diff --git a/lite/core/profile/profiler.cc b/lite/core/profile/profiler.cc index a51b769c8f46a5ca8cb9ed74740b93844882cb16..78317f78ac6bf7024c1984c2127434d55b738ad6 100644 --- a/lite/core/profile/profiler.cc +++ b/lite/core/profile/profiler.cc @@ -21,6 +21,13 @@ namespace paddle { namespace lite { namespace profile { +namespace { +auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) { + return (c1.target < c2.target) || (c1.op_type < c2.op_type) || + (c1.kernel_name < c2.kernel_name) || (c1.remark < c2.remark); +}; +} + int Profiler::NewTimer(const OpCharacter& ch) { StatisUnit unit; unit.character = ch; @@ -50,61 +57,66 @@ float Profiler::StopTiming(const int index, KernelContext* ctx) { return units_[index].timer->Stop(ctx); } -std::string Profiler::Summary(bool concise) { +std::string Profiler::Summary(bool concise, size_t w) { + using std::setw; + using std::left; + using std::fixed; STL::stringstream ss; - auto cout_title = [&ss](const std::string& title, const std::string& name) { - // clang-format off - ss << "===== " << title << ": " << name << " =====" << std::endl; - ss << std::setw(25) << std::left << "Operator Type" \ - << std::setw(40) << std::left << "Kernel Name" \ - << std::setw(10) << std::left << "Remark" \ - << std::setw(10) << std::left << "Avg (ms)" \ - << std::setw(10) << std::left << "Min (ms)" \ - << std::setw(10) << std::left << "Max (ms)" \ + std::string title; + // Title. + if (concise) { + ss << "Timing cycle = " << units_.front().timer->LapTimes().Size() << std::endl; - // clang-format on - }; + ss << "===== Concise Profiler Summary: " << name_ << ", Exclude " << w + << " warm-ups =====" << std::endl; + } else { + ss << "===== Detailed Profiler Summary: " << name_ << ", Exclude " << w + << " warm-ups =====" << std::endl; + } + ss << setw(25) << left << "Operator Type" + << " " << setw(40) << left << "Kernel Name" + << " " << setw(12) << left << "Remark" + << " " << setw(12) << left << "Avg (ms)" + << " " << setw(12) << left << "Min (ms)" + << " " << setw(12) << left << "Max (ms)" + << " " << setw(12) << left << "Last (ms)" << std::endl; + // Profile information. 
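+  // In concise mode, the laps of all timers sharing the same OpCharacter
+  // (target, op type, kernel name and remark) are accumulated into a single
+  // row; in detailed mode each timer gets its own row.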
if (concise) { - auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) { - return (c1.target < c2.target) || (c1.op_type < c2.op_type) || - (c1.kernel_name < c2.kernel_name) || (c1.remark < c2.remark); - }; std::map summary(op_comp); for (auto& unit : units_) { auto ch = summary.find(unit.character); if (ch != summary.end()) { - ch->second.avg += unit.timer->LapTimes().Avg(); - ch->second.min += unit.timer->LapTimes().Min(); - ch->second.max += unit.timer->LapTimes().Max(); + ch->second.avg += unit.timer->LapTimes().Avg(w); + ch->second.min += unit.timer->LapTimes().Min(w); + ch->second.max += unit.timer->LapTimes().Max(w); } else { - TimeInfo info({unit.timer->LapTimes().Avg(), - unit.timer->LapTimes().Min(), - unit.timer->LapTimes().Max()}); + TimeInfo info({unit.timer->LapTimes().Avg(w), + unit.timer->LapTimes().Min(w), + unit.timer->LapTimes().Max(w)}); summary.insert({unit.character, info}); } } - cout_title("Concise Profiler Summary", name_); for (const auto& item : summary) { // clang-format off - ss << std::setw(25) << std::left << item.first.op_type \ - << std::setw(40) << std::left << item.first.kernel_name \ - << std::setw(10) << std::left << item.first.remark \ - << std::setw(10) << std::left << item.second.avg \ - << std::setw(10) << std::left << item.second.min \ - << std::setw(10) << std::left << item.second.max \ - << std::endl; + ss << setw(25) << left << fixed << item.first.op_type \ + << " " << setw(40) << left << fixed << item.first.kernel_name \ + << " " << setw(12) << left << fixed << item.first.remark \ + << " " << setw(12) << left << fixed << item.second.avg \ + << " " << setw(12) << left << fixed << item.second.min \ + << " " << setw(12) << left << fixed << item.second.max \ + << " " << std::endl; // clang-format on } } else { - cout_title("Detailed Profiler Summary", name_); for (auto& unit : units_) { // clang-format off - ss << std::setw(25) << std::left << unit.character.op_type \ - << std::setw(40) << std::left << unit.character.kernel_name \ - << std::setw(10) << std::left << unit.character.remark \ - << std::setw(10) << std::left << unit.timer->LapTimes().Avg() \ - << std::setw(10) << std::left << unit.timer->LapTimes().Min() \ - << std::setw(10) << std::left << unit.timer->LapTimes().Max() \ + ss << setw(25) << left << fixed << unit.character.op_type \ + << " " << setw(40) << left << fixed << unit.character.kernel_name \ + << " " << setw(12) << left << fixed << unit.character.remark \ + << " " << setw(12) << left << fixed << unit.timer->LapTimes().Avg(w) \ + << " " << setw(12) << left << fixed << unit.timer->LapTimes().Min(w) \ + << " " << setw(12) << left << fixed << unit.timer->LapTimes().Max(w) \ + << " " << setw(12) << left << fixed << unit.timer->LapTimes().Last(w) \ << std::endl; // clang-format on } diff --git a/lite/core/profile/profiler.h b/lite/core/profile/profiler.h index 0fce8167cdd5383c2cc4ae5d641433582f0ee6a7..4e9e9ae31c1a6d7f331eac2e77c4971986bd42a1 100644 --- a/lite/core/profile/profiler.h +++ b/lite/core/profile/profiler.h @@ -47,7 +47,7 @@ class Profiler final { int NewTimer(const OpCharacter& ch); void StartTiming(const int index, KernelContext* ctx); float StopTiming(const int index, KernelContext* ctx); - std::string Summary(bool concise = true); + std::string Summary(bool concise = true, size_t warm_up = 10); private: std::string name_{std::string("N/A")}; diff --git a/lite/core/profile/timer.h b/lite/core/profile/timer.h index 1e86f0d7b9be4914bdf1a6874195276d3c1b61ee..e9bb16bd27d5ec6fd21814c35db52b2467a12b51 100644 --- 
diff --git a/lite/core/profile/profiler.h b/lite/core/profile/profiler.h
index 0fce8167cdd5383c2cc4ae5d641433582f0ee6a7..4e9e9ae31c1a6d7f331eac2e77c4971986bd42a1 100644
--- a/lite/core/profile/profiler.h
+++ b/lite/core/profile/profiler.h
@@ -47,7 +47,7 @@ class Profiler final {
   int NewTimer(const OpCharacter& ch);
   void StartTiming(const int index, KernelContext* ctx);
   float StopTiming(const int index, KernelContext* ctx);
-  std::string Summary(bool concise = true);
+  std::string Summary(bool concise = true, size_t warm_up = 10);
 
  private:
   std::string name_{std::string("N/A")};
diff --git a/lite/core/profile/timer.h b/lite/core/profile/timer.h
index 1e86f0d7b9be4914bdf1a6874195276d3c1b61ee..e9bb16bd27d5ec6fd21814c35db52b2467a12b51 100644
--- a/lite/core/profile/timer.h
+++ b/lite/core/profile/timer.h
@@ -15,7 +15,7 @@
 #pragma once
 #include <algorithm>
 #include <chrono>  // NOLINT
-#include <list>
+#include <vector>
 #ifdef LITE_WITH_CUDA
 #include "lite/backends/cuda/cuda_utils.h"
 #endif
@@ -30,20 +30,44 @@ class TimeList {
  public:
   void Clear() { laps_t_.clear(); }
   void Add(T t) { laps_t_.push_back(t); }
-  T Max() const { return *std::max_element(laps_t_.begin(), laps_t_.end()); }
-  T Min() const { return *std::min_element(laps_t_.begin(), laps_t_.end()); }
-  T Sum() const { return std::accumulate(laps_t_.begin(), laps_t_.end(), 0.0); }
-  size_t Size() const { return laps_t_.size(); }
-  T Avg() const {
-    if (!Size()) {
+  T Last(size_t offset = 0) const {
+    if (!Size(offset)) {
       return 0;
     }
-    return Sum() / Size();
+    return laps_t_.back();
   }
-  const std::list<T>& Raw() const { return laps_t_; }
+  T Max(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return *std::max_element((laps_t_.begin() + offset), laps_t_.end());
+  }
+  T Min(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return *std::min_element((laps_t_.begin() + offset), laps_t_.end());
+  }
+  T Sum(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return std::accumulate((laps_t_.begin() + offset), laps_t_.end(), 0.0);
+  }
+  size_t Size(size_t offset = 0) const {
+    size_t size = (laps_t_.size() <= offset) ? 0 : (laps_t_.size() - offset);
+    return size;
+  }
+  T Avg(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return Sum(offset) / Size(offset);
+  }
+  const std::vector<T>& Raw() const { return laps_t_; }
 
  private:
-  std::list<T> laps_t_;
+  std::vector<T> laps_t_;
 };
 
 class Timer {
@@ -69,8 +93,10 @@ class Timer {
   const TimeList<float>& LapTimes() const { return laps_t_; }
 
  protected:
-  std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
   TimeList<float> laps_t_;
+
+ private:
+  std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
 };
 
 template <TargetType Target>
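
Note: every TimeList statistic now takes an offset that skips the first offset laps, so warm-up iterations can be excluded without copying the lap list. Size(offset) clamps to zero and the other methods use it to guard empty ranges, so Avg never divides by zero; Last uses offset only for that emptiness check and always returns the final lap. A minimal usage sketch with made-up lap times:

#include "lite/core/profile/timer.h"

void WarmupExclusionDemo() {
  paddle::lite::profile::TimeList<float> laps;
  for (float ms : {9.0f, 2.0f, 1.0f, 1.2f}) laps.Add(ms);  // lap 0 is warm-up
  float avg = laps.Avg(1);   // (2.0 + 1.0 + 1.2) / 3 = 1.4, warm-up excluded
  float max = laps.Max(1);   // 2.0, since the 9.0 warm-up lap is ignored
  float sum = laps.Sum(10);  // offset >= Size(), so the guard returns 0
  (void)avg; (void)max; (void)sum;
}
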
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 45796a478b3f2309912e6382b3380bf0734bd6ae..b0c61bf00ed29e2fa71072b64f11f6ba30f77691 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -17,6 +17,8 @@
 #include "lite/model_parser/cpp/block_desc.h"
 #include "lite/model_parser/cpp/op_desc.h"
 #include "lite/model_parser/cpp/var_desc.h"
+#include "lite/operators/conditional_block_op.h"
+#include "lite/operators/subgraph_op.h"
 #include "lite/operators/while_op.h"
 #ifdef LITE_WITH_PROFILE
 #include "lite/core/profile/precision_profiler.h"
@@ -30,10 +32,32 @@ void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
   // NOTE: RuntimeProgram does not have all meta info, so saving the model
   // just updates the origin model
   CHECK(desc->BlocksSize());
-  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
-  main_block.ClearOps();
+  auto main_block = desc->GetBlock<cpp::BlockDesc>(0);
+  main_block->ClearOps();
   for (auto& node : instructions_) {
-    auto* op = main_block.AddOp<cpp::OpDesc>();
+    auto op_type = node.op()->op_info()->Type();
+    if (op_type == "subgraph") {
+      auto subgraph_op = const_cast<operators::SubgraphOp*>(
+          static_cast<const operators::SubgraphOp*>(node.op()));
+      int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
+      if (sub_block_idx < 0) {
+        // It's a new subgraph op when its sub_block_idx < 0. Add its subblock
+        // desc to the program desc, then update its sub_block_idx to the
+        // index of that block desc within the program desc.
+        sub_block_idx = desc->BlocksSize();
+        auto sub_block_desc = subgraph_op->GetSubBlock();
+        CHECK(sub_block_desc);
+        auto new_block_desc = desc->AddBlock<cpp::BlockDesc>();
+        *new_block_desc = *sub_block_desc;
+        delete sub_block_desc;
+        subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block",
+                                                         sub_block_idx);
+        subgraph_op->SetSubBlock(new_block_desc);
+        // Update main block desc after a new subblock desc is added
+        main_block = desc->GetBlock<cpp::BlockDesc>(0);
+      }
+    }
+    auto op = main_block->AddOp<cpp::OpDesc>();
     *op = *node.op()->op_info();
     op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
   }
@@ -123,7 +147,7 @@ void RuntimeProgram::Run() {
 #endif  // LITE_WITH_PROFILE
   }
 #ifdef LITE_WITH_PROFILE
-  LOG(INFO) << "\n" << profiler_.Summary();
+  LOG(INFO) << "\n" << profiler_.Summary(false, 0);
 #endif  // LITE_WITH_PROFILE
 }
 
@@ -141,12 +165,26 @@ void Program::Build(const cpp::ProgramDesc& prog) {
     VLOG(4) << "create Op [" << op_type << "]";
     auto op = LiteOpRegistry::Global().Create(op_type);
     CHECK(op) << "no Op found for " << op_type;
-    if (op_type == "while") {
+    if (op_type == "while" || op_type == "conditional_block" ||
+        op_type == "subgraph") {
       auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
-      auto sub_block =
+      CHECK(sub_block_idx >= 0 && sub_block_idx < program.BlocksSize())
+          << "Invalid attribute sub_block(" << sub_block_idx << ") for "
+          << op_type;
+      auto sub_block_desc =
           const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
               sub_block_idx);
-      static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
+      CHECK(sub_block_desc);
+      if (op_type == "while") {
+        static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      } else if (op_type == "conditional_block") {
+        static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      } else if (op_type == "subgraph") {
+        static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      }
     }
     ops_.emplace_back(std::move(op));
     ops_.back()->Attach(op_desc, exec_scope_);
@@ -162,6 +200,27 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
   tmp_vars_.push_back("feed");
   tmp_vars_.push_back("fetch");
 
+  auto VarPrecision2KernlPrecision =
+      [](const lite::VarDescAPI::Type& type) -> PrecisionType {
+    switch (type) {
+      case lite::VarDescAPI::Type::FP32:
+        return PRECISION(kFloat);
+      case lite::VarDescAPI::Type::FP16:
+        return PRECISION(kFP16);
+      case lite::VarDescAPI::Type::INT8:
+        return PRECISION(kInt8);
+      case lite::VarDescAPI::Type::INT16:
+        return PRECISION(kInt16);
+      case lite::VarDescAPI::Type::INT32:
+        return PRECISION(kInt32);
+      case lite::VarDescAPI::Type::INT64:
+        return PRECISION(kInt64);
+      default:
+        // LOG(FATAL) << "not supported type: " << static_cast<int>(type);
+        return PRECISION(kUnk);
+    }
+  };
+
   auto program = prog;
   CHECK(program.BlocksSize());
   for (size_t b = 0; b < program.BlocksSize(); ++b) {
@@ -169,7 +228,16 @@
     for (size_t i = 0; i < main_block.VarsSize(); ++i) {
       auto& var_desc = *main_block.GetVar<cpp::VarDesc>(i);
       if (!var_desc.Persistable()) {
+        if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR &&
+            VarPrecision2KernlPrecision(var_desc.GetDataType()) !=
+                PRECISION(kUnk)) {
+          var_data_type_[var_desc.Name()] =
+              VarPrecision2KernlPrecision(var_desc.GetDataType());
+        }
         tmp_vars_.push_back(var_desc.Name());
+        VLOG(4) << "var name: " << var_desc.Name() << " type is "
+                << static_cast<int>(var_desc.GetType()) << " data type is "
+                << static_cast<int>(var_desc.GetDataType());
         exec_scope_->Var(var_desc.Name());
         if (b > 0) {
           VLOG(4) << "var: " << var_desc.Name();
@@ -194,14 +262,10 @@ void Instruction::Run() {
   if (op_->run_once() && has_run_) {
     return;
   }
-#ifndef LITE_SHUTDOWN_LOG
-  VLOG(4) << "kernel launch";
-#endif
+  // VLOG(4) << "kernel launch";
   op_->InferShape();
-#ifndef LITE_SHUTDOWN_LOG
-  VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
-          << TargetToStr(kernel_->target());
-#endif
+  // VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
+  //         << TargetToStr(kernel_->target());
   kernel_->Launch();
   has_run_ = true;
 }
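
Note: two details in the program.cc changes are easy to miss. A subgraph op generated in memory by a fusion pass carries sub_block = -1 as a sentinel, since its block desc is owned by the op and not yet part of the program desc; SaveOpInfosToProgram appends the block, rewrites the attribute, and re-fetches main_block because AddBlock can reallocate the desc's block storage. Separately, PrepareWorkspace now records a kernel precision for every non-persistable LOD_TENSOR variable whose type maps cleanly, which later code can query by name. A hedged lookup sketch (the variable name is hypothetical):

#include "lite/core/program.h"

void QueryVarPrecision(const paddle::lite::Program& program) {
  const auto& types = program.var_data_type();
  auto it = types.find("conv2d_0.tmp_0");  // hypothetical variable name
  if (it != types.end()) {
    paddle::lite::PrecisionType p = it->second;  // e.g. kFloat for an FP32 var
    (void)p;
  }
}
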
diff --git a/lite/core/program.h b/lite/core/program.h
index 1c1e4975c3a13bcfa9a22999a705f3a78b0fc68e..291252619b396f18576b935a0189f4ecdba7867f 100644
--- a/lite/core/program.h
+++ b/lite/core/program.h
@@ -16,6 +16,7 @@
 #include <list>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "lite/core/kernel.h"
@@ -63,6 +64,10 @@ struct Program {
   lite::Scope* exec_scope() { return exec_scope_; }
   lite::Scope* scope() { return scope_.get(); }
 
+  const std::unordered_map<std::string, PrecisionType>& var_data_type() const {
+    return var_data_type_;
+  }
+
  private:
   // Build from a program and scope.
   void Build(const cpp::ProgramDesc& program);
@@ -70,6 +75,7 @@ struct Program {
   void PrepareWorkspace(const cpp::ProgramDesc& program);
 
  private:
+  std::unordered_map<std::string, PrecisionType> var_data_type_;
   std::list<std::string> tmp_vars_;
   std::list<std::string> weights_;
   std::list<std::shared_ptr<OpLite>> ops_;
@@ -135,6 +141,11 @@ class LITE_API RuntimeProgram {
     set_profiler();
 #endif
   }
+  ~RuntimeProgram() {
+#ifdef LITE_WITH_PROFILE
+    LOG(INFO) << "\n" << profiler_.Summary();
+#endif  // LITE_WITH_PROFILE
+  }
 
   void Run();
diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc
index 1c7db871c7b525d6e4944fd0d669e81bcaff7f2a..ecfdcf3d1107953f1c41ea57b6f12187b29686c6 100644
--- a/lite/core/tensor.cc
+++ b/lite/core/tensor.cc
@@ -104,6 +104,12 @@ const cl::Image2D *TensorLite::data<float, cl::Image2D>() const {
   if (nullptr == buffer_->data()) return nullptr;
   return static_cast<const cl::Image2D *>(buffer_->data());
 }
+
+template <>  // use int16_t to represent half float
+const cl::Image2D *TensorLite::data<int16_t, cl::Image2D>() const {
+  if (nullptr == buffer_->data()) return nullptr;
+  return static_cast<const cl::Image2D *>(buffer_->data());
+}
 #endif
 
 }  // namespace lite
diff --git a/lite/core/tensor.h b/lite/core/tensor.h
index 8c4fe1604a517332e52b243404828e81af26f419..a1141c613e29326a5f9ffb2fdc1427e3fbe84481 100644
--- a/lite/core/tensor.h
+++ b/lite/core/tensor.h
@@ -147,9 +147,11 @@ class TensorLite {
 
 #ifdef LITE_WITH_OPENCL
   template <typename T, typename R = T>
-  R *mutable_data(const size_t img_w, const size_t img_h) {
+  R *mutable_data(const size_t img_w,
+                  const size_t img_h,
+                  void *host_ptr = nullptr) {
     target_ = TARGET(kOpenCL);
-    buffer_->ResetLazyImage2D<T>(target_, img_w, img_h);
+    buffer_->ResetLazyImage2D<T>(target_, img_w, img_h, host_ptr);
     return static_cast<R *>(buffer_->data());
   }
 #endif
@@ -251,6 +253,9 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) {
 #ifdef LITE_WITH_OPENCL
 template <>
 const cl::Image2D *TensorLite::data<float, cl::Image2D>() const;
+
+template <>  // use int16_t to represent half float
+const cl::Image2D *TensorLite::data<int16_t, cl::Image2D>() const;
 #endif
 
 }  // namespace lite
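
Note: C++ has no built-in half-precision type, so the new specialization keys FP16 image access on int16_t, which has the matching size and alignment, and the added host_ptr parameter lets mutable_data wrap an existing host buffer instead of allocating. A usage sketch, assuming an OpenCL-enabled build; the tensor, buffer, and image size are made up:

#include <cstdint>
#include "lite/core/tensor.h"

void HalfImageDemo(paddle::lite::TensorLite *t, void *host_buf) {
  // Lazily create a 256x64 half-float image; int16_t stands in for cl_half.
  cl::Image2D *img = t->mutable_data<int16_t, cl::Image2D>(256, 64);
  // Or wrap an existing host buffer instead of allocating device memory:
  cl::Image2D *img2 = t->mutable_data<int16_t, cl::Image2D>(256, 64, host_buf);
  // Read-only access resolves through the new const specialization:
  const cl::Image2D *view = t->data<int16_t, cl::Image2D>();
  (void)img; (void)img2; (void)view;
}
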
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index 0c8866eaf88145d3bb0703b32ffb3eaf80332898..f543c000f8a202d891cd27958fb23dcf38e0240c 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -1,3 +1,10 @@
+# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered
+# to the model_optimize_tool.
+if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
+  return()
+endif()
+
+message(STATUS "compile with lite ARM kernels")
 # 1. basic kernels for basic models
 # for conv op
@@ -41,6 +48,7 @@ add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc D
 add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
 
 ## 2. other basic kernels: basic kernels that are not used in basic models
 add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -61,11 +69,17 @@ add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${
 add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
 
 # for OCR specific
@@ -87,13 +101,6 @@ add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEP
 add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
 
-# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered
-# to the model_optimize_tool.
-if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
-  return()
-endif()
-
-message(STATUS "compile with lite ARM kernels")
 
 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
 lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
@@ -107,6 +114,8 @@ lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS tran
 lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
 lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
 if(LITE_BUILD_EXTRA)
+  lite_cc_test(test_split_lod_tensor_compute_arm SRCS split_lod_tensor_compute_test.cc DEPS split_lod_tensor_compute_arm)
+  lite_cc_test(test_merge_lod_tensor_compute_arm SRCS merge_lod_tensor_compute_test.cc DEPS merge_lod_tensor_compute_arm)
   lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm)
   lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm)
   lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc
index 1fef52bcb77b7c3efdcd848ee63f8ec46c16d6f8..266ae1fc916af4303aca274c39b9b4923fdbb154 100644
--- a/lite/kernels/arm/cast_compute.cc
+++ b/lite/kernels/arm/cast_compute.cc
@@ -74,6 +74,6 @@ void CastCompute::Run() {
 
 REGISTER_LITE_KERNEL(
     cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .Finalize();
diff --git a/lite/kernels/arm/collect_fpn_proposals_compute.cc b/lite/kernels/arm/collect_fpn_proposals_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d54b96348e866bbe16898ddd6fdbd45beb62afa0
--- /dev/null
+++ b/lite/kernels/arm/collect_fpn_proposals_compute.cc
@@ -0,0 +1,147 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/collect_fpn_proposals_compute.h"
+#include <algorithm>
+#include <cstring>  // std::memcpy
+#include <vector>
+#include "lite/backends/arm/math/funcs.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+struct ScoreWithID {
+  float score;
+  int batch_id;
+  int index;
+  int level;
+  ScoreWithID() {
+    batch_id = -1;
+    index = -1;
+    level = -1;
+  }
+  ScoreWithID(float score_, int batch_id_, int index_, int level_) {
+    score = score_;
+    batch_id = batch_id_;
+    index = index_;
+    level = level_;
+  }
+};
+
+static inline bool CompareByScore(ScoreWithID a, ScoreWithID b) {
+  return a.score >= b.score;
+}
+
+static inline bool CompareByBatchid(ScoreWithID a, ScoreWithID b) {
+  return a.batch_id < b.batch_id;
+}
+
+void CollectFpnProposalsCompute::Run() {
+  auto& param = Param<param_t>();
+  auto multi_layer_rois = param.multi_level_rois;
+  auto multi_layer_scores = param.multi_level_scores;
+  auto* fpn_rois = param.fpn_rois;
+  int post_nms_topN = param.post_nms_topN;
+
+  if (multi_layer_rois.size() != multi_layer_scores.size()) {
+    LOG(FATAL) << "multi_layer_rois.size() should be equal to "
+                  "multi_layer_scores.size()";
+  }
+
+  size_t num_fpn_level = multi_layer_rois.size();
+  std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
+  for (size_t i = 0; i < num_fpn_level; ++i) {
+    auto cur_rois_lod = multi_layer_rois[i]->lod().back();
+    integral_of_all_rois[i + 1] = static_cast<int>(
+        integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1]);
+  }
+
+  std::vector<ScoreWithID> scores_of_all_rois(
+      integral_of_all_rois[num_fpn_level], ScoreWithID());
+  for (int i = 0; i < num_fpn_level; ++i) {
+    const float* cur_level_scores = multi_layer_scores[i]->data<float>();
+    int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
+    auto cur_scores_lod = multi_layer_scores[i]->lod().back();
+    int cur_batch_id = 0;
+    for (int j = 0; j < cur_level_num; ++j) {
+      if (j >= cur_scores_lod[cur_batch_id + 1]) {
+        cur_batch_id++;
+      }
+      int cur_index = j + integral_of_all_rois[i];
+      scores_of_all_rois[cur_index].score = cur_level_scores[j];
+      scores_of_all_rois[cur_index].index = j;
+      scores_of_all_rois[cur_index].level = i;
+      scores_of_all_rois[cur_index].batch_id = cur_batch_id;
+    }
+  }
+
+  // keep top post_nms_topN rois, sort the rois by the score
+  if (post_nms_topN > integral_of_all_rois[num_fpn_level]) {
+    post_nms_topN = integral_of_all_rois[num_fpn_level];
+  }
+  std::stable_sort(
+      scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByScore);
+  scores_of_all_rois.resize(post_nms_topN);
+  // sort by batch id
+  std::stable_sort(
+      scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByBatchid);
+  // create a pointer array
+  std::vector<const float*> multi_fpn_rois_data(num_fpn_level);
+  for (int i = 0; i < num_fpn_level; ++i) {
+    multi_fpn_rois_data[i] = multi_layer_rois[i]->data<float>();
+  }
+
+  // initialize the outputs
+  const int kBoxDim = 4;
+  auto fpn_rois_data = fpn_rois->mutable_data<float>();
+  std::vector<uint64_t> lod0(1, 0);
+  int cur_batch_id = 0;
+  for (int i = 0; i < post_nms_topN; ++i) {
+    int cur_fpn_level = scores_of_all_rois[i].level;
+    int cur_level_index = scores_of_all_rois[i].index;
+    std::memcpy(fpn_rois_data,
+                multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim,
+                kBoxDim * sizeof(float));
+    fpn_rois_data += kBoxDim;
+    if (scores_of_all_rois[i].batch_id != cur_batch_id) {
+      cur_batch_id = scores_of_all_rois[i].batch_id;
+      lod0.emplace_back(i);
+    }
+  }
+  lod0.emplace_back(post_nms_topN);
+  lite::LoD lod;
+  lod.emplace_back(lod0);
+  fpn_rois->set_lod(lod);
+  return;
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(collect_fpn_proposals,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::CollectFpnProposalsCompute,
+                     def)
+    .BindInput("MultiLevelRois", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("MultiLevelScores", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
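
Note: CompareByScore returns a.score >= b.score, which is not a strict weak ordering (two equal scores each compare as "less" than the other), and std::stable_sort requires one; with >=, tied scores also lose the stable ordering the call presumably wants. A suggested replacement, not part of this patch:

// Strict > keeps the descending sort, satisfies the strict-weak-ordering
// requirement, and lets stable_sort preserve the original order of ties.
static inline bool CompareByScore(const ScoreWithID& a, const ScoreWithID& b) {
  return a.score > b.score;
}
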
diff --git a/lite/kernels/xpu/bridges/registry.cc b/lite/kernels/arm/collect_fpn_proposals_compute.h
similarity index 62%
rename from lite/kernels/xpu/bridges/registry.cc
rename to lite/kernels/arm/collect_fpn_proposals_compute.h
index 4ab1b69a25a29aeb1c1ceaff25525459ef2e94cd..f1e7448a07aee4f9c2b57a1c6d2223f4262c59b4 100644
--- a/lite/kernels/xpu/bridges/registry.cc
+++ b/lite/kernels/arm/collect_fpn_proposals_compute.h
@@ -12,30 +12,27 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/kernels/xpu/bridges/registry.h"
-#include <utility>
+#pragma once
+#include <algorithm>
+#include "lite/core/kernel.h"
+#include "lite/operators/axpy_op.h"
 
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace xpu {
-namespace bridges {
+namespace arm {
 
-Factory& Factory::Instance() {
-  static Factory g_xpu_bridge;
-  return g_xpu_bridge;
-}
+class CollectFpnProposalsCompute
+    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::CollectFpnProposalsParam;
 
-bool Factory::HasType(const std::string& op_type) const {
-  return map_.count(op_type);
-}
+  void Run() override;
 
-void Factory::Insert(const std::string& op_type, const func_type& func_name) {
-  map_.insert(std::make_pair(op_type, func_name));
-}
+  virtual ~CollectFpnProposalsCompute() = default;
+};
 
-}  // namespace bridges
-}  // namespace xpu
+}  // namespace arm
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc
index 95014b4ccd427e152dfe919643afa5ff5eb3011d..6118cbc6e403645cada84d2434497b084636a4a3 100644
--- a/lite/kernels/arm/compare_compute.cc
+++ b/lite/kernels/arm/compare_compute.cc
@@ -112,6 +112,42 @@ void CompareCompute::Run() {
   }
 }
 
+template