提交 62ed0122,作者:chenjiaoAngel

test=develop

Merge branch 'conv_dw_5x5' of https://github.com/chenjiaoAngel/Paddle-Lite into conv_dw_5x5
...@@ -23,6 +23,9 @@ if(NOT DEFINED BM_SDK_ROOT)
  endif()
endif()
set(BM_SDK_CPLIB_RPATH ${BM_SDK_ROOT}/lib/bmcompiler)
set(BM_SDK_LIB_RPATH ${BM_SDK_ROOT}/lib/bmnn/pcie)
message(STATUS "BM_SDK_ROOT: ${BM_SDK_ROOT}")
find_path(BM_SDK_INC NAMES bmruntime_interface.h
          PATHS ${BM_SDK_ROOT}/include/bmruntime NO_DEFAULT_PATH)
...@@ -37,43 +40,35 @@ include_directories("${BM_SDK_ROOT}/include/bmcpu")
include_directories("${BM_SDK_ROOT}/include/bmlog")
find_library(BM_SDK_RT_LIB NAMES bmrt
             PATHS ${BM_SDK_LIB_RPATH})
if(NOT BM_SDK_RT_LIB)
  message(FATAL_ERROR "Can not find bmrt Library in ${BM_SDK_ROOT}")
else()
  message(STATUS "Found bmrt Library: ${BM_SDK_RT_LIB}")
  add_library(bmrt SHARED IMPORTED GLOBAL)
  set_property(TARGET bmrt PROPERTY IMPORTED_LOCATION ${BM_SDK_RT_LIB})
endif()
find_library(BM_SDK_BM_LIB NAMES bmlib
             PATHS ${BM_SDK_LIB_RPATH})
if(NOT BM_SDK_BM_LIB)
  message(FATAL_ERROR "Can not find bmlib Library in ${BM_SDK_ROOT}")
else()
  message(STATUS "Found bmlib Library: ${BM_SDK_BM_LIB}")
  add_library(bmlib SHARED IMPORTED GLOBAL)
  set_property(TARGET bmlib PROPERTY IMPORTED_LOCATION ${BM_SDK_BM_LIB})
endif()
find_library(BM_SDK_COMPILER_LIB NAMES bmcompiler
             PATHS ${BM_SDK_CPLIB_RPATH})
if(NOT BM_SDK_COMPILER_LIB)
  message(FATAL_ERROR "Can not find bmcompiler Library in ${BM_SDK_ROOT}")
else()
  message(STATUS "Found bmcompiler Library: ${BM_SDK_COMPILER_LIB}")
  add_library(bmcompiler SHARED IMPORTED GLOBAL)
  set_property(TARGET bmcompiler PROPERTY IMPORTED_LOCATION ${BM_SDK_COMPILER_LIB})
endif()
find_library(BM_SDK_CPU_LIB NAMES bmcpu
             PATHS ${BM_SDK_LIB_RPATH})
if(NOT BM_SDK_CPU_LIB)
  message(FATAL_ERROR "Can not find bmcpu Library in ${BM_SDK_ROOT}")
else()
  message(STATUS "Found bmcpu Library: ${BM_SDK_CPU_LIB}")
  add_library(bmcpu SHARED IMPORTED GLOBAL)
  set_property(TARGET bmcpu PROPERTY IMPORTED_LOCATION ${BM_SDK_CPU_LIB})
endif()
set(bm_runtime_libs bmrt bmlib bmcompiler bmcpu CACHE INTERNAL "bm runtime libs")
......
...@@ -44,6 +44,8 @@ sh run_benchmark.sh
3. 自动执行另一个脚本`benchmark.sh`(多台手机连接USB,请在`benchmark.sh`脚本中对`adb`命令后加上测试手机的`serial number`);
4. 从手机下载benchmark结果`result_armv7.txt`和`result_armv8.txt`到当前目录,并显示Benchmark结果。
> **注意:** 如果运行中遇到`Operation not permitted`的问题,请使用`sudo sh run_benchmark.sh`给予授权,并尝试重新关闭/打开手机**USB调试**和**文件传输模式**,或者通过USB重新连接手机之后再次运行脚本。
## 二. 逐步Benchmark
### 1. 编译benchmark可执行文件
......
...@@ -36,9 +36,11 @@
**需要的环境**: Android Studio、Android手机(开启USB调试模式)、下载到本地的[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)工程
**预先要求**:如果您的Android Studio尚未配置NDK,请根据Android Studio用户指南中的[安装及配置NDK和CMake](https://developer.android.com/studio/projects/install-ndk)内容,预先配置好NDK。您可以选择最新的NDK版本,或者与[Android编译环境配置](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#android)中的NDK版本保持一致。
**部署步骤**
1、目标检测的Android示例位于 `Paddle-Lite-Demo\PaddleLite-android-demo\object_detection_demo`
2、用Android Studio 打开object_detection_demo工程 (本步骤需要联网)。
...@@ -46,12 +48,17 @@
![Android_studio](https://paddlelite-data.bj.bcebos.com/doc_images/Android_iOS_demo/android/Android_studio.png)
**注意:** 如果您在导入项目、编译或者运行过程中遇到NDK配置错误的提示,请打开 File > Project Structure > SDK Location,修改 "Android NDK location" 为您本机配置的NDK所在路径。如果您是通过Android Studio的SDK Tools下载的NDK(见本章节"预先要求"),可以直接点击下拉框选择默认路径。如果以上步骤仍旧无法解决NDK配置错误,请尝试根据Android Studio官方文档中的[更新 Android Gradle 插件](https://developer.android.com/studio/releases/gradle-plugin?hl=zh-cn#updating-plugin)章节更新Android Gradle plugin版本。
<p align="center"><img width="600" height="450" src="https://paddlelite-data.bj.bcebos.com/doc_images/Android_iOS_demo/android/Andriod_Studio_NDK.png"/></p>
4、按下 Run按钮,自动编译APP并安装到手机。(该过程会自动下载Paddle-Lite预测库和模型,需要联网)
成功后效果如下,图一:APP安装到手机;图二:APP打开后的效果,会自动识别图片中的物体并标记
<p align="center"><img width="300" height="450" src="https://paddlelite-data.bj.bcebos.com/doc_images/Android_iOS_demo/android/AndroidApp0.png"/>&#8194;&#8194;&#8194;&#8194;&#8194;<img width="300" height="450" src="https://paddlelite-data.bj.bcebos.com/doc_images/Android_iOS_demo/android/AndroidApp1.jpg"/></p>
## Android demo结构讲解
Android 示例的代码结构如下图所示:
......
# PaddleLite使用华为NPU(Kirin SoC)预测部署
Paddle Lite是首款支持华为自研达芬奇架构NPU(Kirin 810/990 SoC搭载的NPU)的预测框架。
原理是在线分析Paddle模型,将Paddle算子转成HiAI IR后,调用HiAI IR/Builder/Runtime APIs生成并执行HiAI模型。
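下面给出一个最小的 C++ 调用示意,说明应用层如何触发上述在线转换流程:只需在全量 API(CxxConfig)的 valid places 中把 NPU 放在 ARM 之前即可,其余代码与普通 CPU 推理一致。示例中的模型路径、输入尺寸均为假设值,仅作参考。

```c++
#include <iostream>
#include <vector>
#include "paddle_api.h"  // Paddle Lite 全量 API 头文件

using namespace paddle::lite_api;  // NOLINT

int main() {
  // 假设 ./mobilenet_v1 为 Paddle 推理模型目录
  CxxConfig config;
  config.set_model_dir("./mobilenet_v1");
  // kNPU 在前:支持的算子子图转成 HiAI 模型跑在 NPU 上,其余算子回退到 ARM CPU
  config.set_valid_places({Place{TARGET(kNPU), PRECISION(kFloat)},
                           Place{TARGET(kARM), PRECISION(kFloat)}});
  auto predictor = CreatePaddlePredictor<CxxConfig>(config);

  // 填充输入(1x3x224x224 为假设的输入尺寸)
  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 224, 224});
  auto* data = input->mutable_data<float>();
  for (int i = 0; i < 1 * 3 * 224 * 224; ++i) data[i] = 1.f;

  // 首次 Run() 时,标记的子图才会被转成 HiAI IR 并生成、执行 NPU 模型
  predictor->Run();
  auto output = predictor->GetOutput(0);
  std::cout << "output[0] = " << output->data<float>()[0] << std::endl;
  return 0;
}
```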
## 支持现状
### 已支持的芯片
- Kirin 810/990/9000。
### 已支持的设备
- HUAWEI nova5、nova5i Pro、mate30、mate30 pro、mate30 5G、荣耀v30、p40、p40 pro和即将推出的mate40。
### 已支持的Paddle模型
- [MobileNetV1](https://paddlelite-demo.bj.bcebos.com/models/mobilenet_v1_fp32_224_fluid.tar.gz)
- [MobileNetV2](https://paddlelite-demo.bj.bcebos.com/models/mobilenet_v2_fp32_224_fluid.tar.gz)
- ResNet系列(例如[ResNet18](https://paddlelite-demo.bj.bcebos.com/models/resnet18_fp32_224_fluid.tar.gz)、[ResNet50](https://paddlelite-demo.bj.bcebos.com/models/resnet50_fp32_224_fluid.tar.gz))
- [SqueezeNet](https://paddlelite-demo.bj.bcebos.com/models/squeezenet_fp32_224_fluid.tar.gz)
- [MnasNet](https://paddlelite-demo.bj.bcebos.com/models/mnasnet_fp32_224_fluid.tar.gz)
- [MobileNet-SSD](https://paddlelite-demo.bj.bcebos.com/models/ssd_mobilenet_v1_pascalvoc_fp32_300_fluid.tar.gz) *
- YOLOv3系列(例如[YOLOv3-MobileNetV3](https://paddlelite-demo.bj.bcebos.com/models/yolov3_mobilenet_v3_prune86_FPGM_320_fp32_fluid.tar.gz))*
- [Transformer](https://github.com/PaddlePaddle/models/tree/release/1.8/PaddleNLP/machine_translation/transformer) *
- CycleGAN
- 百度内部业务模型(由于涉密,不方便透露具体细节)
*表示该模型的部分算子不支持NPU加速,而是采用CPU+NPU异构计算方式获得支持。
### 已支持(或部分支持)的Paddle算子
| | | | |
|-|-|-|-|
|sigmoid|relu|tanh|relu_clipped|
|leaky_relu|softsign|hard_sigmoid|log|
|sqrt|square|thresholded_relu|batch_norm|
|less_than|concat|conv2d|depthwise_conv2d|
|conv2d_transpose|dropout|elementwise_add|elementwise_sub|
|elementwise_mul|elementwise_div|expand|fusion_elementwise_add_activation|
|fusion_elementwise_sub_activation|fusion_elementwise_mul_activation|fusion_elementwise_div_activation|increment|
|instance_norm (需要HiAI DDK330)|layer_norm (需要HiAI DDK330)|fc|bilinear_interp|
|nearest_interp|matmul|mul|pad2d|
|pool2d|reduce_mean|reshape|reshape2|
|scale|shuffle_channel|softmax|split|
|transpose|transpose2|unsqueeze|unsqueeze2|
可以通过访问[https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/kernels/npu/bridges/paddle_use_bridges.h](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/kernels/npu/bridges/paddle_use_bridges.h)获得最新的算子支持列表。
## 参考示例演示
### 测试设备(HUAWEI Mate30 5G)
![huwei_mate30_5g](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/huawei_mate30_5g.jpg)
### 准备设备环境
- 由于HiAI DDK可能依赖特定版本的ROM,建议用户更新至最新版EMUI系统,具体参考华为官方[手机升级指南](https://consumer.huawei.com/cn/support/update/)
### 准备交叉编译环境
- 为了保证编译环境一致,建议参考[源码编译](../user_guides/source_compile)中的Docker开发环境进行配置。
### 运行图像分类示例程序
- 从[https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/PaddleLite-android-demo.tar.gz](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/PaddleLite-android-demo.tar.gz)下载示例程序,解压后清单如下:
```shell
- PaddleLite-android-demo
- image_classification_demo # 基于MobileNetV1的图像分类示例程序
- assets
- images
- tabby_cat.jpg # 测试图片
- labels
- synset_words.txt # 1000分类label文件
- models
- mobilenet_v1_fp32_224_fluid # Paddle fluid non-combined格式的mobilenetv1 float32模型
- mobilenet_v1_fp32_224_for_cpu
- model.nb # 已通过opt转好的、适合ARM CPU的mobilenetv1模型
- mobilenet_v1_fp32_224_for_npu
- model.nb # 已通过opt转好的、适合华为 NPU的mobilenetv1模型
- shell # android shell端的示例程序,注意:HiAI存在限制,拥有ROOT权限才能正常运行shell端程序
- CMakeLists.txt # android shell端的示例程序CMake脚本
- build
- image_classification_demo # 已编译好的android shell端的示例程序
- image_classification_demo.cc # 示例程序源码
- build.sh # android shell端的示例程序编译脚本
- run.sh # android shell端的示例程序运行脚本
- apk # 常规android应用程序,无需ROOT
- app
- src
- main
- java # java层代码
- cpp # 自定义的jni实现
- app.iml
- build.gradle
- gradle
...
- libs
- PaddleLite
- bin
- opt # 适合Ubuntu x86平台、预编译的模型优化工具
- armeabi-v7a # 适合armv7架构的PaddleLite预编译库以及HiAI运行时库
- include # PaddleLite头文件,每次版本更新时记得替换掉,否则可能会出现segmentation fault或精度无法对齐的问题
- lib
- libc++_shared.so # HiAI DDK中的so库是基于c++_shared编译生成的,部署时记得带上它
- libpaddle_light_api_shared.so # 用于最终移动端部署的预编译PaddleLite库(tiny publish模式下编译生成的库)
- libpaddle_full_api_shared.so # 用于直接加载Paddle模型进行测试和Debug的预编译PaddleLite库(full publish模式下编译生成的库)
- libhiai.so # HiAI runtime库函数,主要实现模型加载、执行和Tensor的操作
- libhiai_ir.so # HiAI IR/Graph的定义
- libhiai_ir_build.so # HiAI IRGraph转om模型的接口
- libhcl.so # HiAI NPU高性能算子库
- libcpucl.so # HiAI的CPU算子库,PaddleLite中没有用到,理论上可以删掉
- arm64-v8a # 适合armv8架构的PaddleLite预编译库以及HiAI运行时库
- OpenCV # OpenCV 4.2 for android
- object_detection_demo # 基于YOLOv3_MobileNetV3的目标检测示例程序(手动子图划分章节会详细介绍)
```
- Android shell端的示例程序
- 按照以下命令分别运行转换后的ARM CPU模型和华为NPU模型,比较它们的性能和结果;
```shell
注意:
1)run.sh只能在连接设备的系统上运行,不能在docker环境执行(可能无法找到设备),也不能在设备上运行;
2)build.sh需要在docker环境中执行,否则,需要将build.sh的ANDROID_NDK修改为当前环境下的NDK路径;
3)以下执行结果均由armeabi-v7a库生成,如果需要测试arm64-v8a库,可将build.sh的ANDROID_ABI修改成arm64-v8a后重新生成image_classification_demo,同时将run.sh的ANDROID_ABI也修改成arm64-v8a即可。
运行适用于华为NPU的mobilenetv1模型
$ cd PaddleLite-android-demo/image_classification_demo/assets/models
$ cp mobilenet_v1_fp32_224_for_npu/model.nb mobilenet_v1_fp32_224_fluid.nb
$ cd ../../shell
$ ./run.sh
...
iter 0 cost: 2.426000 ms
iter 1 cost: 2.428000 ms
iter 2 cost: 2.465000 ms
iter 3 cost: 2.401000 ms
iter 4 cost: 2.406000 ms
iter 5 cost: 2.492000 ms
iter 6 cost: 2.411000 ms
iter 7 cost: 2.397000 ms
iter 8 cost: 2.441000 ms
iter 9 cost: 2.402000 ms
warmup: 5 repeat: 10, average: 2.426900 ms, max: 2.492000 ms, min: 2.397000 ms
results: 3
Top0 tabby, tabby cat - 0.477539
Top1 Egyptian cat - 0.408447
Top2 tiger cat - 0.094788
Preprocess time: 1.724000 ms
Prediction time: 2.426900 ms
Postprocess time: 0.127000 ms
运行适用于ARM CPU的mobilenetv1模型
$ cd PaddleLite-android-demo/image_classification_demo/assets/models
$ cp mobilenet_v1_fp32_224_for_cpu/model.nb mobilenet_v1_fp32_224_fluid.nb
$ cd ../../shell
$ ./run.sh
...
iter 0 cost: 34.467999 ms
iter 1 cost: 34.514999 ms
iter 2 cost: 34.646000 ms
iter 3 cost: 34.713001 ms
iter 4 cost: 34.612000 ms
iter 5 cost: 34.551998 ms
iter 6 cost: 34.741001 ms
iter 7 cost: 34.655998 ms
iter 8 cost: 35.035000 ms
iter 9 cost: 34.661999 ms
warmup: 5 repeat: 10, average: 34.659999 ms, max: 35.035000 ms, min: 34.467999 ms
results: 3
Top0 tabby, tabby cat - 0.475009
Top1 Egyptian cat - 0.409486
Top2 tiger cat - 0.095744
Preprocess time: 1.714000 ms
Prediction time: 34.659999 ms
Postprocess time: 0.082000 ms
```
- 如果需要更改测试图片,可将图片拷贝到PaddleLite-android-demo/image_classification_demo/assets/images目录下,然后将run.sh的IMAGE_NAME设置成指定文件名即可;
- 如果需要重新编译示例程序,直接运行./build.sh即可。
- 常规Android应用程序
(如果不想按照以下步骤编译Android应用程序,可以直接在Android设备上通过浏览器访问[https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/image_classification_demo.apk](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/image_classification_demo.apk)下载和安装已编译好的apk)
- 访问[https://developer.android.google.cn/studio](https://developer.android.google.cn/studio/)下载安装Android Studio(当前Android demo app是基于Android Studio3.4开发的),如果无法访问,可以从[http://www.android-studio.org](http://www.android-studio.org/)下载;
- 打开Android Studio,在"Welcome to Android Studio"窗口点击"Open an existing Android Studio project",在弹出的路径选择窗口中进入"PaddleLite-android-demo/image_classification_demo/apk"目录,然后点击右下角的"Open"按钮即可导入工程;
- 通过USB连接Android手机、平板或开发板;
- 待工程加载完成后,首先,点击菜单栏的File->Sync Project with Gradle Files手动同步项目构建;然后,点击菜单栏的Build->Rebuild Project按钮,如果提示CMake版本不匹配,请点击错误提示中的'Install CMake xxx.xxx.xx'按钮,重新安装CMake,再次点击菜单栏的Build->Rebuild Project按钮;
- 待工程编译完成后,点击菜单栏的Run->Run 'App'按钮,在弹出的"Select Deployment Target"窗口选择已经连接的Android设备,然后点击"OK"按钮;
- 等待大约1分钟后(第一次时间比较长,需要耐心等待),app已经安装到设备上。默认使用ARM CPU模型进行推理,如下图所示,推理耗时34.8ms,整个流程(含预处理和后处理)的帧率约22fps;
![huawei_mate30_5g_mobilenet_v1_cpu](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/huawei_mate30_5g_mobilenet_v1_cpu.jpg)
- 点击app界面右下角的设置按钮,在弹出的设置页面点击"Choose pre-installed models",选择"mobilenet_v1_fp32_for_npu",点击返回按钮后,app将切换到华为NPU模型,如下图所示,推理耗时下降到3.4ms,帧率提高到29fps(由于代码中帧率统计限制在30fps以内,因此实际帧率会更高,具体地,您可以手动计算截图中Read GLFBO time、Write GLTexture time、Predict time和Postprocess time的总耗时)。
![huaewi_mate30_5g_mobilenet_v1_npu](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/huawei_mate30_5g_mobilenet_v1_npu.jpg)
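上图中的帧率可以按下式估算(示意公式,各项为截图中显示的各阶段耗时,单位为 ms):

$$\mathrm{FPS} \approx \frac{1000}{t_{\mathrm{ReadGLFBO}} + t_{\mathrm{WriteGLTexture}} + t_{\mathrm{Predict}} + t_{\mathrm{Postprocess}}}$$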
### 更新模型
- 通过Paddle Fluid训练,或X2Paddle转换得到MobileNetV1 float32模型[mobilenet_v1_fp32_224_fluid](https://paddlelite-demo.bj.bcebos.com/models/mobilenet_v1_fp32_224_fluid.tar.gz);
- 参考[模型转化方法](../user_guides/model_optimize_tool),利用opt工具转换生成华为NPU模型,仅需将valid_targets设置为npu,arm即可。
```shell
注意:为了保证opt工具和库版本一致,使用了PaddleLite-android-demo.tar.gz自带的opt程序(需要在Ubuntu x86平台执行)演示NPU模型生成的过程。
$ cd PaddleLite-android-demo/image_classification_demo/assets/models
$ GLOG_v=5 ../../../libs/PaddleLite/bin/opt --model_dir=mobilenet_v1_fp32_224_fluid \
--optimize_out_type=naive_buffer \
--optimize_out=opt_model \
--valid_targets=npu,arm
...
[I 8/12 6:56:25.460 ...elease/Paddle-Lite/lite/core/optimizer.h:229 RunPasses] == Running pass: memory_optimize_pass
[I 8/12 6:56:25.460 ...elease/Paddle-Lite/lite/core/optimizer.h:242 RunPasses] - Skip memory_optimize_pass because the target or kernel does not match.
[I 8/12 6:56:25.461 ...te/lite/core/mir/generate_program_pass.h:37 GenProgram] insts.size 1
[I 8/12 6:56:25.683 ...e-Lite/lite/model_parser/model_parser.cc:593 SaveModelNaive] Save naive buffer model in 'opt_model.nb' successfully
替换自带的NPU模型
$ cp opt_model.nb mobilenet_v1_fp32_224_for_npu/model.nb
```
- 注意:opt生成的模型只是标记了华为NPU支持的Paddle算子,并没有真正生成华为NPU模型,只有在执行时才会将标记的Paddle算子转成HiAI IR并组网得到HiAI IRGraph,然后生成并执行华为NPU模型(具体原理请参考Pull Request[#2576](https://github.com/PaddlePaddle/Paddle-Lite/pull/2576));
- 不同模型,不同型号(ROM版本)的华为手机,在执行阶段,由于某些Paddle算子无法完全转成HiAI IR,或目标手机的HiAI版本过低等原因,可能导致HiAI模型无法成功生成,在这种情况下,Paddle Lite会调用CPU版算子进行运算完成整个预测任务。
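补充一个最小的 C++ 加载示意(基于轻量级 MobileConfig API 的常规用法,文件名与参数均为假设值):无论加载的是 ARM CPU 模型还是标记了 NPU 子图的模型,应用侧代码完全相同,差别只在传入的 .nb 文件。

```c++
#include <memory>
#include <string>
#include "paddle_api.h"  // Paddle Lite 轻量级 API(MobileConfig)

using namespace paddle::lite_api;  // NOLINT

// model_path 可以是 mobilenet_v1_fp32_224_for_cpu/model.nb,
// 也可以是 mobilenet_v1_fp32_224_for_npu/model.nb(假设路径)
std::shared_ptr<PaddlePredictor> LoadModel(const std::string& model_path) {
  MobileConfig config;
  config.set_model_from_file(model_path);  // 加载 opt 生成的 .nb 模型
  config.set_threads(1);                   // CPU 线程数,仅对 CPU 算子生效
  config.set_power_mode(LITE_POWER_HIGH);  // 绑大核,减少 CPU 侧耗时波动
  // 对于标记了 NPU 子图的模型,首次 Run() 时才会真正生成并执行 HiAI 模型
  return CreatePaddlePredictor<MobileConfig>(config);
}
```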
### 更新支持华为NPU的PaddleLite库
- 下载PaddleLite源码和最新版HiAI DDK
```shell
$ git clone https://github.com/PaddlePaddle/Paddle-Lite.git
$ cd Paddle-Lite
$ git checkout <release-version-tag>
$ wget https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/hiai_ddk_lib_330.tar.gz
$ tar -xvf hiai_ddk_lib_330.tar.gz
```
- 编译并生成PaddleLite+NPU for armv8 and armv7的部署库
```shell
For armv8
tiny_publish
$ ./lite/tools/build_android.sh --android_stl=c++_shared --with_extra=ON --with_log=ON --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=./hiai_ddk_lib_330
full_publish
$ ./lite/tools/build_android.sh --android_stl=c++_shared --with_extra=ON --with_log=ON --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=./hiai_ddk_lib_330 full_publish
For armv7
tiny_publish
$ ./lite/tools/build_android.sh --arch=armv7 --android_stl=c++_shared --with_extra=ON --with_log=ON --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=./hiai_ddk_lib_330
full_publish
$ ./lite/tools/build_android.sh --arch=armv7 --android_stl=c++_shared --with_extra=ON --with_log=ON --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=./hiai_ddk_lib_330 full_publish
备注:由于HiAI DDK的so库均基于c++_shared构建,建议将android stl设置为c++_shared,更多选项还可以通过 "./lite/tools/build_android.sh help" 查看。
```
- 将编译生成的build.lite.android.armv8.gcc/inference_lite_lib.android.armv8.npu/cxx/include替换PaddleLite-android-demo/libs/PaddleLite/arm64-v8a/include目录;
- 将tiny_publish模式下编译生成的build.lite.android.armv8.gcc/inference_lite_lib.android.armv8.npu/cxx/lib/libpaddle_light_api_shared.so替换PaddleLite-android-demo/libs/PaddleLite/arm64-v8a/lib/libpaddle_light_api_shared.so文件;
- 将full_publish模式下编译生成的build.lite.android.armv8.gcc/inference_lite_lib.android.armv8.npu/cxx/lib/libpaddle_full_api_shared.so替换PaddleLite-android-demo/libs/PaddleLite/arm64-v8a/lib/libpaddle_full_api_shared.so文件;
- 将编译生成的build.lite.android.armv7.gcc/inference_lite_lib.android.armv7.npu/cxx/include替换PaddleLite-android-demo/libs/PaddleLite/armeabi-v7a/include目录;
- 将tiny_publish模式下编译生成的build.lite.android.armv7.gcc/inference_lite_lib.android.armv7.npu/cxx/lib/libpaddle_light_api_shared.so替换PaddleLite-android-demo/libs/PaddleLite/armeabi-v7a/lib/libpaddle_light_api_shared.so文件;
- 将full_publish模式下编译生成的build.lite.android.armv7.gcc/inference_lite_lib.android.armv7.npu/cxx/lib/libpaddle_full_api_shared.so替换PaddleLite-android-demo/libs/PaddleLite/armeabi-v7a/lib/libpaddle_full_api_shared.so文件。
## 如何支持CPU+NPU异构计算?
- 上述示例中所使用的MobileNetV1 float32模型[mobilenet_v1_fp32_224_fluid](https://paddlelite-demo.bj.bcebos.com/models/mobilenet_v1_fp32_224_fluid.tar.gz),它的所有算子均能成功转成华为NPU的HiAI IR,因此,能够获得非常好的NPU加速效果;
- 而实际情况是,你的模型中可能存在NPU不支持的算子,尽管opt工具可以成功生成CPU+NPU的异构模型,但可能因为一些限制等原因,模型最终执行失败或性能不够理想;
- 我们首先用一个简单的目标检测示例程序让你直观感受到CPU+NPU异构模型带来的性能提升;然后,简要说明一下华为NPU接入PaddleLite的原理;最后,详细介绍如何使用『自定义子图分割』功能生成正常运行的CPU+NPU异构模型。
### 运行目标检测示例程序
- 『运行图像分类示例程序』章节中的PaddleLite-android-demo.tar.gz同样包含基于[YOLOv3_MobileNetV3](https://paddlelite-demo.bj.bcebos.com/models/yolov3_mobilenet_v3_prune86_FPGM_320_fp32_fluid.tar.gz)的目标检测示例程序;
```shell
- PaddleLite-android-demo
- image_classification_demo # 基于MobileNetV1的图像分类示例程序
- libs # PaddleLite和OpenCV预编译库
- object_detection_demo # 基于YOLOv3_MobileNetV3的目标检测示例程序
- assets
- images
- kite.jpg # 测试图片
- labels
- coco-labels-2014_2017.txt # coco数据集的label文件
- models
- yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid # Paddle fluid combined格式的、剪枝后的YOLOv3_MobileNetV3 float32模型
- yolov3_mobilenet_v3_prune86_FPGM_fp32_320_for_cpu
- model.nb # 已通过opt转好的、适合CPU的YOLOv3_MobileNetV3模型
- yolov3_mobilenet_v3_prune86_FPGM_fp32_320_for_hybrid_cpu_npu
- model.nb # 已通过opt转好的、适合NPU+CPU的YOLOv3_MobileNetV3异构模型
- subgraph_custom_partition_config_file.txt # YOLOv3_MobileNetV3自定义子图分割配置文件
- shell # android shell端的示例程序,注意:HiAI存在限制,拥有ROOT权限才能正常运行shell端程序
- CMakeLists.txt # android shell端的示例程序CMake脚本
- build
- object_detection_demo # 已编译好的android shell端的示例程序
- object_detection_demo.cc # 示例程序源码
- build.sh # android shell端的示例程序编译脚本
- run.sh # android shell端的示例程序运行脚本
- apk # 常规android应用程序,无需ROOT
```
- 运行Android shell端的示例程序
- 参考『运行图像分类示例程序』章节的类似步骤,通过以下命令比较CPU模型、CPU+NPU异构模型的性能和结果;
```shell
运行YOLOv3_MobileNetV3 CPU模型
$ cd PaddleLite-android-demo/object_detection_demo/assets/models
$ cp yolov3_mobilenet_v3_prune86_FPGM_fp32_320_for_cpu/model.nb yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid.nb
$ cd ../../shell
$ ./run.sh
...
warmup: 5 repeat: 10, average: 53.963000 ms, max: 54.161999 ms, min: 53.562000 ms
results: 24
[0] person - 0.986361 211.407288,334.633301,51.627228,133.759537
[1] person - 0.879052 261.493347,342.849823,40.597961,120.775108
...
[22] kite - 0.272905 362.982941,119.011330,14.060059,11.157372
[23] kite - 0.254866 216.051910,175.607956,70.241974,23.265827
Preprocess time: 4.882000 ms
Prediction time: 53.963000 ms
Postprocess time: 0.548000 ms
运行YOLOv3_MobileNetV3 CPU+NPU异构模型
$ cd PaddleLite-android-demo/object_detection_demo/assets/models
$ cp yolov3_mobilenet_v3_prune86_FPGM_fp32_320_for_hybrid_cpu_npu/model.nb yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid.nb
$ cd ../../shell
$ ./run.sh
...
warmup: 5 repeat: 10, average: 23.767200 ms, max: 25.287001 ms, min: 22.292000 ms
results: 24
[0] person - 0.986164 211.420929,334.705780,51.559906,133.627930
[1] person - 0.879287 261.553680,342.857300,40.531372,120.751106
...
[22] kite - 0.271422 362.977722,119.014709,14.053833,11.162636
[23] kite - 0.257437 216.123276,175.631500,70.095078,23.248249
Preprocess time: 4.951000 ms
Prediction time: 23.767200 ms
Postprocess time: 1.015000 ms
```
- 运行常规Android应用程序
(如果不想按照以下步骤编译Android应用程序,可以直接在Android设备上通过浏览器访问[https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/object_detection_demo.apk](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/object_detection_demo.apk)下载和安装已编译好的apk)
- 参考『运行图像分类示例程序』章节的类似步骤,通过Android Studio导入"PaddleLite-android-demo/object_detection_demo/apk"工程,生成和运行常规Android应用程序;
- 默认使用ARM CPU模型进行推理,如下图所示,推理耗时55.1ms,整个流程(含预处理和后处理)的帧率约15fps;
![huawei_mate30_5g_yolov3_mobilenet_v3_cpu](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/huawei_mate30_5g_yolov3_mobilenet_v3_cpu.jpg)
- 选择"yolov3_mobilenet_v3_for_hybrid_cpu_npu"后,如下图所示,推理耗时下降到26.9ms,帧率提高到28fps
![huawei_mate30_5g_yolov3_mobilenet_v3_hybrid_cpu_npu](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/huawei_mate30_5g_yolov3_mobilenet_v3_hybrid_cpu_npu.jpg)
### PaddleLite是如何支持华为NPU的?
- PaddleLite是如何加载Paddle模型并执行一次推理的?
- 如下图左半部分所示,Paddle模型的读取和执行,经历了Paddle推理模型文件的加载和解析、计算图的转化、图分析和优化、运行时程序的生成和执行等步骤:
![how_to_intergrate_hiai_to_paddlelite](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/how_to_intergrate_hiai_to_paddlelite.png)
- Paddle推理模型文件的加载和解析:基于ProtoBuf协议对Paddle推理模型文件进行反序列化,解析生成网络结构(描述算子和张量的关系)和参数信息(包括算子属性和权重张量);
- 计算图的转化:为了更好地描述网络拓扑结构和方便后续的优化,依据算子的输入、输出张量关系,构建一个由算子节点、张量节点组成的有向无环图;
- 图分析和优化:由一系列pass(优化器)组成,每个pass描述了由一个计算图优化生成另一个计算图的过程;例如conv2d_bn_fuse_pass,它用于将模型中每一对相连的conv2d、batch_norm算子融合成一个conv2d算子以便获得性能上的提升;
- 运行时程序的生成和执行:按照拓扑顺序遍历最终优化后的计算图,生成算子kernel列表,依次执行每一个算子kernel后即完成一次模型的推理。
- PaddleLite是如何支持华为NPU呢?
- 为了支持华为NPU,我们额外增加了(如上图标黄的区域):Subgraph detection pass、NPU subgraph op kernel和Paddle2HiAI op/tensor bridges。其中Subgraph detection pass是后续自定义子图划分涉及的关键步骤;
- Subgraph detection pass:该pass的作用是遍历计算图中所有的算子节点,标记能够转成HiAI IR的算子节点,然后通过图分割算法,将那些支持转为HiAI IR的、相邻的算子节点融合成一个subgraph(子图)算子节点(需要注意的是,这个阶段算子节点并没有真正转为HiAI IR,更没有生成HiAI模型);
- NPU subgraph op kernel:根据Subgraph detection pass的分割结果,在生成的算子kernel列表中,可能存在多个subgraph算子kernel;每个subgraph算子kernel,都会将它所包裹的、能够转成HiAI IR的、所有Paddle算子,如上图右半部所示,依次调用对应的op bridge,组网生成一个HiAI Graph,最终,调用HiAI Runtime APIs生成并执行NPU模型;
- Paddle2HiAI op/tensor bridges:Paddle算子/张量转HiAI IR/tensor的桥接器,其目的是将Paddle算子、输入、输出张量转为HiAI组网IR和常量张量。
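为帮助理解 Subgraph detection pass 的"标记 + 融合"思路,下面给出一个高度简化的概念性 C++ 片段(仅为示意,并非 Paddle Lite 的真实数据结构或实现;真实实现是在有向无环图上进行子图检测,这里简化为按拓扑序展平的算子序列):

```c++
#include <string>
#include <vector>

struct OpNode {
  std::string type;
  bool npu_supported;  // 是否存在对应的 Paddle->HiAI bridge,且未被配置文件禁用
};

// 将相邻且都能转成 HiAI IR 的算子聚成一个 subgraph,其余算子留在 CPU 上
std::vector<std::vector<int>> DetectSubgraphs(const std::vector<OpNode>& ops) {
  std::vector<std::vector<int>> subgraphs;
  std::vector<int> current;
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    if (ops[i].npu_supported) {
      current.push_back(i);          // 连续可转 HiAI IR 的算子并入当前子图
    } else if (!current.empty()) {
      subgraphs.push_back(current);  // 遇到不支持的算子,当前子图封口
      current.clear();
    }
  }
  if (!current.empty()) subgraphs.push_back(current);
  return subgraphs;  // 每个子图在执行期才会被转成 HiAI IR 并生成 NPU 模型
}
```

可以看到,不支持 NPU 的算子越分散,得到的子图就越多,这也正是下一节需要通过"自定义子图分割"来人为干预分割结果的原因。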
### 编写配置文件完成自定义子图分割,生成华为NPU与ARM CPU的异构模型
- 为什么需要进行手动子图划分?如果模型中存在不支持转HiAI IR的算子,在没有人工干预的情况下,Subgraph detection pass可能将计算图分割为许多小的子图,从而出现如下问题:
- 过多的子图会产生频繁的CPU<->NPU数据传输和NPU任务调度,影响整体性能;
- 由于NPU模型暂时不支持dynamic shape,因此,如果模型中存在输入和输出不定长的算子(例如一些检测类算子,NLP类算子),在模型推理过程中,可能会因输入、输出shape变化而不断生成NPU模型,从而导致性能变差,更有可能使得NPU模型生成失败。
- 实现原理
- Subgraph detection pass在执行分割任务前,通过读取指定配置文件的方式获得禁用NPU的算子列表,实现人为干预分割结果的目的。
- 具体步骤(以YOLOv3_MobileNetV3目标检测示例程序为例)
- 步骤1:查看[YOLOv3_MobileNetV3](https://paddlelite-demo.bj.bcebos.com/models/yolov3_mobilenet_v3_prune86_FPGM_320_fp32_fluid.tar.gz)的模型结构,具体是将PaddleLite-android-demo/object_detection_demo/assets/models/yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid目录下的model复制并重命名为`__model__`后,拖入[Netron页面](https://lutzroeder.github.io/netron/)即得到如下图所示的网络结构(部分):
![yolov3_mobilenet_v3_netron](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/yolov3_mobilenet_v3_netron.jpg)
- 步骤2:访问[https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/kernels/npu/bridges/paddle_use_bridges.h](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/kernels/npu/bridges/paddle_use_bridges.h)查看已支持的算子列表,发现NPU不支持yolo_box、multiclass_nms这两个算子;
- 步骤3:如果直接使用opt工具生成NPU模型,会发现整个网络被分割成3个子图(即3个subgraph op),subgraph1为MobileNetV3 backbone,subgraph2为1个transpose2和1个concat,subgraph3为2个transpose2和1个concat,它们都将运行在NPU上;
```shell
$ cd PaddleLite-android-demo/object_detection_demo/assets/models
$ GLOG_v=5 ../../../libs/PaddleLite/bin/opt --model_file=yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid/model \
--param_file=yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid/params \
--optimize_out_type=protobuf \
--optimize_out=opt_model \
--valid_targets=npu,arm
...
[4 8/12 14:12:50.559 ...e/Paddle-Lite/lite/core/mir/ssa_graph.cc:27 CheckBidirectionalConnection] node count 398
[4 8/12 14:12:50.560 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement feed host/any/any
[4 8/12 14:12:50.560 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement feed host/any/any
[4 8/12 14:12:50.560 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement subgraph npu/any/NCHW
[4 8/12 14:12:50.560 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement subgraph npu/any/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement subgraph npu/any/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement multiclass_nms host/float/NCHW
[4 8/12 14:12:50.561 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement fetch host/any/any
[I 8/12 14:12:50.561 ...te/lite/core/mir/generate_program_pass.h:37 GenProgram] insts.size 1
[4 8/12 14:12:50.836 ...e-Lite/lite/model_parser/model_parser.cc:308 SaveModelPb] Save protobuf model in 'opt_model' successfully
注意:为了方便查看优化后的模型,上述命令将`optimize_out_type`参数设置为protobuf,执行成功后将opt_model目录下的model文件复制为__model__并拖入Netron页面进行可视化。
```
![yolov3_mobilenet_v3_hybrid_cpu_npu_auto_split_netron](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/yolov3_mobilenet_v3_hybrid_cpu_npu_auto_split_netron.jpg)
- 步骤4:为了防止CPU与NPU频繁切换,去除subgraph2和subgraph3,强制让transpose2和concat运行在CPU上。那么,我们就需要通过环境变量SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE设置『自定义子图分割配置文件』,实现人为干预分割结果;
```shell
$ cd PaddleLite-android-demo/object_detection_demo/assets/models
$ cat ./subgraph_custom_partition_config_file.txt
transpose2:yolo_box0.tmp_1:transpose_0.tmp_0,transpose_0.tmp_1
transpose2:yolo_box1.tmp_1:transpose_1.tmp_0,transpose_1.tmp_1
transpose2:yolo_box2.tmp_1:transpose_2.tmp_0,transpose_2.tmp_1
concat:yolo_box0.tmp_0,yolo_box1.tmp_0,yolo_box2.tmp_0:concat_2.tmp_0
concat:transpose_0.tmp_0,transpose_1.tmp_0,transpose_2.tmp_0:concat_3.tmp_0
$ export SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE=./subgraph_custom_partition_config_file.txt
$ GLOG_v=5 ../../../libs/PaddleLite/bin/opt --model_file=yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid/model \
--param_file=yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid/params \
--optimize_out_type=protobuf \
--optimize_out=opt_model \
--valid_targets=npu,arm
...
[4 8/12 14:15:37.609 ...e/Paddle-Lite/lite/core/mir/ssa_graph.cc:27 CheckBidirectionalConnection] node count 401
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement feed host/any/any
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement feed host/any/any
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement subgraph npu/any/NCHW
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement transpose2 arm/float/NCHW
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:15:37.610 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement transpose2 arm/float/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement yolo_box arm/float/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement transpose2 arm/float/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement concat arm/any/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement concat arm/any/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement multiclass_nms host/float/NCHW
[4 8/12 14:15:37.611 ...e/lite/core/mir/generate_program_pass.cc:46 Apply] Statement fetch host/any/any
[I 8/12 14:15:37.611 ...te/lite/core/mir/generate_program_pass.h:37 GenProgram] insts.size 1
[4 8/12 14:15:37.998 ...e-Lite/lite/model_parser/model_parser.cc:308 SaveModelPb] Save protobuf model in 'opt_model'' successfully
```
![yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_netron](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_netron.jpg)
- 步骤5:上述步骤中,PaddleLite-android-demo/object_detection_demo/assets/models/subgraph_custom_partition_config_file.txt是示例自带的『自定义子图分割配置文件』,它的格式是什么样的呢?
- 每行记录由『算子类型:输入张量名列表:输出张量名列表』组成(即以冒号分隔算子类型、输入张量名列表和输出张量名列表),以逗号分隔输入、输出张量名列表中的每个张量名(本节末尾附有一个解析该格式的示意代码);
- 可省略输入、输出张量名列表中的部分张量名(如果不设置任何输入、输出张量列表,则代表计算图中该类型的所有算子节点均被强制运行在CPU上);
- 示例说明:
```
op_type0:var_name0,var_name1:var_name2 表示将算子类型为op_type0、输入张量为var_name0和var_name1、输出张量为var_name2的节点强制运行在CPU上
op_type1::var_name3 表示将算子类型为op_type1、任意输入张量、输出张量为var_name3的节点强制运行在CPU上
op_type2:var_name4 表示将算子类型为op_type2、输入张量为var_name4、任意输出张量的节点强制运行在CPU上
op_type3 表示任意算子类型为op_type3的节点均被强制运行在CPU上
```
- 步骤6:对于YOLOv3_MobileNetV3的模型,我们如何得到PaddleLite-android-demo/object_detection_demo/assets/models/subgraph_custom_partition_config_file.txt的配置呢?
- 重新在Netron打开PaddleLite-android-demo/object_detection_demo/assets/models/yolov3_mobilenet_v3_prune86_FPGM_fp32_320_fluid模型,如下图所示,1~5号节点需要强制放在CPU上运行。
![yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_step1_netron](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_step1_netron.jpg)
- 在Netron中依次点击1~5号节点,右侧将显示每个节点的输入、输出张量名称,如下图所示,1号节点为transpose2类型算子,它的输入为yolo_box0.tmp_1、输出为transpose_0.tmp_0,transpose_0.tmp_1,即可得到配置文件的第一条记录"transpose2:yolo_box0.tmp_1:transpose_0.tmp_0,transpose_0.tmp_1";
![yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_step2_netron](https://paddlelite-demo.bj.bcebos.com/devices/huawei/kirin/yolov3_mobilenet_v3_hybrid_cpu_npu_manual_split_step2_netron.jpg)
- 步骤7:将步骤4中的"optimize_out_type"修改为naive_buffer,重新执行步骤4即可以生成用于部署的CPU+NPU异构模型。
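作为对步骤5中配置文件格式的补充,下面给出一个解析单行记录的最小 C++ 片段(仅为格式示意,并非 Paddle Lite 的实际实现;空的输入/输出张量名列表表示"任意"):

```c++
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// 按分隔符切分字符串,保留空字段(例如 "op1::var3" 的输入字段为空,表示任意输入)
static std::vector<std::string> Split(const std::string& s, char delim) {
  std::vector<std::string> items;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) items.push_back(item);
  return items;
}

struct SubgraphPartitionRecord {
  std::string op_type;                    // 算子类型
  std::vector<std::string> input_names;   // 输入张量名列表(空表示任意输入)
  std::vector<std::string> output_names;  // 输出张量名列表(空表示任意输出)
};

// 解析形如 "op_type:in0,in1:out0,out1" 的一行记录
SubgraphPartitionRecord ParseLine(const std::string& line) {
  SubgraphPartitionRecord record;
  auto fields = Split(line, ':');
  if (!fields.empty()) record.op_type = fields[0];
  if (fields.size() > 1) record.input_names = Split(fields[1], ',');
  if (fields.size() > 2) record.output_names = Split(fields[2], ',');
  return record;
}

int main() {
  auto r = ParseLine("transpose2:yolo_box0.tmp_1:transpose_0.tmp_0,transpose_0.tmp_1");
  // 输出:transpose2 inputs=1 outputs=2
  std::cout << r.op_type << " inputs=" << r.input_names.size()
            << " outputs=" << r.output_names.size() << std::endl;
  return 0;
}
```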
## 其它说明
- 华为达芬奇架构的NPU内部大量采用float16进行运算,因此,预测结果会存在偏差,但大部分情况下精度不会有较大损失,可参考[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)中Image Classification Demo for Android对同一张图片CPU与NPU的预测结果。
- 华为Kirin 810/990 Soc搭载的自研达芬奇架构的NPU,与Kirin 970/980 Soc搭载的寒武纪NPU不一样,同样的,与Hi3559A、Hi3519A使用的NNIE也不一样,Paddle Lite只支持华为自研达芬奇架构NPU。
- 我们正在持续增加能够适配HiAI IR的Paddle算子bridge/converter,以便适配更多Paddle模型,同时华为研发同学也在持续对HiAI IR性能进行优化。
# PaddleLite使用NPU(华为)预测部署
Paddle Lite是首款支持华为自研达芬奇架构NPU(Kirin 810/990 SoC搭载的NPU)的预测框架。
原理是在线分析Paddle模型,将Paddle算子转成HiAI IR后,调用HiAI IR/Builder/Runtime APIs生成并执行HiAI模型。
## 已支持的设备
- 华为nova5、nova5i pro、mate30、mate30 pro、mate30 5G、荣耀v30、p40、p40 pro,以及即将推出的mate40。据华为透露,今后上市的大部分手机都会搭载其自研达芬奇架构NPU。
## 已支持的模型
- MobileNetV1
- MobileNetV2
- ResNet-18/50
- ShuffleNetV2
- squeezenet
- mnasnet
- yolov3
- CycleGAN (暂时需要华为内部rom的支持)
- 百度内部业务模型(由于涉密,不方便透露具体细节)
*CPU/NPU混合调度在部分模型可以获得更佳的性能*
## 已支持(或部分支持)的Paddle算子
- sigmoid
- relu
- tanh
- relu_clipped
- leaky_relu
- softsign
- hard_sigmoid
- batch_norm
- concat
- conv2d
- depthwise_conv2d
- conv2d_transpose
- dropout
- elementwise_add
- elementwise_sub
- elementwise_mul
- elementwise_div
- fusion_elementwise_add_activation
- fusion_elementwise_sub_activation
- fusion_elementwise_mul_activation
- fusion_elementwise_div_activation
- fc
- bilinear_interp
- nearest_interp
- matmul
- mul
- pad2d
- pool2d
- reduce_mean
- reshape
- reshape2
- scale
- shuffle_channel
- softmax
- split
- sqrt
- square
- transpose
- transpose2
- unsqueeze
- unsqueeze2
- instance_norm (暂时需要华为内部rom的支持)
- layer_norm (暂时需要华为内部rom的支持)
## 编译支持NPU的Paddle Lite库
- 从[华为HiAI平台](https://developer.huawei.com/consumer/cn/hiai)下载华为HiAI DDK后解压到任意路径(注意:华为提供了多个版本的DDK,我们需要下载针对麒麟810/990芯片的HiAI Foundation开发套件,例如[DDK V310版本](https://obs.cn-north-2.myhwclouds.com/hms-ds-wf/sdk/hwhiai-ddk-100.310.011.010.zip))。
- 将HiAI DDK中的ai_ddk_lib目录拷贝至Paddle Lite源码根目录后,使用[编译脚本](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/tools/build_android.sh)编译 (需要指定NPU相关选项)。
注意:以下是HiAI DDK V310版解压后的目录结构,需要将ai_ddk_lib目录拷贝至Paddle Lite源码根目录。
```shell
- app_sample
- ddk
- ai_ddk_lib
- include
- lib # for armv7
- lib64 # for armv8
- document
- tools
```
- 推荐编译命令。由于HiAI DDK的so库均基于c++_shared构建,因此,建议使用c++_shared编译Paddle Lite。
```shell
# huawei_kirin_npu_sdk_root 需要指向 ai_ddk_lib 的路径
$ ./lite/tools/build_android.sh --android_stl=c++_shared --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=<path-to-ai_ddk_lib>
# 其它选项可以通过 "./lite/tools/build_android.sh help" 查看,例如arm版本等
```
注意:为了保证编译环境一致,建议参考[源码编译](../user_guides/source_compile)中的Docker开发环境进行配置,然后再执行上述命令。
## 优化生成NPU模型
- model_optimize_tool工具已经支持生成NPU模型,仅需要将valid_targets设置为npu,arm即可,具体参考[模型转化方法](../user_guides/model_optimize_tool)
```shell
./model_optimize_tool --model_dir=<model_param_dir> \
--model_file=<model_path> \
--param_file=<param_path> \
--optimize_out_type=(protobuf|naive_buffer) \
--optimize_out=<output_optimize_model_dir> \
--valid_targets=npu,arm \
--record_tailoring_info=(true|false)
```
- model_optimize_tool生成的模型只是标记了NPU支持的Paddle算子,并没有真正生成NPU HiAI模型,只有在执行时才会将标记的Paddle算子转成HiAI IR,最终生成并执行HiAI模型,具体实现参考PR[2576](https://github.com/PaddlePaddle/Paddle-Lite/pull/2576)
- 不同模型,不同型号(ROM版本)的华为手机,在执行阶段,由于某些Paddle算子无法完全转成HiAI IR,或目标手机的HiAI版本过低等原因,可能导致HiAI模型无法成功生成,在这种情况下,Paddle Lite会调用CPU版算子进行运算完成整个预测任务。
## 通过JAVA接口加载并执行NPU模型
**注意:由于华为手机root权限限制,现在仅支持JAVA接口加载和执行NPU模型**
- 使用方法和[Java实例](java_demo)一致,无需额外设置任何参数,只需将模型换成NPU模型即可。[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)中的Image Classification Demo for Android是同时支持CPU和NPU两种模型的图像分类Demo。
注意:在拷贝libpaddle_lite_jni.so的时候,由于依赖HiAI DDK so和libc++_shared.so库,需要将HiAI DDK中ai_ddk_lib/lib或ai_ddk_lib/lib64目录下的所有so和libc++_shared.so,拷到libpaddle_lite_jni.so同级目录下。
## 其它说明
- 华为达芬奇架构的NPU内部大量采用float16进行运算,因此,预测结果会存在偏差,但大部分情况下精度不会有较大损失,可参考[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)中Image Classification Demo for Android对同一张图片CPU与NPU的预测结果。
- 华为Kirin 810/990 Soc搭载的自研达芬奇架构的NPU,与Kirin 970/980 Soc搭载的寒武纪NPU不一样,同样的,与Hi3559A、Hi3519A使用的NNIE也不一样,Paddle Lite只支持华为自研达芬奇架构NPU。
- 我们正在持续增加能够适配HiAI IR的Paddle算子bridge/converter,以便适配更多Paddle模型,同时华为研发同学也在持续对HiAI IR性能进行优化。
## 手动分割子图
### 背景
- Paddle-Lite已经支持了大量的华为NPU的算子,但是仍然不能满足所有模型的需求。对于一个有部分算子不支持的模型,Paddle-Lite会将模型划分为可以跑在NPU上的子图和跑在CPU上的子图,实现NPU和CPU自动调度的功能,通常情况下可以获得比较好的性能。在一些特殊情况下,模型会被自动划分为比较多的子图,导致CPU和NPU的切换开销很大,从而导致整体性能变差。因此,需要手动分割子图的功能来指定一些算子跑在CPU上,避免子图过多。
### 功能
- 通过配置文件来指定需要强制跑在CPU上的算子
### 使用方法
- 1、通过netron打开paddle模型文件,可以查看模型结构,获得算子的类型、输入名称、输出名称。
- 注意:Paddle-Lite会对模型进行优化,优化后模型中的算子可能发生改变,需要以优化后的模型算子为准。后面会举例说明。
- 2、生成配置文件 ```split_cfg.txt```,记录需要跑在CPU上的算子信息。
- 每行一条OP记录信息,以冒号":"分隔"op名称","op输入名","op输出名",以逗号","分隔"op输入名"和"op输出名"中的不同var名。
- 可以部分省略输入或者输出名。比如:```op3:in3_var0```表示,指定类型为"op3",输入为"in3_var0"的算子;```op4```表示所有类型为"op4"的算子
- 例子1:
```
op0:in0_var0,in0_var1:out0_var0,out0_var1
op1:in1_var0,in1_var1:out1_var0
op2::out2_var0
op3:in3_var0
op4
```
- 例子2:
```
transpose:conv2d_22.tmp_1:transpose_0.tmp_0
```
![image](https://user-images.githubusercontent.com/50474132/80475316-4a5fda80-897b-11ea-910a-6aee13243387.png)
- 3、使用环境变量```SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE```指定配置文件的位置。
- 例如:
```
export SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE=/data/local/tmp/split_cfg.txt
```
- 4、以上步骤完成后,运行的模型中符合条件的算子将被强制跑在CPU上。
### 举例
- 以模型[ssd_mobilenet_v1](https://paddlelite-demo.bj.bcebos.com/models/ssd_mobilenet_v1_pascalvoc_fp32_300_fluid.tar.gz)为例
- 1、可以使用netron查看模型
- 2、初步分析
- 下图是ssd_mobilenet_v1中的部分结构。其中红色部分暂时不支持在NPU上运行,蓝色部分在NPU上的性能可能不理想。此时,如果直接让预测库自动调度的话,可能会分成多个子图,而且整体性能不佳。因此,可以将蓝色部分和绿色部分整体指定在CPU上运行,让其他部分自动运行在NPU上(红色部分会自动在CPU上运行)。
![](https://user-images.githubusercontent.com/50474132/80453173-525b5280-895a-11ea-847f-c7dd5b5799de.png)
- 3、使用opt转换模型
- opt转换过程中会打印log信息。在log中搜索```digraph G```和```// end G```可以找到优化后的模型图。
![](https://user-images.githubusercontent.com/50474132/80454098-145f2e00-895c-11ea-9f16-dde1483a9beb.png)
![](https://user-images.githubusercontent.com/50474132/80454123-1de89600-895c-11ea-86b9-a62d78a6616d.png)
- 将从```digraph G```开始的,到```// end G```结束的整段模型图信息,保存到```.dot```格式的文件中。可以用```graphviz```打开查看,或者在[网页版](http://dreampuf.github.io/GraphvizOnline/)查看。
![](https://user-images.githubusercontent.com/50474132/80454841-47ee8800-895d-11ea-9531-5689c5560fcb.png)
- 在此处确认需要被指定的算子是否被优化了。(期望是被指定的算子都还独立存在,如果被融合为了一个算子,需要指定此时融合后的算子)。
- 4、写配置文件
- 在配置文件中指定可以支持NPU但是需要指定在CPU上运行的算子。
```
reshape
transpose
concat
softmax
```
- 由于这些算子都指定在CPU上运行,因此不需要特意配置算子的输入输出名称。
- 5、指定配置文件路径
- 通过```export SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE=your_split_config_file```的方式实现。
- 6、性能测试
- 设备:华为mate30 5G
- HIAI ddk版本:320
- 性能:CPU约71.8ms,NPU约16.6ms。
...@@ -55,7 +55,7 @@ Welcome to Paddle-Lite's documentation!
demo_guides/cuda
demo_guides/opencl
demo_guides/fpga
demo_guides/huawei_kirin_npu
demo_guides/baidu_xpu
demo_guides/rockchip_npu
demo_guides/mediatek_apu
......
...@@ -3,7 +3,7 @@
**注意:本编译方法只适用于release/v2.6.0之后版本(包括 v2.6.0)**
如果您还没有配置好Android交叉编译环境,请先根据[环境准备](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#id2)中的内容,结合您的开发环境安装编译Android预测库所需的编译环境。运行编译脚本之前,请先检查环境变量`NDK_ROOT`是否指向正确的Android NDK安装路径,之后即可下载并编译 Paddle-Lite源码。
```shell
# 1. 下载Paddle-Lite源码 并切换到release分支
...@@ -14,6 +14,7 @@ cd Paddle-Lite && git checkout release/v2.3
./lite/tools/build_android.sh
```
> **提示:** 编译过程中,如果程序在下载第三方库时花费较多时间,请尝试删除Paddle-Lite下的`<lite-repo>/third-party`目录之后再次运行编译脚本,脚本会自动下载存储于百度云的第三方库代码包,节省从git repo下载第三方库代码的时间。
### 编译结果
......
...@@ -3,10 +3,14 @@
opt是 x86 平台上的可执行文件,需要在PC端运行:支持Linux终端和Mac终端。
### 帮助信息
执行opt时不加入任何输入选项,会输出帮助信息,提示当前支持的选项:
```bash
./opt
```
> **注意:** 如果您是通过[准备opt](https://paddle-lite.readthedocs.io/zh/latest/user_guides/model_optimize_tool.html#id1)页面中,"方法二:下载opt可执行文件" 中提供的链接下载得到的opt可执行文件,请先通过`chmod +x ./opt`命令为下载的opt文件添加可执行权限。
![](https://paddlelite-data.bj.bcebos.com/doc_images/1.png)
### 功能一:转化模型为Paddle-Lite格式
......
...@@ -38,7 +38,7 @@
### 2.3 配置校准数据生成器
静态离线量化内部使用异步数据读取的方式读取校准数据,大家只需要根据模型的输入,配置读取数据的sample_generator。sample_generator是Python生成器,**必须每次返回单个样本数据**,会用作`DataLoader.set_sample_generator()`的数据源。
建议参考[异步数据读取文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/data_preparing/static_mode/use_py_reader.html)和本文示例,学习如何配置校准数据生成器。
### 2.4 调用静态离线量化
......
# 预编译库下载
## 编译版本介绍
......
# 模型转换工具 X2Paddle
X2Paddle可以将caffe、tensorflow、onnx模型转换成Paddle支持的模型。目前支持版本为caffe 1.0;tensorflow 1.x,推荐1.4.0;ONNX 1.6.0,OpSet支持 9, 10, 11版本。
[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将Caffe/TensorFlow模型转换为PaddlePaddle模型。
支持的模型可参考**X2Paddle模型测试库:**
......
...@@ -39,12 +39,16 @@ USE_MIR_PASS(identity_dropout_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_match_matrix_activation_fuse_pass);
USE_MIR_PASS(lite_scales_fuse_pass);
USE_MIR_PASS(lite_sequence_reverse_embedding_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
USE_MIR_PASS(type_layout_cast_preprocess_pass);
USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(lite_reshape_fuse_pass);
USE_MIR_PASS(multi_stream_analysis_pass);
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
......
...@@ -9,6 +9,7 @@ if(WIN32)
  target_link_libraries(lite_pybind ${os_dependency_modules})
else()
  lite_cc_library(lite_pybind SHARED SRCS pybind.cc DEPS ${PYBIND_DEPS})
target_sources(lite_pybind PUBLIC ${__lite_cc_files})
endif(WIN32)
if (LITE_ON_TINY_PUBLISH)
......
(该文件的 source diff 过大,未在此显示。)
...@@ -20,61 +20,117 @@ namespace lite {
namespace arm {
namespace math {
void conv_depthwise_3x3s1p1_bias_relu6(float *dout,
                                       const float *din,
                                       const float *weights,
                                       const float *bias,
                                       const float *six,
                                       bool flag_bias,
                                       const int num,
                                       const int ch_in,
                                       const int h_in,
                                       const int w_in,
                                       const int h_out,
                                       const int w_out,
                                       ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         const float *six,
                                         bool flag_bias,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_relu6(float *dout,
                                       const float *din,
                                       const float *weights,
                                       const float *bias,
                                       const float *six,
                                       bool flag_bias,
                                       const int num,
                                       const int ch_in,
                                       const int h_in,
                                       const int w_in,
                                       const int h_out,
                                       const int w_out,
                                       ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         const float *six,
                                         bool flag_bias,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1_fp32(const float *din,
                               float *dout,
...@@ -92,138 +148,270 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                               const operators::ActivationParam act_param,
                               ARMContext *ctx) {
  bool has_active = act_param.has_active;
  auto act_type = act_param.active_type;
  float tmp = act_param.Relu_clipped_coef;
  float ss = act_param.Leaky_relu_alpha;
  float vsix[4] = {tmp, tmp, tmp, tmp};
  float vscale[4] = {ss, ss, ss, ss};
  if (has_active) {
if (act_param.active_type == lite_api::ActivationType::kRelu) { switch (act_type) {
flag_relu = true; case lite_api::ActivationType::kRelu:
} else { if (pad == 0) {
relu6 = true; if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kRelu6:
if (pad == 0) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kLeakyRelu:
if (pad == 0) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
<< " fuse not support";
} }
} } else {
if (pad == 0) { if (pad == 0) {
if (w_in > 5) { if (w_in > 5) {
if (relu6) { conv_depthwise_3x3s1p0_bias_no_relu(dout,
conv_depthwise_3x3s1p0_bias(dout, din,
din, weights,
weights, bias,
bias, flag_bias,
flag_bias, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
act_param, ctx);
ctx);
} else {
conv_depthwise_3x3s1p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s1p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else { } else {
conv_depthwise_3x3s1p0_bias_s_relu(dout, conv_depthwise_3x3s1p0_bias_s_no_relu(dout,
din, din,
weights, weights,
bias, bias,
flag_bias, flag_bias,
flag_relu, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
ctx); ctx);
} }
} }
} if (pad == 1) {
if (pad == 1) { if (w_in > 4) {
if (w_in > 4) { conv_depthwise_3x3s1p1_bias_no_relu(dout,
if (relu6) { din,
conv_depthwise_3x3s1p1_bias(dout, weights,
din, bias,
weights, flag_bias,
bias, false,
flag_bias, num,
num, ch_in,
ch_in, h_in,
h_in, w_in,
w_in, h_out,
h_out, w_out,
w_out, ctx);
act_param,
ctx);
} else { } else {
conv_depthwise_3x3s1p1_bias_relu(dout, conv_depthwise_3x3s1p1_bias_s_no_relu(dout,
din, din,
weights, weights,
bias, bias,
flag_bias, flag_bias,
flag_relu, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
ctx); ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s1p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} }
} }
} }
...@@ -1978,338 +2166,19 @@ void conv_depthwise_3x3s1_fp32(const float *din, ...@@ -1978,338 +2166,19 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#endif #endif
#ifdef __aarch64__ void conv_depthwise_3x3s1p1_bias_relu6(float *dout,
void act_switch_3x3s1p1(const float *din_ptr0, const float *din,
const float *din_ptr1, const float *weights,
const float *din_ptr2, const float *bias,
const float *din_ptr3, const float *six,
const float *din_ptr4, bool flag_bias,
const float *din_ptr5, const int num,
float *doutr0, const int ch_in,
float *doutr1, const int h_in,
float *doutr2, const int w_in,
float *doutr3, const int h_out,
float32x4_t wr0, const int w_out,
float32x4_t wr1, ARMContext *ctx) {
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
const operators::ActivationParam act_param) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#else
void act_switch_3x3s1p1(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[six_ptr] "r"(vsix),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[scale_ptr] "r"(vscale),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
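// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this kernel, added only for readability):
// a plain scalar reference of what the 3x3, stride-1, pad-1 depthwise kernels
// in this file are expected to compute, with per-channel bias and an optional
// fused ReLU. The name ref_depthwise_3x3s1p1 and its signature are
// hypothetical; the real entry points above and below are hand-vectorized
// equivalents of this loop nest.
// ---------------------------------------------------------------------------
static void ref_depthwise_3x3s1p1(float *dout,
                                  const float *din,
                                  const float *weights,
                                  const float *bias,
                                  bool flag_bias,
                                  bool flag_relu,
                                  int ch,
                                  int h,
                                  int w) {
  for (int c = 0; c < ch; ++c) {
    const float *wc = weights + c * 9;  // 9 taps per channel
    float b = flag_bias ? bias[c] : 0.f;
    for (int oh = 0; oh < h; ++oh) {
      for (int ow = 0; ow < w; ++ow) {
        float sum = b;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh + kh - 1;  // pad = 1 on every border
            int iw = ow + kw - 1;
            if (ih >= 0 && ih < h && iw >= 0 && iw < w) {
              sum += din[(c * h + ih) * w + iw] * wc[kh * 3 + kw];
            }
          }
        }
        dout[(c * h + oh) * w + ow] = flag_relu && sum < 0.f ? 0.f : sum;
      }
    }
  }
}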
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p1_bias(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
...@@ -2355,7 +2224,9 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
for (int n = 0; n < num; ++n) {
  const float *din_batch = din + n * ch_in * size_in_channel;
  float *dout_batch = dout + n * ch_in * size_out_channel;
...@@ -2458,25 +2329,56 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
      }
      int cnt = cnt_col;
      asm volatile(
          INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
          MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
          : [cnt] "+r"(cnt),
            [din_ptr0] "+r"(din_ptr0),
            [din_ptr1] "+r"(din_ptr1),
            [din_ptr2] "+r"(din_ptr2),
            [din_ptr3] "+r"(din_ptr3),
            [din_ptr4] "+r"(din_ptr4),
            [din_ptr5] "+r"(din_ptr5),
            [doutr0] "+r"(doutr0),
            [doutr1] "+r"(doutr1),
            [doutr2] "+r"(doutr2),
            [doutr3] "+r"(doutr3)
          : [w0] "w"(wr0),
            [w1] "w"(wr1),
            [w2] "w"(wr2),
            [vsix] "w"(vsix),
            [bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
      dout_ptr = dout_ptr + 4 * w_out;
    }
#else
...@@ -2525,759 +2427,58 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
      int cnt = cnt_col;
      unsigned int *rmask_ptr = rmask;
      unsigned int *vmask_ptr = vmask;
      asm volatile(
          INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
          MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
          : [dout_ptr1] "+r"(doutr0),
            [dout_ptr2] "+r"(doutr1),
            [din0_ptr] "+r"(din_ptr0),
            [din1_ptr] "+r"(din_ptr1),
            [din2_ptr] "+r"(din_ptr2),
            [din3_ptr] "+r"(din_ptr3),
            [cnt] "+r"(cnt),
            [rmask] "+r"(rmask_ptr),
            [vmask] "+r"(vmask_ptr)
          : [wr0] "w"(wr0),
            [wr1] "w"(wr1),
            [wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[six_ptr] "r"(six),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
      dout_ptr += 2 * w_out;
    }  //! end of processing mid rows
#endif
    }
  }
}
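// ---------------------------------------------------------------------------
// Illustrative sketch (not taken from the asm above): how one 4-wide output
// tile of a 3x3 stride-1 row can be formed with NEON intrinsics, using
// vextq_f32 to build the shifted "1234" / "2345" views that the asm comments
// refer to. The helper name compute_row4 is hypothetical; the real kernels do
// the same work for four output rows at once inside the INIT_S1 /
// MID_COMPUTE_S1 assembly blocks.
// ---------------------------------------------------------------------------
#ifdef __ARM_NEON
#include <arm_neon.h>
// r0/r1/r2: the three input rows, pointing at the leftmost input column of
// this tile; w: 9 filter taps, row-major; bias: per-channel bias value.
static inline float32x4_t compute_row4(const float *r0,
                                        const float *r1,
                                        const float *r2,
                                        const float *w,
                                        float bias) {
  float32x4_t acc = vdupq_n_f32(bias);
  const float *rows[3] = {r0, r1, r2};
  for (int k = 0; k < 3; ++k) {
    float32x4_t a = vld1q_f32(rows[k]);      // columns 0..3
    float32x4_t b = vld1q_f32(rows[k] + 4);  // columns 4..7
    float32x4_t c1 = vextq_f32(a, b, 1);     // columns 1..4
    float32x4_t c2 = vextq_f32(a, b, 2);     // columns 2..5
    acc = vmlaq_n_f32(acc, a, w[k * 3 + 0]);
    acc = vmlaq_n_f32(acc, c1, w[k * 3 + 1]);
    acc = vmlaq_n_f32(acc, c2, w[k * 3 + 2]);
  }
  return acc;  // the caller applies the fused activation and stores 4 floats
}
#endif  // __ARM_NEON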
void act_switch_3x3s1p1_s(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp,
float32x4_t vzero,
float32x4_t wbias,
const operators::ActivationParam act_param) {
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[six_ptr] "r"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
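// ---------------------------------------------------------------------------
// Illustrative sketch (an assumption about the macro bodies, not their actual
// text): the three fused activations dispatched by the act_switch_* helpers
// reduce to the scalar epilogues below, applied to every accumulator after
// the bias has been added. apply_act is a hypothetical helper used only to
// document the semantics.
// ---------------------------------------------------------------------------
static inline float apply_act(float x,
                              lite_api::ActivationType type,
                              float six,
                              float scale) {
  switch (type) {
    case lite_api::ActivationType::kRelu:
      return x < 0.f ? 0.f : x;
    case lite_api::ActivationType::kRelu6:  // clamp to [0, six]
      x = x < 0.f ? 0.f : x;
      return x > six ? six : x;
    case lite_api::ActivationType::kLeakyRelu:  // x >= 0 ? x : x * scale
      return x >= 0.f ? x : x * scale;
    default:
      return x;  // no fused activation
  }
}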
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
act_switch_3x3s1p1_s(dr0_ptr,
dr1_ptr,
dr2_ptr,
dr3_ptr,
out_buf1,
out_buf2,
wr0,
wr1,
wr2,
vmask_rp,
vzero,
wbias,
act_param);
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
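// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the kernel): the top/bottom padding scheme
// used above never branches inside the vectorized loop. Out-of-range row
// pointers are redirected to a zero-filled row before the asm block runs, and
// the bottom overrun is handled with an intentional fall-through switch. The
// helper select_rows below is hypothetical and only restates that logic for a
// window of four input rows feeding output rows j and j + 1 (pad_h = 1).
// ---------------------------------------------------------------------------
static inline void select_rows(const float *din,
                               const float *zero_row,
                               int j,
                               int h_in,
                               int w_in,
                               const float *rows[4]) {
  for (int k = 0; k < 4; ++k) {
    int r = j - 1 + k;  // input row index; -1 and h_in.. are implicit padding
    rows[k] = (r < 0 || r >= h_in) ? zero_row : din + r * w_in;
  }
}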
#ifdef __aarch64__
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
const float *din_ptr4,
const float *din_ptr5,
float *doutr0,
float *doutr1,
float *doutr2,
float *doutr3,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
int remain,
const operators::ActivationParam act_param) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#else
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
int remain,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
  switch (act_param.active_type) {
    case lite_api::ActivationType::kRelu:
      asm volatile(INIT_S1
                   "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "vext.32 q6, q8, q9, #1 @ 0012\n"
                   "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
                   MID_RESULT_S1_RELU
                   "cmp %[remain], #1 \n"
                   "blt 0f \n" RIGHT_COMPUTE_S1
                   RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU6
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_LEAKY_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
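// ---------------------------------------------------------------------------
// Illustrative sketch (assumption, kept separate from the kernels): the p0
// kernels below carve each output row into full 4-wide tiles plus a masked
// remainder. This restates one common variant of that arithmetic (the form
// used by the relu6 / leakyRelu p0 kernels further below); TilePlan and
// tile_plan_3x3s1p0 are hypothetical names.
// ---------------------------------------------------------------------------
struct TilePlan {
  int tile_w;                   // full 4-wide output tiles per row
  int remain;                   // outputs handled by the masked right tile
  unsigned int size_pad_right;  // trailing input columns that must read as 0
};
static inline TilePlan tile_plan_3x3s1p0(int w_in, int w_out) {
  TilePlan p;
  p.tile_w = w_out >> 2;
  p.remain = w_out % 4;
  // Each 4-wide tile consumes 6 input columns (4 outputs + 2 tap overlap).
  p.size_pad_right = (unsigned int)(6 + (p.tile_w << 2) - w_in);
  if (p.remain == 0 && p.size_pad_right == 6) {
    // w_in == w_out and w_out % 4 == 0: the last full tile would load two
    // columns past the row, so fold it into the masked remainder path.
    p.tile_w -= 1;
    p.remain = 4;
    p.size_pad_right = 2;
  }
  return p;
}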
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 4
 */
void conv_depthwise_3x3s1p0_bias(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
...@@ -3293,14 +2494,19 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
int tile_w = w_out >> 2;
int remain = w_out % 4;
int cnt_col = tile_w - 1;
unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
  size_pad_right = 1;
  cnt_col -= 1;
  remain = 4;
} else if (remain == 0 && size_pad_right == 6) {  // w_in == w_out and w_out % 4 == 0
  size_pad_right = 2;
  cnt_col -= 1;
  remain = 4;
}
uint32x4_t vmask_rp1 =
...@@ -3308,7 +2514,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
uint32x4_t vmask_rp2 =
    vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
    vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
...@@ -3318,7 +2524,9 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
for (int n = 0; n < num; ++n) {
  const float *din_batch = din + n * ch_in * size_in_channel;
  float *dout_batch = dout + n * ch_in * size_out_channel;
...@@ -3355,7 +2563,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
    const float *din_ptr3 = dr3;
    const float *din_ptr4 = dr4;
    const float *din_ptr5 = dr5;
    float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
    for (int i = 0; i < h_out; i += 4) {
...@@ -3371,26 +2578,37 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      doutr1 = doutr0 + w_out;
      doutr2 = doutr1 + w_out;
      doutr3 = doutr2 + w_out;
if (i == 0) {
        din_ptr0 = zero_ptr;
        din_ptr1 = dr0;
        din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
      dr3 = dr2 + w_in;
      dr4 = dr3 + w_in;
      dr5 = dr4 + w_in;
      //! process bottom pad
      if (i + 5 > h_in) {
        switch (i + 5 - h_in) {
          case 5:
            din_ptr1 = zero_ptr;
          case 4:
            din_ptr2 = zero_ptr;
          case 3:
            din_ptr3 = zero_ptr;
          case 2:
            din_ptr4 = zero_ptr;
          case 1:
            din_ptr5 = zero_ptr;
          default:
            break;
...@@ -3410,31 +2628,62 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
        }
      }
      int cnt = cnt_col;
      asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
                   MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
                   RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
                   : [cnt] "+r"(cnt),
                     [din_ptr0] "+r"(din_ptr0),
                     [din_ptr1] "+r"(din_ptr1),
                     [din_ptr2] "+r"(din_ptr2),
                     [din_ptr3] "+r"(din_ptr3),
                     [din_ptr4] "+r"(din_ptr4),
                     [din_ptr5] "+r"(din_ptr5),
                     [doutr0] "+r"(doutr0),
                     [doutr1] "+r"(doutr1),
                     [doutr2] "+r"(doutr2),
                     [doutr3] "+r"(doutr3)
                   : [w0] "w"(wr0),
                     [w1] "w"(wr1),
                     [w2] "w"(wr2),
                     [vscale] "w"(vscale),
                     [bias_val] "r"(vbias),
                     [vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
      dout_ptr = dout_ptr + 4 * w_out;
    }
#else
    for (int i = 0; i < h_out; i += 2) {
      //! process top pad pad_h = 1
      din_ptr0 = dr0;
      din_ptr1 = dr1;
      din_ptr2 = dr2;
...@@ -3443,13 +2692,24 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      doutr0 = dout_ptr;
      doutr1 = dout_ptr + w_out;
      if (i == 0) {
        din_ptr0 = zero_ptr;
        din_ptr1 = dr0;
        din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
      //! process bottom pad
      if (i + 3 > h_in) {
        switch (i + 3 - h_in) {
          case 3:
            din_ptr1 = zero_ptr;
          case 2:
...@@ -3464,292 +2724,1140 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      if (i + 2 > h_out) {
        doutr1 = write_ptr;
      }
      int cnt = cnt_col;
unsigned int *rmask_ptr = rmask; unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask; unsigned int *vmask_ptr = vmask;
      asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
                   MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
                   RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
                   : [dout_ptr1] "+r"(doutr0),
                     [dout_ptr2] "+r"(doutr1),
                     [din0_ptr] "+r"(din_ptr0),
                     [din1_ptr] "+r"(din_ptr1),
                     [din2_ptr] "+r"(din_ptr2),
                     [din3_ptr] "+r"(din_ptr3),
                     [cnt] "+r"(cnt),
                     [rmask] "+r"(rmask_ptr),
                     [vmask] "+r"(vmask_ptr)
                   : [wr0] "w"(wr0),
                     [wr1] "w"(wr1),
                     [wr2] "w"(wr2),
                     [bias_val] "r"(bias_val),
[scale_ptr] "r"(scale),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
      dout_ptr += 2 * w_out;
    }  //! end of processing mid rows
#endif
    }
  }
}
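// ---------------------------------------------------------------------------
// Illustrative sketch (an assumption about intent, not the asm itself): the
// vmask / rmask vectors built by the wide kernels are per-lane
// "is this column still inside the row" masks. make_right_masks is a
// hypothetical restatement: a lane of the load mask is all-ones when its
// right_pad_idx entry is >= size_pad_right, and a lane of the store mask is
// all-ones for the first `remain` outputs of the right-edge tile.
// ---------------------------------------------------------------------------
static inline void make_right_masks(unsigned int size_pad_right,
                                    int remain,
                                    unsigned int vmask[8],
                                    unsigned int rmask[4]) {
  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
  for (int i = 0; i < 8; ++i) {  // load mask: masked lanes read as zero
    vmask[i] = (right_pad_idx[i] >= size_pad_right) ? 0xFFFFFFFFu : 0u;
  }
  for (int i = 0; i < 4; ++i) {  // store mask: only `remain` lanes written
    rmask[i] = (static_cast<unsigned int>(remain) > static_cast<unsigned int>(i))
                   ? 0xFFFFFFFFu
                   : 0u;
  }
}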
void act_switch_3x3s1p0_s(const float *din_ptr0,
const float *din_ptr1, void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout,
const float *din_ptr2, const float *din,
const float *din_ptr3, const float *weights,
float *doutr0, const float *bias,
float *doutr1, const float *six,
float32x4_t wr0, bool flag_bias,
float32x4_t wr1, const int num,
float32x4_t wr2, const int ch_in,
uint32x4_t vmask_rp1, const int h_in,
uint32x4_t vmask_rp2, const int w_in,
float32x4_t vzero, const int h_out,
float32x4_t wbias, const int w_out,
unsigned int *vmask_ptr, ARMContext *ctx) {
float bias_val, const int right_pad_idx[4] = {3, 2, 1, 0};
const operators::ActivationParam act_param) { const float zero[4] = {0.f, 0.f, 0.f, 0.f};
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
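// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the kernel): the wide kernels emit 2
// (armv7) or 4 (aarch64) output rows per iteration, so the rows that would
// fall past h_out are redirected into a scratch row (write_ptr / trash_buf in
// the code) instead of guarding every store. redirect_rows is a hypothetical
// restatement of that pattern.
// ---------------------------------------------------------------------------
static inline void redirect_rows(float *rows[4],
                                 int i,
                                 int h_out,
                                 float *scratch) {
  for (int k = 0; k < 4; ++k) {
    if (i + k >= h_out) rows[k] = scratch;  // surplus rows land in scratch
  }
}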
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 4
 */
void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
if (j == 0) {
din_ptr0 = zero;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
din_ptr1 = zero;
case 2:
din_ptr2 = zero;
case 1:
din_ptr3 = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[six_ptr] "r"(six),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
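// ---------------------------------------------------------------------------
// Illustrative sketch (an assumption about intent): the small-width (_s)
// kernels always compute a full 4-wide vector into out_buf1 / out_buf2 and
// then copy only the first w_out lanes to the real output, so the vector
// stores never run past a row that is at most 4 floats wide.
// store_narrow_row is a hypothetical helper mirroring the copy loop above.
// ---------------------------------------------------------------------------
static inline void store_narrow_row(float *dst,
                                    const float out_buf[4],
                                    int w_out) {
  for (int w = 0; w < w_out; ++w) {  // w_out <= 4 on this path
    dst[w] = out_buf[w];
  }
}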
void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0_ptr),
[din1] "+r"(dr1_ptr),
[din2] "+r"(dr2_ptr),
[din3] "+r"(dr3_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0_ptr),
[din1] "+r"(dr1_ptr),
[din2] "+r"(dr2_ptr),
[din3] "+r"(dr3_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(scale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
void conv_depthwise_3x3s1p0_bias_relu6(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU6
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(six),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
float bias_val = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_val = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 4 > h_in) {
switch (j + 4 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
default:
break;
}
}
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
unsigned int *vmask_ptr = vmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[six_ptr] "r"(six),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_LEAKY_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(scale),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
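// ---------------------------------------------------------------------------
// Illustrative sketch (an assumption, not the RESULT_*_LEAKY_RELU macro
// body): a leaky-ReLU epilogue on a 4-lane accumulator can be written with a
// compare and a bitwise select, the usual NEON idiom for
// "din >= 0 ? din : din * scale". leaky_relu_q is a hypothetical helper.
// ---------------------------------------------------------------------------
#ifdef __ARM_NEON
#include <arm_neon.h>
static inline float32x4_t leaky_relu_q(float32x4_t x, float32x4_t vscale) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // lane mask: x >= 0
  float32x4_t scaled = vmulq_f32(x, vscale);            // x * scale
  return vbslq_f32(ge_zero, x, scaled);                 // pick x or x * scale
}
#endif  // __ARM_NEON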
void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
    vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
    vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
...@@ -3808,22 +3916,70 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
          doutr1 = trash_buf;
        }
        unsigned int *vmask_ptr = vmask;
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(dr0),
                       [din1] "+r"(dr1),
                       [din2] "+r"(dr2),
                       [din3] "+r"(dr3)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [vbias] "w"(wbias),
                       [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2),
                       [vzero] "w"(vzero),
                       [vscale] "w"(vscale),
                       [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(scale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *doutr0++ = out_buf1[w];
          *doutr1++ = out_buf2[w];
...
...@@ -1202,19 +1202,19 @@ namespace math {
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
 * width > 4
 */
void conv_depthwise_3x3s1p1_bias_relu(float *dout,
                                      const float *din,
                                      const float *weights,
                                      const float *bias,
                                      bool flag_bias,
                                      bool flag_relu,
                                      const int num,
                                      const int ch_in,
                                      const int h_in,
                                      const int w_in,
                                      const int h_out,
                                      const int w_out,
                                      ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
...@@ -1363,106 +1363,54 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
      }
      int cnt = cnt_col;
if (flag_relu) { asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
asm volatile( MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 : [cnt] "+r"(cnt),
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU [din_ptr0] "+r"(din_ptr0),
: [cnt] "+r"(cnt), [din_ptr1] "+r"(din_ptr1),
[din_ptr0] "+r"(din_ptr0), [din_ptr2] "+r"(din_ptr2),
[din_ptr1] "+r"(din_ptr1), [din_ptr3] "+r"(din_ptr3),
[din_ptr2] "+r"(din_ptr2), [din_ptr4] "+r"(din_ptr4),
[din_ptr3] "+r"(din_ptr3), [din_ptr5] "+r"(din_ptr5),
[din_ptr4] "+r"(din_ptr4), [doutr0] "+r"(doutr0),
[din_ptr5] "+r"(din_ptr5), [doutr1] "+r"(doutr1),
[doutr0] "+r"(doutr0), [doutr2] "+r"(doutr2),
[doutr1] "+r"(doutr1), [doutr3] "+r"(doutr3)
[doutr2] "+r"(doutr2), : [w0] "w"(wr0),
[doutr3] "+r"(doutr3) [w1] "w"(wr1),
: [w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [bias_val] "r"(vbias),
[w2] "w"(wr2), [vmask] "r"(vmask),
[bias_val] "r"(vbias), [rmask] "r"(rmask),
[vmask] "r"(vmask), [vzero] "w"(vzero)
[rmask] "r"(rmask), : "cc",
[vzero] "w"(vzero) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22",
"v21", "v23",
"v22", "v24",
"v23", "v25");
"v24",
"v25");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
dout_ptr = dout_ptr + 4 * w_out;
}
#else
...@@ -1512,70 +1460,36 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
int cnt = cnt_col;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
if (flag_relu) { asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
asm volatile( MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 : [dout_ptr1] "+r"(doutr0),
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU [dout_ptr2] "+r"(doutr1),
: [dout_ptr1] "+r"(doutr0), [din0_ptr] "+r"(din_ptr0),
[dout_ptr2] "+r"(doutr1), [din1_ptr] "+r"(din_ptr1),
[din0_ptr] "+r"(din_ptr0), [din2_ptr] "+r"(din_ptr2),
[din1_ptr] "+r"(din_ptr1), [din3_ptr] "+r"(din_ptr3),
[din2_ptr] "+r"(din_ptr2), [cnt] "+r"(cnt),
[din3_ptr] "+r"(din_ptr3), [rmask] "+r"(rmask_ptr),
[cnt] "+r"(cnt), [vmask] "+r"(vmask_ptr)
[rmask] "+r"(rmask_ptr), : [wr0] "w"(wr0),
[vmask] "+r"(vmask_ptr) [wr1] "w"(wr1),
: [wr0] "w"(wr0), [wr2] "w"(wr2),
[wr1] "w"(wr1), [bias_val] "r"(bias_val),
[wr2] "w"(wr2), [vzero] "w"(vzero)
[bias_val] "r"(bias_val), : "cc",
[vzero] "w"(vzero) "memory",
: "cc", "q4",
"memory", "q5",
"q4", "q6",
"q5", "q7",
"q6", "q8",
"q7", "q9",
"q8", "q10",
"q9", "q11",
"q10", "q12",
"q11", "q13",
"q12", "q14",
"q13", "q15");
"q14",
"q15");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
...@@ -1583,221 +1497,7 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
}
}
/** void conv_depthwise_3x3s1p1_bias_relu(float *dout,
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p0_bias_relu(float *dout,
const float *din, const float *din,
const float *weights, const float *weights,
const float *bias, const float *bias,
...@@ -1825,16 +1525,27 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
int tile_w = w_out >> 2;
int remain = w_out % 4;
int cnt_col = tile_w - 1;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3}; const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
size_pad_right = 1;
cnt_col -= 1;
remain = 4;
} else if (remain == 0 && size_pad_right == 6) {
size_pad_right = 2;
cnt_col -= 1;
remain = 4;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));  // was vcgtq_s32 / vld1q_s32
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
...@@ -1881,10 +1592,9 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) { for (int i = 0; i < h_in; i += 4) {
//! process top pad pad_h = 1 //! process top pad pad_h = 1
din_ptr0 = dr0; din_ptr0 = dr0;
din_ptr1 = dr1; din_ptr1 = dr1;
...@@ -1897,26 +1607,37 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout, ...@@ -1897,26 +1607,37 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
doutr1 = doutr0 + w_out; doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out; doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out; doutr3 = doutr2 + w_out;
if (i == 0) {
dr0 = dr4; din_ptr0 = zero_ptr;
dr1 = dr5; din_ptr1 = dr0;
dr2 = dr1 + w_in; din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
dr3 = dr2 + w_in; dr3 = dr2 + w_in;
dr4 = dr3 + w_in; dr4 = dr3 + w_in;
dr5 = dr4 + w_in; dr5 = dr4 + w_in;
//! process bottom pad //! process bottom pad
if (i + 5 >= h_in) { if (i + 5 > h_in) {
switch (i + 5 - h_in) { switch (i + 5 - h_in) {
case 4: case 5:
din_ptr1 = zero_ptr; din_ptr1 = zero_ptr;
case 3: case 4:
din_ptr2 = zero_ptr; din_ptr2 = zero_ptr;
case 2: case 3:
din_ptr3 = zero_ptr; din_ptr3 = zero_ptr;
case 1: case 2:
din_ptr4 = zero_ptr; din_ptr4 = zero_ptr;
case 0: case 1:
din_ptr5 = zero_ptr; din_ptr5 = zero_ptr;
default: default:
break; break;
...@@ -1936,132 +1657,61 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout, ...@@ -1936,132 +1657,61 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
} }
} }
int cnt = tile_w; int cnt = cnt_col;
if (flag_relu) { asm volatile(
asm volatile( INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
INIT_S1 MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ : [cnt] "+r"(cnt),
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ [din_ptr0] "+r"(din_ptr0),
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ [din_ptr1] "+r"(din_ptr1),
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ [din_ptr2] "+r"(din_ptr2),
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ [din_ptr3] "+r"(din_ptr3),
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ [din_ptr4] "+r"(din_ptr4),
MID_COMPUTE_S1 MID_RESULT_S1_RELU [din_ptr5] "+r"(din_ptr5),
"cmp %w[remain], #1 \n" [doutr0] "+r"(doutr0),
"blt 0f \n" RIGHT_COMPUTE_S1 [doutr1] "+r"(doutr1),
RIGHT_RESULT_S1_RELU "0: \n" [doutr2] "+r"(doutr2),
: [cnt] "+r"(cnt), [doutr3] "+r"(doutr3)
[din_ptr0] "+r"(din_ptr0), : [w0] "w"(wr0),
[din_ptr1] "+r"(din_ptr1), [w1] "w"(wr1),
[din_ptr2] "+r"(din_ptr2), [w2] "w"(wr2),
[din_ptr3] "+r"(din_ptr3), [bias_val] "r"(vbias),
[din_ptr4] "+r"(din_ptr4), [vmask] "r"(vmask),
[din_ptr5] "+r"(din_ptr5), [rmask] "r"(rmask),
[doutr0] "+r"(doutr0), [vzero] "w"(vzero)
[doutr1] "+r"(doutr1), : "cc",
[doutr2] "+r"(doutr2), "memory",
[doutr3] "+r"(doutr3) "v0",
: [w0] "w"(wr0), "v1",
[w1] "w"(wr1), "v2",
[w2] "w"(wr2), "v3",
[bias_val] "r"(vbias), "v4",
[vmask] "r"(vmask), "v5",
[rmask] "r"(rmask), "v6",
[vzero] "w"(vzero), "v7",
[remain] "r"(remain) "v8",
: "cc", "v9",
"memory", "v10",
"v0", "v11",
"v1", "v12",
"v2", "v13",
"v3", "v14",
"v4", "v15",
"v5", "v16",
"v6", "v17",
"v7", "v18",
"v8", "v19",
"v9", "v20",
"v10", "v21",
"v11", "v22",
"v12", "v23",
"v13", "v24",
"v14", "v25");
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
dout_ptr = dout_ptr + 4 * w_out; dout_ptr = dout_ptr + 4 * w_out;
} }
#else #else
for (int i = 0; i < h_out; i += 2) { for (int i = 0; i < h_in; i += 2) {
//! process top pad pad_h = 1
din_ptr0 = dr0; din_ptr0 = dr0;
din_ptr1 = dr1; din_ptr1 = dr1;
din_ptr2 = dr2; din_ptr2 = dr2;
...@@ -2069,13 +1719,25 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout, ...@@ -2069,13 +1719,25 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
doutr0 = dout_ptr; doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out; doutr1 = dout_ptr + w_out;
// unsigned int* rst_mask = rmask;
dr0 = dr2; if (i == 0) {
dr1 = dr3; din_ptr0 = zero_ptr;
dr2 = dr1 + w_in; din_ptr1 = dr0;
dr3 = dr2 + w_in; din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
//! process bottom pad //! process bottom pad
if (i + 3 >= h_in) { if (i + 3 > h_in) {
switch (i + 3 - h_in) { switch (i + 3 - h_in) {
case 3: case 3:
din_ptr1 = zero_ptr; din_ptr1 = zero_ptr;
...@@ -2083,8 +1745,6 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout, ...@@ -2083,8 +1745,6 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
din_ptr2 = zero_ptr; din_ptr2 = zero_ptr;
case 1: case 1:
din_ptr3 = zero_ptr; din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default: default:
break; break;
} }
...@@ -2093,131 +1753,73 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout, ...@@ -2093,131 +1753,73 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
if (i + 2 > h_out) { if (i + 2 > h_out) {
doutr1 = write_ptr; doutr1 = write_ptr;
} }
int cnt = tile_w; int cnt = cnt_col;
unsigned int *rmask_ptr = rmask; unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask; unsigned int *vmask_ptr = vmask;
if (flag_relu) { asm volatile(
asm volatile(INIT_S1 INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" : [dout_ptr1] "+r"(doutr0),
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" [dout_ptr2] "+r"(doutr1),
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" [din0_ptr] "+r"(din_ptr0),
"vext.32 q6, q8, q9, #1 @ 0012\n" [din1_ptr] "+r"(din_ptr1),
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 [din2_ptr] "+r"(din_ptr2),
MID_RESULT_S1_RELU [din3_ptr] "+r"(din_ptr3),
"cmp %[remain], #1 \n" [cnt] "+r"(cnt),
"blt 0f \n" RIGHT_COMPUTE_S1 [rmask] "+r"(rmask_ptr),
RIGHT_RESULT_S1_RELU "0: \n" [vmask] "+r"(vmask_ptr)
: [dout_ptr1] "+r"(doutr0), : [wr0] "w"(wr0),
[dout_ptr2] "+r"(doutr1), [wr1] "w"(wr1),
[din0_ptr] "+r"(din_ptr0), [wr2] "w"(wr2),
[din1_ptr] "+r"(din_ptr1), [bias_val] "r"(bias_val),
[din2_ptr] "+r"(din_ptr2), [vzero] "w"(vzero)
[din3_ptr] "+r"(din_ptr3), : "cc",
[cnt] "+r"(cnt), "memory",
[rmask] "+r"(rmask_ptr), "q4",
[vmask] "+r"(vmask_ptr) "q5",
: [wr0] "w"(wr0), "q6",
[wr1] "w"(wr1), "q7",
[wr2] "w"(wr2), "q8",
[bias_val] "r"(bias_val), "q9",
[vzero] "w"(vzero), "q10",
[remain] "r"(remain) "q11",
: "cc", "q12",
"memory", "q13",
"q4", "q14",
"q5", "q15");
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
dout_ptr += 2 * w_out; dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s_no_relu(float *dout,  // was conv_depthwise_3x3s1p0_bias_s_relu
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 = uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
...@@ -2231,38 +1833,907 @@ void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) { for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + j * w_in; const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in; const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in; const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in; const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out; if (hs == -1) {
doutr1 = doutr0 + w_out; dr0 = zero;
}
if (j + 3 >= h_in) { switch (he - h_in) {
switch (j + 3 - h_in) { case 2:
case 3: dr2 = zero;
dr1 = zero_ptr; doutr1 = trash_buf;
case 2: case 1:
dr2 = zero_ptr; dr3 = zero;
default:
break;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
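// Editor's sketch (hypothetical helper, not part of Paddle-Lite): the width <= 4
// kernels above build their lane mask as vcgeq_s32(right_pad_idx, 4 - w_in) with
// right_pad_idx = {3, 2, 1, 0}, which keeps exactly the first w_in of the 4 lanes.
// A scalar equivalent, assuming 0 < w_in <= 4:
static inline bool dw3x3s1_small_lane_valid(int lane, int w_in) {
  const int right_pad_idx[4] = {3, 2, 1, 0};
  return right_pad_idx[lane] >= 4 - w_in;  // true exactly when lane < w_in
}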
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
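// Editor's sketch (illustrative only, not library code): the pad == 1 small-width
// kernels above walk the input in 4-row windows [hs, hs + 3] starting at hs = -1,
// write 2 output rows per window, substitute the shared zero row for rows outside
// [0, h_in) and redirect an out-of-range second output row to trash_buf. The row
// selection amounts to:
static inline const float *dw3x3_pick_row(const float *din_channel, int row,
                                          int h_in, int w_in,
                                          const float *zero_row) {
  // implicit zero padding: rows above the top or below the bottom read zeros
  return (row < 0 || row >= h_in) ? zero_row : din_channel + row * w_in;
}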
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p0_bias_no_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(
INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
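// Editor's sketch (hypothetical helper): the pad == 0, width > 4 kernels above
// process w_out in blocks of 4. vmask marks, lane by lane, whether a loaded
// right-edge column is real input (right_pad_idx[k] >= size_pad_right) or padding,
// and rmask marks which of the last block's 4 output lanes may be stored
// (lane < remain). A scalar equivalent of the mask setup:
static inline void dw3x3s1p0_right_masks(int w_in, int w_out,
                                         bool col_valid[8], bool out_valid[4]) {
  const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
  int tile_w = w_out >> 2;
  int remain = w_out % 4;
  int size_pad_right = 6 + (tile_w << 2) - w_in;
  for (int k = 0; k < 8; ++k) col_valid[k] = right_pad_idx[k] >= size_pad_right;
  for (int k = 0; k < 4; ++k) out_valid[k] = k < remain;
}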
void conv_depthwise_3x3s1p0_bias_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
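// Editor's sketch (naive reference, not the optimized path): one output element of
// the 3x3, stride-1, pad-0 depthwise convolution the two kernels above implement.
// Bias is folded in first, exactly as vbias/bias_val is in the asm; the _relu
// variant then clamps the result at zero.
static inline float dw3x3s1p0_ref(const float *din_channel, const float *weights9,
                                  int w_in, int oh, int ow, float bias_val) {
  float acc = bias_val;
  for (int kh = 0; kh < 3; ++kh) {
    for (int kw = 0; kw < 3; ++kw) {
      acc += din_channel[(oh + kh) * w_in + (ow + kw)] * weights9[kh * 3 + kw];
    }
  }
  return acc;
}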
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p0_bias_s_no_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1: case 1:
dr3 = zero_ptr; dr3 = zero_ptr;
doutr1 = trash_buf; doutr1 = trash_buf;
...@@ -2276,133 +2747,227 @@ void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
}
}
#ifdef __aarch64__
if (flag_relu) { asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU : [din0] "+r"(dr0),
: [din0] "+r"(dr0), [din1] "+r"(dr1),
[din1] "+r"(dr1), [din2] "+r"(dr2),
[din2] "+r"(dr2), [din3] "+r"(dr3)
[din3] "+r"(dr3) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [vbias] "w"(wbias),
[vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
[mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2),
[mask2] "w"(vmask_rp2), [zero] "w"(vzero),
[zero] "w"(vzero), [out1] "r"(out_buf1),
[out1] "r"(out_buf1), [out2] "r"(out_buf2)
[out2] "r"(out_buf2) : "cc",
: "cc", "memory",
"memory", "v0",
"v0", "v1",
"v1", "v2",
"v2", "v3",
"v3", "v4",
"v4", "v5",
"v5", "v6",
"v6", "v7",
"v7", "v8",
"v8", "v9",
"v9", "v10",
"v10", "v11",
"v11", "v12",
"v12", "v13",
"v13", "v14",
"v14", "v15");
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else #else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) { asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU : [din0] "+r"(dr0),
: [din0] "+r"(dr0), [din1] "+r"(dr1),
[din1] "+r"(dr1), [din2] "+r"(dr2),
[din2] "+r"(dr2), [din3] "+r"(dr3),
[din3] "+r"(dr3), [vmask] "+r"(vmask_ptr)
[vmask] "+r"(vmask_ptr) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [vzero] "w"(vzero),
[vzero] "w"(vzero), [bias_val] "r"(bias_val),
[bias_val] "r"(bias_val), [out1] "r"(out_buf1),
[out1] "r"(out_buf1), [out2] "r"(out_buf2)
[out2] "r"(out_buf2) : "cc",
: "cc", "memory",
"memory", "q4",
"q4", "q5",
"q5", "q6",
"q6", "q7",
"q7", "q8",
"q8", "q9",
"q9", "q10",
"q10", "q11",
"q11", "q12",
"q12", "q13",
"q13", "q14",
"q14", "q15");
"q15"); #endif
} else { for (int w = 0; w < w_out; ++w) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 *doutr0++ = out_buf1[w];
: [din0] "+r"(dr0), *doutr1++ = out_buf2[w];
[din1] "+r"(dr1), }
[din2] "+r"(dr2), } // end of processing heights
[din3] "+r"(dr3), } // end of processing channels
[vmask] "+r"(vmask_ptr) } // end of processing batchs
: [wr0] "w"(wr0), }
[wr1] "w"(wr1),
[wr2] "w"(wr2), void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
[vzero] "w"(vzero), const float *din,
[bias_val] "r"(bias_val), const float *weights,
[out1] "r"(out_buf1), const float *bias,
[out2] "r"(out_buf2) bool flag_bias,
: "cc", bool flag_relu,
"memory", const int num,
"q4", const int ch_in,
"q5", const int h_in,
"q6", const int w_in,
"q7", const int h_out,
"q8", const int w_out,
"q9", ARMContext *ctx) {
"q10", //! 3x3s1 convolution, implemented by direct algorithm
"q11", //! pad is done implicit
"q12", //! for 4x6 convolution window
"q13", const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
"q14", const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
"q15");
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
default:
break;
}
} }
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif #endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
......
...@@ -20,61 +20,117 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_depthwise_3x3s2p0_bias(float* dout, void conv_depthwise_3x3s2p0_bias_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s(float* dout, void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* scale,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias(float* dout, void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s(float* dout, void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* scale,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
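// Editor's note on the declarations above: "p0"/"p1" encodes the padding, the "_s"
// suffix marks the small-width kernels chosen when w_in is at most the block width,
// and the "_relu6"/"_leakyRelu" suffix names the fused activation; the extra
// "six"/"scale" pointers carry the activation parameter broadcast to 4 lanes by the
// dispatcher below.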
void conv_depthwise_3x3s2_fp32(const float* din,
float* dout,
...@@ -92,142 +148,275 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const operators::ActivationParam act_param,
ARMContext* ctx) {
bool has_active = act_param.has_active;
bool flag_relu = false; auto act_type = act_param.active_type;
bool relu6 = false; float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
if (has_active) { if (has_active) {
if (act_param.active_type == lite_api::ActivationType::kRelu) { switch (act_type) {
flag_relu = true; case lite_api::ActivationType::kRelu:
} else { if (pad == 0) {
relu6 = true; if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kRelu6:
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kLeakyRelu:
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
<< " fuse not support";
} }
} } else {
if (pad == 0) { if (pad == 0) {
if (w_in > 8) { if (w_in > 8) {
if (relu6) { conv_depthwise_3x3s2p0_bias_no_relu(dout,
conv_depthwise_3x3s2p0_bias(dout, din,
din, weights,
weights, bias,
bias, flag_bias,
flag_bias, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
act_param, ctx);
ctx);
} else {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else { } else {
conv_depthwise_3x3s2p0_bias_s_relu(dout, conv_depthwise_3x3s2p0_bias_s_no_relu(dout,
din, din,
weights, weights,
bias, bias,
flag_bias, flag_bias,
flag_relu, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
ctx); ctx);
} }
} }
} if (pad == 1) {
if (pad == 1) { if (w_in > 7) {
if (w_in > 7) { conv_depthwise_3x3s2p1_bias_no_relu(dout,
if (relu6) { din,
conv_depthwise_3x3s2p1_bias(dout, weights,
din, bias,
weights, flag_bias,
bias, false,
flag_bias, num,
num, ch_in,
ch_in, h_in,
h_in, w_in,
w_in, h_out,
h_out, w_out,
w_out, ctx);
act_param,
ctx);
} else { } else {
conv_depthwise_3x3s2p1_bias_relu(dout, conv_depthwise_3x3s2p1_bias_s_no_relu(dout,
din, din,
weights, weights,
bias, bias,
flag_bias, flag_bias,
flag_relu, false,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
ctx); ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
}
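// Editor's note: the dispatcher above selects a kernel from the activation type
// (none, kRelu, kRelu6 or kLeakyRelu), the padding (0 or 1) and the input width
// (wide kernels for w_in > 8 at pad 0 and w_in > 7 at pad 1, otherwise the _s
// variants). A hedged sketch of filling the activation fields it reads; the field
// names come from the code above, while default construction of ActivationParam is
// an assumption:
static inline operators::ActivationParam make_relu6_act_param_sketch(float coef) {
  operators::ActivationParam act;
  act.has_active = true;
  act.active_type = lite_api::ActivationType::kRelu6;
  act.Relu_clipped_coef = coef;  // broadcast into vsix[4] by the dispatcher
  return act;
}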
// clang-format off
#ifdef __aarch64__
#define INIT_S2 \
"prfm pldl1keep, [%[inptr0]] \n" \
...@@ -746,6 +935,18 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fmax v4.4s, v4.4s, v9.4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_RELU6 \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
"fmax v4.4s, v4.4s, v9.4s \n" \
"fmin v4.4s, v4.4s, %[vsix].4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_LEAKY_RELU \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
"fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\
"fmul v12.4s, v4.4s, %[vscale].4s \n"\
"bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v4.4s}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
...@@ -785,6 +986,15 @@ void conv_depthwise_3x3s2_fp32(const float* din,
#define RESULT_S_S2_P0_RELU \
"fmax v4.4s, v4.4s, v9.4s \n" \
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU6 \
"fmax v4.4s, v4.4s, v9.4s \n" \
"fmin v4.4s, v4.4s, %[vsix].4s \n" \
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_LEAKY_RELU \
"fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\
"fmul v12.4s, v4.4s, %[vscale].4s \n"\
"bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v4.4s}, [%[out]] \n"
#else
#define INIT_S2 \
...@@ -822,14 +1032,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \ "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \ "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \
\ \
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \ "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n"
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define LEFT_RESULT_S2 \ #define LEFT_RESULT_S2 \
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ "vadd.f32 q3, q3, q4 @ add \n"\
"cmp %[cnt], #1 \n" \ "vadd.f32 q3, q3, q5 @ add \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n" "blt 1f \n"
#define MID_COMPUTE_S2 \ #define MID_COMPUTE_S2 \
...@@ -860,12 +1069,11 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -860,12 +1069,11 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \
\ \
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \ "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n"
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define MID_RESULT_S2 \ #define MID_RESULT_S2 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"subs %[cnt], #1 \n" \ "subs %[cnt], #1 \n" \
\ \
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ "vst1.32 {d6-d7}, [%[outptr]]! \n" \
...@@ -910,36 +1118,104 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -910,36 +1118,104 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n"
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RIGHT_RESULT_S2 \ #define RIGHT_RESULT_S2 \
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vbif.f32 q3, q10, q11 @ write mask\n" \ "vbif.f32 q3, q10, q11 @ write mask\n" \
\ \
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ "vst1.32 {d6-d7}, [%[outptr]]! \n" \
"3: \n" "3: \n"
#define LEFT_RESULT_S2_RELU \ #define LEFT_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \ "vadd.f32 q3, q3, q4 @ add \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ "vadd.f32 q3, q3, q5 @ add \n"\
"cmp %[cnt], #1 \n" \ "vmax.f32 q3, q3, q9 \n"\
"blt 1f \n" "cmp %[cnt], #1 \n"\
#define MID_RESULT_S2_RELU \ "vst1.32 {d6-d7}, [%[outptr]]! \n"\
"vmax.f32 q3, q3, q9 @ relu \n" \ "blt 1f \n"
"subs %[cnt], #1 \n" \ #define LEFT_RESULT_S2_RELU6 \
\ "vadd.f32 q3, q3, q4 @ add \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"bne 2b \n" "vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
#define RIGHT_RESULT_S2_RELU \ "cmp %[cnt], #1 \n"\
"vmax.f32 q3, q3, q9 @ relu \n" \ "vmin.f32 q3, q3, q6 @ relu \n"\
"vbif.f32 q3, q10, q11 @ write mask\n" \ "vst1.32 {d6-d7}, [%[outptr]]! \n"\
\ "blt 1f \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n" \ #define LEFT_RESULT_S2_LEAKY_RELU \
"3: \n" "vadd.f32 q3, q3, q4 \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"cmp %[cnt], #1 \n"\
"vbif q3, q8, q7 @ choose \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"subs %[cnt], #1 \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define MID_RESULT_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
"subs %[cnt], #1 \n"\
"vmin.f32 q3, q3, q6 @ relu \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"subs %[cnt], #1 \n"\
"vbif q3, q8, q7 @ choose \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define RIGHT_RESULT_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vmin.f32 q3, q3, q6 @ relu \n"\
\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"vbif q3, q8, q7 @ choose \n"\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define COMPUTE_S_S2 \ #define COMPUTE_S_S2 \
"vmov.u32 q9, #0 \n" \ "vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \
...@@ -976,17 +1252,36 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -976,17 +1252,36 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \ "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n"
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n" #define RESULT_S_S2 \
#define RESULT_S_S2_RELU \ "vadd.f32 q3, q3, q4 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n" \ "vadd.f32 q3, q3, q5 @ add \n"\
\ "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
\
"vst1.32 {d6-d7}, [%[out]] \n" "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vmin.f32 q3, q3, q6 @ relu\n"\
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"vbif q3, q8, q7 @ choose \n"\
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \ #define COMPUTE_S_S2_P0 \
"vmov.u32 q9, #0 \n" \ "vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \ "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \
...@@ -1023,207 +1318,309 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -1023,207 +1318,309 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n" \ "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n"
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n" #define RESULT_S_S2_P0 \
#define RESULT_S_S2_P0_RELU \ "vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vmax.f32 q3, q3, q9 @ relu \n" \ "vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n" \
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vmax.f32 q3, q3, q9 @ relu\n" \
"vmin.f32 q3, q3, q6 @ relu\n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n" \
"vld1.f32 {d12-d13}, [%[scale_ptr]] @ load six \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vcge.f32 q7, q3, q9 \n" \
"vmul.f32 q8, q3, q6 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vst1.32 {d6-d7}, [%[out]] \n" "vst1.32 {d6-d7}, [%[out]] \n"
#endif #endif
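// A minimal NEON-intrinsics sketch of what the RESULT_* activation epilogues
// above compute per 4-lane vector, assuming <arm_neon.h> and
// lite_api::ActivationType are in scope as elsewhere in this file. The helper
// name is hypothetical; the real kernels keep this logic inside the inline
// assembly.
static inline float32x4_t act_epilogue_sketch(float32x4_t sum,
                                              float32x4_t vzero,
                                              float32x4_t vsix,
                                              float32x4_t vscale,
                                              lite_api::ActivationType type) {
  switch (type) {
    case lite_api::ActivationType::kRelu:
      return vmaxq_f32(sum, vzero);                   // max(x, 0)
    case lite_api::ActivationType::kRelu6:
      return vminq_f32(vmaxq_f32(sum, vzero), vsix);  // min(max(x, 0), six)
    case lite_api::ActivationType::kLeakyRelu: {
      uint32x4_t pos = vcgeq_f32(sum, vzero);         // lanes where x >= 0
      return vbslq_f32(pos, sum, vmulq_f32(sum, vscale));  // x : x * scale
    }
    default:
      return sum;
  }
}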
#ifdef __aarch64__ // clang-format on
void act_switch_3x3s2p1(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
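// act_switch_3x3s2p1 above hoists the activation-type branch out of the hot
// loop: the caller resolves act_param.active_type once per output-row pair,
// and each case runs a fully pre-assembled kernel (relu, relu6, or leaky
// relu), so no per-element branching happens inside the assembly.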
/** /**
* \brief depthwise convolution kernel 3x3, stride 2 * \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7 * w_in > 7
*/ */
void conv_depthwise_3x3s2p1_bias(float* dout, void conv_depthwise_3x3s2p1_bias_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx) { ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
int cnt_col = tile_w - 1;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
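// Right-edge handling, for reference: the input is read de-interleaved
// (vld2), so vmask_rp1 masks the even columns {0,2,4,6} and vmask_rp2 the
// odd columns {1,3,5,7} of the last 8-wide load, while wmask keeps only the
// first cnt_remain lanes of the final output store. When w_out is an exact
// multiple of 4 and the final masked load would have no valid lanes, the
// block above instead treats the last full tile as the masked tail
// (tile_w -= 1, cnt_remain = 4), so the kernel never reads past the end of
// the row.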
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
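// The switch below falls through on purpose: (i * 2 + 4 - h_in) counts how
// many of the lower input rows run past the bottom of the image, and a value
// of N redirects the last N of din1..din4 to zero_ptr (the zero-filled
// scratch row), so deeper overruns zero out progressively more rows.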
if (i * 2 + 4 > h_in) {
switch (i * 2 + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i * 2 + 2 > h_in) {
switch (i * 2 + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(six),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
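// A plain scalar reference for one output row of the 3x3, stride-2, pad-1
// depthwise convolution with fused ReLU6, sketched here only as a checking
// aid for the assembly path above. The helper name and signature are
// hypothetical and not part of the original file; callers are expected to
// substitute a zero row for r0 / r2 at the top / bottom border.
static void dw3x3s2p1_relu6_row_ref(const float* r0,
                                    const float* r1,
                                    const float* r2,
                                    const float* w,  // 9 weights, row-major
                                    float bias,
                                    float six,
                                    float* out,
                                    int w_in,
                                    int w_out) {
  const float* rows[3] = {r0, r1, r2};
  for (int ow = 0; ow < w_out; ++ow) {
    float sum = bias;
    for (int kh = 0; kh < 3; ++kh) {
      for (int kw = 0; kw < 3; ++kw) {
        int iw = ow * 2 + kw - 1;  // stride 2, pad_left = 1
        float v = (iw >= 0 && iw < w_in) ? rows[kh][iw] : 0.f;
        sum += v * w[kh * 3 + kw];
      }
    }
    sum = sum < 0.f ? 0.f : sum;      // relu
    out[ow] = sum > six ? six : sum;  // clamp to six
  }
}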
void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3}; int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in; int size_pad_bottom = h_out * 2 - h_in;
...@@ -1350,24 +1747,52 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1350,24 +1747,52 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
doutr1_ptr = write_ptr; doutr1_ptr = write_ptr;
} }
int cnt = cnt_col; int cnt = cnt_col;
act_switch_3x3s2p1(din0_ptr, asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
din1_ptr, MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
din2_ptr, RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
din3_ptr, : [inptr0] "+r"(din0_ptr),
din4_ptr, [inptr1] "+r"(din1_ptr),
doutr0_ptr, [inptr2] "+r"(din2_ptr),
doutr1_ptr, [inptr3] "+r"(din3_ptr),
wr0, [inptr4] "+r"(din4_ptr),
wr1, [outptr0] "+r"(doutr0_ptr),
wr2, [outptr1] "+r"(doutr1_ptr),
vmask_rp1, [cnt] "+r"(cnt)
vmask_rp2, : [vzero] "w"(vzero),
wmask, [w0] "w"(wr0),
wbias, [w1] "w"(wr1),
vzero, [w2] "w"(wr2),
cnt, [remain] "r"(cnt_remain),
cnt_remain, [scale_ptr] "r"(scale),
act_param); [mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out; doutr0 = doutr0 + 2 * w_out;
} }
#else #else
...@@ -1404,8 +1829,9 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1404,8 +1829,9 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
} }
int cnt = cnt_col; int cnt = cnt_col;
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr), [din2_ptr] "+r"(din2_ptr),
...@@ -1416,6 +1842,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1416,6 +1842,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
[wr0] "w"(wr0), [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[scale_ptr] "r"(scale),
[bias] "r"(bias_c) [bias] "r"(bias_c)
: "cc", : "cc",
"memory", "memory",
...@@ -1432,10 +1859,6 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1432,10 +1859,6 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
"q13", "q13",
"q14", "q14",
"q15"); "q15");
// do act
if (act_param.has_active) {
act_switch_process(doutr0, doutr0, w_out, &act_param);
}
doutr0 = doutr0 + w_out; doutr0 = doutr0 + w_out;
} }
#endif #endif
...@@ -1446,19 +1869,19 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1446,19 +1869,19 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
/** /**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4 * \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/ */
void conv_depthwise_3x3s2p1_bias_s(float* dout, void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx) { ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3}; int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f}; float zeros[8] = {0.0f};
...@@ -1474,7 +1897,9 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1474,7 +1897,9 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int dmask[8]; unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1); vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2); vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel; const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel; float* dout_batch = dout + n * ch_in * size_out_channel;
...@@ -1513,7 +1938,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1513,7 +1938,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S2 RESULT_S_S2 asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr), [din2_ptr] "+r"(din2_ptr),
...@@ -1522,6 +1947,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1522,6 +1947,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[bias] "w"(vbias), [bias] "w"(vbias),
[vsix] "w"(vsix),
[out] "r"(out_buf) [out] "r"(out_buf)
: "v4", : "v4",
"v5", "v5",
...@@ -1536,7 +1962,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1536,7 +1962,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"v14", "v14",
"v15"); "v15");
#else #else
asm volatile(COMPUTE_S_S2 RESULT_S_S2 asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr), [din2_ptr] "+r"(din2_ptr),
...@@ -1545,6 +1971,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1545,6 +1971,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[bias] "r"(bias_c), [bias] "r"(bias_c),
[six_ptr] "r"(six),
[out] "r"(out_buf) [out] "r"(out_buf)
: "cc", : "cc",
"memory", "memory",
...@@ -1562,10 +1989,6 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1562,10 +1989,6 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
// do act
if (act_param.has_active) {
act_switch_process(out_buf, out_buf, w_out, &act_param);
}
for (int w = 0; w < w_out; ++w) { for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w]; *dout_channel++ = out_buf[w];
} }
...@@ -1575,231 +1998,154 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, ...@@ -1575,231 +1998,154 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
} }
} }
} }
void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__ #ifdef __aarch64__
void act_switch_3x3s2p0(const float* din0_ptr, float32x4_t vscale = vld1q_f32(scale);
const float* din1_ptr, float32x4_t vzero = vdupq_n_f32(0.f);
const float* din2_ptr, #endif
const float* din3_ptr, for (int n = 0; n < num; ++n) {
const float* din4_ptr, const float* din_batch = din + n * ch_in * size_in_channel;
float* doutr0_ptr, float* dout_batch = dout + n * ch_in * size_out_channel;
float* doutr1_ptr, #pragma omp parallel for
float32x4_t wr0, for (int i = 0; i < ch_in; ++i) {
float32x4_t wr1, const float* din_channel = din_batch + i * size_in_channel;
float32x4_t wr2, float* dout_channel = dout_batch + i * size_out_channel;
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2, const float* weight_ptr = weights + i * 9;
uint32x4_t wmask, float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wbias, float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t vzero, float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
int cnt,
int cnt_remain, float bias_c = 0.f;
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef; if (flag_bias) {
float ss = act_param.Leaky_relu_alpha; bias_c = bias[i];
float vsix[4] = {tmp, tmp, tmp, tmp}; }
float vscale[4] = {ss, ss, ss, ss}; float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
switch (act_param.active_type) { unsigned int* mask_ptr = dmask;
case lite_api::ActivationType::kRelu: #ifdef __aarch64__
asm volatile( asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU
INIT_S2 : [din0_ptr] "+r"(din0_ptr),
"ld1 {v15.4s}, [%[inptr0]] \n" [din1_ptr] "+r"(din1_ptr),
"ld1 {v18.4s}, [%[inptr1]] \n" [din2_ptr] "+r"(din2_ptr),
"ld1 {v19.4s}, [%[inptr2]] \n" [mask_ptr] "+r"(mask_ptr)
"ld1 {v20.4s}, [%[inptr3]] \n" : [wr0] "w"(wr0),
"ld1 {v21.4s}, [%[inptr4]] \n" [wr1] "w"(wr1),
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} [wr2] "w"(wr2),
MID_COMPUTE_S2 MID_RESULT_S2_RELU [bias] "w"(vbias),
"cmp %w[remain], #1 \n" [vzero] "w"(vzero),
"blt 4f \n" RIGHT_COMPUTE_S2 [vscale] "w"(vscale),
RIGHT_RESULT_S2_RELU [out] "r"(out_buf)
"4: \n" : "v4",
: [inptr0] "+r"(din0_ptr), "v5",
[inptr1] "+r"(din1_ptr), "v6",
[inptr2] "+r"(din2_ptr), "v7",
[inptr3] "+r"(din3_ptr), "v8",
[inptr4] "+r"(din4_ptr), "v9",
[outptr0] "+r"(doutr0_ptr), "v10",
[outptr1] "+r"(doutr1_ptr), "v11",
[cnt] "+r"(cnt) "v12",
: [vzero] "w"(vzero), "v13",
[w0] "w"(wr0), "v14",
[w1] "w"(wr1), "v15");
[w2] "w"(wr2), #else
[remain] "r"(cnt_remain), asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU
[mask1] "w"(vmask_rp1), : [din0_ptr] "+r"(din0_ptr),
[mask2] "w"(vmask_rp2), [din1_ptr] "+r"(din1_ptr),
[wmask] "w"(wmask), [din2_ptr] "+r"(din2_ptr),
[vbias] "w"(wbias) [mask_ptr] "+r"(mask_ptr)
: "cc", : [wr0] "w"(wr0),
"memory", [wr1] "w"(wr1),
"v0", [wr2] "w"(wr2),
"v1", [bias] "r"(bias_c),
"v2", [scale_ptr] "r"(scale),
"v3", [out] "r"(out_buf)
"v4", : "cc",
"v5", "memory",
"v6", "q3",
"v7", "q4",
"v8", "q5",
"v9", "q6",
"v10", "q7",
"v11", "q8",
"v12", "q9",
"v13", "q10",
"v14", "q11",
"v15", "q12",
"v16", "q13",
"v17", "q14",
"v18", "q15");
"v19", #endif
"v20", for (int w = 0; w < w_out; ++w) {
"v21"); *dout_channel++ = out_buf[w];
break; }
case lite_api::ActivationType::kRelu6: hs += 2;
/* 0 <= din <= 6 */ he += 2;
asm volatile( }
INIT_S2 }
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
} }
} }
#endif
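// The pad-0 (p0) kernels that follow differ from the p1 ones above only at
// the borders: with no left or top padding there is no LEFT_COMPUTE_S2 /
// LEFT_RESULT_* prologue, the assembly chains start straight at
// MID_COMPUTE_S2, and the first input row processed is dr0 itself rather
// than the zero scratch row.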
/** /**
* \brief depthwise convolution kernel 3x3, stride 2 * \brief depthwise convolution kernel 3x3, stride 2
*/ */
// w_in > 7 // w_in > 7
void conv_depthwise_3x3s2p0_bias(float* dout, void conv_depthwise_3x3s2p0_bias_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx) { ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3}; int out_pad_idx[4] = {0, 1, 2, 3};
...@@ -1918,24 +2264,63 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -1918,24 +2264,63 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
doutr1_ptr = write_ptr; doutr1_ptr = write_ptr;
} }
int cnt = tile_w; int cnt = tile_w;
act_switch_3x3s2p0(din0_ptr, asm volatile(
din1_ptr, INIT_S2
din2_ptr, "ld1 {v15.4s}, [%[inptr0]] \n"
din3_ptr, "ld1 {v18.4s}, [%[inptr1]] \n"
din4_ptr, "ld1 {v19.4s}, [%[inptr2]] \n"
doutr0_ptr, "ld1 {v20.4s}, [%[inptr3]] \n"
doutr1_ptr, "ld1 {v21.4s}, [%[inptr4]] \n"
wr0, "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
wr1, "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
wr2, MID_RESULT_S2_RELU6
vmask_rp1, "cmp %w[remain], #1 \n"
vmask_rp2, "blt 4f \n" RIGHT_COMPUTE_S2
wmask, RIGHT_RESULT_S2_RELU6
wbias, "4: \n"
vzero, : [inptr0] "+r"(din0_ptr),
cnt, [inptr1] "+r"(din1_ptr),
cnt_remain, [inptr2] "+r"(din2_ptr),
act_param); [inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out; doutr0 = doutr0 + 2 * w_out;
} }
#else #else
...@@ -1963,8 +2348,8 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -1963,8 +2348,8 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
} }
int cnt = tile_w; int cnt = tile_w;
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 RIGHT_RESULT_S2_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr), [din2_ptr] "+r"(din2_ptr),
...@@ -1972,6 +2357,7 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -1972,6 +2357,7 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
[cnt] "+r"(cnt), [cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr) [mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain), : [remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[wr0] "w"(wr0), [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
...@@ -1991,9 +2377,257 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -1991,9 +2377,257 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
"q13", "q13",
"q14", "q14",
"q15"); "q15");
if (act_param.has_active) { doutr0 = doutr0 + w_out;
act_switch_process(doutr0, doutr0, w_out, &act_param); }
#endif
}
}
}
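// Note on the row loops in these kernels: the aarch64 path consumes five
// input rows (din0..din4) and writes two output rows per iteration
// (doutr0 / doutr1, i += 2), while the armv7 path under the #else consumes
// three input rows and writes a single output row per iteration, which is
// why the pointer rotation differs (dr0 = dr4 versus dr0 = dr2).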
void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 5 > h_in) {
switch (i * 2 + 5 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
} }
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(scale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i * 2 + 3 > h_in) {
switch (i * 2 + 3 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[scale_ptr] "r"(scale),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out; doutr0 = doutr0 + w_out;
} }
#endif #endif
...@@ -2004,19 +2638,19 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -2004,19 +2638,19 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
/** /**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4 * \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/ */
void conv_depthwise_3x3s2p0_bias_s(float* dout, void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout,
const float* din, const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
bool flag_bias, const float* six,
const int num, bool flag_bias,
const int ch_in, const int num,
const int h_in, const int ch_in,
const int w_in, const int h_in,
const int h_out, const int w_in,
const int w_out, const int h_out,
const operators::ActivationParam act_param, const int w_out,
ARMContext* ctx) { ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3}; int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f}; float zeros[8] = {0.0f};
...@@ -2033,6 +2667,10 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2033,6 +2667,10 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int dmask[8]; unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1); vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2); vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel; const float* din_batch = din + n * ch_in * size_in_channel;
...@@ -2077,7 +2715,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2077,7 +2715,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr), [din2_ptr] "+r"(din2_ptr),
...@@ -2086,6 +2724,8 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2086,6 +2724,8 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[bias] "w"(vbias), [bias] "w"(vbias),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out] "r"(out_buf) [out] "r"(out_buf)
: "cc", : "cc",
"memory", "memory",
...@@ -2104,7 +2744,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2104,7 +2744,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"v16"); "v16");
#else #else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr) [din2_ptr] "+r"(din2_ptr)
...@@ -2113,6 +2753,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2113,6 +2753,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
[wr2] "w"(wr2), [wr2] "w"(wr2),
[bias] "r"(bias_c), [bias] "r"(bias_c),
[out] "r"(out_buf), [out] "r"(out_buf),
[six_ptr] "r"(six),
[mask_ptr] "r"(dmask) [mask_ptr] "r"(dmask)
: "cc", : "cc",
"memory", "memory",
...@@ -2130,9 +2771,145 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2130,9 +2771,145 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
if (act_param.has_active) { for (int w = 0; w < w_out; ++w) {
act_switch_process(out_buf, out_buf, w_out, &act_param); *dout_channel++ = out_buf[w];
}
}
}
}
}
void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; j++) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
if (j * 2 + 2 >= h_in) {
switch (j * 2 + 2 - h_in) {
case 1:
din1_ptr = zero_ptr;
case 0:
din2_ptr = zero_ptr;
default:
break;
}
} }
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[scale_ptr] "r"(scale),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) { for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w]; *dout_channel++ = out_buf[w];
} }
......
...@@ -20,6 +20,7 @@ namespace lite { ...@@ -20,6 +20,7 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// clang-format off
#ifdef __aarch64__ #ifdef __aarch64__
#define INIT_S2 \ #define INIT_S2 \
"prfm pldl1keep, [%[inptr0]] \n" \ "prfm pldl1keep, [%[inptr0]] \n" \
...@@ -683,6 +684,7 @@ namespace math { ...@@ -683,6 +684,7 @@ namespace math {
"vst1.32 {d6-d7}, [%[out]] \n" "vst1.32 {d6-d7}, [%[out]] \n"
#endif #endif
// clang-format on
/** /**
* \brief depthwise convolution kernel 3x3, stride 2 * \brief depthwise convolution kernel 3x3, stride 2
...@@ -825,96 +827,50 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout, ...@@ -825,96 +827,50 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout,
doutr1_ptr = write_ptr; doutr1_ptr = write_ptr;
} }
int cnt = cnt_col; int cnt = cnt_col;
if (flag_relu) { asm volatile(
asm volatile( INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU : [inptr0] "+r"(din0_ptr),
: [inptr0] "+r"(din0_ptr), [inptr1] "+r"(din1_ptr),
[inptr1] "+r"(din1_ptr), [inptr2] "+r"(din2_ptr),
[inptr2] "+r"(din2_ptr), [inptr3] "+r"(din3_ptr),
[inptr3] "+r"(din3_ptr), [inptr4] "+r"(din4_ptr),
[inptr4] "+r"(din4_ptr), [outptr0] "+r"(doutr0_ptr),
[outptr0] "+r"(doutr0_ptr), [outptr1] "+r"(doutr1_ptr),
[outptr1] "+r"(doutr1_ptr), [cnt] "+r"(cnt)
[cnt] "+r"(cnt) : [vzero] "w"(vzero),
: [vzero] "w"(vzero), [w0] "w"(wr0),
[w0] "w"(wr0), [w1] "w"(wr1),
[w1] "w"(wr1), [w2] "w"(wr2),
[w2] "w"(wr2), [remain] "r"(cnt_remain),
[remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
[mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2),
[mask2] "w"(vmask_rp2), [wmask] "w"(wmask),
[wmask] "w"(wmask), [vbias] "w"(wbias)
[vbias] "w"(wbias) : "cc",
: "cc", "memory",
"memory", "v0",
"v0", "v1",
"v1", "v2",
"v2", "v3",
"v3", "v4",
"v4", "v5",
"v5", "v6",
"v6", "v7",
"v7", "v8",
"v8", "v9",
"v9", "v10",
"v10", "v11",
"v11", "v12",
"v12", "v13",
"v13", "v14",
"v14", "v15",
"v15", "v16",
"v16", "v17",
"v17", "v18",
"v18", "v19",
"v19", "v20",
"v20", "v21");
"v21");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
doutr0 = doutr0 + 2 * w_out; doutr0 = doutr0 + 2 * w_out;
} }
#else #else
...@@ -951,66 +907,286 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout, ...@@ -951,66 +907,286 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout,
} }
int cnt = cnt_col; int cnt = cnt_col;
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
if (flag_relu) { asm volatile(
asm volatile( INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr),
[din2_ptr] "+r"(din2_ptr), [outptr] "+r"(doutr0_ptr),
[outptr] "+r"(doutr0_ptr), [cnt] "+r"(cnt),
[cnt] "+r"(cnt), [mask_ptr] "+r"(mask_ptr)
[mask_ptr] "+r"(mask_ptr) : [remain] "r"(cnt_remain),
: [remain] "r"(cnt_remain), [wr0] "w"(wr0),
[wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "r"(bias_c)
[bias] "r"(bias_c) : "cc",
: "cc", "memory",
"memory", "q3",
"q3", "q4",
"q4", "q5",
"q5", "q6",
"q6", "q7",
"q7", "q8",
"q8", "q9",
"q9", "q10",
"q10", "q11",
"q11", "q12",
"q12", "q13",
"q13", "q14",
"q14", "q15");
"q15"); doutr0 = doutr0 + w_out;
}
#endif
}
}
}
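// With the body split into _relu and _no_relu variants, the activation check
// is resolved once per call instead of inside the row loop. A hypothetical
// caller-side dispatch (not taken from the original sources) would look like:
//
//   if (flag_relu) {
//     conv_depthwise_3x3s2p1_bias_relu(dout, din, weights, bias, flag_bias,
//                                      flag_relu, num, ch_in, h_in, w_in,
//                                      h_out, w_out, ctx);
//   } else {
//     conv_depthwise_3x3s2p1_bias_no_relu(dout, din, weights, bias, flag_bias,
//                                         flag_relu, num, ch_in, h_in, w_in,
//                                         h_out, w_out, ctx);
//   }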
void conv_depthwise_3x3s2p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); //
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else { } else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 dr0 = dr4;
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 dr1 = dr0 + w_in;
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} }
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out; doutr0 = doutr0 + w_out;
} }
#endif #endif
...@@ -1088,107 +1264,179 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, ...@@ -1088,107 +1264,179 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
#ifdef __aarch64__ #ifdef __aarch64__
if (flag_relu) { asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr),
[din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
[mask_ptr] "+r"(mask_ptr) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "w"(vbias),
[bias] "w"(vbias), [out] "r"(out_buf)
[out] "r"(out_buf) : "v4",
: "v4", "v5",
"v5", "v6",
"v6", "v7",
"v7", "v8",
"v8", "v9",
"v9", "v10",
"v10", "v11",
"v11", "v12",
"v12", "v13",
"v13", "v14",
"v14", "v15");
"v15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else #else
if (flag_relu) { asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr),
[din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
[mask_ptr] "+r"(mask_ptr) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "r"(bias_c),
[bias] "r"(bias_c), [out] "r"(out_buf)
[out] "r"(out_buf) : "cc",
: "cc", "memory",
"memory", "q3",
"q3", "q4",
"q4", "q5",
"q5", "q6",
"q6", "q7",
"q7", "q8",
"q8", "q9",
"q9", "q10",
"q10", "q11",
"q11", "q12",
"q12", "q13",
"q13", "q14",
"q14", "q15");
"q15"); #endif
} else { for (int w = 0; w < w_out; ++w) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2 *dout_channel++ = out_buf[w];
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} }
hs += 2;
he += 2;
}
}
}
}
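// The *_s variants (w_in <= 4, see the doc comments above) compute a single
// 4-lane result into out_buf with masks built directly from w_in, then copy
// only the first w_out values into the output tensor, so partially valid
// lanes never touch user memory.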
void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif #endif
for (int w = 0; w < w_out; ++w) { for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w]; *dout_channel++ = out_buf[w];
...@@ -1334,117 +1582,60 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout, ...@@ -1334,117 +1582,60 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout,
doutr1_ptr = write_ptr; doutr1_ptr = write_ptr;
} }
int cnt = tile_w; int cnt = tile_w;
if (flag_relu) { asm volatile(
asm volatile( INIT_S2
INIT_S2 "ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v15.4s}, [%[inptr0]] \n" "ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n" "ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} MID_COMPUTE_S2 MID_RESULT_S2_RELU
MID_COMPUTE_S2 MID_RESULT_S2_RELU "cmp %w[remain], #1 \n"
"cmp %w[remain], #1 \n" "blt 4f \n" RIGHT_COMPUTE_S2
"blt 4f \n" RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
RIGHT_RESULT_S2_RELU "4: \n"
"4: \n" : [inptr0] "+r"(din0_ptr),
: [inptr0] "+r"(din0_ptr), [inptr1] "+r"(din1_ptr),
[inptr1] "+r"(din1_ptr), [inptr2] "+r"(din2_ptr),
[inptr2] "+r"(din2_ptr), [inptr3] "+r"(din3_ptr),
[inptr3] "+r"(din3_ptr), [inptr4] "+r"(din4_ptr),
[inptr4] "+r"(din4_ptr), [outptr0] "+r"(doutr0_ptr),
[outptr0] "+r"(doutr0_ptr), [outptr1] "+r"(doutr1_ptr),
[outptr1] "+r"(doutr1_ptr), [cnt] "+r"(cnt)
[cnt] "+r"(cnt) : [vzero] "w"(vzero),
: [vzero] "w"(vzero), [w0] "w"(wr0),
[w0] "w"(wr0), [w1] "w"(wr1),
[w1] "w"(wr1), [w2] "w"(wr2),
[w2] "w"(wr2), [remain] "r"(cnt_remain),
[remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
[mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2),
[mask2] "w"(vmask_rp2), [wmask] "w"(wmask),
[wmask] "w"(wmask), [vbias] "w"(wbias)
[vbias] "w"(wbias) : "cc",
: "cc", "memory",
"memory", "v0",
"v0", "v1",
"v1", "v2",
"v2", "v3",
"v3", "v4",
"v4", "v5",
"v5", "v6",
"v6", "v7",
"v7", "v8",
"v8", "v9",
"v9", "v10",
"v10", "v11",
"v11", "v12",
"v12", "v13",
"v13", "v14",
"v14", "v15",
"v15", "v16",
"v16", "v17",
"v17", "v18",
"v18", "v19",
"v19", "v20",
"v20", "v21");
"v21");
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
doutr0 = doutr0 + 2 * w_out; doutr0 = doutr0 + 2 * w_out;
} }
#else #else
...@@ -1472,72 +1663,284 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout, ...@@ -1472,72 +1663,284 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout,
} }
int cnt = tile_w; int cnt = tile_w;
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
if (flag_relu) { asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU RIGHT_COMPUTE_S2
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU RIGHT_RESULT_S2_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr),
[din2_ptr] "+r"(din2_ptr), [outptr] "+r"(doutr0_ptr),
[outptr] "+r"(doutr0_ptr), [cnt] "+r"(cnt),
[cnt] "+r"(cnt), [mask_ptr] "+r"(mask_ptr)
[mask_ptr] "+r"(mask_ptr) : [remain] "r"(cnt_remain),
: [remain] "r"(cnt_remain), [wr0] "w"(wr0),
[wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "r"(bias_c)
[bias] "r"(bias_c) : "cc",
: "cc", "memory",
"memory", "q3",
"q3", "q4",
"q4", "q5",
"q5", "q6",
"q6", "q7",
"q7", "q8",
"q8", "q9",
"q9", "q10",
"q10", "q11",
"q11", "q12",
"q12", "q13",
"q13", "q14",
"q14", "q15");
"q15");
} else {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
doutr0 = doutr0 + w_out; doutr0 = doutr0 + w_out;
} }
#endif #endif
} }
} }
} }
void conv_depthwise_3x3s2p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
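// Each pass of the s2 kernel reads 8 input columns and writes 4 outputs: tile_w full
// passes, cnt_remain leftover outputs, and size_right_remain (= w_in - 8 * tile_w)
// valid input columns for the masked right-edge pass.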
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
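// wmask keeps only the cnt_remain valid lanes when the last output block is partial.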
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 5 > h_in) {
switch (i * 2 + 5 - h_in) {
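// deliberate fall-through: every input row at or below the bottom edge is
// redirected to the zero buffer.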
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i * 2 + 3 > h_in) {
switch (i * 2 + 3 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
/** /**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4 * \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/ */
...@@ -1614,113 +2017,189 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, ...@@ -1614,113 +2017,189 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
unsigned int* mask_ptr = dmask; unsigned int* mask_ptr = dmask;
#ifdef __aarch64__ #ifdef __aarch64__
if (flag_relu) { asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr),
[din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
[mask_ptr] "+r"(mask_ptr) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "w"(vbias),
[bias] "w"(vbias), [out] "r"(out_buf)
[out] "r"(out_buf) : "cc",
: "cc", "memory",
"memory", "v4",
"v4", "v5",
"v5", "v6",
"v6", "v7",
"v7", "v8",
"v8", "v9",
"v9", "v10",
"v10", "v11",
"v11", "v12",
"v12", "v13",
"v13", "v14",
"v14", "v15",
"v15", "v16");
"v16");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
}
#else #else
if (flag_relu) { asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU : [din0_ptr] "+r"(din0_ptr),
: [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
[din1_ptr] "+r"(din1_ptr), [din2_ptr] "+r"(din2_ptr)
[din2_ptr] "+r"(din2_ptr) : [wr0] "w"(wr0),
: [wr0] "w"(wr0), [wr1] "w"(wr1),
[wr1] "w"(wr1), [wr2] "w"(wr2),
[wr2] "w"(wr2), [bias] "r"(bias_c),
[bias] "r"(bias_c), [out] "r"(out_buf),
[out] "r"(out_buf), [mask_ptr] "r"(dmask)
[mask_ptr] "r"(dmask) : "cc",
: "cc", "memory",
"memory", "q3",
"q3", "q4",
"q4", "q5",
"q5", "q6",
"q6", "q7",
"q7", "q8",
"q8", "q9",
"q9", "q10",
"q10", "q11",
"q11", "q12",
"q12", "q13",
"q13", "q14",
"q14", "q15");
"q15"); #endif
} else { for (int w = 0; w < w_out; ++w) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 *dout_channel++ = out_buf[w];
: [din0_ptr] "+r"(din0_ptr), }
[din1_ptr] "+r"(din1_ptr), }
[din2_ptr] "+r"(din2_ptr) }
: [wr0] "w"(wr0), }
[wr1] "w"(wr1), }
[wr2] "w"(wr2), void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout,
[bias] "r"(bias_c), const float* din,
[out] "r"(out_buf), const float* weights,
[mask_ptr] "r"(dmask) const float* bias,
: "cc", bool flag_bias,
"memory", bool flag_relu,
"q3", const int num,
"q4", const int ch_in,
"q5", const int h_in,
"q6", const int w_in,
"q7", const int h_out,
"q8", const int w_out,
"q9", ARMContext* ctx) {
"q10", int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
"q11", int out_pad_idx[4] = {0, 1, 2, 3};
"q12", float zeros[8] = {0.0f};
"q13", const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
"q14",
"q15"); uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; j++) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
if (j * 2 + 2 >= h_in) {
switch (j * 2 + 2 - h_in) {
case 1:
din1_ptr = zero_ptr;
case 0:
din2_ptr = zero_ptr;
default:
break;
}
} }
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif #endif
for (int w = 0; w < w_out; ++w) { for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w]; *dout_channel++ = out_buf[w];
......
...@@ -106,6 +106,42 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, ...@@ -106,6 +106,42 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
int padh, int padh,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s1_int8_int8_impl(int8_t* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
void conv_depthwise_3x3s1_int8_float_impl(float* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template <typename Dtype> template <typename Dtype>
void conv_depthwise_3x3s2_int8(Dtype* dout, void conv_depthwise_3x3s2_int8(Dtype* dout,
const int8_t* din, const int8_t* din,
...@@ -340,6 +376,118 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, ...@@ -340,6 +376,118 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const int w_out, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -841,24 +841,52 @@ void conv_depthwise_3x3_int8_fp32(const void* din, ...@@ -841,24 +841,52 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
alpha[3] = local_alpha; alpha[3] = local_alpha;
} }
} }
bool support_act_type = flag_act <= 1;
bool support_pad_type =
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
bool support_width_type = w_in > 9 ? true : false;
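// The dedicated s1 implementation added below only supports flag_act <= 1 (none/ReLU),
// the same padding of 0 or 1 on every side, stride 1, and w_in > 9; any other case
// falls back to the generic conv_depthwise_3x3s1_int8 template.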
if (stride == 1) { if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout), if (!support_act_type || !support_pad_type || !support_stride_type ||
reinterpret_cast<const int8_t*>(din), !support_width_type) {
reinterpret_cast<const int8_t*>(weights), conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
scale, reinterpret_cast<const int8_t*>(din),
bias, reinterpret_cast<const int8_t*>(weights),
flag_bias, scale,
flag_act, bias,
alpha, flag_bias,
num, flag_act,
ch_in, alpha,
h_in, num,
w_in, ch_in,
h_out, h_in,
w_out, w_in,
pad_w, h_out,
pad_h, w_out,
ctx); pad_w,
pad_h,
ctx);
} else {
conv_depthwise_3x3s1_int8_float_impl(
reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_act,
alpha,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
}
} else if (stride == 2) { } else if (stride == 2) {
conv_depthwise_3x3s2_int8(reinterpret_cast<float*>(dout), conv_depthwise_3x3s2_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din), reinterpret_cast<const int8_t*>(din),
...@@ -924,24 +952,52 @@ void conv_depthwise_3x3_int8_int8(const void* din, ...@@ -924,24 +952,52 @@ void conv_depthwise_3x3_int8_int8(const void* din,
alpha[3] = local_alpha; alpha[3] = local_alpha;
} }
} }
bool support_act_type = flag_act <= 1;
bool support_pad_type =
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
bool support_width_type = w_in > 9 ? true : false;
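// Same dispatch rule as the fp32-output path above: the specialized s1 kernel needs
// act <= relu, uniform 0/1 padding, stride 1, and w_in > 9.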
if (stride == 1) { if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout), if (!support_act_type || !support_pad_type || !support_stride_type ||
reinterpret_cast<const int8_t*>(din), !support_width_type) {
reinterpret_cast<const int8_t*>(weights), conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
scale, reinterpret_cast<const int8_t*>(din),
bias, reinterpret_cast<const int8_t*>(weights),
flag_bias, scale,
flag_act, bias,
alpha, flag_bias,
num, flag_act,
ch_in, alpha,
h_in, num,
w_in, ch_in,
h_out, h_in,
w_out, w_in,
pad_w, h_out,
pad_h, w_out,
ctx); pad_w,
pad_h,
ctx);
} else {
conv_depthwise_3x3s1_int8_int8_impl(
reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_act,
alpha,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
}
} else if (stride == 2) { } else if (stride == 2) {
conv_depthwise_3x3s2_int8(reinterpret_cast<int8_t*>(dout), conv_depthwise_3x3s2_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din), reinterpret_cast<const int8_t*>(din),
......
...@@ -300,13 +300,15 @@ void fill_bias_act<float>(float* tensor, ...@@ -300,13 +300,15 @@ void fill_bias_act<float>(float* tensor,
switch (act_param->active_type) { switch (act_param->active_type) {
case lite_api::ActivationType::kRelu: case lite_api::ActivationType::kRelu:
for (int i = 0; i < remain; i++) { for (int i = 0; i < remain; i++) {
*dst = *src >= 0.f ? *src : 0.f; float tmp = (*src + bias_data);
*dst = tmp >= 0.f ? tmp : 0.f;
src++; src++;
dst++; dst++;
} }
case lite_api::ActivationType::kRelu6: case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) { for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f; float tmp = (*src + bias_data);
tmp = tmp >= 0.f ? tmp : 0.f;
*dst = tmp <= act_param->Relu_clipped_coef *dst = tmp <= act_param->Relu_clipped_coef
? tmp ? tmp
: act_param->Relu_clipped_coef; : act_param->Relu_clipped_coef;
...@@ -315,10 +317,11 @@ void fill_bias_act<float>(float* tensor, ...@@ -315,10 +317,11 @@ void fill_bias_act<float>(float* tensor,
} }
case lite_api::ActivationType::kLeakyRelu: case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) { for (int i = 0; i < remain; i++) {
if (*src >= 0.f) { float tmp = (*src + bias_data);
*dst = *src; if (tmp >= 0.f) {
*dst = tmp;
} else { } else {
*dst = *src * act_param->Leaky_relu_alpha; *dst = tmp * act_param->Leaky_relu_alpha;
} }
src++; src++;
dst++; dst++;
...@@ -336,17 +339,24 @@ void fill_bias_act<float>(float* tensor, ...@@ -336,17 +339,24 @@ void fill_bias_act<float>(float* tensor,
float32x4_t vbias = vdupq_n_f32(bias_data); float32x4_t vbias = vdupq_n_f32(bias_data);
float* src = data + j * channel_size; float* src = data + j * channel_size;
float* dst = data + j * channel_size; float* dst = data + j * channel_size;
if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(FILL_BIAS FILL_STORE asm volatile(FILL_BIAS FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) :
: [vbias] "w"(vbias) [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: "memory", "cc", "v0", "v1", "v2", "v3"); : [vbias] "w"(vbias)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else #else
asm volatile(FILL_BIAS FILL_STORE asm volatile(FILL_BIAS FILL_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) :
: [vbias] "w"(vbias) [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: "memory", "cc", "q3", "q4", "q5", "q6"); : [vbias] "w"(vbias)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif #endif
}
for (int i = 0; i < remain; i++) {
  // scalar tail not covered by the vectorized loop
  *dst = *src + bias_data;
  src++;
  dst++;
}
} }
} }
} }
......
...@@ -2,4 +2,5 @@ if (NOT LITE_WITH_BM) ...@@ -2,4 +2,5 @@ if (NOT LITE_WITH_BM)
return() return()
endif() endif()
lite_cc_library(target_wrapper_bm SRCS target_wrapper.cc DEPS ${bm_runtime_libs}) add_library(target_wrapper_bm STATIC target_wrapper.cc)
target_link_libraries(target_wrapper_bm -Wl,-rpath,${BM_SDK_CPLIB_RPATH}:${BM_SDK_LIB_RPATH} -L${BM_SDK_CPLIB_RPATH} -L${BM_SDK_LIB_RPATH} -lbmcompiler -lbmcpu -lbmlib -lbmrt)
...@@ -23,12 +23,16 @@ lite_cc_library(mir_passes ...@@ -23,12 +23,16 @@ lite_cc_library(mir_passes
fusion/quant_dequant_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc fusion/scale_activation_fuse_pass.cc
fusion/reshape_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__resnet_cbam_fuse_pass.cc fusion/__xpu__resnet_cbam_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc fusion/__xpu__mmdnn_fuse_pass.cc
fusion/match_matrix_activation_fuse_pass.cc
fusion/scales_fuse_pass.cc
fusion/sequence_reverse_embedding_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc
......
...@@ -37,6 +37,18 @@ lite_cc_library(fuse_sequence_pool_concat ...@@ -37,6 +37,18 @@ lite_cc_library(fuse_sequence_pool_concat
lite_cc_library(fuse_scale_activation lite_cc_library(fuse_scale_activation
SRCS scale_activation_fuser.cc SRCS scale_activation_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
lite_cc_library(fuse_reshape
SRCS reshape_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_match_matrix_activation
SRCS match_matrix_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_scales
SRCS scales_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_sequence_reverse_embedding
SRCS sequence_reverse_embedding_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers set(mir_fusers
fuse_fc fuse_fc
...@@ -52,6 +64,10 @@ set(mir_fusers ...@@ -52,6 +64,10 @@ set(mir_fusers
fuse_interpolate fuse_interpolate
fuse_sequence_pool_concat fuse_sequence_pool_concat
fuse_scale_activation fuse_scale_activation
fuse_reshape
fuse_match_matrix_activation
fuse_scales
fuse_sequence_reverse_embedding
CACHE INTERNAL "fusers") CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
...@@ -104,9 +104,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -104,9 +104,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_weight_t = auto conv_weight_t =
scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>(); scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
auto groups = conv_op_desc->GetAttr<int>("groups"); auto groups = conv_op_desc->GetAttr<int>("groups");
bool depthwise = false;
if (conv_type_ == "conv2d_transpose") { if (conv_type_ == "conv2d_transpose") {
depthwise = (conv_weight_t->dims()[0] == conv_weight_t->dims()[1] * groups);
CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()), CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
static_cast<size_t>(conv_weight_t->dims()[1] * groups)) static_cast<size_t>(conv_weight_t->dims()[1] * groups))
<< "The BN bias's size should be equal to the size of the first " << "The BN bias's size should be equal to the size of the first "
...@@ -120,7 +118,6 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -120,7 +118,6 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
size_t weight_num = conv_weight_t->data_size(); size_t weight_num = conv_weight_t->data_size();
bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false; bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
bool is_weight_quantization = conv_op_desc->HasAttr("quantize_weight_bits"); bool is_weight_quantization = conv_op_desc->HasAttr("quantize_weight_bits");
// compute BN alpha and beta // compute BN alpha and beta
Tensor alpha_tensor, beta_tensor; Tensor alpha_tensor, beta_tensor;
alpha_tensor.CopyDataFrom(*bn_bias_t); alpha_tensor.CopyDataFrom(*bn_bias_t);
...@@ -162,12 +159,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -162,12 +159,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>(); auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
// compute new conv_weight for int8 // compute new conv_weight for int8
auto weight_scale = conv_op_desc->GetInputScale(weight_name); auto weight_scale = conv_op_desc->GetInputScale(weight_name);
if (conv_type_ == "conv2d_transpose" && !depthwise) { if (conv_type_ == "conv2d_transpose") {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * int cout = conv_weight_t->dims()[1] * groups;
conv_weight_t->dims()[3]; int cin_group = conv_weight_t->dims()[0] / groups;
int c_size = cout * conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
for (int k = 0; k < conv_weight_t->dims()[0]; ++k) { for (int k = 0; k < cin_group; ++k) {
for (int i = 0; i < h; ++i) { for (int i = 0; i < cout; ++i) {
weight_scale[i] *= fabsf(alpha_data[i]); weight_scale[i] *= fabsf(alpha_data[i]);
if (alpha_data[i] < 0.f) { if (alpha_data[i] < 0.f) {
auto ptr_row = conv_weight_d + k * c_size + i * hw; auto ptr_row = conv_weight_d + k * c_size + i * hw;
...@@ -203,12 +201,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -203,12 +201,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else { } else {
// compute new conv_weight // compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>(); auto conv_weight_d = conv_weight_t->mutable_data<float>();
if (conv_type_ == "conv2d_transpose" && !depthwise) { if (conv_type_ == "conv2d_transpose") {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * int cout = conv_weight_t->dims()[1] * groups;
conv_weight_t->dims()[3]; int cin_group = conv_weight_t->dims()[0] / groups;
int c_size = cout * conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
for (int k = 0; k < conv_weight_t->dims()[0]; ++k) { for (int k = 0; k < cin_group; ++k) {
for (int i = 0; i < h; ++i) { for (int i = 0; i < cout; ++i) {
auto ptr_row = conv_weight_d + k * c_size + i * hw; auto ptr_row = conv_weight_d + k * c_size + i * hw;
for (int j = 0; j < hw; ++j) { for (int j = 0; j < hw; ++j) {
ptr_row[j] *= alpha_data[i]; ptr_row[j] *= alpha_data[i];
......
...@@ -23,7 +23,7 @@ namespace lite { ...@@ -23,7 +23,7 @@ namespace lite {
namespace mir { namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
#ifdef LITE_WITH_X86 #if defined(LITE_WITH_X86) || defined(LITE_WITH_CUDA)
#ifdef LITE_WITH_MLU #ifdef LITE_WITH_MLU
fusion::FcFuser fuser(false); fusion::FcFuser fuser(false);
fuser(graph.get()); fuser(graph.get());
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/match_matrix_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/match_matrix_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void MatchMatrixActFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::MatchMatrixActFuser fuser("relu");
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_match_matrix_activation_fuse_pass,
paddle::lite::mir::MatchMatrixActFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class MatchMatrixActFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/match_matrix_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void MatchMatrixActFuser::BuildPattern() {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("match_matrix_tensor", "X");
auto* W = VarNode("W")->assert_is_op_input("match_matrix_tensor", "W");
auto* y = VarNode("y")->assert_is_op_input("match_matrix_tensor", "Y");
auto* mm = OpNode("match_matrix_tensor", "match_matrix_tensor");
auto* mm_out =
VarNode("mm_out")->assert_is_op_output("match_matrix_tensor", "Out");
auto* mm_tmp =
VarNode("mm_tmp")->assert_is_op_output("match_matrix_tensor", "Tmp");
auto* act = OpNode("act", activation_);
auto* out = VarNode("Out")->assert_is_op_output(activation_, "Out");
// create topology.
std::vector<PMNode*> mm_inputs{x, W, y};
  std::vector<PMNode*> mm_outputs{mm_out, mm_tmp};
  mm_inputs >> *mm >> mm_outputs;
// Some op specialities.
mm_out->AsIntermediate();
mm->AsIntermediate();
act->AsIntermediate();
*mm_out >> *act >> *out;
}
void MatchMatrixActFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto mm_op = LiteOpRegistry::Global().Create("match_matrix_tensor");
auto mm = matched.at("match_matrix_tensor")->stmt()->op();
auto* scope = mm->scope();
auto& valid_places = mm->valid_places();
mm_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(mm_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("y"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
cpp::OpDesc MatchMatrixActFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("match_matrix_tensor")->stmt()->op_info();
int dim_t = matched.at("match_matrix_tensor")
->stmt()
->op_info()
->GetAttr<int>("dim_t");
op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear();
op_desc.SetType("match_matrix_tensor");
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Y", {matched.at("y")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetOutput("Tmp", {matched.at("mm_tmp")->arg()->name});
op_desc.SetAttr("dim_t", dim_t);
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class MatchMatrixActFuser : public FuseBase {
public:
explicit MatchMatrixActFuser(std::string activation)
: activation_(activation) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string activation_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -175,6 +175,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -175,6 +175,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) { for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
op_desc.SetInputScale(weight_name, weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
...@@ -280,9 +281,8 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -280,9 +281,8 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("X", {quantized_op_input->arg()->name}); op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
} }
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
}
op_desc.SetInputScale(weight_name, weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/reshape_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/reshape_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ReshapeFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> reshape_type_cases{"reshape", "reshape2"};
for (auto type_ : reshape_type_cases) {
fusion::ReshapeFuser reshape_fuser(type_);
reshape_fuser(graph.get());
}
for (auto type_ : reshape_type_cases) {
fusion::Reshape2OutFuser reshape2Out_fuser(type_);
reshape2Out_fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_reshape_fuse_pass, paddle::lite::mir::ReshapeFusePass)
.BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ReshapeFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/reshape_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ReshapeFuser::BuildPattern() {
auto* x = VarNode("x");
auto* reshape = OpNode("reshape", type_);
auto* reshape_out = VarNode("Out");
auto* out1 = OpNode("out1");
*x >> *reshape >> *reshape_out >> *out1;
}
void ReshapeFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = const_cast<OpInfo*>(matched.at("reshape")->stmt()->op_info());
op_desc->SetAttr<bool>("inplace", true);
}
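// Reshape2OutFuser below reverts the in-place mark when the reshape output is read by
// a second op, so the shared tensor is not overwritten in place.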
void Reshape2OutFuser::BuildPattern() {
auto* x = VarNode("x");
auto* reshape =
OpNode("reshape", type_)->assert_op_attr<bool>("inplace", true);
auto* reshape_out = VarNode("Out");
auto* out1 = OpNode("out1");
auto* out2 = OpNode("out2");
*x >> *reshape >> *reshape_out >> *out1;
*reshape_out >> *out2;
}
void Reshape2OutFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = const_cast<OpInfo*>(matched.at("reshape")->stmt()->op_info());
op_desc->SetAttr<bool>("inplace", false);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ReshapeFuser : public FuseBase {
public:
explicit ReshapeFuser(const std::string& type) : type_(type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
std::string type_;
};
class Reshape2OutFuser : public FuseBase {
public:
explicit Reshape2OutFuser(const std::string& type) : type_(type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
std::string type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scales_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScalesFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ScalesFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scales_fuse_pass, paddle::lite::mir::ScalesFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScalesFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScalesFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
auto scales_teller = [](const Node* node) -> bool {
bool bias_after_scale =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<bool>(
"bias_after_scale");
return bias_after_scale;
};
// create op nodes
auto* scale1 = OpNode("scale1", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
auto* scale2 = OpNode("scale2", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
// create intermediate nodes
auto* scale1_out = VarNode("scale1_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input("scale", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("out")->assert_is_op_output("scale", "Out")->AsOutput();
// create topology.
*x >> *scale1 >> *scale1_out >> *scale2 >> *out;
}
void ScalesFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale1")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc ScalesFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("scale1")->stmt()->op_info();
float scale1 = op_desc.GetAttr<float>("scale");
float bias1 = op_desc.GetAttr<float>("bias");
float scale2 =
matched.at("scale2")->stmt()->op_info()->GetAttr<float>("scale");
float bias2 = matched.at("scale2")->stmt()->op_info()->GetAttr<float>("bias");
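  // Fold the two affine ops (both guaranteed bias_after_scale == true by the pattern):
  //   y = s2 * (s1 * x + b1) + b2 = (s1 * s2) * x + (b1 * s2 + b2)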
op_desc.SetAttr<float>("scale", scale1 * scale2);
op_desc.SetAttr<float>("bias", bias1 * scale2 + bias2);
auto& out_name = matched.at("out")->arg()->name;
op_desc.SetOutput("Out", {out_name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScalesFuser : public FuseBase {
public:
ScalesFuser() {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -23,8 +23,11 @@ namespace lite { ...@@ -23,8 +23,11 @@ namespace lite {
namespace mir { namespace mir {
void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::SequencePoolConcatFuser fuser; fusion::SequencePool7ConcatFuser fuser;
fuser(graph.get()); fuser(graph.get());
fusion::SequencePool2ConcatFuser fuser2;
fuser2(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -21,22 +21,6 @@ namespace lite { ...@@ -21,22 +21,6 @@ namespace lite {
namespace mir { namespace mir {
namespace fusion { namespace fusion {
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePoolConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
#define STR1(R) #R #define STR1(R) #R
#define STR2(R) STR1(R) #define STR2(R) STR1(R)
...@@ -58,6 +42,22 @@ void SequencePoolConcatFuser::BuildPattern() { ...@@ -58,6 +42,22 @@ void SequencePoolConcatFuser::BuildPattern() {
*sequence_pool_##num >> *sequence_pool_##num##_idx; \ *sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat; *x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePool7ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out = auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out"); VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out; *concat >> *concat_out;
...@@ -69,14 +69,10 @@ void SequencePoolConcatFuser::BuildPattern() { ...@@ -69,14 +69,10 @@ void SequencePoolConcatFuser::BuildPattern() {
POOL_CONCAT_PATTERN(5); POOL_CONCAT_PATTERN(5);
POOL_CONCAT_PATTERN(6); POOL_CONCAT_PATTERN(6);
POOL_CONCAT_PATTERN(7); POOL_CONCAT_PATTERN(7);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
} }
void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph, void SequencePool7ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) { const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op = auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat"); LiteOpRegistry::Global().Create("sequence_pool_concat");
...@@ -99,7 +95,7 @@ void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph, ...@@ -99,7 +95,7 @@ void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out")); IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
} }
cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc SequencePool7ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info(); cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat"); op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X", op_desc.SetInput("X",
...@@ -147,6 +143,64 @@ cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -147,6 +143,64 @@ cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc; return op_desc;
} }
void SequencePool2ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
POOL_CONCAT_PATTERN(1);
POOL_CONCAT_PATTERN(2);
}
void SequencePool2ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
auto concat = matched.at("concat")->stmt()->op();
auto* scope = concat->scope();
auto& valid_places = concat->valid_places();
sequence_pool_concat_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePool2ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
{matched.at("sequence_pool_x_1")->arg()->name,
matched.at("sequence_pool_x_2")->arg()->name});
std::vector<std::string> pooltypes;
pooltypes.push_back(matched.at("sequence_pool_1")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_2")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
op_desc.SetAttr("pooltype", pooltypes);
op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
return op_desc;
}
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
...@@ -23,7 +23,16 @@ namespace lite { ...@@ -23,7 +23,16 @@ namespace lite {
namespace mir { namespace mir {
namespace fusion { namespace fusion {
class SequencePoolConcatFuser : public FuseBase { class SequencePool7ConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
class SequencePool2ConcatFuser : public FuseBase {
 public:
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void SequenceReverseEmbeddingFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::SequenceReverseEmbeddingFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_sequence_reverse_embedding_fuse_pass,
paddle::lite::mir::SequenceReverseEmbeddingFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class SequenceReverseEmbeddingFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void SequenceReverseEmbeddingFuser::BuildPattern() {
// create input nodes.
auto* x =
VarNode("x")->assert_is_op_input("sequence_reverse", "X")->AsInput();
auto* w = VarNode("w")->assert_is_op_input("lookup_table", "W")->AsInput();
// create op nodes
auto* sequence_reverse = OpNode("sequence_reverse", "sequence_reverse")
->assert_is_op("sequence_reverse")
->AsIntermediate();
auto* lookup_table = OpNode("lookup_table", "lookup_table")
->assert_is_op("lookup_table")
->AsIntermediate();
// create intermediate nodes
auto* sequence_reverse_out =
VarNode("sequence_reverse_out")
->assert_is_op_output("sequence_reverse", "Y")
->assert_is_op_input("lookup_table", "Ids")
->AsIntermediate();
// create output node
auto* out =
VarNode("out")->assert_is_op_output("lookup_table", "Out")->AsOutput();
// create topology.
*x >> *sequence_reverse >> *sequence_reverse_out >> *lookup_table >> *out;
*w >> *lookup_table;
}
void SequenceReverseEmbeddingFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fuse_op = LiteOpRegistry::Global().Create("sequence_reverse_embedding");
auto lookup_table = matched.at("lookup_table")->stmt()->op();
auto* scope = lookup_table->scope();
auto& valid_places = lookup_table->valid_places();
fuse_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fuse_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("w"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc SequenceReverseEmbeddingFuser::GenOpDesc(
const key2nodes_t& matched) {
auto op_desc = *matched.at("lookup_table")->stmt()->op_info();
op_desc.SetType("sequence_reverse_embedding");
auto& in_name = matched.at("x")->arg()->name;
auto& w_name = matched.at("w")->arg()->name;
auto& out_name = matched.at("out")->arg()->name;
op_desc.SetInput("Ids", {in_name});
op_desc.SetInput("W", {w_name});
op_desc.SetOutput("Out", {out_name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class SequenceReverseEmbeddingFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -25,6 +25,9 @@ void VarConvActivationFuser::BuildPattern() {
  // create nodes.
  auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
  auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
auto* column =
VarNode("COLUMN")->assert_is_op_input(conv_type_, "COLUMN")->AsInput();
auto* row = VarNode("ROW")->assert_is_op_input(conv_type_, "ROW")->AsInput();
  auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
...@@ -42,7 +45,7 @@ void VarConvActivationFuser::BuildPattern() {
      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
  // create topology.
  std::vector<PMNode*> conv2d_inputs{filter, input, column, row};
  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
  *conv2d >> *conv2d_out_1;
}
...@@ -60,6 +63,8 @@ void VarConvActivationFuser::InsertNewNode(SSAGraph* graph,
  IR_NODE_LINK_TO(matched.at("X"), new_op_node);
  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("COLUMN"), new_op_node);
IR_NODE_LINK_TO(matched.at("ROW"), new_op_node);
  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
......
...@@ -91,11 +91,14 @@ class Optimizer {
      // kernels for devices automatically.
      "lite_conv_activation_fuse_pass",              //
      "lite_var_conv_2d_activation_fuse_pass",       //
      "lite_match_matrix_activation_fuse_pass",      //
      "lite_fc_fuse_pass",                           //
      "lite_shuffle_channel_fuse_pass",              //
      "lite_transpose_softmax_transpose_fuse_pass",  //
      "lite_interpolate_fuse_pass",                  //
      "identity_scale_eliminate_pass",               //
"lite_scales_fuse_pass", //
"lite_sequence_reverse_embedding_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", // "elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", // "lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", // "lite_scale_activation_fuse_pass", //
...@@ -161,6 +164,7 @@ class Optimizer { ...@@ -161,6 +164,7 @@ class Optimizer {
"runtime_context_assign_pass", "runtime_context_assign_pass",
"argument_type_display_pass", "argument_type_display_pass",
"lite_reshape_fuse_pass",
"memory_optimize_pass"}}; "memory_optimize_pass"}};
......
...@@ -159,7 +159,9 @@ RuntimeProgram::RuntimeProgram(
    int block_idx)
    : exec_scope_(exec_scope) {
#ifdef LITE_WITH_OPENCL
  bool opencl_valid = paddle::lite::CLWrapper::Global()->OpenclLibFound() &&
                      paddle::lite::CLWrapper::Global()->DlsymSuccess() &&
                      CLRuntime::Global()->OpenCLAvaliableForDevice();
  using OpenCLContext = Context<TargetType::kOpenCL>;
  std::unique_ptr<KernelContext> unique_opencl_ctx(new KernelContext());
  if (opencl_valid) {
......
...@@ -35,7 +35,6 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
  bool ch_four = channel <= 4 * win;
  // select dw conv kernel
  if (kw == 3) {
    // VLOG(5) << "invoke 3x3 dw conv fp32";
    bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
    if (ch_four && pads_less && paddings[0] == paddings[2] &&
        (paddings[0] == 0 || paddings[0] == 1)) {
...@@ -116,23 +115,44 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
      w_scale_[i] = scale[i] * in_scale;
    }
  }
auto paddings = *param.paddings;
auto strides = param.strides;
auto x_dims = param.x->dims();
int iw = x_dims[3];
int ih = x_dims[2];
auto act_param = param.activation_param;
bool has_act = act_param.has_active;
lite_api::ActivationType act_type = act_param.active_type;
// no activation and relu activation is supported now
bool support_act_type =
(has_act == false) ||
(has_act == true && act_type == lite_api::ActivationType::kRelu);
bool support_pad_type =
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
bool support_width_type = iw > 9 ? true : false;
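  // when all of the above conditions hold, the direct int8 depthwise kernel
  // consumes the original filter layout, so the NC8 weight re-pack below can
  // be skipped (flag_trans_weights_ stays false)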
  /// select dw conv kernel
  if (kw == 3) {
    // trans weights
    // VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
    impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
#ifdef LITE_WITH_PROFILE
    kernel_func_name_ = "conv_depthwise_3x3_int8_fp32";
#endif
    if (!support_act_type || !support_pad_type || !support_stride_type ||
        !support_width_type) {
      int cround = ROUNDUP(w_dims[0], 8);
      weights_.Resize({cround / 8, 1, kh * kw, 8});
      auto wptr = param.filter->data<int8_t>();
      auto wptr_new = weights_.mutable_data<int8_t>();
      lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
      flag_trans_weights_ = true;
    } else {
      flag_trans_weights_ = false;
    }
  } else if (kw == 5) {
    // trans weights
    // VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
    impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
#ifdef LITE_WITH_PROFILE
    kernel_func_name_ = "conv_depthwise_5x5_int8_fp32";
...@@ -187,23 +207,45 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
    param.activation_param.Relu_clipped_coef =
        param.activation_param.Relu_clipped_coef / param.output_scale;
  }
auto paddings = *param.paddings;
auto strides = param.strides;
auto x_dims = param.x->dims();
int iw = x_dims[3];
int ih = x_dims[2];
auto act_param = param.activation_param;
bool has_act = act_param.has_active;
lite_api::ActivationType act_type = act_param.active_type;
// no activation and relu activation is supported now
bool support_act_type =
(has_act == false) ||
(has_act == true && act_type == lite_api::ActivationType::kRelu);
bool support_pad_type =
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
bool support_width_type = iw > 9 ? true : false;
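  // as in the fp32-output path above, the weight re-pack below is only done
  // when the direct int8 depthwise kernel cannot be used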
  /// select dw conv kernel
  if (kw == 3) {
    // trans weights
    // VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
    impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
#ifdef LITE_WITH_PROFILE
    kernel_func_name_ = "conv_depthwise_3x3_int8_int8";
#endif
    if (!support_act_type || !support_pad_type || !support_stride_type ||
        !support_width_type) {
      int cround = ROUNDUP(w_dims[0], 8);
      weights_.Resize({cround / 8, 1, kh * kw, 8});
      auto wptr = param.filter->data<int8_t>();
      auto wptr_new = weights_.mutable_data<int8_t>();
      lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
      flag_trans_weights_ = true;
    } else {
      flag_trans_weights_ = false;
    }
  } else if (kw == 5) {
    // trans weights
    // VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
    impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
#ifdef LITE_WITH_PROFILE
    kernel_func_name_ = "conv_depthwise_5x5_int8_int8";
...@@ -295,7 +337,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();
  int iw = x_dims[3];
  int ih = x_dims[2];
  int ic = x_dims[1];
  int bs = x_dims[0];
...@@ -345,7 +387,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();
  int iw = x_dims[3];
  int ih = x_dims[2];
  int ic = x_dims[1];
  int bs = x_dims[0];
......
...@@ -73,7 +73,6 @@ void Conv2DTransposeCompute::Run() {
  int kw = w_dims[3];  // oihw
  int kh = w_dims[2];
  int group = param.groups;
  bool fuse_relu = param.fuse_relu;
  bool flag_bias = (param.bias != nullptr);
  auto paddings = *param.paddings;
...@@ -104,6 +103,7 @@ void Conv2DTransposeCompute::Run() {
  auto dout = param.output->mutable_data<float>();
  auto weights = param.filter->data<float>();
  auto act_param = param.activation_param;
bool has_act = act_param.has_active;
  for (int i = 0; i < num; i++) {
    const float* din_batch = din + i * chin * hin * win;
    float* dout_batch = dout + i * chout * hout * wout;
...@@ -152,13 +152,14 @@ void Conv2DTransposeCompute::Run() {
          dout_batch);
    }
    if (flag_bias) {
      act_param.has_active = has_act;
      lite::arm::math::fill_bias_act<float>(
          dout_batch,
          static_cast<const float*>(param.bias->data<float>()),
          chout,
          wout * hout,
          flag_bias,
          &act_param);
    }
  }
}
......
...@@ -78,6 +78,9 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
  void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
  auto weights_data_ = weights_.mutable_data<float>();
memset(reinterpret_cast<char*>(weights_data_),
0,
weights_.numel() * sizeof(float));
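  // the transformed-weight buffer is padded to oc_pad/ic_pad, so clear it
  // before the transform fills only the valid oc x ic entries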
  if (!choose_small_) {
    lite::arm::math::weight_trans_c4_8x8(
        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
...@@ -251,6 +254,9 @@ void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
  void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic);
  auto weights_data_ = weights_.mutable_data<int16_t>();
memset(reinterpret_cast<char*>(weights_data_),
0,
weights_.numel() * sizeof(int16_t));
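  // same as the fp32 path: clear the padded int16 weight buffer before the
  // transform writes the valid entries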
  if (!choose_small_) {
  } else {
    lite::arm::math::weight_trans_c8_4x4_int8(
......
...@@ -38,6 +38,7 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_embedding_compute_cuda CUDA extra SRCS sequence_reverse_embedding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
......
...@@ -52,6 +52,7 @@ __global__ void padding_out(const dtype* src,
                            const int max_len_r,
                            const int tl,
                            const int count,
                            const bool fuse_relu,
                            dtype* dst) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int thread_num = blockDim.x * gridDim.x;
...@@ -62,7 +63,13 @@ __global__ void padding_out(const dtype* src,
    int r_id = tid % max_len_r;
    int cur_len = offset[seq_id + 1] - offset[seq_id];
    if (r_id < cur_len) {
      if (fuse_relu) {
        dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id] > 0
                       ? src[(offset[seq_id] + r_id) * tl + tl_id]
                       : 0;
      } else {
        dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id];
      }
    } else {
      dst[tid] = 0.f;
    }
...@@ -86,6 +93,7 @@ void MatchMatrixTensorCompute::Run() {
  auto* tmp = param.tmp;
  int dim_t = param.dim_t;
  int dim_in = x->dims()[1];
  bool fuse_relu = param.fuse_relu;
  const auto& offset_l = x->lod()[0];
  const auto& offset_r = y->lod()[0];
...@@ -155,6 +163,7 @@ void MatchMatrixTensorCompute::Run() {
      max_len_r,
      dim_t * len_l,
      count,
      fuse_relu,
      out_data);
  out->set_lod(y->lod());
}
......
...@@ -37,6 +37,40 @@ __global__ void SequenceMaskKernel(T* dst,
  }
}
template <typename T>
__global__ void VecMaxKernel(const T* in_data, T* out, const int count) {
extern __shared__ T cache[];
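  // one partial max per thread via a grid-stride loop, staged in dynamic
  // shared memory (the launch passes blockDim.x * sizeof(T) bytes)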
int i = blockDim.x * blockIdx.x + threadIdx.x;
int cache_index = threadIdx.x;
T tmp = -1;
while (i < count) {
if (in_data[i] > tmp) {
tmp = in_data[i];
}
i += blockDim.x * gridDim.x;
}
cache[cache_index] = tmp;
__syncthreads();
// perform parallel reduction, blockDim.x must be 2^n
int ib = blockDim.x / 2;
while (ib != 0) {
if (cache_index < ib && cache[cache_index + ib] > cache[cache_index]) {
cache[cache_index] = cache[cache_index + ib];
}
__syncthreads();
ib /= 2;
}
if (cache_index == 0) {
out[blockIdx.x] = cache[0];
}
}
template <typename T, PrecisionType Ptype>
void SequenceMaskCompute<T, Ptype>::Run() {
  auto& param = this->template Param<param_t>();
...@@ -57,11 +91,34 @@ void SequenceMaskCompute<T, Ptype>::Run() {
  }
  if (maxlen < 0) {
    // choose algorithm according to magic_num.
    const int magic_num = 256;
    std::vector<int64_t> h_max_data;
    if (x->numel() < magic_num) {
      h_max_data.resize(x->numel());
TargetWrapperCuda::MemcpySync(h_max_data.data(),
x_data,
x->numel() * sizeof(int64_t),
IoDirection::DtoH);
} else {
const int threads = 256;
const int blocks = (x->numel() + threads - 1) / threads;
max_tensor_.Resize({blocks});
auto* max_data = max_tensor_.mutable_data<int64_t>(TARGET(kCUDA));
VecMaxKernel<
int64_t><<<blocks, threads, threads * sizeof(int64_t), stream>>>(
x_data, max_data, x->numel());
h_max_data.resize(blocks);
TargetWrapperCuda::MemcpyAsync(h_max_data.data(),
max_data,
sizeof(int64_t) * blocks,
IoDirection::DtoH,
stream);
TargetWrapperCuda::StreamSync(stream);
}
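    // h_max_data now holds either the raw values (small input) or one
    // per-block maximum; finish the reduction on the host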
auto maxlen_iterator =
std::max_element(h_max_data.begin(), h_max_data.end());
maxlen = h_max_data[std::distance(h_max_data.begin(), maxlen_iterator)];
  }
  auto y_dim = x->dims().Vectorize();
......
...@@ -28,6 +28,9 @@ class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
  void Run() override;
  virtual ~SequenceMaskCompute() = default;
 private:
  lite::Tensor max_tensor_;
};
} // namespace cuda
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_reverse_embedding_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__host__ __device__ inline size_t UpperBound(const T* x,
const int num,
const T& val) {
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto* first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
auto step = (count >> 1);
auto* it = first + step;
if (val < *it) {
count = step;
} else {
first = ++it;
count -= (step + 1);
}
}
return static_cast<size_t>(first - x);
}
template <typename T>
__global__ void SequenceReverseEmbeddingKernel(const int64_t* ids,
const T* table,
T* out,
const int64_t* lod,
const int lod_count,
const int width,
const int count,
const bool padding_flag,
const int64_t padding_idx) {
CUDA_KERNEL_LOOP(tid, count) {
int64_t row = tid / width;
int col = tid % width;
auto lod_idx = UpperBound(lod, lod_count, row);
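    // mirror the row inside its own LoD sequence [lod[lod_idx - 1], lod[lod_idx])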
auto reverse_row = lod[lod_idx - 1] + lod[lod_idx] - 1 - row;
if (padding_flag) {
if (ids[reverse_row] == padding_idx)
out[tid] = 0;
else
out[tid] = table[ids[reverse_row] * width + col];
} else {
out[tid] = table[ids[reverse_row] * width + col];
}
}
}
template <typename T, PrecisionType Ptype>
void SequenceReverseEmbeddingCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto io_stream = ctx.io_stream();
auto* table_data = param.W->template data<T>();
auto* out_data = param.Out->template mutable_data<T>(TARGET(kCUDA));
auto* ids_data = param.Ids->template data<int64_t>();
const auto lod = param.Ids->lod()[param.Ids->lod().size() - 1];
const int lod_count = lod.size();
const int width = param.W->dims()[1];
const int count = param.Out->numel();
lod_info_.Resize({static_cast<int64_t>(lod.size())});
int64_t* lod_data = lod_info_.mutable_data<int64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(lod_data,
lod.data(),
sizeof(int64_t) * lod.size(),
IoDirection::HtoD,
stream);
int64_t padding_idx = param.padding_idx;
bool padding_flag = padding_idx != -1;
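  // padding_idx == -1 means no padding entry; otherwise ids equal to
  // padding_idx produce zero vectors in the kernel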
SequenceReverseEmbeddingKernel<
T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(ids_data,
table_data,
out_data,
lod_data,
lod_count,
width,
count,
padding_flag,
padding_idx);
CUDA_POST_KERNEL_CHECK;
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SeqReverseEmbFp32 = paddle::lite::kernels::cuda::
SequenceReverseEmbeddingCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
sequence_reverse_embedding, kCUDA, kFloat, kNCHW, SeqReverseEmbFp32, def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SequenceReverseEmbeddingCompute
: public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::LookupTableParam;
void Run() override;
virtual ~SequenceReverseEmbeddingCompute() = default;
private:
lite::Tensor lod_info_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -55,7 +55,9 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
      output_data +
      (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
      blockIdx.y * topk_size;
  for (int i = 0; i < topk_size; ++i) {
    fm_row_out_data[i] = 0;
  }
  Dtype *smem_start_col = smem + idx * col_max;
  int counter = max_k;  // topk_size;
...@@ -151,6 +153,9 @@ __global__ void topk_avg_pooling_kernel_for_big_data(
       blockIdx.z * actual_row_in_shared_mem + idx) *
          feat_map_num * topk_size +
      blockIdx.y * topk_size;
for (int i = 0; i < topk_size; ++i) {
fm_row_out_data[i] = 0;
}
  Dtype *smem_start_col = smem + idx * col_max;
...@@ -239,8 +244,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
  Tensor *out_tensor = param.Out;
  const T *in_data = x_tensor->data<T>();
  T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
  TargetWrapperCuda::MemsetAsync(
      out_data, 0, sizeof(T) * param.Out->numel(), cuda_stream);
  int topk_num = param.topks.size();
  lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
......
...@@ -23,7 +23,9 @@ add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kerne
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
    lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
    lite_cc_test(test_one_hot_compute_host SRCS one_hot_compute_test.cc DEPS one_hot_compute_host)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T>
void OneHotKernelFunctor(const Tensor* in,
Tensor* out,
int depth,
bool allow_out_of_range = false) {
auto* p_in_data = in->data<T>();
auto numel = in->numel();
auto* p_out_data = out->mutable_data<T>();
memset(p_out_data, 0, out->numel() * sizeof(T));
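  // out is a [numel, depth] matrix; each input index selects the single
  // column that is set to 1.0 in its row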
if (allow_out_of_range) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth) {
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
CHECK_GE(p_in_data[i], 0) << "Illegal index value, Input(input) value "
"should be at least 0, but received input ("
<< p_in_data[i] << ") less than 0";
      CHECK_LT(p_in_data[i], depth)
<< "Illegal index value, Input(input) value should be less than "
"Input(depth), but received input ("
<< p_in_data[i] << ") not less than depth (" << depth << ")";
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
}
void OneHotCompute::Run() {
auto& param = this->template Param<param_t>();
switch (param.dtype) {
case static_cast<int>(lite::core::FluidType::INT64):
OneHotKernelFunctor<int64_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::INT32):
OneHotKernelFunctor<int32_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::FP32):
OneHotKernelFunctor<float>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
default:
LOG(ERROR) << "Unsupported data type for one_hot op:" << param.dtype;
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("depth_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class OneHotCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::OneHotParam;
void Run() override;
virtual ~OneHotCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
TEST(one_hot, test) {
using T = float;
lite::Tensor x, out;
x.Resize({4, 1});
out.Resize({4, 4});
auto* x_data = x.mutable_data<T>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
auto* out_data = out.mutable_data<T>();
float out_ref[4][4] = {
{0, 1, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 1}, {1, 0, 0, 0}};
OneHotCompute one_hot;
operators::OneHotParam param;
param.X = &x;
param.Out = &out;
param.depth = 4;
// static_cast<int>(lite::core::FluidType::FP32) = 5;
param.dtype = 5;
one_hot.SetParam(param);
one_hot.PrepareForRun();
one_hot.Run();
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_data[i], out_ref[i / 4][i % 4], 1e-5);
  }
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
...@@ -17,6 +17,15 @@ lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_
lite_cc_library(subgraph_bridge_softmax_op_huawei_ascend_npu SRCS softmax_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_huawei_ascend_npu SRCS dropout_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_fc_op_huawei_ascend_npu SRCS fc_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reshape_op_huawei_ascend_npu SRCS reshape_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_transpose_op_huawei_ascend_npu SRCS transpose_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_flatten_op_huawei_ascend_npu SRCS flatten_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_layer_norm_op_huawei_ascend_npu SRCS layer_norm_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_huawei_ascend_npu SRCS matmul_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_cast_op_huawei_ascend_npu SRCS cast_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_scale_op_huawei_ascend_npu SRCS scale_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_slice_op_huawei_ascend_npu SRCS slice_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_gather_op_huawei_ascend_npu SRCS gather_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges
    subgraph_bridge_registry
...@@ -32,4 +41,13 @@ set(huawei_ascend_npu_subgraph_bridges
    subgraph_bridge_softmax_op_huawei_ascend_npu
    subgraph_bridge_dropout_op_huawei_ascend_npu
    subgraph_bridge_fc_op_huawei_ascend_npu
subgraph_bridge_reshape_op_huawei_ascend_npu
subgraph_bridge_transpose_op_huawei_ascend_npu
subgraph_bridge_flatten_op_huawei_ascend_npu
subgraph_bridge_layer_norm_op_huawei_ascend_npu
subgraph_bridge_matmul_op_huawei_ascend_npu
subgraph_bridge_cast_op_huawei_ascend_npu
subgraph_bridge_scale_op_huawei_ascend_npu
subgraph_bridge_slice_op_huawei_ascend_npu
subgraph_bridge_gather_op_huawei_ascend_npu
    CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
// auto in_dtype = op_info->GetAttr<int>("in_dtype");
auto out_dtype = op_info->GetAttr<int>("out_dtype");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
PrecisionType ptype = PRECISION(kFloat);
ge::DataType otype = ge::DT_FLOAT;
switch (out_dtype) {
case 0: // BOOL = 0;
ptype = PRECISION(kBool);
otype = ge::DT_BOOL;
break;
case 1: // INT16 = 1
ptype = PRECISION(kInt16);
otype = ge::DT_INT16;
break;
case 2: // INT32 = 2
ptype = PRECISION(kInt32);
otype = ge::DT_INT32;
break;
case 3: // INT64 = 3
ptype = PRECISION(kInt64);
otype = ge::DT_INT64;
break;
case 4: // FP16 = 4
ptype = PRECISION(kFP16);
otype = ge::DT_FLOAT16;
break;
case 5: // FP32 = 5
ptype = PRECISION(kFloat);
otype = ge::DT_FLOAT;
break;
case 21: // INT8 = 21
ptype = PRECISION(kInt8);
otype = ge::DT_INT8;
break;
default:
LOG(FATAL) << "unsupported data type: " << out_dtype;
break;
}
// Cast node
auto cast_node = graph->Add<ge::op::Cast>(out_name, ptype);
auto cast_op = cast_node->data<ge::op::Cast>();
cast_op->set_input_x(*x_node->data());
cast_op->set_attr_dst_type(otype);
INPUT_UPDATE(cast_op, x, x_node);
OUTPUT_UPDATE(cast_op, y, cast_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
cast,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::CastConverter);
...@@ -132,19 +132,22 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
    return FAILED;
  }
// Filter node
std::shared_ptr<Node> filter_node = nullptr;
  // Check depthwise mode, and decide whether use DepthwiseConv2D Op
  bool use_depthwise_conv = false;
  bool is_depthwise_mode = (ic == groups && oc == groups);
  if (is_depthwise_mode && dilations[0] == 1 && dilations[1] == 1) {
    use_depthwise_conv = true;
    // Change filter shape {oc, ic/groups = 1, kh, kw} => { K=1, oc, kh, hw}
    filter_node = graph->Add(
        filter_name, *filter, {1L, oc, filter_dims[2], filter_dims[3]});
    LOG(WARNING) << "[HUAWEI_ASCEND_NPU] DepthwiseConv2D op is used.";
  } else {
    filter_node = graph->Add(filter_name, *filter);
  }
  // Add bias node if exists bias
  // Supports the bias nodes with the following dimensions
  // 0: {oc} => 1D tensor of foramt ND
......
...@@ -138,7 +138,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  std::shared_ptr<Node> x_node = nullptr;
  if (graph->Has(x_name)) {
    x_node = graph->Get(x_name);
    auto shape_node = graph->Add<int64_t>(x_name + "/x_shape", x_new_shape);
    auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
    auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
    reshaped_x_op->set_input_x(*x_node->data());
...@@ -156,7 +156,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  std::shared_ptr<Node> y_node = nullptr;
  if (graph->Has(y_name)) {
    y_node = graph->Get(y_name);
    auto shape_node = graph->Add<int64_t>(y_name + "/y_shape", y_new_shape);
    auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
    auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
    reshaped_y_op->set_input_x(*y_node->data());
...@@ -224,7 +224,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  auto out_shape = out_dims.Vectorize();
  if (out_shape != x_new_shape) {
    auto shape_node = graph->Add<int64_t>(out_name + "/out_shape", out_shape);
    auto reshaped_elt_node = graph->Add<ge::op::Reshape>(out_name);
    auto reshaped_elt_op = reshaped_elt_node->data<ge::op::Reshape>();
    reshaped_elt_op->set_input_x(*elt_node->data());
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
VLOG(3) << "output shape is: " << out_dims.repr();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Const Shape node
auto shape_node =
graph->Add<int64_t>(x_name + "/shape", out_dims.Vectorize());
// Reshape node
auto reshaped_x_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_x(*x_node->data());
reshaped_x_op->set_input_shape(*shape_node->data());
reshaped_x_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_x_op, x, x_node);
INPUT_UPDATE(reshaped_x_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_x_op, y, reshaped_x_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
flatten,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
REGISTER_SUBGRAPH_BRIDGE(
flatten2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto index_name = op_info->Input("Index").front();
auto index = scope->FindTensor(index_name);
auto index_dims = index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1))
<< "index dims unmatch";
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Index node
std::shared_ptr<Node> index_node = nullptr;
if (graph->Has(index_name)) {
index_node = graph->Get(index_name);
} else {
index_node = graph->Add(index_name, *index);
}
// Gather node
auto gather_node = graph->Add<ge::op::Gather>(out_name);
auto gather_op = gather_node->data<ge::op::Gather>();
gather_op->set_input_x(*x_node->data());
gather_op->set_input_indices(*index_node->data());
INPUT_UPDATE(gather_op, x, x_node);
INPUT_UPDATE(gather_op, indices, index_node);
OUTPUT_UPDATE(gather_op, y, gather_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
gather,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::GatherConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x_rank = static_cast<int>(x_dims.size());
CHECK(x_rank >= 2 && x_rank <= 4);
bool has_bias = op_info->HasInput("Bias");
bool has_scale = op_info->HasInput("Scale");
auto y_name = op_info->Output("Y").front();
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto mean_name = op_info->Output("Mean").front();
auto mean = scope->FindMutableTensor(mean_name);
auto mean_dims = mean->dims();
CHECK_EQ(mean_dims.size(), 1);
auto var_name = op_info->Output("Variance").front();
auto var = scope->FindMutableTensor(var_name);
auto var_dims = var->dims();
CHECK_EQ(var_dims.size(), 1);
// Get op attributes
auto epsilon = op_info->GetAttr<float>("epsilon");
auto begin_norm_axis = op_info->GetAttr<int>("begin_norm_axis");
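  // A negative begin_norm_axis counts from the last dimension, so shift it
  // into the [0, x_rank) range before use.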
if (begin_norm_axis < 0) {
begin_norm_axis += x_rank;
}
  CHECK_GT(begin_norm_axis, 0);
  CHECK_LT(begin_norm_axis, x_rank);
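  // Flatten2D folds the dims before begin_norm_axis into the batch size and
  // the remaining dims into the feature size, i.e. layer_norm's 2-D view.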
auto matrix_dim = x_dims.Flatten2D(begin_norm_axis);
int batch_size = matrix_dim[0];
int feature_size = matrix_dim[1];
CHECK_EQ(mean_dims.production(), batch_size);
CHECK_EQ(var_dims.production(), batch_size);
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Get shape of bias and scale
DDim scale_bias_dims = x_dims.Slice(begin_norm_axis, x_dims.size());
CHECK_EQ(scale_bias_dims.production(), feature_size);
// auto scale_bias_dims = DDim({x_dims[x_dims.size()-1]});
// Bias node
std::shared_ptr<Node> bias_node = nullptr;
if (has_bias) {
auto bias_name = op_info->Input("Bias").front();
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(bias_dims.size(), 1);
CHECK_EQ(bias_dims.production(), feature_size);
bias_node = graph->Add(bias_name, *bias, scale_bias_dims);
} else {
bias_node = graph->Add<float>(y_name + "/bias", 0.f, scale_bias_dims);
}
// Scale node
std::shared_ptr<Node> scale_node = nullptr;
if (has_scale) {
auto scale_name = op_info->Input("Scale").front();
auto scale = scope->FindMutableTensor(scale_name);
auto scale_dims = scale->dims();
CHECK_EQ(scale_dims.size(), 1);
CHECK_EQ(scale_dims.production(), feature_size);
scale_node = graph->Add(scale_name, *scale, scale_bias_dims);
} else {
scale_node = graph->Add<float>(y_name + "/scale", 1.f, scale_bias_dims);
}
// LayerNorm node
auto layer_norm_node = graph->Add<ge::op::LayerNorm>(y_name + "/layer_norm");
auto layer_norm_op = layer_norm_node->data<ge::op::LayerNorm>();
layer_norm_op->set_input_x(*x_node->data());
layer_norm_op->set_input_gamma(*scale_node->data());
layer_norm_op->set_input_beta(*bias_node->data());
layer_norm_op->set_attr_begin_norm_axis(begin_norm_axis);
layer_norm_op->set_attr_begin_params_axis(begin_norm_axis);
layer_norm_op->set_attr_epsilon(epsilon);
INPUT_UPDATE(layer_norm_op, x, x_node);
INPUT_UPDATE(layer_norm_op, gamma, scale_node);
INPUT_UPDATE(layer_norm_op, beta, bias_node);
OUTPUT_UPDATE(layer_norm_op, y, layer_norm_node);
OUTPUT_UPDATE(layer_norm_op, mean, layer_norm_node);
OUTPUT_UPDATE(layer_norm_op, variance, layer_norm_node);
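  // The single LayerNorm node carries three outputs (y, mean, variance), so
  // an Identity node per output exposes each one under its own Paddle
  // variable name.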
// Get output of Y
auto out_y_node = graph->Add<ge::op::Identity>(y_name);
auto out_y_op = out_y_node->data<ge::op::Identity>();
out_y_op->set_input_x(*layer_norm_node->data(), "y");
INPUT_UPDATE(out_y_op, x, layer_norm_node);
OUTPUT_UPDATE(out_y_op, y, out_y_node);
// Get output of Mean
auto out_mean_node = graph->Add<ge::op::Identity>(mean_name);
auto out_mean_op = out_mean_node->data<ge::op::Identity>();
out_mean_op->set_input_x(*layer_norm_node->data(), "mean");
INPUT_UPDATE(out_mean_op, x, layer_norm_node);
OUTPUT_UPDATE(out_mean_op, y, out_mean_node);
// Get output of Variance
auto out_var_node = graph->Add<ge::op::Identity>(var_name);
auto out_var_op = out_var_node->data<ge::op::Identity>();
out_var_op->set_input_x(*layer_norm_node->data(), "variance");
INPUT_UPDATE(out_var_op, x, layer_norm_node);
OUTPUT_UPDATE(out_var_op, y, out_var_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
layer_norm,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::LayerNormConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto x_dims = x->dims();
if (x_dims.size() < 2) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Input dims should be equal or large "
"than 2 in Huawei Ascend NPU DDK.";
return FAILED;
}
auto y_name = op_info->Input("Y").front();
auto y = scope->FindTensor(y_name);
auto y_dims = y->dims();
if (y_dims.size() < 2) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] Input dims should be equal or large "
"than 2 in Huawei Ascend NPU DDK.";
return FAILED;
}
if (x_dims.size() != y_dims.size()) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] dims size of input x1 and x2 must be "
"same in Huawei Ascend NPU DDK.";
return FAILED;
}
auto out_name = op_info->Output("Out").front();
auto out = scope->FindTensor(out_name);
auto out_dims = out->dims();
bool transpose_x = op_info->GetAttr<bool>("transpose_X");
bool transpose_y = op_info->GetAttr<bool>("transpose_Y");
float alpha = op_info->GetAttr<float>("alpha");
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
} else {
y_node = graph->Add(y_name, *y);
}
// Matmul node
std::shared_ptr<Node> matmul_node = nullptr;
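  // 2-D inputs map to MatMul with transpose flags; higher-rank inputs map to
  // BatchMatMul, whose adj_x1/adj_x2 attributes play the same role.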
if (x_dims.size() == 2) {
matmul_node = graph->Add<ge::op::MatMul>(out_name);
auto matmul_op = matmul_node->data<ge::op::MatMul>();
matmul_op->set_input_x1(*x_node->data());
matmul_op->set_input_x2(*y_node->data());
matmul_op->set_attr_transpose_x1(transpose_x);
matmul_op->set_attr_transpose_x2(transpose_y);
INPUT_UPDATE(matmul_op, x1, x_node);
INPUT_UPDATE(matmul_op, x2, y_node);
OUTPUT_UPDATE(matmul_op, y, matmul_node);
} else {
matmul_node = graph->Add<ge::op::BatchMatMul>(out_name);
auto matmul_op = matmul_node->data<ge::op::BatchMatMul>();
matmul_op->set_input_x1(*x_node->data());
matmul_op->set_input_x2(*y_node->data());
matmul_op->set_attr_adj_x1(transpose_x);
matmul_op->set_attr_adj_x2(transpose_y);
INPUT_UPDATE(matmul_op, x1, x_node);
INPUT_UPDATE(matmul_op, x2, y_node);
OUTPUT_UPDATE(matmul_op, y, matmul_node);
}
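  // matmul's alpha scales the product; apply it with a Muls node only when it
  // differs from 1.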
if (fabs(alpha - 1.f) > 1e-6f) {
auto scale_node = graph->Add<ge::op::Muls>(out_name);
auto scale_op = scale_node->data<ge::op::Muls>();
scale_op->set_input_x(*matmul_node->data());
scale_op->set_attr_value(alpha);
INPUT_UPDATE(scale_op, x, matmul_node);
OUTPUT_UPDATE(scale_op, y, scale_node);
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
matmul,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::MatMulConverter);
...@@ -42,3 +42,15 @@ USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU); ...@@ -42,3 +42,15 @@ USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softmax, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(softmax, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(dropout, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(dropout, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fc, kHuaweiAscendNPU); USE_SUBGRAPH_BRIDGE(fc, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(reshape, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(reshape2, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(transpose, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(transpose2, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(flatten, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(flatten2, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(layer_norm, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(matmul, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(cast, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(scale, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(slice, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(gather, kHuaweiAscendNPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Shape Const node
if (op_info->HasInput("ShapeTensor")) {
    LOG(WARNING) << "[HUAWEI_ASCEND_NPU] A target shape given by more than "
                    "one tensor (ShapeTensor) is not supported.";
return FAILED;
}
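  // The target shape comes either from a "Shape" input tensor or from the
  // "shape" attribute; ValidateShape resolves it against the input dims.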
std::shared_ptr<Node> actual_shape_node = nullptr;
if (op_info->HasInput("Shape")) {
auto actual_shape_name = op_info->Input("Shape").front();
if (graph->Has(actual_shape_name)) {
actual_shape_node = graph->Get(actual_shape_name);
} else {
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape =
std::vector<int>(actual_shape_data,
actual_shape_data + actual_shape_dims.production());
auto out_shape = lite::operators::ValidateShape(shape, x_dims);
actual_shape_node =
graph->Add<int>(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end()));
}
} else if (op_info->HasAttr("shape")) {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_shape = lite::operators::ValidateShape(shape, x_dims);
out_shape = CvtShape(out_shape);
actual_shape_node = graph->Add<int64_t>(
out_name + "/shape",
std::vector<int64_t>(out_shape.begin(), out_shape.end()));
}
// actual_shape_node should not be nullptr
CHECK(actual_shape_node);
// Reshape node
auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
auto reshape_op = reshape_node->data<ge::op::Reshape>();
reshape_op->set_input_x(*x_node->data());
reshape_op->set_input_shape(*actual_shape_node->data());
INPUT_UPDATE(reshape_op, x, x_node);
INPUT_UPDATE(reshape_op, shape, actual_shape_node);
OUTPUT_UPDATE(reshape_op, y, reshape_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
reshape,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(
reshape2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
float scale = op_info->GetAttr<float>("scale");
float bias = op_info->GetAttr<float>("bias");
bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
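  // If the bias is applied before the scale, y = scale * (x + bias), which
  // equals scale * x + scale * bias, so folding the scale into the bias
  // reduces it to the bias-after-scale form.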
if (!bias_after_scale) {
bias *= scale;
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// const node
auto input_scale_node =
graph->Add<float>(out_name + "/scale", scale, x_dims.Vectorize());
// scale node
auto scale_node = graph->Add<ge::op::Scale>(out_name);
auto scale_op = scale_node->data<ge::op::Scale>();
scale_op->set_input_x(*x_node->data());
scale_op->set_input_scale(*input_scale_node->data());
scale_op->set_attr_axis(0);
scale_op->set_attr_num_axes(-1);
scale_op->set_attr_scale_from_blob(true);
INPUT_UPDATE(scale_op, x, x_node);
INPUT_UPDATE(scale_op, scale, input_scale_node);
OUTPUT_UPDATE(scale_op, y, scale_node);
  // Add bias node (filled with the bias value)
  if (fabs(bias) > 1e-6f) {
    auto bias_node = graph->Add(out_name + "/bias", bias, x_dims.Vectorize());
    scale_op->set_input_bias(*bias_node->data());
    INPUT_UPDATE(scale_op, bias, bias_node);
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
scale,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ScaleConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto input_name = op_info->Input("Input").front();
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
auto input_rank = static_cast<int>(input_dims.size());
std::vector<int64_t> input_shape = input_dims.Vectorize();
auto out_name = op_info->Output("Out").front();
auto axes = op_info->GetAttr<std::vector<int>>("axes");
auto starts = op_info->GetAttr<std::vector<int>>("starts");
auto ends = op_info->GetAttr<std::vector<int>>("ends");
CHECK_EQ(axes.size(), starts.size());
CHECK_EQ(axes.size(), ends.size());
// X node
std::shared_ptr<Node> input_node = nullptr;
if (graph->Has(input_name)) {
input_node = graph->Get(input_name);
} else {
input_node = graph->Add(input_name, *input);
}
  // Compute offsets and sizes from axes, starts and ends
  std::vector<int> offset_vec(input_rank, 0);
  std::vector<int> size_vec(input_shape.begin(), input_shape.end());
for (int i = 0; i < axes.size(); i++) {
auto axis = axes[i];
    CHECK_LT(axis, input_rank)
        << "[HUAWEI_ASCEND_NPU] axes value should be less than input rank.";
offset_vec[axis] = starts[i];
size_vec[axis] = ends[i] - starts[i];
}
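  // SliceD takes the offsets and sizes as compile-time attributes, hence the
  // REBUILD_WHEN_SHAPE_CHANGED return value below.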
  // Slice node
auto slice_node = graph->Add<ge::op::SliceD>(out_name);
auto slice_op = slice_node->data<ge::op::SliceD>();
slice_op->set_input_x(*input_node->data());
slice_op->set_attr_offsets(
ge::Operator::OpListInt(offset_vec.begin(), offset_vec.end()));
slice_op->set_attr_size(
ge::Operator::OpListInt(size_vec.begin(), size_vec.end()));
INPUT_UPDATE(slice_op, x, input_node);
OUTPUT_UPDATE(slice_op, y, slice_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
slice,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::SliceConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Transpose node
auto transpose_node = graph->Add<ge::op::TransposeD>(out_name);
auto transpose_op = transpose_node->data<ge::op::TransposeD>();
transpose_op->set_input_x(*x_node->data());
transpose_op->set_attr_perm(
ge::Operator::OpListInt(axis.begin(), axis.end()));
INPUT_UPDATE(transpose_op, x, x_node);
OUTPUT_UPDATE(transpose_op, y, transpose_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
transpose,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(
transpose2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
...@@ -66,7 +66,7 @@ ge::DataType CvtPrecisionType(PrecisionType itype); ...@@ -66,7 +66,7 @@ ge::DataType CvtPrecisionType(PrecisionType itype);
ge::Format CvtDataLayoutType(DataLayoutType itype); ge::Format CvtDataLayoutType(DataLayoutType itype);
// Padding the shape to 4-dimensions(NCHW) for HiAI // Padding the shape to 4-dimensions(NCHW) for Huawei Ascend NPU
std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape); std::vector<int64_t> CvtShape(const std::vector<int64_t>& in_shape);
std::vector<int64_t> CvtShape(const DDim& in_dims); std::vector<int64_t> CvtShape(const DDim& in_dims);
......
...@@ -19,9 +19,9 @@ endif() ...@@ -19,9 +19,9 @@ endif()
if (NOT LITE_ON_TINY_PUBLISH) if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(compatible_pb SRCS compatible_pb.cc lite_cc_library(compatible_pb SRCS compatible_pb.cc
DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto) DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto fbs_io)
else() else()
lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper}) lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper} fbs_io)
endif() endif()
lite_cc_library(model_parser SRCS model_parser.cc DEPS lite_cc_library(model_parser SRCS model_parser.cc DEPS
......
...@@ -83,9 +83,9 @@ class VectorView { ...@@ -83,9 +83,9 @@ class VectorView {
operator std::vector<T>() const { operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp; std::vector<T> tmp;
tmp.reserve(size()); tmp.resize(size());
for (size_t i = 0; i < size(); ++i) { for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)); tmp[i] = cvec_->operator[](i);
} }
return tmp; return tmp;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "lite/model_parser/compatible_pb.h" #include "lite/model_parser/compatible_pb.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/naive_buffer/block_desc.h" #include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
...@@ -73,6 +74,18 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>( ...@@ -73,6 +74,18 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
}*/ }*/
} }
template <>
void TransformVarDescAnyToCpp<fbs::VarDesc>(const fbs::VarDesc &any_desc,
cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}
}
/// For OpDesc transform /// For OpDesc transform
template <typename OpDescType> template <typename OpDescType>
void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
...@@ -219,100 +232,102 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { ...@@ -219,100 +232,102 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
} }
/// For BlockDesc transform /// For BlockDesc transform
#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT, PNT) \ #define TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpT, VarT, NT, PNT) \
template <> \ template <> \
void TransformBlockDescAnyToCpp<NT::T>(const NT::T &any_desc, \ void TransformBlockDescAnyToCpp<NT::BlockDesc>( \
cpp::BlockDesc *cpp_desc) { \ const NT::BlockDesc &any_desc, cpp::BlockDesc *cpp_desc) { \
NT::T desc = any_desc; \ NT::BlockDesc &desc = const_cast<NT::BlockDesc &>(any_desc); \
cpp_desc->SetIdx(desc.Idx()); \ cpp_desc->SetIdx(desc.Idx()); \
cpp_desc->SetParentIdx(desc.ParentIdx()); \ cpp_desc->SetParentIdx(desc.ParentIdx()); \
cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
\ \
cpp_desc->ClearOps(); \ cpp_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpDesc>(i)); \ auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpT>(i)); \
auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>(); \ auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>(); \
TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \ TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \
} \ } \
\ \
cpp_desc->ClearVars(); \ cpp_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarDesc>(i)); \ auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarT>(i)); \
auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>(); \ auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>(); \
TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \ TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \
} \ } \
} \ } \
\ \
template <> \ template <> \
void TransformBlockDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ void TransformBlockDescCppToAny<NT::BlockDesc>( \
NT::T *any_desc) { \ const cpp::BlockDesc &cpp_desc, NT::BlockDesc *any_desc) { \
const cpp::T &desc = cpp_desc; \ const cpp::BlockDesc &desc = cpp_desc; \
any_desc->SetIdx(desc.Idx()); \ any_desc->SetIdx(desc.Idx()); \
any_desc->SetParentIdx(desc.ParentIdx()); \ any_desc->SetParentIdx(desc.ParentIdx()); \
any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
\ \
any_desc->ClearOps(); \ any_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto *cpp_op_desc = desc.GetOp<cpp::OpDesc>(i); \ auto *cpp_op_desc = desc.GetOp<cpp::OpDesc>(i); \
auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpDesc>()); \ auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpT>()); \
TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \ TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \
} \ } \
\ \
any_desc->ClearVars(); \ any_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto *cpp_var_desc = desc.GetVar<cpp::VarDesc>(i); \ auto *cpp_var_desc = desc.GetVar<cpp::VarDesc>(i); \
auto any_var_desc = \ auto any_var_desc = NT::VarDesc(any_desc->AddVar<PNT::proto::VarT>()); \
NT::VarDesc(any_desc->AddVar<PNT::proto::VarDesc>()); \ TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \
TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \ } \
} \
} }
/// For ProgramDesc transform /// For ProgramDesc transform
#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT, PNT) \ #define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockT, NT, PNT) \
template <> \ template <> \
void TransformProgramDescAnyToCpp<NT::T>(const NT::T &any_desc, \ void TransformProgramDescAnyToCpp<NT::ProgramDesc>( \
cpp::ProgramDesc *cpp_desc) { \ const NT::ProgramDesc &any_desc, cpp::ProgramDesc *cpp_desc) { \
NT::T desc = any_desc; \ NT::ProgramDesc &desc = const_cast<NT::ProgramDesc &>(any_desc); \
if (desc.HasVersion()) { \ if (desc.HasVersion()) { \
cpp_desc->SetVersion(desc.Version()); \ cpp_desc->SetVersion(desc.Version()); \
} \ } \
\ \
cpp_desc->ClearBlocks(); \ cpp_desc->ClearBlocks(); \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
auto any_block_desc = \ NT::BlockDesc any_block_desc(desc.GetBlock<PNT::proto::BlockT>(i)); \
NT::BlockDesc(desc.GetBlock<PNT::proto::BlockDesc>(i)); \ auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>(); \
auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>(); \ TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \
TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \ } \
} \ } \
} \ \
\ template <> \
template <> \ void TransformProgramDescCppToAny<NT::ProgramDesc>( \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ const cpp::ProgramDesc &cpp_desc, NT::ProgramDesc *any_desc) { \
NT::T *any_desc) { \ auto &desc = cpp_desc; \
auto &desc = cpp_desc; \ if (desc.HasVersion()) { \
if (desc.HasVersion()) { \ any_desc->SetVersion(desc.Version()); \
any_desc->SetVersion(desc.Version()); \ } \
} \ \
\ any_desc->ClearBlocks(); \
any_desc->ClearBlocks(); \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ auto *cpp_block_desc = desc.GetBlock<cpp::BlockDesc>(i); \
auto *cpp_block_desc = desc.GetBlock<cpp::BlockDesc>(i); \ NT::BlockDesc any_block_desc(any_desc->AddBlock<PNT::proto::BlockT>()); \
auto any_block_desc = \ TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \
NT::BlockDesc(any_desc->AddBlock<PNT::proto::BlockDesc>()); \ } \
TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \
} \
} }
TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc); TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc); TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer); TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, naive_buffer, naive_buffer);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, naive_buffer, naive_buffer); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer);
TRANS_VAR_ANY_WITH_CPP_IMPL(fbs::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(fbs::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDescT, VarDescT, fbs, fbs);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDescT, fbs, fbs);
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc); TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc); TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework); TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, pb, framework);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, pb, framework); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework);
#endif #endif
#undef TRANS_VAR_ANY_WITH_CPP_IMPL #undef TRANS_VAR_ANY_WITH_CPP_IMPL
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "lite/model_parser/compatible_pb.h" #include "lite/model_parser/compatible_pb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "lite/model_parser/cpp_desc.h" #include "lite/model_parser/cpp_desc.h"
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/flatbuffers/test_helper.h"
#include "lite/model_parser/naive_buffer/block_desc.h" #include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
...@@ -430,5 +432,14 @@ TEST(ProgramDesc, AnyToCpp) { ...@@ -430,5 +432,14 @@ TEST(ProgramDesc, AnyToCpp) {
TestProgramAnyToCpp<naive_buffer::ProgramDesc>(&nb_desc); TestProgramAnyToCpp<naive_buffer::ProgramDesc>(&nb_desc);
} }
TEST(ProgramDesc, FbsCpp) {
fbs::ProgramDesc fbs_program(fbs::test::GenerateProgramCache());
cpp::ProgramDesc cpp_program;
TransformProgramDescAnyToCpp(fbs_program, &cpp_program);
fbs::ProgramDesc fbs_program_2;
TransformProgramDescCppToAny(cpp_program, &fbs_program_2);
fbs::test::CheckProgramCache(&fbs_program_2);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -21,52 +21,52 @@ namespace fbs { ...@@ -21,52 +21,52 @@ namespace fbs {
template <> template <>
proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const { proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return desc_->vars()->Get(idx); return desc_->vars()->Get(idx);
} }
template <> template <>
proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const { proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return desc_->ops()->Get(idx); return desc_->ops()->Get(idx);
} }
template <> template <>
VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const { VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return &vars_[idx]; return &vars_[idx];
} }
template <> template <>
OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const { OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return &ops_[idx]; return &ops_[idx];
} }
template <> template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) { proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return vars_[idx].raw_desc(); return vars_[idx]->raw_desc();
} }
template <> template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() { proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT)); desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars(); SyncVars();
return vars_.back().raw_desc(); return vars_.back()->raw_desc();
} }
template <> template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) { proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= vars.size()";
return ops_[idx].raw_desc(); return ops_[idx]->raw_desc();
} }
template <> template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() { proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT)); desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps(); SyncOps();
return ops_.back().raw_desc(); return ops_.back()->raw_desc();
} }
} // namespace fbs } // namespace fbs
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/framework_generated.h"
...@@ -29,13 +30,13 @@ class BlockDescView : public BlockDescAPI { ...@@ -29,13 +30,13 @@ class BlockDescView : public BlockDescAPI {
public: public:
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) { explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_); CHECK(desc_);
vars_.reserve(VarsSize()); vars_.resize(VarsSize());
ops_.reserve(OpsSize()); ops_.resize(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) { for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDescView(desc_->vars()->Get(idx))); vars_[idx] = VarDescView(desc_->vars()->Get(idx));
} }
for (size_t idx = 0; idx < OpsSize(); ++idx) { for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDescView(desc_->ops()->Get(idx))); ops_[idx] = OpDescView(desc_->ops()->Get(idx));
} }
} }
...@@ -75,7 +76,7 @@ class BlockDescView : public BlockDescAPI { ...@@ -75,7 +76,7 @@ class BlockDescView : public BlockDescAPI {
return desc_->forward_block_idx(); return desc_->forward_block_idx();
} }
BlockDescView() { NotImplemented(); } BlockDescView() = default;
private: private:
proto::BlockDesc const* desc_; // not_own proto::BlockDesc const* desc_; // not_own
...@@ -150,24 +151,24 @@ class BlockDesc : public BlockDescAPI { ...@@ -150,24 +151,24 @@ class BlockDesc : public BlockDescAPI {
void SyncVars() { void SyncVars() {
vars_.resize(desc_->vars.size()); vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) { for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) { if (!vars_[i] || vars_[i]->raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get()); vars_[i].reset(new VarDesc(desc_->vars[i].get()));
} }
} }
} }
void SyncOps() { void SyncOps() {
ops_.resize(desc_->ops.size()); ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) { for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) { if (!ops_[i] || ops_[i]->raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get()); ops_[i].reset(new OpDesc(desc_->ops[i].get()));
} }
} }
} }
bool owned_{false}; bool owned_{false};
proto::BlockDescT* desc_{nullptr}; proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_; std::vector<std::unique_ptr<VarDesc>> vars_;
std::vector<OpDesc> ops_; std::vector<std::unique_ptr<OpDesc>> ops_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -25,11 +25,12 @@ namespace fbs { ...@@ -25,11 +25,12 @@ namespace fbs {
std::vector<char> LoadFile(const std::string& path) { std::vector<char> LoadFile(const std::string& path) {
FILE* file = fopen(path.c_str(), "rb"); FILE* file = fopen(path.c_str(), "rb");
CHECK(file);
fseek(file, 0, SEEK_END); fseek(file, 0, SEEK_END);
int64_t length = ftell(file); uint64_t length = ftell(file);
rewind(file); rewind(file);
std::vector<char> buf(length); std::vector<char> buf(length);
CHECK(fread(buf.data(), 1, length, file) == length); CHECK_EQ(fread(buf.data(), 1, length, file), length);
fclose(file); fclose(file);
return buf; return buf;
} }
...@@ -37,6 +38,7 @@ std::vector<char> LoadFile(const std::string& path) { ...@@ -37,6 +38,7 @@ std::vector<char> LoadFile(const std::string& path) {
void SaveFile(const std::string& path, const void* src, size_t byte_size) { void SaveFile(const std::string& path, const void* src, size_t byte_size) {
CHECK(src); CHECK(src);
FILE* file = fopen(path.c_str(), "wb"); FILE* file = fopen(path.c_str(), "wb");
CHECK(file);
CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size); CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size);
fclose(file); fclose(file);
} }
...@@ -60,7 +62,7 @@ void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) { ...@@ -60,7 +62,7 @@ void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) {
} }
void SetCombinedParamsWithScope(const lite::Scope& scope, void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name, const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params) { CombinedParamsDescWriteAPI* params) {
for (const auto& name : params_name) { for (const auto& name : params_name) {
auto* param = params->AddParamDesc(); auto* param = params->AddParamDesc();
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/scope.h" #include "lite/core/scope.h"
...@@ -30,8 +31,9 @@ void SaveFile(const std::string& path, const void* src, size_t byte_size); ...@@ -30,8 +31,9 @@ void SaveFile(const std::string& path, const void* src, size_t byte_size);
void SetScopeWithCombinedParams(lite::Scope* scope, void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params); const CombinedParamsDescReadAPI& params);
void SetCombinedParamsWithScope(const lite::Scope& scope, void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name, const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params); CombinedParamsDescWriteAPI* params);
} // namespace fbs } // namespace fbs
......
...@@ -32,7 +32,7 @@ void set_tensor(paddle::lite::Tensor* tensor, ...@@ -32,7 +32,7 @@ void set_tensor(paddle::lite::Tensor* tensor,
tensor->Resize(dims); tensor->Resize(dims);
std::vector<T> data; std::vector<T> data;
data.resize(production); data.resize(production);
for (size_t i = 0; i < production; ++i) { for (int i = 0; i < production; ++i) {
data[i] = i / 2.f; data[i] = i / 2.f;
} }
std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size()); std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size());
...@@ -53,7 +53,8 @@ TEST(CombinedParamsDesc, Scope) { ...@@ -53,7 +53,8 @@ TEST(CombinedParamsDesc, Scope) {
set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1})); set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
// Set combined parameters // Set combined parameters
fbs::CombinedParamsDesc combined_param; fbs::CombinedParamsDesc combined_param;
SetCombinedParamsWithScope(scope, params_name, &combined_param); std::set<std::string> params_set(params_name.begin(), params_name.end());
SetCombinedParamsWithScope(scope, params_set, &combined_param);
/* --------- Check scope ---------- */ /* --------- Check scope ---------- */
auto check_params = [&](const CombinedParamsDescReadAPI& desc) { auto check_params = [&](const CombinedParamsDescReadAPI& desc) {
......
...@@ -19,8 +19,8 @@ namespace lite { ...@@ -19,8 +19,8 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
std::string OpDescView::GetAttr<std::string>(const std::string& name) const { std::string OpDescView::GetAttr<std::string>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name);
if (!it->s()) { if (!it->s()) {
return std::string(); return std::string();
} }
...@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const { ...@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
} }
template <> template <>
std::string OpDescView::GetAttr<std::string>(size_t idx) const { std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->Get(idx); return GetAttr<std::string>(name.c_str());
if (!it->s()) {
return std::string();
}
return it->s()->str();
} }
template <> template <>
lite::VectorView<std::string, Flatbuffers> lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const { OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name);
CHECK(it) << "Attr " << name << "does not exist."; CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings()); return VectorView<std::string>(it->strings());
} }
template <> template <>
VectorView<std::string, Flatbuffers> lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const { OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->Get(idx); return GetAttr<std::vector<std::string>>(name.c_str());
CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings());
} }
#define GET_ATTR_IMPL(T, fb_f__) \ #define GET_ATTR_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name); \
return it->fb_f__(); \ return it->fb_f__(); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->Get(idx); \ return GetAttr<T>(name.c_str()); \
return it->fb_f__(); \
} }
#define GET_ATTRS_IMPL(T, fb_f__) \ #define GET_ATTRS_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \ return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->Get(idx); \ return GetAttr<T>(name.c_str()); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} }
GET_ATTR_IMPL(int32_t, i); GET_ATTR_IMPL(int32_t, i);
...@@ -103,6 +95,7 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs); ...@@ -103,6 +95,7 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs);
new proto::OpDesc_::AttrT())), \ new proto::OpDesc_::AttrT())), \
&(desc_->attrs)); \ &(desc_->attrs)); \
p->fb_f__ = v; \ p->fb_f__ = v; \
p->type = ConvertAttrType(OpDataTypeTrait<T>::AT); \
SetKey(name, &p); \ SetKey(name, &p); \
} }
ATTR_IMPL(int32_t, i); ATTR_IMPL(int32_t, i);
......
...@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI { ...@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI {
std::string Type() const override { return desc_->type()->str(); } std::string Type() const override { return desc_->type()->str(); }
// Get the arguments of parameter called `param` std::vector<std::string> Input(const char* param) const {
std::vector<std::string> Input(const std::string& param) const override { const auto& var = desc_->inputs()->LookupByKey(param);
const auto& var = desc_->inputs()->LookupByKey(param.c_str());
std::vector<std::string> args_vec; std::vector<std::string> args_vec;
if (var->arguments()) { if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size()); args_vec.resize(var->arguments()->size());
for (const auto& in : *var->arguments()) { for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec.push_back(in->str()); args_vec[i] = (*var->arguments())[i]->str();
} }
} }
return args_vec; return args_vec;
} }
std::vector<std::string> Input(const std::string& param) const override {
return Input(param.c_str());
}
std::vector<std::string> InputArgumentNames() const override { std::vector<std::string> InputArgumentNames() const override {
const auto& vars = desc_->inputs(); const auto& vars = desc_->inputs();
std::vector<std::string> input_names_vec; std::vector<std::string> input_names_vec;
if (vars) { if (vars) {
input_names_vec.reserve(vars->size()); input_names_vec.resize(vars->size());
for (const auto& in : *vars) { for (size_t i = 0; i < vars->size(); ++i) {
input_names_vec.push_back(in->parameter()->str()); input_names_vec[i] = (*vars)[i]->parameter()->str();
} }
} }
return input_names_vec; return input_names_vec;
} }
std::vector<std::string> Output(const std::string& param) const override { std::vector<std::string> Output(const char* param) const {
const auto& var = desc_->outputs()->LookupByKey(param.c_str()); const auto& var = desc_->outputs()->LookupByKey(param);
std::vector<std::string> args_vec; std::vector<std::string> args_vec;
if (var && var->arguments()) { if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size()); args_vec.resize(var->arguments()->size());
for (const auto& out : *var->arguments()) { for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec.push_back(out->str()); args_vec[i] = (*var->arguments())[i]->str();
} }
} }
return args_vec; return args_vec;
} }
std::vector<std::string> Output(const std::string& param) const override {
return Output(param.c_str());
}
std::vector<std::string> OutputArgumentNames() const override { std::vector<std::string> OutputArgumentNames() const override {
const auto& vars = desc_->outputs(); const auto& vars = desc_->outputs();
std::vector<std::string> output_names_vec; std::vector<std::string> output_names_vec;
if (vars) { if (vars) {
output_names_vec.reserve(vars->size()); output_names_vec.resize(vars->size());
for (const auto& out : *vars) { for (size_t i = 0; i < vars->size(); ++i) {
output_names_vec.push_back(out->parameter()->str()); output_names_vec[i] = (*vars)[i]->parameter()->str();
} }
} }
return output_names_vec; return output_names_vec;
} }
bool HasAttr(const char* name) const {
return desc_->attrs()->LookupByKey(name) != nullptr;
}
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr; return HasAttr(name.c_str());
} }
size_t AttrsSize() const { return desc_->attrs()->size(); } size_t AttrsSize() const { return desc_->attrs()->size(); }
...@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI { ...@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI {
return desc_->attrs()->Get(idx)->name()->str(); return desc_->attrs()->Get(idx)->name()->str();
} }
OpDescAPI::AttrType GetAttrType(const std::string& name) const override { OpDescAPI::AttrType GetAttrType(const char* name) const {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str()); const auto& attr = desc_->attrs()->LookupByKey(name);
CHECK(attr) << "Can not find attr: " << name; CHECK(attr) << "Can not find attr: " << name;
return ConvertAttrType(attr->type()); return ConvertAttrType(attr->type());
} }
OpDescAPI::AttrType GetAttrType(size_t idx) const { OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->Get(idx); return GetAttrType(name.c_str());
CHECK(attr);
return ConvertAttrType(attr->type());
} }
std::vector<std::string> AttrNames() const override { std::vector<std::string> AttrNames() const override {
const auto& attrs = desc_->attrs(); const auto& attrs = desc_->attrs();
std::vector<std::string> attr_names_vec; std::vector<std::string> attr_names_vec;
if (attrs) { if (attrs) {
attr_names_vec.reserve(attrs->size()); attr_names_vec.resize(attrs->size());
for (const auto& attr : *attrs) { for (size_t i = 0; i < attrs->size(); ++i) {
attr_names_vec.push_back(attr->name()->str()); attr_names_vec[i] = (*attrs)[i]->name()->str();
} }
} }
return attr_names_vec; return attr_names_vec;
...@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI { ...@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI {
template <typename T> template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr( typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const; const char* name) const;
template <typename T> template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const; typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
private: private:
proto::OpDesc const* desc_; proto::OpDesc const* desc_;
...@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI { ...@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI {
// caused by different building options. // caused by different building options.
public: public:
OpDescView() { NotImplemented(); } OpDescView() = default;
bool HasInput(const std::string& param) const { bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr; return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
} }
......
...@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI { ...@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI {
std::vector<int64_t> Dim() const override { std::vector<int64_t> Dim() const override {
const auto& dims = tensor_desc_->dim(); const auto& dims = tensor_desc_->dim();
std::vector<int64_t> dims_vec; std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size()); dims_vec.resize(dims->size());
for (const auto& dim : *dims) { for (size_t i = 0; i < dims->size(); ++i) {
dims_vec.push_back(dim); dims_vec[i] = dims->operator[](i);
} }
return dims_vec; return dims_vec;
} }
...@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI { ...@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI {
size_t byte_size() const override { return tensor_desc_->data()->size(); } size_t byte_size() const override { return tensor_desc_->data()->size(); }
ParamDescView() = delete; ParamDescView() = default;
private: private:
proto::ParamDesc const* desc_; proto::ParamDesc const* desc_;
...@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI { ...@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
void InitParams() { void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data()); desc_ = proto::GetCombinedParamsDesc(buf_.data());
size_t params_size = desc_->params()->size(); size_t params_size = desc_->params()->size();
params_.reserve(params_size); params_.resize(params_size);
for (size_t idx = 0; idx < params_size; ++idx) { for (size_t idx = 0; idx < params_size; ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx))); params_[idx] = ParamDescView(desc_->params()->Get(idx));
} }
} }
...@@ -115,7 +115,11 @@ class ParamDesc : public ParamDescAPI { ...@@ -115,7 +115,11 @@ class ParamDesc : public ParamDescAPI {
} }
explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) { explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT()); if (desc_->variable.type == proto::ParamDesc_::VariableDesc_NONE) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
}
CHECK(desc_->variable.type ==
proto::ParamDesc_::VariableDesc_LoDTensorDesc);
lod_tensor_ = desc_->variable.AsLoDTensorDesc(); lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_); CHECK(lod_tensor_);
} }
...@@ -169,7 +173,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -169,7 +173,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
} }
const ParamDescReadAPI* GetParamDesc(size_t idx) const override { const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
return &params_[idx]; return params_[idx].get();
} }
size_t GetParamsSize() const override { return desc_.params.size(); } size_t GetParamsSize() const override { return desc_.params.size(); }
...@@ -178,7 +182,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -178,7 +182,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
desc_.params.push_back( desc_.params.push_back(
std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT)); std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT));
SyncParams(); SyncParams();
return &params_[params_.size() - 1]; return params_[params_.size() - 1].get();
} }
const void* data() { const void* data() {
...@@ -195,8 +199,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -195,8 +199,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
void SyncParams() { void SyncParams() {
params_.resize(GetParamsSize()); params_.resize(GetParamsSize());
for (size_t i = 0; i < GetParamsSize(); ++i) { for (size_t i = 0; i < GetParamsSize(); ++i) {
if (params_[i].raw_desc() != desc_.params[i].get()) { if (!params_[i] || params_[i]->raw_desc() != desc_.params[i].get()) {
params_[i] = ParamDesc(desc_.params[i].get()); params_[i].reset(new ParamDesc(desc_.params[i].get()));
} }
} }
} }
...@@ -212,7 +216,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -212,7 +216,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
flatbuffers::DetachedBuffer buf_; flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_; flatbuffers::FlatBufferBuilder fbb_;
proto::CombinedParamsDescT desc_; proto::CombinedParamsDescT desc_;
std::vector<ParamDesc> params_; std::vector<std::unique_ptr<ParamDesc>> params_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -21,21 +21,21 @@ namespace fbs { ...@@ -21,21 +21,21 @@ namespace fbs {
template <> template <>
proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>( proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const { int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx); return desc_->blocks()->Get(idx);
} }
template <> template <>
BlockDescView const* ProgramDescView::GetBlock<BlockDescView>( BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const { int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return &blocks_[idx]; return &blocks_[idx];
} }
template <> template <>
proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) { proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) {
CHECK_LT(idx, BlocksSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= vars.size()";
return blocks_[idx].raw_desc(); return blocks_[idx]->raw_desc();
} }
template <> template <>
...@@ -43,7 +43,7 @@ proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() { ...@@ -43,7 +43,7 @@ proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() {
desc_.blocks.push_back( desc_.blocks.push_back(
std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT)); std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT));
SyncBlocks(); SyncBlocks();
return blocks_.back().raw_desc(); return blocks_.back()->raw_desc();
} }
} // namespace fbs } // namespace fbs
......
...@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI { ...@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI {
void InitProgramDesc() { void InitProgramDesc() {
desc_ = proto::GetProgramDesc(buf_.data()); desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize()); blocks_.resize(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) { for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx))); blocks_[idx] = BlockDescView(desc_->blocks()->Get(idx));
} }
} }
...@@ -150,8 +150,8 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -150,8 +150,8 @@ class ProgramDesc : public ProgramDescAPI {
void SyncBlocks() { void SyncBlocks() {
blocks_.resize(desc_.blocks.size()); blocks_.resize(desc_.blocks.size());
for (size_t i = 0; i < desc_.blocks.size(); ++i) { for (size_t i = 0; i < desc_.blocks.size(); ++i) {
if (blocks_[i].raw_desc() != desc_.blocks[i].get()) { if (!blocks_[i] || blocks_[i]->raw_desc() != desc_.blocks[i].get()) {
blocks_[i] = BlockDesc(desc_.blocks[i].get()); blocks_[i].reset(new BlockDesc(desc_.blocks[i].get()));
} }
} }
} }
...@@ -167,7 +167,7 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -167,7 +167,7 @@ class ProgramDesc : public ProgramDescAPI {
flatbuffers::DetachedBuffer buf_; flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_; flatbuffers::FlatBufferBuilder fbb_;
proto::ProgramDescT desc_; proto::ProgramDescT desc_;
std::vector<BlockDesc> blocks_; std::vector<std::unique_ptr<BlockDesc>> blocks_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -15,136 +15,22 @@ ...@@ -15,136 +15,22 @@
#include "lite/model_parser/flatbuffers/program_desc.h" #include "lite/model_parser/flatbuffers/program_desc.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include "lite/model_parser/flatbuffers/test_helper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
namespace {
std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
} // namespace
TEST(ProgramDesc, LoadTest) { TEST(ProgramDesc, LoadTest) {
ProgramDesc program(GenerateProgramCache()); ProgramDesc program(test::GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600); test::CheckProgramCache(&program);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
auto block_a = BlockDesc(program.GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
auto block_b = BlockDesc(program.GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
} }
TEST(ProgramDescView, LoadTest) { TEST(ProgramDescView, LoadTest) {
const ProgramDescView program(GenerateProgramCache()); const ProgramDescView program(test::GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600); test::CheckProgramCache(program);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
} }
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
namespace test {
inline std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
inline void CheckProgramCache(ProgramDesc* program) {
CHECK_EQ(program->Version(), 1000600);
CHECK_EQ(program->BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
BlockDesc block_a(program->GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
BlockDesc block_b(program->GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
inline void CheckProgramCache(const ProgramDescView& program) {
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
} // namespace test
} // namespace fbs
} // namespace lite
} // namespace paddle
...@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI { ...@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR); CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
const auto& dims = desc_->type()->lod_tensor()->tensor()->dims(); const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
std::vector<int64_t> dims_vec; std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size()); dims_vec.resize(dims->size());
for (const auto& dim : *dims) { for (size_t i = 0; i < dims->size(); ++i) {
dims_vec.push_back(dim); dims_vec[i] = dims->operator[](i);
} }
return dims_vec; return dims_vec;
} }
...@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI { ...@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI {
// caused by different building options. // caused by different building options.
public: public:
VarDescView() { NotImplemented(); } VarDescView() = default;
void SetDataType(Type data_type) { NotImplemented(); } void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); } void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
...@@ -93,9 +93,14 @@ class VarDesc : public VarDescAPI { ...@@ -93,9 +93,14 @@ class VarDesc : public VarDescAPI {
Type GetType() const override { return ConvertVarType(type_->type); } Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override { void SetType(Type type) override { type_->type = ConvertVarType(type); }
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type); void SetDataType(Type type) {
type_->lod_tensor->tensor->data_type = ConvertVarType(type);
}
Type GetDataType() const {
return ConvertVarType(type_->lod_tensor->tensor->data_type);
} }
bool Persistable() const override { return desc_->persistable; } bool Persistable() const override { return desc_->persistable; }
......
...@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> { ...@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> {
operator std::vector<std::string>() const { operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp; std::vector<std::string> tmp;
tmp.reserve(size()); tmp.resize(size());
for (size_t i = 0; i < size(); ++i) { for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)->str()); tmp[i] = cvec_->operator[](i)->str();
} }
return tmp; return tmp;
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <set> #include <set>
#include "lite/core/scope.h" #include "lite/core/scope.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/core/variable.h" #include "lite/core/variable.h"
...@@ -27,6 +28,7 @@ ...@@ -27,6 +28,7 @@
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
#include "lite/model_parser/naive_buffer/var_desc.h" #include "lite/model_parser/naive_buffer/var_desc.h"
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
#include "lite/model_parser/flatbuffers/io.h"
#include "lite/model_parser/pb/program_desc.h" #include "lite/model_parser/pb/program_desc.h"
#include "lite/model_parser/pb/var_desc.h" #include "lite/model_parser/pb/var_desc.h"
#endif #endif
...@@ -592,7 +594,54 @@ void SaveModelNaive(const std::string &model_dir, ...@@ -592,7 +594,54 @@ void SaveModelNaive(const std::string &model_dir,
LOG(INFO) << "Save naive buffer model in '" << model_dir LOG(INFO) << "Save naive buffer model in '" << model_dir
<< ".nb' successfully"; << ".nb' successfully";
} }
#endif
/* ---------- Flatbuffers ---------- */
void SaveModelFbs(const std::string &model_dir,
const Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
/* 1. Save model to model.fbs */
const std::string prog_path = model_dir + "/model.fbs";
fbs::ProgramDesc fbs_prog;
TransformProgramDescCppToAny(cpp_prog, &fbs_prog);
fbs::SaveFile(prog_path, fbs_prog.data(), fbs_prog.buf_size());
/* 2. Get param names from cpp::ProgramDesc */
auto &main_block_desc = *cpp_prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable() ||
unique_var_names.count(var.Name()) > 0)
continue;
unique_var_names.emplace(var.Name());
}
/* 3. Save combined params to params.fbs */
const std::string params_path = model_dir + "/params.fbs";
fbs::CombinedParamsDesc params_prog;
fbs::SetCombinedParamsWithScope(exec_scope, unique_var_names, &params_prog);
fbs::SaveFile(params_path, params_prog.data(), params_prog.buf_size());
}
void LoadModelFbsFromFile(const std::string &filename,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
/* 1. Save cpp::ProgramDesc with model.fbs */
const std::string prog_path = filename + "/model.fbs";
fbs::ProgramDesc program(fbs::LoadFile(prog_path));
TransformProgramDescAnyToCpp(program, cpp_prog);
/* 2. Save scope with params.fbs */
const std::string params_path = filename + "/params.fbs";
fbs::CombinedParamsDesc params(fbs::LoadFile(params_path));
fbs::SetScopeWithCombinedParams(scope, params);
}
#endif // LITE_ON_TINY_PUBLISH
template <typename T> template <typename T>
void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) { void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) {
......
...@@ -88,7 +88,15 @@ void SaveModelNaive(const std::string& model_dir, ...@@ -88,7 +88,15 @@ void SaveModelNaive(const std::string& model_dir,
const Scope& exec_scope, const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog, const cpp::ProgramDesc& cpp_prog,
bool combined = true); bool combined = true);
#endif
void SaveModelFbs(const std::string& model_dir,
const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog);
void LoadModelFbsFromFile(const std::string& filename,
Scope* scope,
cpp::ProgramDesc* cpp_prog);
#endif // LITE_ON_TINY_PUBLISH
void LoadParamNaive(const std::string& path, void LoadParamNaive(const std::string& path,
lite::Scope* scope, lite::Scope* scope,
......
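As a quick orientation for the two new entry points above, here is a minimal usage sketch, assuming a non-tiny-publish build; the output directory name and the way the program/scope objects are obtained are illustrative only and not part of this patch.
#ifndef LITE_ON_TINY_PUBLISH
#include <string>
#include "lite/model_parser/model_parser.h"
// Hypothetical round trip: writes <dir>/model.fbs and <dir>/params.fbs,
// then reads them back into a fresh scope and program desc.
void RoundTripFbs(const paddle::lite::Scope &exec_scope,
                  const paddle::lite::cpp::ProgramDesc &cpp_prog) {
  const std::string model_dir = "./fbs_model";  // assumed output directory
  paddle::lite::SaveModelFbs(model_dir, exec_scope, cpp_prog);
  paddle::lite::Scope loaded_scope;
  paddle::lite::cpp::ProgramDesc loaded_prog;
  paddle::lite::LoadModelFbsFromFile(model_dir, &loaded_scope, &loaded_prog);
}
#endif  // LITE_ON_TINY_PUBLISH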
...@@ -101,6 +101,7 @@ add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) ...@@ -101,6 +101,7 @@ add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_grad extra SRCS sequence_pool_grad_op.cc DEPS ${op_DEPS}) add_operator(sequence_pool_grad extra SRCS sequence_pool_grad_op.cc DEPS ${op_DEPS})
add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS}) add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS}) add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_embedding_op_lite extra SRCS sequence_reverse_embedding_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS}) add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS}) add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
...@@ -148,6 +149,7 @@ add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS}) ...@@ -148,6 +149,7 @@ add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS}) add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS}) add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
add_operator(one_hot_op extra SRCS one_hot_op.cc DEPS ${op_DEPS})
# for content-dnn specific # for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS}) add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS}) add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
...@@ -176,7 +178,7 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS}) ...@@ -176,7 +178,7 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS}) add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS}) add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS}) add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
if (NOT LITE_WITH_X86) if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory DEPS fc_op memory
......
...@@ -141,9 +141,25 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -141,9 +141,25 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
} }
} }
} }
if (op_desc.HasAttr("fuse_relu")) { if (op_desc.HasAttr("with_act") && op_desc.GetAttr<bool>("with_act")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu"); param_.activation_param.has_active = true;
param_.activation_param.active_type = lite_api::ActivationType::kRelu; auto act_type = op_desc.GetAttr<std::string>("act_type");
if (act_type == "relu") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu;
param_.fuse_relu = true;
} else if (act_type == "relu6") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu6;
param_.activation_param.Relu_clipped_coef =
op_desc.GetAttr<float>("fuse_brelu_threshold"); // 6.f
} else if (act_type == "leaky_relu") {
param_.activation_param.active_type =
lite_api::ActivationType::kLeakyRelu;
param_.activation_param.Leaky_relu_alpha =
op_desc.GetAttr<float>("leaky_relu_alpha");
} else {
      CHECK(false) << "The fused conv_transpose only supports relu, relu6 "
                      "and leaky_relu activations";
}
} }
if (op_desc.HasAttr("output_size")) { if (op_desc.HasAttr("output_size")) {
param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size"); param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size");
......
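For context, the attribute layout the new activation branch expects could be set up as below; this is a hand-built desc for illustration (normally the activation fuse pass writes these attributes), not code taken from this patch.
cpp::OpDesc fused_desc;
fused_desc.SetType("conv2d_transpose");
fused_desc.SetAttr<bool>("with_act", true);
fused_desc.SetAttr<std::string>("act_type", "leaky_relu");
fused_desc.SetAttr<float>("leaky_relu_alpha", 0.1f);  // read only for leaky_relu
// For act_type == "relu6", "fuse_brelu_threshold" (typically 6.f) is read instead.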
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/operators/dropout_op.h" #include "lite/operators/dropout_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -66,8 +68,10 @@ bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -66,8 +68,10 @@ bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.fix_seed = op_desc.GetAttr<bool>("fix_seed"); param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
param_.seed = op_desc.GetAttr<int>("seed"); param_.seed = op_desc.GetAttr<int>("seed");
param_.dropout_implementation = if (op_desc.HasAttr("dropout_implementation")) {
op_desc.GetAttr<std::string>("dropout_implementation"); param_.dropout_implementation =
op_desc.GetAttr<std::string>("dropout_implementation");
}
return true; return true;
} }
......
...@@ -50,9 +50,15 @@ bool FcOpLite::CheckShape() const { ...@@ -50,9 +50,15 @@ bool FcOpLite::CheckShape() const {
bool FcOpLite::InferShapeImpl() const { bool FcOpLite::InferShapeImpl() const {
const auto& input_dims = param_.input->dims(); const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims(); int64_t w_dims_1;
if (param_.w_dims.empty()) {
const auto& w_dims = param_.w->dims();
w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
} else {
const auto& w_dims = param_.w_dims;
w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
}
int in_num_col_dims = param_.in_num_col_dims; int in_num_col_dims = param_.in_num_col_dims;
int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
// Set output dims // Set output dims
std::vector<DDim::value_type> output_dims(in_num_col_dims + 1); std::vector<DDim::value_type> output_dims(in_num_col_dims + 1);
...@@ -77,6 +83,7 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -77,6 +83,7 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>(); param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>(); param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>();
param_.w_dims = param_.w->dims();
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames(); std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) { input_arg_names.end()) {
......
...@@ -97,7 +97,9 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -97,7 +97,9 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.gate_activation = op_desc.GetAttr<std::string>("gate_activation"); param_.gate_activation = op_desc.GetAttr<std::string>("gate_activation");
param_.activation = op_desc.GetAttr<std::string>("activation"); param_.activation = op_desc.GetAttr<std::string>("activation");
param_.is_reverse = op_desc.GetAttr<bool>("is_reverse"); param_.is_reverse = op_desc.GetAttr<bool>("is_reverse");
param_.origin_mode = op_desc.GetAttr<bool>("origin_mode"); if (op_desc.HasAttr("origin_mode")) {
param_.origin_mode = op_desc.GetAttr<bool>("origin_mode");
}
return true; return true;
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool OneHotOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool OneHotOp::InferShapeImpl() const {
auto out_dims = param_.X->dims();
CHECK_GE(out_dims.size(), 2);
  int depth = param_.depth_tensor ? param_.depth_tensor->data<int32_t>()[0]
                                  : param_.depth;
out_dims[out_dims.size() - 1] = depth;
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
return true;
}
bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.Out = scope->FindMutableTensor(out);
if (op_desc.HasInput("depth_tensor") &&
!op_desc.Input("depth_tensor").empty()) {
auto depth_tensor = op_desc.Input("depth_tensor").front();
param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>();
}
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
}
param_.dtype = op_desc.GetAttr<int>("dtype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
class OneHotOp : public OpLite {
public:
OneHotOp() {}
explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "one_hot"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.X->dims());
ch->output_shape = ch->DimToStr(param_.Out->dims());
ch->macs = param_.X->numel() * 1.f;
}
#endif
private:
mutable OneHotParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
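To make the note in the header concrete, here is a standalone reference sketch of the expansion for int32 indices (illustrative only, not the Lite kernel implementation):
#include <vector>
// Expands each index in `ids` into a row of length `depth` with a single 1.
// Out-of-range indices are left as all-zero rows, which only makes sense when
// allow_out_of_range is enabled, mirroring the note above.
std::vector<std::vector<float>> OneHotRef(const std::vector<int> &ids,
                                          int depth) {
  std::vector<std::vector<float>> out(ids.size(),
                                      std::vector<float>(depth, 0.f));
  for (size_t i = 0; i < ids.size(); ++i) {
    if (ids[i] >= 0 && ids[i] < depth) {
      out[i][ids[i]] = 1.f;
    }
  }
  return out;
}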
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(one_hot_op_lite, TestHost) {
// prepare variables
Scope scope;
auto* x = scope.Var("X")->GetMutable<Tensor>();
auto* depth_tensor = scope.Var("depth_tensor")->GetMutable<Tensor>();
auto* output = scope.Var("Out")->GetMutable<Tensor>();
depth_tensor->dims();
output->dims();
// set data
x->Resize(DDim(std::vector<int64_t>({4, 1})));
auto* x_data = x->mutable_data<int32_t>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
// prepare op desc
cpp::OpDesc desc;
desc.SetType("one_hot");
desc.SetInput("X", {"X"});
desc.SetInput("depth_tensor", {"depth_tensor"});
desc.SetOutput("Out", {"Out"});
desc.SetAttr("depth", static_cast<int>(4));
desc.SetAttr("dtype", static_cast<int>(1));
desc.SetAttr("allow_out_of_range", static_cast<bool>(0));
OneHotOp one_hot("one_hot");
one_hot.SetValidPlaces({Place{TARGET(kHost), PRECISION(kAny)}});
one_hot.Attach(desc, &scope);
auto kernels = one_hot.CreateKernels({Place{TARGET(kHost), PRECISION(kAny)}});
ASSERT_FALSE(kernels.empty());
}
} // namespace operators
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
...@@ -103,6 +103,8 @@ struct FcParam : ParamBase { ...@@ -103,6 +103,8 @@ struct FcParam : ParamBase {
lite::Tensor* bias{nullptr}; lite::Tensor* bias{nullptr};
lite::Tensor* output{nullptr}; lite::Tensor* output{nullptr};
lite::DDim in_mat_dims; lite::DDim in_mat_dims;
// original dims of input weight
lite::DDim w_dims;
int in_num_col_dims{1}; int in_num_col_dims{1};
std::string activation_type{""}; std::string activation_type{""};
bool padding_weights{false}; bool padding_weights{false};
...@@ -1824,6 +1826,15 @@ struct PrintParam : ParamBase { ...@@ -1824,6 +1826,15 @@ struct PrintParam : ParamBase {
bool is_forward{true}; bool is_forward{true};
}; };
struct OneHotParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* depth_tensor{nullptr};
lite::Tensor* Out{};
int depth;
int dtype;
bool allow_out_of_range;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -43,6 +43,8 @@ bool SequencePoolOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -43,6 +43,8 @@ bool SequencePoolOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.Out = param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.MaxIndex = scope->FindVar(opdesc.Output("MaxIndex").front())
->GetMutable<lite::Tensor>();
param_.pool_type = opdesc.GetAttr<std::string>("pooltype"); param_.pool_type = opdesc.GetAttr<std::string>("pooltype");
CHECK(param_.X); CHECK(param_.X);
CHECK(param_.Out); CHECK(param_.Out);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_reverse_embedding_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceReverseEmbeddingOp::CheckShape() const {
CHECK_OR_FALSE(param_.W)
CHECK_OR_FALSE(param_.Ids)
CHECK_OR_FALSE(param_.Out)
CHECK_EQ(param_.Ids->lod().empty(), false)
<< "Input(Ids) Tensor of SequenceReverseEmbeddingOp does not contain "
"LoD information.";
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
int ids_rank = ids_dims.size();
CHECK_EQ_OR_FALSE(table_dims.size(), 2)
CHECK_EQ_OR_FALSE(ids_dims[ids_rank - 1], 1)
return true;
}
bool SequenceReverseEmbeddingOp::InferShapeImpl() const {
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
auto out_dims = ids_dims;
int ids_rank = ids_dims.size();
out_dims[ids_rank - 1] = table_dims[1];
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids->lod());
return true;
}
bool SequenceReverseEmbeddingOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
auto input = op_desc.Input("W").front();
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindTensor(input);
param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_reverse_embedding,
paddle::lite::operators::SequenceReverseEmbeddingOp);
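The kernel itself is not part of this diff; as a rough reference for the intended semantics (an assumption based on the op name and the shape inference above), the op reverses each LoD sequence of ids and then performs an embedding lookup:
#include <cstdint>
#include <vector>
// Reference sketch: for one sequence [begin, end) of ids, emit the embedding
// rows in reverse order. `table` is row-major with shape [vocab, width].
void SequenceReverseEmbeddingRef(const std::vector<int64_t> &ids,
                                 size_t begin, size_t end,
                                 const std::vector<float> &table,
                                 int64_t width,
                                 std::vector<float> *out) {
  for (size_t i = end; i > begin; --i) {
    const int64_t id = ids[i - 1];
    out->insert(out->end(),
                table.begin() + id * width,
                table.begin() + (id + 1) * width);
  }
}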
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceReverseEmbeddingOp : public OpLite {
public:
SequenceReverseEmbeddingOp() {}
explicit SequenceReverseEmbeddingOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "sequence_reverse_embedding";
}
private:
mutable LookupTableParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -14,7 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT ...@@ -14,7 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
...@@ -40,21 +40,21 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT ...@@ -40,21 +40,21 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
...@@ -73,7 +73,7 @@ if(LITE_BUILD_EXTRA) ...@@ -73,7 +73,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif() endif()
endif() endif()
...@@ -90,5 +90,6 @@ endif() ...@@ -90,5 +90,6 @@ endif()
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_flatten_compute SRCS flatten_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif() endif()
...@@ -137,17 +137,20 @@ TEST(Cast, precision) { ...@@ -137,17 +137,20 @@ TEST(Cast, precision) {
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU); place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else #else
return; return;
#endif #endif
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6; // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21; // SIZE_T = 19;UINT8 = 20;INT8 = 21;
#ifndef LITE_WITH_XPU #if !defined(LITE_WITH_XPU) && !defined(LITE_WITH_HUAWEI_ASCEND_NPU)
TestCast(place, abs_error, 20, 5); TestCast(place, abs_error, 20, 5);
#endif #endif
TestCast(place, abs_error, 2, 5); TestCast(place, abs_error, 2, 5);
#ifdef LITE_WITH_XPU #if defined(LITE_WITH_XPU) || defined(LITE_WITH_HUAWEI_ASCEND_NPU)
TestCast(place, abs_error, 3, 5); TestCast(place, abs_error, 3, 5);
TestCast(place, abs_error, 5, 3); TestCast(place, abs_error, 5, 3);
#endif #endif
......
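Note: the integer arguments in the TestCast calls above are the dtype codes listed in the comment (BOOL = 0 ... INT8 = 21), so TestCast(place, abs_error, 2, 5) exercises an INT32-to-FP32 cast, while the XPU / Huawei Ascend NPU branch skips the UINT8 case and adds the INT64/FP32 pairs. A minimal standalone lookup, illustrative only and not part of the patch:

#include <iostream>
#include <map>
#include <string>

int main() {
  // Codes copied from the comment in the hunk above.
  const std::map<int, std::string> dtype_codes = {
      {0, "BOOL"}, {1, "INT16"}, {2, "INT32"}, {3, "INT64"}, {4, "FP16"},
      {5, "FP32"}, {6, "FP64"}, {19, "SIZE_T"}, {20, "UINT8"}, {21, "INT8"}};
  // TestCast(place, abs_error, 2, 5) therefore casts INT32 input to FP32 output.
  std::cout << dtype_codes.at(2) << " -> " << dtype_codes.at(5) << std::endl;
  return 0;
}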
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class FlattenComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "flatten";
std::string input_ = "x";
std::string output_ = "out";
std::string xshape_ = "xshape";
DDim dims_;
int axis_;
public:
FlattenComputeTester(const Place& place,
const std::string& alias,
DDim dims,
int axis)
: TestCase(place, alias), dims_(dims), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
auto* x = scope->FindTensor(input_);
int64_t outer = 1, inner = 1;
for (size_t i = 0; i < dims_.size(); ++i) {
if (i < axis_) {
outer *= dims_[i];
} else {
inner *= dims_[i];
}
}
std::vector<int64_t> out_shape(2);
out_shape[0] = outer;
out_shape[1] = inner;
out->Resize(out_shape);
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
memcpy(out_data, x_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
if (op_type_ == "flatten2") {
op_desc->SetOutput("XShape", {xshape_});
}
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
}
};
void TestFlatten(Place place, float abs_error) {
DDim dims{{2, 3, 4, 5}};
std::vector<int> axes{0, 1, 2, 3};
for (auto axis : axes) {
std::unique_ptr<arena::TestCase> tester(
new FlattenComputeTester(place, "def", dims, axis));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision({"xshape"});
}
}
TEST(flatten, precision) {
LOG(INFO) << "test flatten op";
Place place;
float abs_error = 1e-5;
#if defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else
return;
#endif
TestFlatten(place, abs_error);
}
} // namespace lite
} // namespace paddle
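Note: RunBaseline above collapses every dimension before axis into the first output dimension and the remaining dimensions into the second, then copies the data through unchanged. A standalone sketch of that shape math (plain C++, not part of the test file):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the outer/inner computation in FlattenComputeTester::RunBaseline.
std::vector<int64_t> FlattenShape(const std::vector<int64_t>& dims, int axis) {
  int64_t outer = 1, inner = 1;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (static_cast<int>(i) < axis) {
      outer *= dims[i];
    } else {
      inner *= dims[i];
    }
  }
  return {outer, inner};
}

int main() {
  // The test uses dims {2, 3, 4, 5} and axes 0..3:
  // axis=0 -> {1, 120}, axis=1 -> {2, 60}, axis=2 -> {6, 20}, axis=3 -> {24, 5}.
  for (int axis : {0, 1, 2, 3}) {
    auto s = FlattenShape({2, 3, 4, 5}, axis);
    std::cout << "axis=" << axis << " -> {" << s[0] << ", " << s[1] << "}\n";
  }
  return 0;
}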
...@@ -96,6 +96,9 @@ TEST(Gather, precision) { ...@@ -96,6 +96,9 @@ TEST(Gather, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
......
...@@ -152,6 +152,9 @@ TEST(LayerNorm, precision) { ...@@ -152,6 +152,9 @@ TEST(LayerNorm, precision) {
#elif defined(LITE_WITH_NPU) #elif defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; abs_error = 1e-2;
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
abs_error = 6e-5; abs_error = 6e-5;
......
...@@ -118,6 +118,9 @@ TEST(LookupTable, precision) { ...@@ -118,6 +118,9 @@ TEST(LookupTable, precision) {
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU); place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else #else
return; return;
#endif #endif
......
...@@ -455,6 +455,9 @@ TEST(Matmul2x2, precision) { ...@@ -455,6 +455,9 @@ TEST(Matmul2x2, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
...@@ -472,6 +475,9 @@ TEST(Matmul2x2_x_transpose, precision) { ...@@ -472,6 +475,9 @@ TEST(Matmul2x2_x_transpose, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#else #else
...@@ -487,6 +493,9 @@ TEST(Matmul2x2_y_transpose, precision) { ...@@ -487,6 +493,9 @@ TEST(Matmul2x2_y_transpose, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
...@@ -559,6 +568,9 @@ TEST(Matmulnxn, precision) { ...@@ -559,6 +568,9 @@ TEST(Matmulnxn, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_ARM) #elif defined(LITE_WITH_ARM)
place = TARGET(kARM); place = TARGET(kARM);
#else #else
......
...@@ -208,6 +208,9 @@ TEST(Reshape, precision) { ...@@ -208,6 +208,9 @@ TEST(Reshape, precision) {
place = TARGET(kHost); place = TARGET(kHost);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU); place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else #else
return; return;
#endif #endif
......
...@@ -168,6 +168,9 @@ TEST(Scale, precision) { ...@@ -168,6 +168,9 @@ TEST(Scale, precision) {
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU); place = TARGET(kXPU);
abs_error = 3e-4; // Some operations use fp16 in XPU abs_error = 3e-4; // Some operations use fp16 in XPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#elif defined(LITE_WITH_X86) #elif defined(LITE_WITH_X86)
place = TARGET(kX86); place = TARGET(kX86);
#else #else
......
...@@ -103,6 +103,11 @@ class SearchAlignedMatMulComputeTester : public arena::TestCase { ...@@ -103,6 +103,11 @@ class SearchAlignedMatMulComputeTester : public arena::TestCase {
out->Resize(out_dims); out->Resize(out_dims);
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<float>();
// Prevent 0*nan=nan in basic_gemm
int64_t out_num = out_dims.production();
for (int64_t i = 0; i < out_num; i++) {
out_data[i] = 0;
}
for (int i = 0; i < seq_num; i++) { for (int i = 0; i < seq_num; i++) {
basic_gemm<float, float>(x_transpose_, basic_gemm<float, float>(x_transpose_,
y_transpose_, y_transpose_,
......
...@@ -87,12 +87,18 @@ class SearchSeqFcOPTest : public arena::TestCase { ...@@ -87,12 +87,18 @@ class SearchSeqFcOPTest : public arena::TestCase {
} }
out->set_lod(x_lod); out->set_lod(x_lod);
out->Resize({x_dims[0], w_dims[0]}); DDim out_dims({x_dims[0], w_dims[0]});
out->Resize(out_dims);
int M = x_dims[0]; int M = x_dims[0];
int K = x_dims[1]; int K = x_dims[1];
int N = w_dims[0]; int N = w_dims[0];
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<float>();
// Prevent 0*nan=nan in basic_gemm
int64_t out_num = out_dims.production();
for (int64_t i = 0; i < out_num; i++) {
out_data[i] = 0;
}
basic_gemm<float, float>(false, basic_gemm<float, float>(false,
true, true,
M, M,
......
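Note: this hunk and the SearchAlignedMatMul hunk above add the same guard: zero the output buffer before calling basic_gemm. As the added comment notes, the reference gemm mixes the previous output contents into the result, so a NaN left in freshly allocated memory would survive 0 * NaN and poison the precision check. A tiny standalone illustration (the beta-style update below is an assumption for demonstration, not code from the patch):

#include <iostream>
#include <limits>

int main() {
  float stale = std::numeric_limits<float>::quiet_NaN();  // stand-in for uninitialized memory
  float sum = 3.0f;                                        // freshly computed dot product
  float beta = 0.0f;
  float without_guard = beta * stale + sum;  // 0 * NaN + 3 evaluates to NaN, not 3
  float with_guard = beta * 0.0f + sum;      // zero-filled buffer: result is 3 as expected
  std::cout << without_guard << " vs " << with_guard << std::endl;  // prints: nan vs 3
  return 0;
}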
...@@ -271,6 +271,9 @@ TEST(Slice, precision) { ...@@ -271,6 +271,9 @@ TEST(Slice, precision) {
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
Place place(TARGET(kXPU)); Place place(TARGET(kXPU));
test_slice(place); test_slice(place);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
Place place = TARGET(kHuaweiAscendNPU);
test_slice(place);
#endif #endif
} }
......
...@@ -169,6 +169,9 @@ TEST(Transpose, precision) { ...@@ -169,6 +169,9 @@ TEST(Transpose, precision) {
#elif defined(LITE_WITH_NPU) #elif defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else #else
return; return;
#endif #endif
......
...@@ -125,7 +125,7 @@ void release_param(ConvParam* param) { ...@@ -125,7 +125,7 @@ void release_param(ConvParam* param) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
void test_conv_int8(const std::vector<DDim>& input_dims, void test_conv_int8(const DDim& dim_in,
const DDim& weight_dim, const DDim& weight_dim,
int group, int group,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -237,241 +237,234 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -237,241 +237,234 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
conv_int8_fp32.SetContext(std::move(ctx2)); conv_int8_fp32.SetContext(std::move(ctx2));
/// set param and context /// set param and context
for (auto& dim_in : input_dims) { param_int8_out.x->Resize(dim_in);
param_int8_out.x->Resize(dim_in); DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out);
DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out); if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) { return;
continue;
}
param_fp32_out.x->Resize(dim_in);
param_int8_out.output->Resize(out_tmp_dims);
param_fp32_out.output->Resize(out_tmp_dims);
break;
} }
param_fp32_out.x->Resize(dim_in);
param_int8_out.output->Resize(out_tmp_dims);
param_fp32_out.output->Resize(out_tmp_dims);
conv_int8_int8.SetParam(param_int8_out); conv_int8_int8.SetParam(param_int8_out);
conv_int8_fp32.SetParam(param_fp32_out); conv_int8_fp32.SetParam(param_fp32_out);
/// prepare for run /// prepare for run
conv_int8_int8.PrepareForRun(); conv_int8_int8.PrepareForRun();
conv_int8_fp32.PrepareForRun(); conv_int8_fp32.PrepareForRun();
for (auto& dim_in : input_dims) { CHECK_EQ(weight_dim[1] * group, dim_in[1])
CHECK_EQ(weight_dim[1] * group, dim_in[1]) << "input channel must equal to weights channel";
<< "input channel must equal to weights channel"; DDim dim_out = compute_out_dim(dim_in, param_int8_out);
DDim dim_out = compute_out_dim(dim_in, param_int8_out); if (dim_out[2] < 1 || dim_out[3] < 1) {
if (dim_out[2] < 1 || dim_out[3] < 1) { continue;
continue; }
} delete param_fp32_out.output;
delete param_fp32_out.output; param_fp32_out.output = new Tensor;
param_fp32_out.output = new Tensor; param_fp32_out.output->set_precision(PRECISION(kFloat));
param_fp32_out.output->set_precision(PRECISION(kFloat)); delete param_int8_out.output;
delete param_int8_out.output; param_int8_out.output = new Tensor;
param_int8_out.output = new Tensor; param_int8_out.output->set_precision(PRECISION(kInt8));
param_int8_out.output->set_precision(PRECISION(kInt8));
param_int8_out.x->Resize(dim_in);
param_int8_out.x->Resize(dim_in); param_int8_out.output->Resize(dim_out);
param_int8_out.output->Resize(dim_out); param_fp32_out.x->Resize(dim_in);
param_fp32_out.x->Resize(dim_in); param_fp32_out.output->Resize(dim_out);
param_fp32_out.output->Resize(dim_out);
Tensor tin_fp32;
Tensor tin_fp32; tin_fp32.Resize(dim_in);
tin_fp32.Resize(dim_in); tin_fp32.set_precision(PRECISION(kFloat));
tin_fp32.set_precision(PRECISION(kFloat)); Tensor tout_basic_fp32;
Tensor tout_basic_fp32; Tensor tout_basic_int8;
Tensor tout_basic_int8;
paddle::lite::fill_tensor_rand(*param_int8_out.x, -127, 127);
paddle::lite::fill_tensor_rand(*param_int8_out.x, -127, 127); param_fp32_out.x->CopyDataFrom(*param_int8_out.x);
param_fp32_out.x->CopyDataFrom(*param_int8_out.x);
auto din_fp32 = tin_fp32.mutable_data<float>();
auto din_fp32 = tin_fp32.mutable_data<float>(); paddle::lite::arm::math::int8_to_fp32(param_int8_out.x->data<int8_t>(),
paddle::lite::arm::math::int8_to_fp32(param_int8_out.x->data<int8_t>(), din_fp32,
din_fp32, scale_in.data(),
scale_in.data(), 1,
1,
dim_in.production());
if (FLAGS_check_result) {
tout_basic_fp32.set_precision(PRECISION(kFloat));
tout_basic_fp32.Resize(dim_out);
tout_basic_int8.set_precision(PRECISION(kInt8));
tout_basic_int8.Resize(dim_out);
fill_tensor_const(tout_basic_fp32, 0.f);
auto dout_basic_fp32 = tout_basic_fp32.mutable_data<float>();
auto dout_basic_int8 = tout_basic_int8.mutable_data<int8_t>();
conv_basic<float, float>(din_fp32,
dout_basic_fp32,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr_fp32,
bptr_fp32,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[2],
pads[0],
flag_bias,
flag_act,
six,
alpha);
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8,
scale_out.data(),
1, 1,
1, 1,
dim_in.production()); dim_out.production());
}
if (FLAGS_check_result) { double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
tout_basic_fp32.set_precision(PRECISION(kFloat)); weight_dim[3] / group;
tout_basic_fp32.Resize(dim_out); /// warm up
tout_basic_int8.set_precision(PRECISION(kInt8)); for (int i = 0; i < FLAGS_warmup; ++i) {
tout_basic_int8.Resize(dim_out); conv_int8_fp32.Launch();
fill_tensor_const(tout_basic_fp32, 0.f); }
auto dout_basic_fp32 = tout_basic_fp32.mutable_data<float>(); /// compute fp32 output
auto dout_basic_int8 = tout_basic_int8.mutable_data<int8_t>(); Timer t0;
conv_basic<float, float>(din_fp32, for (int i = 0; i < FLAGS_repeats; ++i) {
dout_basic_fp32, t0.Start();
dim_in[0], conv_int8_fp32.Launch();
dim_out[1], t0.Stop();
dim_out[2], }
dim_out[3], LOG(INFO) << "int8 conv, fp32 output: output shape" << dim_out
dim_in[1], << ",running time, avg: " << t0.LapTimes().Avg() << " ms"
dim_in[2], << ", min time: " << t0.LapTimes().Min() << " ms"
dim_in[3], << ", total GOPS: " << 1e-9 * gops
wptr_fp32, << " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
bptr_fp32, << " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
group,
weight_dim[3], // compute int8 output
weight_dim[2], t0.Reset();
strides[1], for (int i = 0; i < FLAGS_repeats; ++i) {
strides[0], t0.Start();
dilas[1], conv_int8_int8.Launch();
dilas[0], t0.Stop();
pads[2], }
pads[0], LOG(INFO) << "int8 conv, int8 output: output shape" << dim_out
flag_bias, << ",running time, avg: " << t0.LapTimes().Avg()
flag_act, << ", min time: " << t0.LapTimes().Min()
six, << ", total GOPS: " << 1e-9 * gops
alpha); << " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32, << " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
dout_basic_int8,
scale_out.data(), /// compare result fp32 output
1, if (FLAGS_check_result) {
1, double max_ratio = 0;
dim_out.production()); double max_diff = 0;
} tensor_cmp_host(
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] * tout_basic_fp32, *param_fp32_out.output, max_ratio, max_diff);
weight_dim[3] / group; LOG(INFO) << "FP32 compare result, max diff: " << max_diff
/// warm up << ", max ratio: " << max_ratio;
for (int i = 0; i < FLAGS_warmup; ++i) { if (std::abs(max_ratio) > 1e-5f) {
conv_int8_int8.Launch(); if (max_diff > 5e-5f) {
} LOG(WARNING) << "basic result";
/// compute fp32 output print_tensor(tout_basic_fp32);
Timer t0; LOG(WARNING) << "lite result";
for (int i = 0; i < FLAGS_repeats; ++i) { print_tensor(*param_fp32_out.output);
t0.Start(); Tensor tdiff;
conv_int8_fp32.Launch(); tdiff.Resize(tout_basic_fp32.dims());
t0.Stop(); tdiff.set_precision(PRECISION(kFloat));
} tensor_diff(tout_basic_fp32, *param_fp32_out.output, tdiff);
LOG(INFO) << "int8 conv, fp32 output: output shape" << dim_out print_tensor(tdiff);
<< ",running time, avg: " << t0.LapTimes().Avg() release_param(&param_int8_out);
<< ", min time: " << t0.LapTimes().Min() release_param(&param_fp32_out);
<< ", total GOPS: " << 1e-9 * gops LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg() << ", output: " << dim_out
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min(); << ", weight dim: " << weight_dim << ", pad: " << pads[0]
<< ", " << pads[1] << ", " << pads[2] << ", " << pads[3]
/// compute int8 output << ", stride: " << strides[0] << ", " << strides[1]
t0.Reset(); << ", dila_: " << dilas[0] << ", " << dilas[1]
for (int i = 0; i < FLAGS_repeats; ++i) { << ", group: " << group
t0.Start(); << ", bias: " << (flag_bias ? "true" : "false")
conv_int8_int8.Launch(); << ", act: " << flag_act << ", threads: " << th
t0.Stop(); << ", power_mode: " << cls << " failed!!\n";
}
LOG(INFO) << "int8 conv, int8 output: output shape" << dim_out
<< ",running time, avg: " << t0.LapTimes().Avg()
<< ", min time: " << t0.LapTimes().Min()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
/// compare result fp32 output
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(
tout_basic_fp32, *param_fp32_out.output, max_ratio, max_diff);
LOG(INFO) << "FP32 compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-5f) {
if (max_diff > 5e-5f) {
LOG(WARNING) << "basic result";
print_tensor(tout_basic_fp32);
LOG(WARNING) << "lite result";
print_tensor(*param_fp32_out.output);
Tensor tdiff;
tdiff.Resize(tout_basic_fp32.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic_fp32, *param_fp32_out.output, tdiff);
print_tensor(tdiff);
release_param(&param_int8_out);
release_param(&param_fp32_out);
LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
}
} }
} }
/// compare result int8 output }
if (FLAGS_check_result) { // compare result int8 output
double max_ratio = 0; if (FLAGS_check_result) {
double max_diff = 0; double max_ratio = 0;
// ! int8 double max_diff = 0;
tensor_cmp_host( // ! int8
tout_basic_int8, *param_int8_out.output, max_ratio, max_diff); tensor_cmp_host(
LOG(INFO) << "int8 compare result, max diff: " << max_diff tout_basic_int8, *param_int8_out.output, max_ratio, max_diff);
<< ", max ratio: " << max_ratio; LOG(INFO) << "int8 compare result, max diff: " << max_diff
if (fabs(max_diff) > 0) { << ", max ratio: " << max_ratio;
Tensor tdiff; if (fabs(max_diff) > 0) {
tdiff.Resize(tout_basic_int8.dims()); Tensor tdiff;
tdiff.set_precision(PRECISION(kInt8)); tdiff.Resize(tout_basic_int8.dims());
tensor_diff(tout_basic_int8, *param_int8_out.output, tdiff); tdiff.set_precision(PRECISION(kInt8));
auto ptr = tdiff.data<int8_t>(); tensor_diff(tout_basic_int8, *param_int8_out.output, tdiff);
auto ptr_basic_fp32 = tout_basic_fp32.data<float>(); auto ptr = tdiff.data<int8_t>();
float count = 0; auto ptr_basic_fp32 = tout_basic_fp32.data<float>();
bool check = true; float count = 0;
for (int i = 0; i < tdiff.numel(); ++i) { bool check = true;
if (abs(ptr[i]) > 1) { for (int i = 0; i < tdiff.numel(); ++i) {
check = false; if (abs(ptr[i]) > 1) {
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i] check = false;
<< ", after scale: " LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ptr_basic_fp32[i] / scale_out[0]; << ", after scale: "
break; << ptr_basic_fp32[i] / scale_out[0];
} break;
if (ptr[i] != 0) {
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ", after scale: "
<< ptr_basic_fp32[i] / scale_out[0];
count += 1;
}
} }
check = if (ptr[i] != 0) {
check && LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
count < std::max(10, static_cast<int>(0.01 * tdiff.numel())); << ", after scale: "
if (!check) { << ptr_basic_fp32[i] / scale_out[0];
LOG(WARNING) << "int8 basic result"; count += 1;
print_tensor(tout_basic_int8);
LOG(WARNING) << "int8 lite result";
print_tensor(*param_int8_out.output);
LOG(WARNING) << "int8 diff tensor";
print_tensor(tdiff);
release_param(&param_int8_out);
release_param(&param_fp32_out);
LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
} }
} }
check = check &&
count < std::max(10, static_cast<int>(0.01 * tdiff.numel()));
if (!check) {
LOG(WARNING) << "int8 basic result";
print_tensor(tout_basic_int8);
LOG(WARNING) << "int8 lite result";
print_tensor(*param_int8_out.output);
LOG(WARNING) << "int8 diff tensor";
print_tensor(tdiff);
release_param(&param_int8_out);
release_param(&param_fp32_out);
LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim << ", pad: " << pads[0]
<< ", " << pads[1] << ", " << pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
}
} }
LOG(INFO) << "test int8 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " successed!!\n";
} }
LOG(INFO) << "test int8 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " successed!!\n";
} }
} }
release_param(&param_int8_out); release_param(&param_int8_out);
release_param(&param_fp32_out); release_param(&param_fp32_out);
} }
#else #else
void test_conv_int8(const std::vector<DDim>& input_dims, void test_conv_int8(const DDim& dims_in,
const DDim& weight_dim, const DDim& weight_dim,
int group, int group,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -493,25 +486,24 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -493,25 +486,24 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) { for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) { for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3}); DDim weights_dim({c, 1, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 33}) { for (auto& h : {1, 3, 15, 33}) {
dims.push_back(DDim({batch, c, h, h})); DDim dims({batch, c, h, h});
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_act,
{FLAGS_threads},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -529,25 +521,24 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -529,25 +521,24 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) { for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 5, 15, 33}) { for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5}); DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 33, 112, 224}) { for (auto& h : {1, 3, 15, 33, 112, 224}) {
dims.push_back(DDim({batch, c, h, h})); DDim dims({batch, c, h, h});
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_act,
{1, 4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_act,
{1, 4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -565,28 +556,27 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -565,28 +556,27 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
continue; continue;
} }
DDim weights_dim({cout, cin / g, 1, 1}); DDim weights_dim({cout, cin / g, 1, 1});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 9, 16, 33}) { for (auto& h : {1, 9, 16, 33}) {
dims.push_back(DDim({batch, cin, h, h})); DDim dims({batch, cin, h, h});
test_conv_int8(dims,
weights_dim,
g,
{1, 1},
{0, 0, 0, 0},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
test_conv_int8(dims,
weights_dim,
g,
{1, 1},
{0, 0, 0, 0},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -606,29 +596,29 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -606,29 +596,29 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) { for (auto& flag_act : {0, 1}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 17, 33}) { for (auto& h : {1, 7, 17, 33}) {
dims.push_back(DDim({batch, cin, h, h})); DDim dims({batch, cin, h, h});
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(
dims,
weights_dim,
1,
{1, 1},
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(dims,
weights_dim,
1,
{1, 1},
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -652,25 +642,25 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -652,25 +642,25 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 33}) { for (auto& h : {1, 7, 19, 33}) {
dims.push_back(DDim({batch, cin, h, h})); DDim dims({batch, cin, h, h});
test_conv_int8(
dims,
weights_dim,
1,
{2, 2},
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
test_conv_int8(dims,
weights_dim,
1,
{2, 2},
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -702,26 +692,27 @@ TEST(TestConvRandInt8, test_conv_rand) { ...@@ -702,26 +692,27 @@ TEST(TestConvRandInt8, test_conv_rand) {
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
break; break;
} }
std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw}); DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 5, 19}) { for (auto& h : {1, 3, 5, 19}) {
dims.push_back(DDim({batch, cin, h, h})); DDim dims({batch, cin, h, h});
test_conv_int8(dims,
weights_dim,
g,
{stride, stride},
{pad_top,
pad_bottom,
pad_left,
pad_right},
{dila, dila},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
test_conv_int8(
dims,
weights_dim,
g,
{stride, stride},
{pad_top, pad_bottom, pad_left, pad_right},
{dila, dila},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
......
...@@ -146,6 +146,10 @@ function make_tiny_publish_so { ...@@ -146,6 +146,10 @@ function make_tiny_publish_so {
prepare_opencl_source_code $workspace $build_dir prepare_opencl_source_code $workspace $build_dir
fi fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
local cmake_mutable_options=" local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_BUILD_EXTRA=$WITH_EXTRA \
...@@ -199,6 +203,10 @@ function make_full_publish_so { ...@@ -199,6 +203,10 @@ function make_full_publish_so {
prepare_opencl_source_code $workspace $build_dir prepare_opencl_source_code $workspace $build_dir
fi fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
local cmake_mutable_options=" local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
......
...@@ -49,6 +49,10 @@ function make_ios { ...@@ -49,6 +49,10 @@ function make_ios {
exit 1 exit 1
fi fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
build_dir=$workspace/build.ios.${os}.${arch} build_dir=$workspace/build.ios.${os}.${arch}
if [ -d $build_dir ] if [ -d $build_dir ]
then then
...@@ -61,7 +65,6 @@ function make_ios { ...@@ -61,7 +65,6 @@ function make_ios {
GEN_CODE_PATH_PREFIX=lite/gen_code GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX} mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
cmake $workspace \ cmake $workspace \
-DWITH_LITE=ON \ -DWITH_LITE=ON \
-DLITE_WITH_ARM=ON \ -DLITE_WITH_ARM=ON \
......
...@@ -173,6 +173,9 @@ function make_tiny_publish_so { ...@@ -173,6 +173,9 @@ function make_tiny_publish_so {
if [ "${WITH_OPENCL}" = "ON" ]; then if [ "${WITH_OPENCL}" = "ON" ]; then
prepare_opencl_source_code $workspace $build_dir prepare_opencl_source_code $workspace $build_dir
fi fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
init_cmake_mutable_options init_cmake_mutable_options
cmake $workspace \ cmake $workspace \
......
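Note: the hunks above (tiny publish, full publish, iOS, and the variant in the other build script) all add the same guard: requesting a stripped, model-tailored library forces the extra op/kernel set on, presumably so every operator the model may need is built before tailoring. A minimal shell sketch of the resulting behavior; the command-line spelling in the last comment is an assumption, not taken from this diff:

#!/bin/bash
# User asks for stripping but leaves the extra ops off:
WITH_STRIP=ON
WITH_EXTRA=OFF
# Guard added by this patch:
if [ "${WITH_STRIP}" == "ON" ]; then
    WITH_EXTRA=ON
fi
echo "WITH_EXTRA=${WITH_EXTRA}"   # prints WITH_EXTRA=ON
# Hypothetical invocation that would hit this path: ./lite/tools/build_android.sh --with_strip=ON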
...@@ -69,7 +69,7 @@ function CheckLibSizeDiff() { ...@@ -69,7 +69,7 @@ function CheckLibSizeDiff() {
# step3: if diff_size > 10485, special approval is needed # step3: if diff_size > 10485, special approval is needed
diff_size=$[$current_size - $develop_size] diff_size=$[$current_size - $develop_size]
if [ $diff_size -gt 10485 ]; then if [ $diff_size -gt 10485 ]; then
echo_line="Your PR has increased basic inference lib for $diff_size Byte, exceeding maximum requirement of 10485 Byte (0.01M). You need Superjomn's (Yunchunwei) approval or you can contact DannyIsFunny(HuZhiqiang).\n" echo_line="Your PR has increased basic inference lib for $diff_size Byte, exceeding maximum requirement of 10485 Byte (0.01M). You need Superjomn's (Yunchunwei) approval or you can contact DannyIsFunny(HuZhiqiang).\n Library size in develop branch: $develop_size byte, library size after merging your code: $current_size byte.\n Compiling method: ./lite/tools/build_android.sh --with_log=OFF\n"
echo "****************" echo "****************"
echo -e "${echo_line[@]}" echo -e "${echo_line[@]}"
echo "There is an approved errors." echo "There is an approved errors."
......
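Note: the 10485-byte limit in this gate is roughly 0.01 MiB (0.01 * 1024 * 1024 = 10485.76), which matches the "(0.01M)" wording in the message. A standalone sketch of the check with made-up sizes; variable names are reused from the hunk and the numbers are examples only:

#!/bin/bash
develop_size=3145728                        # example: lib size on the develop branch, in bytes
current_size=3160000                        # example: lib size after merging the PR, in bytes
diff_size=$((current_size - develop_size))  # 14272 bytes in this example
if [ $diff_size -gt 10485 ]; then           # 10485 B is about 0.01 MiB
    echo "basic inference lib grew by ${diff_size} bytes; special approval is needed"
fi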