diff --git a/.gitignore b/.gitignore
index 894a44cc066a027465cd26d634948d56d13af9af..7c549619732ca4ce247b629414cf2bfe4d6c8a96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,6 @@ __pycache__/
*$py.class
# C extensions
-*.so
# Distribution / packaging
.Python
diff --git a/README.md b/README.md
index a17a7cf8f31341aa49b9ca30a27d1551b326766b..af7ba67a8c2280a51580815762ed8d5e306c567f 100644
--- a/README.md
+++ b/README.md
@@ -64,12 +64,15 @@ $ pip install -r requirements.txt
* [How to train U-Net](./turtorial/finetune_unet.md)
* [How to train ICNet](./turtorial/finetune_icnet.md)
* [How to train PSPNet](./turtorial/finetune_pspnet.md)
+* [How to train HRNet](./turtorial/finetune_hrnet.md)
### Inference Deployment
* [Model export](./docs/model_export.md)
* [Inference with Python](./deploy/python/)
* [Inference with C++](./deploy/cpp/)
+* [Mobile inference deployment](./deploy/lite/)
+
### Advanced Features
diff --git a/configs/hrnet_w18_pet.yaml b/configs/hrnet_w18_pet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1bfb9215e7f204444613fd9f6c78eba9c1c1432
--- /dev/null
+++ b/configs/hrnet_w18_pet.yaml
@@ -0,0 +1,49 @@
+TRAIN_CROP_SIZE: (512, 512) # (width, height); used by unpadding, rangescaling and stepscaling
+EVAL_CROP_SIZE: (512, 512) # (width, height); used by unpadding, rangescaling and stepscaling
+AUG:
+ AUG_METHOD: "unpadding" # one of: unpadding, rangescaling, stepscaling
+ FIX_RESIZE_SIZE: (512, 512) # (width, height); used by unpadding
+
+ INF_RESIZE_VALUE: 500 # for rangescaling
+ MAX_RESIZE_VALUE: 600 # for rangescaling
+ MIN_RESIZE_VALUE: 400 # for rangescaling
+
+ MAX_SCALE_FACTOR: 1.25 # for stepscaling
+ MIN_SCALE_FACTOR: 0.75 # for stepscaling
+ SCALE_STEP_SIZE: 0.25 # for stepscaling
+ MIRROR: True
+BATCH_SIZE: 4
+DATASET:
+ DATA_DIR: "./dataset/mini_pet/"
+ IMAGE_TYPE: "rgb" # rgb or rgba
+ NUM_CLASSES: 3
+ TEST_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
+ TRAIN_FILE_LIST: "./dataset/mini_pet/file_list/train_list.txt"
+ VAL_FILE_LIST: "./dataset/mini_pet/file_list/val_list.txt"
+ VIS_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
+ IGNORE_INDEX: 255
+ SEPARATOR: " "
+FREEZE:
+ MODEL_FILENAME: "__model__"
+ PARAMS_FILENAME: "__params__"
+MODEL:
+ MODEL_NAME: "hrnet"
+ DEFAULT_NORM_TYPE: "bn"
+ HRNET:
+ STAGE2:
+ NUM_CHANNELS: [18, 36]
+ STAGE3:
+ NUM_CHANNELS: [18, 36, 72]
+ STAGE4:
+ NUM_CHANNELS: [18, 36, 72, 144]
+TRAIN:
+ PRETRAINED_MODEL_DIR: "./pretrained_model/hrnet_w18_bn_cityscapes/"
+ MODEL_SAVE_DIR: "./saved_model/hrnet_w18_bn_pet/"
+ SNAPSHOT_EPOCH: 10
+TEST:
+ TEST_MODEL: "./saved_model/hrnet_w18_bn_pet/final"
+SOLVER:
+ NUM_EPOCHS: 100
+ LR: 0.005
+ LR_POLICY: "poly"
+ OPTIMIZER: "sgd"
diff --git a/deploy/README.md b/deploy/README.md
index fc06306b7e99817b9aed6b69e9f5d70fc95c61b1..16fe75b9dfa230c290dfd88fa2b57cc924ec4387 100644
--- a/deploy/README.md
+++ b/deploy/README.md
@@ -8,3 +8,5 @@
[3. Serving deployment (Linux only)](./serving)
+[4. Mobile deployment (Android only)](./lite)
+
diff --git a/deploy/lite/README.md b/deploy/lite/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4ec50be28e75d79ce2f61453737930bccf52cf4
--- /dev/null
+++ b/deploy/lite/README.md
@@ -0,0 +1,70 @@
+# Portrait Segmentation on Mobile Devices
+
+## 1. Introduction
+Taking portrait segmentation on Android as an example, this document describes how to use [Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) to deploy segmentation models on mobile devices. Section 2 explains how to use the Android portrait segmentation demo; the following sections explain how to deploy a PaddleSeg model to an Android device.
+
+## 2. Using the Android Demo
+
+### 2.1 Requirements
+* Android Studio 3.4
+* An Android phone or development board
+
+### 2.2 Installation
+* git clone https://github.com/PaddlePaddle/PaddleSeg.git
+* Open Android Studio. In the "Welcome to Android Studio" window, click "Open an existing Android Studio project", select the "/PaddleSeg/lite/humanseg-android-demo/" directory in the dialog that appears, and click "Open" to import the project
+* Connect an Android phone or development board via USB
+* Once the project has loaded, click Run->Run 'App' in the menu bar, select the connected Android device in the "Select Deployment Target" dialog, and click "OK"
+* The demo's main screen appears on the phone; tap the "Image Segmentation" icon to open the portrait segmentation sample
+* The portrait segmentation demo loads a portrait image by default and shows the CPU prediction result below the image
+* You can also load test images from the gallery or camera via the "Gallery" and "Take Photo" buttons at the top
+
+### 2.3 Notes
+This Android demo is built on top of [Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo); see that repo for more details.
+*Note: photos taken inside the demo are automatically compressed. To test an uncompressed photo, take it with the phone's camera app and then open it from the gallery for prediction.*
+
+### 2.4 Results
+![](./example/human_1.png) ![](./example/human_2.png) ![](./example/human_3.png)
+
+## 3. Model Export
+The portrait segmentation model used by this demo can be downloaded [here](https://paddleseg.bj.bcebos.com/models/humanseg_mobilenetv2_1_0_bn_freeze_model_pr_po.zip). It is a humanseg model based on DeepLabv3+ with MobileNetV2; for an introduction to humanseg, see the [vertical segmentation models](./contrib). For exporting other segmentation models, see [Model export](https://github.com/PaddlePaddle/PaddleSeg/blob/release/v0.2.0/docs/model_export.md).
+
+## 4. Model Conversion
+
+### 4.1 Model conversion tool
+After preparing the model and parameter files exported from PaddleSeg, use the model_optimize_tool provided by Paddle-Lite to optimize the model and convert it into the file format supported by Paddle-Lite. There are two ways to obtain the tool:
+
+* Build model_optimize_tool yourself.
+For detailed conversion steps, see the official Paddle-Lite documentation: [model conversion guide](https://paddlepaddle.github.io/Paddle-Lite/v2.0.0/model_optimize_tool/). Running model_optimize_tool on a model exported from PaddleSeg produces the model.nb and param.nb files.
+
+* Use a prebuilt model_optimize_tool. See [releases](https://github.com/PaddlePaddle/Paddle-Lite/releases/) for the latest prebuilt binaries; this demo uses [model_optimize_tool](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.0.0/model_optimize_tool).
+
+ *Note: if it fails to run, use model_optimize_tool inside the development environment described in [Paddle-Lite source compilation](https://paddlepaddle.github.io/Paddle-Lite/v2.0.0/source_compile/).*
+
+### 4.2 Updating the model
+Replace the files under app/src/main/assets/image_segmentation/models/deeplab_mobilenet_for_cpu with the optimized model.nb and param.nb files; the sketch below shows how the demo locates them at runtime.
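+
+For reference, the demo's `Predictor` class (part of this repo) does not read these files from `assets` directly: any model path that does not start with `/` is first copied from `assets` into the app cache directory, and Paddle-Lite's `MobileConfig` is pointed at the copy. Below is a minimal sketch of that flow under the default asset path; `ModelLocator` is a hypothetical wrapper name, not part of the demo:
+
+```java
+import android.content.Context;
+
+import com.baidu.paddle.lite.MobileConfig;
+import com.baidu.paddle.lite.demo.Utils;
+
+public class ModelLocator {
+    // Sketch only: mirrors the path handling in Predictor.loadModel().
+    public static MobileConfig configFor(Context appCtx) {
+        String modelPath = "image_segmentation/models/deeplab_mobilenet_for_cpu";
+        // Copy model.nb/param.nb out of the APK so Paddle-Lite can read them from disk.
+        String realPath = appCtx.getCacheDir() + "/" + modelPath;
+        Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
+
+        MobileConfig config = new MobileConfig();
+        config.setModelDir(realPath); // directory that now contains the optimized files
+        return config;
+    }
+}
+```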
+
+## 5. Updating the Inference Library
+Paddle-Lite can currently be built in Docker, Linux, and Mac OS development environments; Docker is recommended to avoid dependency problems. Prebuilt inference libraries are also provided. The Paddle-Lite inference library for Android consists of three files:
+
+* PaddlePredictor.jar
+* arm64-v8a/libpaddle_lite_jni.so
+* armeabi-v7a/libpaddle_lite_jni.so
+
+There are two ways to obtain them:
+
+* Use the prebuilt libraries. See [releases](https://github.com/PaddlePaddle/Paddle-Lite/releases/) for the latest prebuilt files; this demo uses:
+
+    * arm64-v8a: [inference_lite_lib.android.armv8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.0.0/inference_lite_lib.android.armv8.gcc.c++_shared.with_extra.full_publish.tar.gz)
+
+    * armeabi-v7a: [inference_lite_lib.android.armv7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.0.0/inference_lite_lib.android.armv7.gcc.c++_shared.with_extra.full_publish.tar.gz)
+
+    After extracting either archive, PaddlePredictor.jar is located at inference_lite_lib.android.xxx/java/jar/PaddlePredictor.jar.
+
+    After extracting inference_lite_lib.android.armv8, arm64-v8a/libpaddle_lite_jni.so is located at inference_lite_lib.android.armv8/java/so/libpaddle_lite_jni.so.
+
+    After extracting inference_lite_lib.android.armv7, armeabi-v7a/libpaddle_lite_jni.so is located at inference_lite_lib.android.armv7/java/so/libpaddle_lite_jni.so.
+
+* Build the Paddle-Lite inference library yourself.
+For environment setup and build instructions, see [Paddle-Lite source compilation](https://paddlepaddle.github.io/Paddle-Lite/v2.0.0/source_compile/).
+
+With the files above in place, you can run inference on Android via the [Java API](https://paddlepaddle.github.io/Paddle-Lite/v2.0.0/java_api_doc/); see the sketch below. For details on using the inference library, see the section on updating the inference library in [Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo).
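+
+As a concrete starting point, the snippet below shows the same call sequence used by the demo's `Predictor` and `ImgSegPredictor` classes: build a `MobileConfig`, create a `PaddlePredictor`, fill the input `Tensor`, run, and read the output. This is a hedged sketch rather than an authoritative reference: the `SegInference` class name, input shape, and preprocessing are assumptions for illustration; consult the Java API documentation linked above for exact signatures.
+
+```java
+import com.baidu.paddle.lite.MobileConfig;
+import com.baidu.paddle.lite.PaddlePredictor;
+import com.baidu.paddle.lite.PowerMode;
+import com.baidu.paddle.lite.Tensor;
+
+public class SegInference {
+    /** Runs one forward pass; chwImage holds normalized CHW float pixels. */
+    public static float[] run(String modelDir, float[] chwImage, long h, long w) {
+        MobileConfig config = new MobileConfig();
+        config.setModelDir(modelDir); // directory containing model.nb/param.nb
+        config.setThreads(1);
+        config.setPowerMode(PowerMode.LITE_POWER_HIGH);
+        PaddlePredictor predictor = PaddlePredictor.createPaddlePredictor(config);
+
+        Tensor input = predictor.getInput(0);
+        input.resize(new long[]{1, 3, h, w}); // NCHW; the shape is an assumption
+        input.setData(chwImage);
+
+        predictor.run();
+
+        Tensor output = predictor.getOutput(0);
+        return output.getFloatData(); // per-pixel predictions
+    }
+}
+```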
diff --git a/deploy/lite/example/human_1.png b/deploy/lite/example/human_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..a167663bf824e0b015b22840414ceb80d6d7923b
Binary files /dev/null and b/deploy/lite/example/human_1.png differ
diff --git a/deploy/lite/example/human_2.png b/deploy/lite/example/human_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..8895e5c0a74e9901404ac0da66de07da986539c4
Binary files /dev/null and b/deploy/lite/example/human_2.png differ
diff --git a/deploy/lite/example/human_3.png b/deploy/lite/example/human_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..31ba12e450bc1dc3e534a703c0b3ff84f84e7cc1
Binary files /dev/null and b/deploy/lite/example/human_3.png differ
diff --git a/deploy/lite/humanseg-android-demo/.gitignore b/deploy/lite/humanseg-android-demo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2b75303ac58f551de0a327638a60b909c6d33ece
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/.gitignore
@@ -0,0 +1,13 @@
+*.iml
+.gradle
+/local.properties
+/.idea/caches
+/.idea/libraries
+/.idea/modules.xml
+/.idea/workspace.xml
+/.idea/navEditor.xml
+/.idea/assetWizardSettings.xml
+.DS_Store
+/build
+/captures
+.externalNativeBuild
diff --git a/deploy/lite/humanseg-android-demo/app/.gitignore b/deploy/lite/humanseg-android-demo/app/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..796b96d1c402326528b4ba3c12ee9d92d0e212e9
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/.gitignore
@@ -0,0 +1 @@
+/build
diff --git a/deploy/lite/humanseg-android-demo/app/build.gradle b/deploy/lite/humanseg-android-demo/app/build.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..087d90ca07b67a94030346989c9b1e8597693f61
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/build.gradle
@@ -0,0 +1,30 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 28
+ defaultConfig {
+ applicationId "com.baidu.paddle.lite.demo"
+ minSdkVersion 15
+ targetSdkVersion 28
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+ }
+ }
+}
+
+dependencies {
+ implementation fileTree(include: ['*.jar'], dir: 'libs')
+ implementation 'com.android.support:appcompat-v7:28.0.0'
+ implementation 'com.android.support.constraint:constraint-layout:1.1.3'
+ implementation 'com.android.support:design:28.0.0'
+ testImplementation 'junit:junit:4.12'
+ androidTestImplementation 'com.android.support.test:runner:1.0.2'
+ androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
+ implementation files('libs/PaddlePredictor.jar')
+}
diff --git a/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.jar b/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000000000000000000000000000000000..f6b961fd5a86aa5fbfe90f707c3138408be7c718
Binary files /dev/null and b/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.properties b/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000000000000000000000000000000000000..7b5dff50f980af2fe06868600aac6c4db88614f8
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Mon Nov 25 17:01:58 CST 2019
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip
diff --git a/deploy/lite/humanseg-android-demo/app/gradlew b/deploy/lite/humanseg-android-demo/app/gradlew
new file mode 100644
index 0000000000000000000000000000000000000000..cccdd3d517fc5249beaefa600691cf150f2fa3e6
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/gradlew
@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+ echo "$*"
+}
+
+die () {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+ NONSTOP* )
+ nonstop=true
+ ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=$((i+1))
+ done
+ case $i in
+ (0) set -- ;;
+ (1) set -- "$args0" ;;
+ (2) set -- "$args0" "$args1" ;;
+ (3) set -- "$args0" "$args1" "$args2" ;;
+ (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Escape application args
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+ cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"
diff --git a/deploy/lite/humanseg-android-demo/app/gradlew.bat b/deploy/lite/humanseg-android-demo/app/gradlew.bat
new file mode 100644
index 0000000000000000000000000000000000000000..f9553162f122c71b34635112e717c3e733b5b212
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/gradlew.bat
@@ -0,0 +1,84 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/deploy/lite/humanseg-android-demo/app/libs/PaddlePredictor.jar b/deploy/lite/humanseg-android-demo/app/libs/PaddlePredictor.jar
new file mode 100644
index 0000000000000000000000000000000000000000..037d569f712578c5cda766b1160654ea491115df
Binary files /dev/null and b/deploy/lite/humanseg-android-demo/app/libs/PaddlePredictor.jar differ
diff --git a/deploy/lite/humanseg-android-demo/app/proguard-rules.pro b/deploy/lite/humanseg-android-demo/app/proguard-rules.pro
new file mode 100644
index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/deploy/lite/humanseg-android-demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ExampleInstrumentedTest.java b/deploy/lite/humanseg-android-demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ExampleInstrumentedTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..353c3677e538eb76e29f22c68232fc68c4240729
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ExampleInstrumentedTest.java
@@ -0,0 +1,26 @@
+package com.baidu.paddle.lite.demo;
+
+import android.content.Context;
+import android.support.test.InstrumentationRegistry;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.junit.Assert.*;
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * @see Testing documentation
+ */
+@RunWith(AndroidJUnit4.class)
+public class ExampleInstrumentedTest {
+ @Test
+ public void useAppContext() {
+ // Context of the app under test.
+ Context appContext = InstrumentationRegistry.getTargetContext();
+
+ assertEquals("com.baidu.paddle.lite.demo", appContext.getPackageName());
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/AndroidManifest.xml b/deploy/lite/humanseg-android-demo/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000000000000000000000000000000000..67e06269f4b2764034d4d7c400f1c93c1504fe6a
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/AndroidManifest.xml
@@ -0,0 +1,33 @@
+<!-- AndroidManifest.xml (33 lines) elided -->
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/AppCompatPreferenceActivity.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/AppCompatPreferenceActivity.java
new file mode 100644
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/AppCompatPreferenceActivity.java
+package com.baidu.paddle.lite.demo;
+
+import android.content.res.Configuration;
+import android.os.Bundle;
+import android.preference.PreferenceActivity;
+import android.support.annotation.LayoutRes;
+import android.support.annotation.Nullable;
+import android.support.v7.app.ActionBar;
+import android.support.v7.app.AppCompatDelegate;
+import android.support.v7.widget.Toolbar;
+import android.view.MenuInflater;
+import android.view.View;
+import android.view.ViewGroup;
+
+/**
+ * A {@link android.preference.PreferenceActivity} which implements and proxies the necessary
+ * calls to be used with AppCompat.
+ * <p>
+ * This technique can be used with an {@link android.app.Activity} class, not just
+ * {@link android.preference.PreferenceActivity}.
+ */
+public abstract class AppCompatPreferenceActivity extends PreferenceActivity {
+ private AppCompatDelegate mDelegate;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ getDelegate().installViewFactory();
+ getDelegate().onCreate(savedInstanceState);
+ super.onCreate(savedInstanceState);
+ }
+
+ @Override
+ protected void onPostCreate(Bundle savedInstanceState) {
+ super.onPostCreate(savedInstanceState);
+ getDelegate().onPostCreate(savedInstanceState);
+ }
+
+ public ActionBar getSupportActionBar() {
+ return getDelegate().getSupportActionBar();
+ }
+
+ public void setSupportActionBar(@Nullable Toolbar toolbar) {
+ getDelegate().setSupportActionBar(toolbar);
+ }
+
+ @Override
+ public MenuInflater getMenuInflater() {
+ return getDelegate().getMenuInflater();
+ }
+
+ @Override
+ public void setContentView(@LayoutRes int layoutResID) {
+ getDelegate().setContentView(layoutResID);
+ }
+
+ @Override
+ public void setContentView(View view) {
+ getDelegate().setContentView(view);
+ }
+
+ @Override
+ public void setContentView(View view, ViewGroup.LayoutParams params) {
+ getDelegate().setContentView(view, params);
+ }
+
+ @Override
+ public void addContentView(View view, ViewGroup.LayoutParams params) {
+ getDelegate().addContentView(view, params);
+ }
+
+ @Override
+ protected void onPostResume() {
+ super.onPostResume();
+ getDelegate().onPostResume();
+ }
+
+ @Override
+ protected void onTitleChanged(CharSequence title, int color) {
+ super.onTitleChanged(title, color);
+ getDelegate().setTitle(title);
+ }
+
+ @Override
+ public void onConfigurationChanged(Configuration newConfig) {
+ super.onConfigurationChanged(newConfig);
+ getDelegate().onConfigurationChanged(newConfig);
+ }
+
+ @Override
+ protected void onStop() {
+ super.onStop();
+ getDelegate().onStop();
+ }
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ getDelegate().onDestroy();
+ }
+
+ public void invalidateOptionsMenu() {
+ getDelegate().invalidateOptionsMenu();
+ }
+
+ private AppCompatDelegate getDelegate() {
+ if (mDelegate == null) {
+ mDelegate = AppCompatDelegate.create(this, null);
+ }
+ return mDelegate;
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/CommonActivity.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/CommonActivity.java
new file mode 100644
index 0000000000000000000000000000000000000000..88146b3961e5f2c8ed366816e505144ba3ac9f6b
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/CommonActivity.java
@@ -0,0 +1,265 @@
+package com.baidu.paddle.lite.demo;
+
+import android.Manifest;
+import android.app.ProgressDialog;
+import android.content.ContentResolver;
+import android.content.Intent;
+import android.content.pm.PackageManager;
+import android.database.Cursor;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.net.Uri;
+import android.os.Bundle;
+import android.os.Environment;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Message;
+import android.provider.MediaStore;
+import android.support.annotation.NonNull;
+import android.support.v4.app.ActivityCompat;
+import android.support.v4.content.ContextCompat;
+import android.support.v4.content.FileProvider;
+import android.support.v7.app.ActionBar;
+import android.support.v7.app.AppCompatActivity;
+import android.util.Log;
+import android.view.Menu;
+import android.view.MenuInflater;
+import android.view.MenuItem;
+import android.widget.Toast;
+
+import java.io.File;
+import java.io.IOException;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+
+public class CommonActivity extends AppCompatActivity {
+ private static final String TAG = CommonActivity.class.getSimpleName();
+ public static final int OPEN_GALLERY_REQUEST_CODE = 0;
+ public static final int TAKE_PHOTO_REQUEST_CODE = 1;
+
+ public static final int REQUEST_LOAD_MODEL = 0;
+ public static final int REQUEST_RUN_MODEL = 1;
+ public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
+ public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
+ public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
+ public static final int RESPONSE_RUN_MODEL_FAILED = 3;
+
+ protected ProgressDialog pbLoadModel = null;
+ protected ProgressDialog pbRunModel = null;
+
+ protected Handler receiver = null; // receive messages from worker thread
+ protected Handler sender = null; // send command to worker thread
+ protected HandlerThread worker = null; // worker thread to load&run model
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ ActionBar supportActionBar = getSupportActionBar();
+ if (supportActionBar != null) {
+ supportActionBar.setDisplayHomeAsUpEnabled(true);
+ }
+
+ receiver = new Handler() {
+ @Override
+ public void handleMessage(Message msg) {
+ switch (msg.what) {
+ case RESPONSE_LOAD_MODEL_SUCCESSED:
+ pbLoadModel.dismiss();
+ onLoadModelSuccessed();
+ break;
+ case RESPONSE_LOAD_MODEL_FAILED:
+ pbLoadModel.dismiss();
+ Toast.makeText(CommonActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
+ onLoadModelFailed();
+ break;
+ case RESPONSE_RUN_MODEL_SUCCESSED:
+ pbRunModel.dismiss();
+ onRunModelSuccessed();
+ break;
+ case RESPONSE_RUN_MODEL_FAILED:
+ pbRunModel.dismiss();
+ Toast.makeText(CommonActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
+ onRunModelFailed();
+ break;
+ default:
+ break;
+ }
+ }
+ };
+
+ worker = new HandlerThread("Predictor Worker");
+ worker.start();
+ sender = new Handler(worker.getLooper()) {
+ public void handleMessage(Message msg) {
+ switch (msg.what) {
+ case REQUEST_LOAD_MODEL:
+ // load model and reload test image
+ if (onLoadModel()) {
+ receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESSED);
+ } else {
+ receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED);
+ }
+ break;
+ case REQUEST_RUN_MODEL:
+ // run model if model is loaded
+ if (onRunModel()) {
+ receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESSED);
+ } else {
+ receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ };
+ }
+
+ public void loadModel() {
+ pbLoadModel = ProgressDialog.show(this, "", "Loading model...", false, false);
+ sender.sendEmptyMessage(REQUEST_LOAD_MODEL);
+ }
+
+ public void runModel() {
+ pbRunModel = ProgressDialog.show(this, "", "Running model...", false, false);
+ sender.sendEmptyMessage(REQUEST_RUN_MODEL);
+ }
+
+ public boolean onLoadModel() {
+ return true;
+ }
+
+ public boolean onRunModel() {
+ return true;
+ }
+
+ public void onLoadModelSuccessed() {
+ }
+
+ public void onLoadModelFailed() {
+ }
+
+ public void onRunModelSuccessed() {
+ }
+
+ public void onRunModelFailed() {
+ }
+
+ public void onImageChanged(Bitmap image) {
+ }
+
+ public void onImageChanged(String path) {
+
+ }
+ public void onSettingsClicked() {
+ }
+
+ @Override
+ public boolean onCreateOptionsMenu(Menu menu) {
+ MenuInflater inflater = getMenuInflater();
+ inflater.inflate(R.menu.menu_action_options, menu);
+ return true;
+ }
+
+ @Override
+ public boolean onOptionsItemSelected(MenuItem item) {
+ switch (item.getItemId()) {
+ case android.R.id.home:
+ finish();
+ break;
+ case R.id.open_gallery:
+ if (requestAllPermissions()) {
+ openGallery();
+ }
+ break;
+ case R.id.take_photo:
+ if (requestAllPermissions()) {
+ takePhoto();
+ }
+ break;
+ case R.id.settings:
+ if (requestAllPermissions()) {
+ // make sure we have SDCard r&w permissions to load model from SDCard
+ onSettingsClicked();
+ }
+ break;
+ }
+ return super.onOptionsItemSelected(item);
+ }
+
+ @Override
+ public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
+ @NonNull int[] grantResults) {
+ super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+ if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) {
+ Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show();
+ }
+ }
+
+ private boolean requestAllPermissions() {
+ if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
+ != PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this,
+ Manifest.permission.CAMERA)
+ != PackageManager.PERMISSION_GRANTED) {
+ ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE,
+ Manifest.permission.CAMERA},
+ 0);
+ return false;
+ }
+ return true;
+ }
+
+ private void openGallery() {
+ Intent intent = new Intent(Intent.ACTION_PICK, null);
+ intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
+ startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE);
+ }
+
+ private void takePhoto() {
+ Intent takePhotoIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
+ if (takePhotoIntent.resolveActivity(getPackageManager()) != null) {
+ startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
+ }
+ }
+
+ @Override
+ protected void onActivityResult(int requestCode, int resultCode, Intent data) {
+ super.onActivityResult(requestCode, resultCode, data);
+ if (resultCode == RESULT_OK && data != null) {
+ switch (requestCode) {
+ case OPEN_GALLERY_REQUEST_CODE:
+ try {
+ ContentResolver resolver = getContentResolver();
+ Uri uri = data.getData();
+ Bitmap image = MediaStore.Images.Media.getBitmap(resolver, uri);
+ String[] proj = {MediaStore.Images.Media.DATA};
+ Cursor cursor = managedQuery(uri, proj, null, null, null);
+ cursor.moveToFirst();
+ onImageChanged(image);
+ } catch (IOException e) {
+ Log.e(TAG, e.toString());
+ }
+ break;
+
+ case TAKE_PHOTO_REQUEST_CODE:
+ Bitmap image = (Bitmap) data.getParcelableExtra("data");
+ onImageChanged(image);
+
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ @Override
+ protected void onResume() {
+ super.onResume();
+ }
+
+ @Override
+ protected void onDestroy() {
+ worker.quit();
+ super.onDestroy();
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/MainActivity.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/MainActivity.java
new file mode 100644
index 0000000000000000000000000000000000000000..00728f865a77e601ec60dc30d2f8dc047aa42472
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/MainActivity.java
@@ -0,0 +1,43 @@
+package com.baidu.paddle.lite.demo;
+
+import android.content.Intent;
+import android.content.SharedPreferences;
+import android.os.Bundle;
+import android.preference.PreferenceManager;
+import android.support.v7.app.AppCompatActivity;
+import android.util.Log;
+import android.view.View;
+
+import com.baidu.paddle.lite.demo.segmentation.ImgSegActivity;
+
+public class MainActivity extends AppCompatActivity implements View.OnClickListener {
+ private static final String TAG = MainActivity.class.getSimpleName();
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_main);
+
+ // clear all settings to avoid the app crashing due to incorrect settings
+ SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ editor.clear();
+ editor.commit();
+ }
+
+ @Override
+ public void onClick(View v) {
+ switch (v.getId()) {
+ case R.id.v_img_seg: {
+ Intent intent = new Intent(MainActivity.this, ImgSegActivity.class);
+ startActivity(intent);
+ } break;
+ }
+ }
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ System.exit(0);
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Predictor.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Predictor.java
new file mode 100644
index 0000000000000000000000000000000000000000..27bd971017eba6bb52901a7e2aa1e0a8e3cf5ef0
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Predictor.java
@@ -0,0 +1,143 @@
+package com.baidu.paddle.lite.demo;
+
+import android.content.Context;
+import android.util.Log;
+import com.baidu.paddle.lite.*;
+
+import java.util.ArrayList;
+import java.util.Date;
+
+public class Predictor {
+ private static final String TAG = Predictor.class.getSimpleName();
+
+ public boolean isLoaded = false;
+ public int warmupIterNum = 0;
+ public int inferIterNum = 1;
+ protected Context appCtx = null;
+ public int cpuThreadNum = 1;
+ public String cpuPowerMode = "LITE_POWER_HIGH";
+ public String modelPath = "";
+ public String modelName = "";
+ protected PaddlePredictor paddlePredictor = null;
+ protected float inferenceTime = 0;
+
+ public Predictor() {
+ }
+
+ public boolean init(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) {
+ this.appCtx = appCtx;
+ isLoaded = loadModel(modelPath, cpuThreadNum, cpuPowerMode);
+ return isLoaded;
+ }
+
+ protected boolean loadModel(String modelPath, int cpuThreadNum, String cpuPowerMode) {
+ // release model if exists
+ releaseModel();
+
+ // load model
+ if (modelPath.isEmpty()) {
+ return false;
+ }
+ String realPath = modelPath;
+ if (!modelPath.substring(0, 1).equals("/")) {
+ // if the model path starts with '/', read the model files from that custom path;
+ // otherwise copy the model from assets to the app cache directory first
+ realPath = appCtx.getCacheDir() + "/" + modelPath;
+ Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
+ }
+ if (realPath.isEmpty()) {
+ return false;
+ }
+ MobileConfig config = new MobileConfig();
+ config.setModelDir(realPath);
+ config.setThreads(cpuThreadNum);
+ if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_HIGH")) {
+ config.setPowerMode(PowerMode.LITE_POWER_HIGH);
+ } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_LOW")) {
+ config.setPowerMode(PowerMode.LITE_POWER_LOW);
+ } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_FULL")) {
+ config.setPowerMode(PowerMode.LITE_POWER_FULL);
+ } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_NO_BIND")) {
+ config.setPowerMode(PowerMode.LITE_POWER_NO_BIND);
+ } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_HIGH")) {
+ config.setPowerMode(PowerMode.LITE_POWER_RAND_HIGH);
+ } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_LOW")) {
+ config.setPowerMode(PowerMode.LITE_POWER_RAND_LOW);
+ } else {
+ Log.e(TAG, "unknown cpu power mode!");
+ return false;
+ }
+ paddlePredictor = PaddlePredictor.createPaddlePredictor(config);
+
+ this.cpuThreadNum = cpuThreadNum;
+ this.cpuPowerMode = cpuPowerMode;
+ this.modelPath = realPath;
+ this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1);
+ return true;
+ }
+
+ public void releaseModel() {
+ paddlePredictor = null;
+ isLoaded = false;
+ cpuThreadNum = 1;
+ cpuPowerMode = "LITE_POWER_HIGH";
+ modelPath = "";
+ modelName = "";
+ }
+
+ public Tensor getInput(int idx) {
+ if (!isLoaded()) {
+ return null;
+ }
+ return paddlePredictor.getInput(idx);
+ }
+
+ public Tensor getOutput(int idx) {
+ if (!isLoaded()) {
+ return null;
+ }
+ return paddlePredictor.getOutput(idx);
+ }
+
+ public boolean runModel() {
+ if (!isLoaded()) {
+ return false;
+ }
+ // warm up
+ for (int i = 0; i < warmupIterNum; i++){
+ paddlePredictor.run();
+ }
+ // inference
+ Date start = new Date();
+ for (int i = 0; i < inferIterNum; i++) {
+ paddlePredictor.run();
+ }
+ Date end = new Date();
+ inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum;
+ return true;
+ }
+
+ public boolean isLoaded() {
+ return paddlePredictor != null && isLoaded;
+ }
+
+ public String modelPath() {
+ return modelPath;
+ }
+
+ public String modelName() {
+ return modelName;
+ }
+
+ public int cpuThreadNum() {
+ return cpuThreadNum;
+ }
+
+ public String cpuPowerMode() {
+ return cpuPowerMode;
+ }
+
+ public float inferenceTime() {
+ return inferenceTime;
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Utils.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Utils.java
new file mode 100644
index 0000000000000000000000000000000000000000..a8b252365d05313d847d4ccd491fb44596f31227
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/Utils.java
@@ -0,0 +1,87 @@
+package com.baidu.paddle.lite.demo;
+
+import android.content.Context;
+import android.os.Environment;
+
+import java.io.*;
+
+public class Utils {
+ private static final String TAG = Utils.class.getSimpleName();
+
+ public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
+ if (srcPath.isEmpty() || dstPath.isEmpty()) {
+ return;
+ }
+ InputStream is = null;
+ OutputStream os = null;
+ try {
+ is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
+ os = new BufferedOutputStream(new FileOutputStream(new File(dstPath)));
+ byte[] buffer = new byte[1024];
+ int length = 0;
+ while ((length = is.read(buffer)) != -1) {
+ os.write(buffer, 0, length);
+ }
+ } catch (FileNotFoundException e) {
+ e.printStackTrace();
+ } catch (IOException e) {
+ e.printStackTrace();
+ } finally {
+ try {
+ os.close();
+ is.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
+ if (srcDir.isEmpty() || dstDir.isEmpty()) {
+ return;
+ }
+ try {
+ if (!new File(dstDir).exists()) {
+ new File(dstDir).mkdirs();
+ }
+ for (String fileName : appCtx.getAssets().list(srcDir)) {
+ String srcSubPath = srcDir + File.separator + fileName;
+ String dstSubPath = dstDir + File.separator + fileName;
+ if (new File(srcSubPath).isDirectory()) {
+ copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
+ } else {
+ copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
+ }
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ public static float[] parseFloatsFromString(String string, String delimiter) {
+ String[] pieces = string.trim().toLowerCase().split(delimiter);
+ float[] floats = new float[pieces.length];
+ for (int i = 0; i < pieces.length; i++) {
+ floats[i] = Float.parseFloat(pieces[i].trim());
+ }
+ return floats;
+ }
+
+ public static long[] parseLongsFromString(String string, String delimiter) {
+ String[] pieces = string.trim().toLowerCase().split(delimiter);
+ long[] longs = new long[pieces.length];
+ for (int i = 0; i < pieces.length; i++) {
+ longs[i] = Long.parseLong(pieces[i].trim());
+ }
+ return longs;
+ }
+
+ public static String getSDCardDirectory() {
+ return Environment.getExternalStorageDirectory().getAbsolutePath();
+ }
+
+ public static boolean isSupportedNPU() {
+ String hardware = android.os.Build.HARDWARE;
+ return hardware.equalsIgnoreCase("kirin810") || hardware.equalsIgnoreCase("kirin990");
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegActivity.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegActivity.java
new file mode 100644
index 0000000000000000000000000000000000000000..d18895aedb892405783c030167cb3e9d1ed2d304
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegActivity.java
@@ -0,0 +1,210 @@
+package com.baidu.paddle.lite.demo.segmentation;
+
+import android.content.Intent;
+import android.content.SharedPreferences;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.os.Bundle;
+import android.preference.PreferenceManager;
+import android.text.method.ScrollingMovementMethod;
+import android.util.Log;
+import android.view.Menu;
+import android.widget.ImageView;
+import android.widget.TextView;
+import android.widget.Toast;
+
+import com.baidu.paddle.lite.demo.CommonActivity;
+import com.baidu.paddle.lite.demo.R;
+import com.baidu.paddle.lite.demo.Utils;
+import com.baidu.paddle.lite.demo.segmentation.config.Config;
+import com.baidu.paddle.lite.demo.segmentation.preprocess.Preprocess;
+import com.baidu.paddle.lite.demo.segmentation.visual.Visualize;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+
+public class ImgSegActivity extends CommonActivity {
+ private static final String TAG = ImgSegActivity.class.getSimpleName();
+
+ protected TextView tvInputSetting;
+ protected ImageView ivInputImage;
+ protected TextView tvOutputResult;
+ protected TextView tvInferenceTime;
+
+ // model config
+ Config config = new Config();
+
+ protected ImgSegPredictor predictor = new ImgSegPredictor();
+
+ Preprocess preprocess = new Preprocess();
+
+ Visualize visualize = new Visualize();
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_img_seg);
+ tvInputSetting = findViewById(R.id.tv_input_setting);
+ ivInputImage = findViewById(R.id.iv_input_image);
+ tvInferenceTime = findViewById(R.id.tv_inference_time);
+ tvOutputResult = findViewById(R.id.tv_output_result);
+ tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance());
+ tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance());
+ }
+
+ @Override
+ public boolean onLoadModel() {
+ return super.onLoadModel() && predictor.init(ImgSegActivity.this, config);
+ }
+
+ @Override
+ public boolean onRunModel() {
+ return super.onRunModel() && predictor.isLoaded() && predictor.runModel(preprocess,visualize);
+ }
+
+ @Override
+ public void onLoadModelSuccessed() {
+ super.onLoadModelSuccessed();
+ // load test image from file_paths and run model
+ try {
+ if (config.imagePath.isEmpty()) {
+ return;
+ }
+ Bitmap image = null;
+ // if the image path starts with '/', read the test image from that custom path;
+ // otherwise read the test image from assets
+ if (!config.imagePath.substring(0, 1).equals("/")) {
+ InputStream imageStream = getAssets().open(config.imagePath);
+ image = BitmapFactory.decodeStream(imageStream);
+ } else {
+ if (!new File(config.imagePath).exists()) {
+ return;
+ }
+ image = BitmapFactory.decodeFile(config.imagePath);
+ }
+ if (image != null && predictor.isLoaded()) {
+ predictor.setInputImage(image);
+ runModel();
+ }
+ } catch (IOException e) {
+ Toast.makeText(ImgSegActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onLoadModelFailed() {
+ super.onLoadModelFailed();
+ }
+
+ @Override
+ public void onRunModelSuccessed() {
+ super.onRunModelSuccessed();
+ // obtain results and update UI
+ tvInferenceTime.setText("Inference time: " + predictor.inferenceTime() + " ms");
+ Bitmap outputImage = predictor.outputImage();
+ if (outputImage != null) {
+ ivInputImage.setImageBitmap(outputImage);
+ }
+ tvOutputResult.setText(predictor.outputResult());
+ tvOutputResult.scrollTo(0, 0);
+ }
+
+ @Override
+ public void onRunModelFailed() {
+ super.onRunModelFailed();
+ }
+
+ @Override
+ public void onImageChanged(Bitmap image) {
+ super.onImageChanged(image);
+ // rerun model if users pick test image from gallery or camera
+ if (image != null && predictor.isLoaded()) {
+// predictor.setConfig(config);
+ predictor.setInputImage(image);
+ runModel();
+ }
+ }
+
+ @Override
+ public void onImageChanged(String path) {
+ super.onImageChanged(path);
+ Bitmap image = BitmapFactory.decodeFile(path);
+ predictor.setInputImage(image);
+ runModel();
+ }
+ public void onSettingsClicked() {
+ super.onSettingsClicked();
+ startActivity(new Intent(ImgSegActivity.this, ImgSegSettingsActivity.class));
+ }
+
+ @Override
+ public boolean onPrepareOptionsMenu(Menu menu) {
+ boolean isLoaded = predictor.isLoaded();
+ menu.findItem(R.id.open_gallery).setEnabled(isLoaded);
+ menu.findItem(R.id.take_photo).setEnabled(isLoaded);
+ return super.onPrepareOptionsMenu(menu);
+ }
+
+ @Override
+ protected void onResume() {
+ Log.i(TAG,"begin onResume");
+ super.onResume();
+
+ SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
+ boolean settingsChanged = false;
+ String model_path = sharedPreferences.getString(getString(R.string.ISG_MODEL_PATH_KEY),
+ getString(R.string.ISG_MODEL_PATH_DEFAULT));
+ String label_path = sharedPreferences.getString(getString(R.string.ISG_LABEL_PATH_KEY),
+ getString(R.string.ISG_LABEL_PATH_DEFAULT));
+ String image_path = sharedPreferences.getString(getString(R.string.ISG_IMAGE_PATH_KEY),
+ getString(R.string.ISG_IMAGE_PATH_DEFAULT));
+ settingsChanged |= !model_path.equalsIgnoreCase(config.modelPath);
+ settingsChanged |= !label_path.equalsIgnoreCase(config.labelPath);
+ settingsChanged |= !image_path.equalsIgnoreCase(config.imagePath);
+ int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.ISG_CPU_THREAD_NUM_KEY),
+ getString(R.string.ISG_CPU_THREAD_NUM_DEFAULT)));
+ settingsChanged |= cpu_thread_num != config.cpuThreadNum;
+ String cpu_power_mode =
+ sharedPreferences.getString(getString(R.string.ISG_CPU_POWER_MODE_KEY),
+ getString(R.string.ISG_CPU_POWER_MODE_DEFAULT));
+ settingsChanged |= !cpu_power_mode.equalsIgnoreCase(config.cpuPowerMode);
+ String input_color_format =
+ sharedPreferences.getString(getString(R.string.ISG_INPUT_COLOR_FORMAT_KEY),
+ getString(R.string.ISG_INPUT_COLOR_FORMAT_DEFAULT));
+ settingsChanged |= !input_color_format.equalsIgnoreCase(config.inputColorFormat);
+ long[] input_shape =
+ Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.ISG_INPUT_SHAPE_KEY),
+ getString(R.string.ISG_INPUT_SHAPE_DEFAULT)), ",");
+
+ settingsChanged |= input_shape.length != config.inputShape.length;
+
+ if (!settingsChanged) {
+ for (int i = 0; i < input_shape.length; i++) {
+ settingsChanged |= input_shape[i] != config.inputShape[i];
+ }
+ }
+
+ if (settingsChanged) {
+ config.init(model_path,label_path,image_path,cpu_thread_num,cpu_power_mode,
+ input_color_format,input_shape);
+ preprocess.init(config);
+ // update UI
+ tvInputSetting.setText("Model: " + config.modelPath.substring(config.modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" +
+ " Thread Num: " + Integer.toString(config.cpuThreadNum) + "\n" + "CPU Power Mode: " + config.cpuPowerMode);
+ tvInputSetting.scrollTo(0, 0);
+ // reload model if configure has been changed
+ loadModel();
+ }
+ }
+
+ @Override
+ protected void onDestroy() {
+ if (predictor != null) {
+ predictor.releaseModel();
+ }
+ super.onDestroy();
+ }
+}
diff --git a/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegPredictor.java b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegPredictor.java
new file mode 100644
index 0000000000000000000000000000000000000000..717e086adf078a2eea69bf3fc720af8c233fd9a3
--- /dev/null
+++ b/deploy/lite/humanseg-android-demo/app/src/main/java/com/baidu/paddle/lite/demo/segmentation/ImgSegPredictor.java
@@ -0,0 +1,179 @@
+package com.baidu.paddle.lite.demo.segmentation;
+
+import android.content.Context;
+import android.graphics.Bitmap;
+import android.util.Log;
+
+import com.baidu.paddle.lite.Tensor;
+import com.baidu.paddle.lite.demo.Predictor;
+import com.baidu.paddle.lite.demo.segmentation.config.Config;
+import com.baidu.paddle.lite.demo.segmentation.preprocess.Preprocess;
+import com.baidu.paddle.lite.demo.segmentation.visual.Visualize;
+
+import java.io.InputStream;
+import java.util.Date;
+import java.util.Vector;
+
+import static android.graphics.Color.blue;
+import static android.graphics.Color.green;
+import static android.graphics.Color.red;
+
+public class ImgSegPredictor extends Predictor {
+ private static final String TAG = ImgSegPredictor.class.getSimpleName();
+ protected Vector<String> wordLabels = new Vector<String>();
diff --git a/docs/configs/model_hrnet_group.md b/docs/configs/model_hrnet_group.md
new file mode 100644
--- /dev/null
+++ b/docs/configs/model_hrnet_group.md
+## `STAGE2.NUM_CHANNELS`
+
+The number of channels of each branch of HRNet in stage 2.
+
+### Default value
+
+[40, 80]
+
+
+
+
+## `STAGE3.NUM_MODULES`
+
+The number of times the modularized block is repeated in stage 3 of HRNet.
+
+### Default value
+
+4
+
+
+
+
+## `STAGE3.NUM_CHANNELS`
+
+The number of channels of each branch of HRNet in stage 3.
+
+### Default value
+
+[40, 80, 160]
+
+
+
+
+## `STAGE4.NUM_MODULES`
+
+The number of times the modularized block is repeated in stage 4 of HRNet.
+
+### Default value
+
+3
+
+
+
+
+## `STAGE4.NUM_CHANNELS`
+
+The number of channels of each branch of HRNet in stage 4.
+
+### Default value
+
+[40, 80, 160, 320]
+
+
+
\ No newline at end of file
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index a591542cc31379b3df75829173d9cf63a4ae69c1..7e625db73a5ae185b8db00e8dd6f04e26d4e11e5 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -22,6 +22,16 @@ PaddleSeg provides pretrained models on public datasets for all built-in segmentation models
| Xception65 | ImageNet | [Xception65_pretrained.tgz](https://paddleseg.bj.bcebos.com/models/Xception65_pretrained.tgz) | 80.32%/94.47% |
| Xception71 | ImageNet | coming soon | -- |
+| Model | Dataset | Download | Top-1/Top-5 Accuracy |
+|---|---|---|---|
+| HRNet_W18 | ImageNet | [hrnet_w18_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w18_imagenet.tar) | 76.92%/93.39% |
+| HRNet_W30 | ImageNet | [hrnet_w30_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar) | 78.04%/94.02% |
+| HRNet_W32 | ImageNet | [hrnet_w32_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar) | 78.28%/94.24% |
+| HRNet_W40 | ImageNet | [hrnet_w40_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar) | 78.77%/94.47% |
+| HRNet_W44 | ImageNet | [hrnet_w44_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w44_imagenet.tar) | 79.00%/94.51% |
+| HRNet_W48 | ImageNet | [hrnet_w48_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w48_imagenet.tar) | 78.95%/94.42% |
+| HRNet_W64 | ImageNet | [hrnet_w64_imagenet.tar](https://paddleseg.bj.bcebos.com/models/hrnet_w64_imagenet.tar) | 79.30%/94.61% |
+
## COCO Pretrained Models
The dataset is a semantic segmentation dataset converted from the COCO instance segmentation dataset
@@ -46,3 +56,4 @@ The training set is the Cityscapes training set; testing uses the Cityscapes validation set
| ICNet/bn | Cityscapes |[icnet_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/icnet_cityscapes.tar.gz) |16|false| 0.6831 |
| PSPNet/bn | Cityscapes |[pspnet50_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/pspnet50_cityscapes.tgz) |16|false| 0.7013 |
| PSPNet/bn | Cityscapes |[pspnet101_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/pspnet101_cityscapes.tgz) |16|false| 0.7734 |
+| HRNet_W18/bn | Cityscapes |[hrnet_w18_bn_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 4 | false | 0.7936 |
diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py
index 12805164832422189a8b33be1d8b24b983767dcc..56a959a9c20e1322ee5e906008e8bdcf392bd044 100644
--- a/pdseg/models/model_builder.py
+++ b/pdseg/models/model_builder.py
@@ -112,6 +112,7 @@ def softmax(logit):
logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit
+
def sigmoid_to_softmax(logit):
"""
one channel to two channel
@@ -143,19 +144,23 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    # when exporting the model, add image normalization preprocessing into the graph to simplify image handling at deployment time
    # at deployment, only a batch_size dimension needs to be added to the input image
if ModelPhase.is_predict(phase):
- origin_image = fluid.layers.data(name='image',
- shape=[ -1, 1, 1, cfg.DATASET.DATA_DIM],
- dtype='float32',
- append_batch_size=False)
+ origin_image = fluid.layers.data(
+ name='image',
+ shape=[-1, 1, 1, cfg.DATASET.DATA_DIM],
+ dtype='float32',
+ append_batch_size=False)
image = fluid.layers.transpose(origin_image, [0, 3, 1, 2])
origin_shape = fluid.layers.shape(image)[-2:]
mean = np.array(cfg.MEAN).reshape(1, len(cfg.MEAN), 1, 1)
mean = fluid.layers.assign(mean.astype('float32'))
std = np.array(cfg.STD).reshape(1, len(cfg.STD), 1, 1)
std = fluid.layers.assign(std.astype('float32'))
- image = (image/255 - mean)/std
- image = fluid.layers.resize_bilinear(image,
- out_shape=[height, width], align_corners=False, align_mode=0)
+ image = fluid.layers.resize_bilinear(
+ image,
+ out_shape=[height, width],
+ align_corners=False,
+ align_mode=0)
+ image = (image / 255 - mean) / std
else:
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
@@ -180,15 +185,20 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
loss_type = list(loss_type)
    # dice_loss and bce_loss are only applicable to binary segmentation
- if class_num > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)):
- raise Exception("dice loss and bce loss is only applicable to binary classfication")
-
+ if class_num > 2 and (("dice_loss" in loss_type) or
+ ("bce_loss" in loss_type)):
+ raise Exception(
+ "dice loss and bce loss are only applicable to binary classification"
+ )
+
# 在两类分割情况下,当loss函数选择dice_loss或bce_loss的时候,最后logit输出通道数设置为1
if ("dice_loss" in loss_type) or ("bce_loss" in loss_type):
class_num = 1
if "softmax_loss" in loss_type:
- raise Exception("softmax loss can not combine with dice loss or bce loss")
-
+ raise Exception(
+ "softmax loss cannot be combined with dice loss or bce loss"
+ )
+
logits = model_func(image, class_num)
    # compute the corresponding losses according to the selected loss functions
@@ -196,9 +206,9 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
loss_valid = False
avg_loss_list = []
valid_loss = []
- if "softmax_loss" in loss_type:
- avg_loss_list.append(multi_softmax_with_loss(logits,
- label, mask,class_num))
+ if "softmax_loss" in loss_type:
+ avg_loss_list.append(
+ multi_softmax_with_loss(logits, label, mask, class_num))
loss_valid = True
valid_loss.append("softmax_loss")
if "dice_loss" in loss_type:
@@ -210,13 +220,17 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
loss_valid = True
valid_loss.append("bce_loss")
if not loss_valid:
- raise Exception("SOLVER.LOSS: {} is set wrong. it should "
- "include one of (softmax_loss, bce_loss, dice_loss) at least"
- " example: ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss']".format(cfg.SOLVER.LOSS))
-
+ raise Exception(
+ "SOLVER.LOSS: {} is set incorrectly. It must include at least one of "
+ "(softmax_loss, bce_loss, dice_loss), for example: "
+ "['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss']"
+ .format(cfg.SOLVER.LOSS))
+
invalid_loss = [x for x in loss_type if x not in valid_loss]
if len(invalid_loss) > 0:
- print("Warning: the loss {} you set is invalid. it will not be included in loss computed.".format(invalid_loss))
+ print(
+ "Warning: the loss {} you set is invalid and will not be included in the computed loss."
+ .format(invalid_loss))
avg_loss = 0
for i in range(0, len(avg_loss_list)):
@@ -238,7 +252,11 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
logit = sigmoid_to_softmax(logit)
else:
logit = softmax(logit)
- logit = fluid.layers.resize_bilinear(logit, out_shape=origin_shape, align_corners=False, align_mode=0)
+ logit = fluid.layers.resize_bilinear(
+ logit,
+ out_shape=origin_shape,
+ align_corners=False,
+ align_mode=0)
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
logit = fluid.layers.argmax(logit, axis=3)
return origin_image, logit
diff --git a/pdseg/models/modeling/hrnet.py b/pdseg/models/modeling/hrnet.py
index 36ca5eaed70720943a666cb5690416eed8077bfd..741834e157105b233403772f2672ed60aafc488f 100644
--- a/pdseg/models/modeling/hrnet.py
+++ b/pdseg/models/modeling/hrnet.py
@@ -146,7 +146,7 @@ def layer1(input, name=None):
name=name + '_' + str(i + 1))
return conv
-def highResolutionNet(input, num_classes):
+def high_resolution_net(input, num_classes):
channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
@@ -198,7 +198,7 @@ def highResolutionNet(input, num_classes):
def hrnet(input, num_classes):
- logit = highResolutionNet(input, num_classes)
+ logit = high_resolution_net(input, num_classes)
return logit
if __name__ == '__main__':
diff --git a/pretrained_model/download_model.py b/pretrained_model/download_model.py
index 6d7c265f6514ee3d16aa8e010ba9b071031ef07b..12b01472457bd25e22005141b21bb9d3014bf4fe 100644
--- a/pretrained_model/download_model.py
+++ b/pretrained_model/download_model.py
@@ -37,6 +37,20 @@ model_urls = {
"https://paddleseg.bj.bcebos.com/models/Xception41_pretrained.tgz",
"xception65_imagenet":
"https://paddleseg.bj.bcebos.com/models/Xception65_pretrained.tgz",
+ "hrnet_w18_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w18_imagenet.tar",
+ "hrnet_w30_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar",
+ "hrnet_w32_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar",
+ "hrnet_w40_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar",
+ "hrnet_w44_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w44_imagenet.tar",
+ "hrnet_w48_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w48_imagenet.tar",
+ "hrnet_w64_bn_imagenet":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w64_imagenet.tar",
# COCO pretrained
"deeplabv3p_mobilenetv2-1-0_bn_coco":
@@ -65,6 +79,8 @@ model_urls = {
"https://paddleseg.bj.bcebos.com/models/pspnet50_cityscapes.tgz",
"pspnet101_bn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/pspnet101_cityscapes.tgz",
+ "hrnet_w18_bn_cityscapes":
+ "https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz",
}
if __name__ == "__main__":
diff --git a/turtorial/finetune_hrnet.md b/turtorial/finetune_hrnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..f7feb9ddafd909fa829cf5f3e3d1c66c82505f57
--- /dev/null
+++ b/turtorial/finetune_hrnet.md
@@ -0,0 +1,131 @@
+# HRNet Model Training Tutorial
+
+* This tutorial shows how to train on a custom dataset with the ***`HRNet`*** pretrained models provided by PaddleSeg.
+
+* Before reading this tutorial, please make sure you have read the [Quick Start](../README.md#快速入门) and [Basic Features](../README.md#基础功能) sections of PaddleSeg so that you have a basic understanding of it.
+
+* All commands in this tutorial are executed from the PaddleSeg root directory.
+
+## 1. Prepare the Training Data
+
+We have prepared a dataset in advance; download it with the following command:
+
+```shell
+python dataset/download_pet.py
+```
+
+## 2. Download the Pretrained Model
+
+For the list of all pretrained models supported by PaddleSeg, check the name and configuration of the model you need in [Model Combinations](#model-combinations).
+
+Then download the corresponding pretrained model:
+
+```shell
+python pretrained_model/download_model.py hrnet_w18_bn_cityscapes
+```
+
+## 3. Prepare the Configuration
+
+Next we need to set up the configuration. For this tutorial, it consists of three parts:
+
+* Dataset
+  * Training set root directory
+  * Training set file list
+  * Test set file list
+  * Evaluation set file list
+* Pretrained model
+  * Pretrained model name
+  * Channel settings for each stage of the pretrained model
+  * Normalization type of the pretrained model
+  * Pretrained model path
+* Other
+  * Learning rate
+  * Batch size
+  * ...
+
+Of the three, the pretrained model configuration is especially important: if it is wrong, the pretrained parameters will not be loaded, which in turn slows convergence. The pretrained-model configuration is the one shown in step 2.
+
+The dataset configuration depends on where the data is stored; in this tutorial the data is located in `dataset/mini_pet`.
+
+The remaining settings are tuned according to the dataset and machine environment. Finally we save a yaml configuration file with the following contents to **configs/hrnet_w18_pet.yaml**:
+
+```yaml
+# Dataset configuration
+DATASET:
+ DATA_DIR: "./dataset/mini_pet/"
+ NUM_CLASSES: 3
+ TEST_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
+ TRAIN_FILE_LIST: "./dataset/mini_pet/file_list/train_list.txt"
+ VAL_FILE_LIST: "./dataset/mini_pet/file_list/val_list.txt"
+ VIS_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
+
+# Pretrained model configuration
+MODEL:
+ MODEL_NAME: "hrnet"
+ DEFAULT_NORM_TYPE: "bn"
+ HRNET:
+ STAGE2:
+ NUM_CHANNELS: [18, 36]
+ STAGE3:
+ NUM_CHANNELS: [18, 36, 72]
+ STAGE4:
+ NUM_CHANNELS: [18, 36, 72, 144]
+
+# Other configuration
+TRAIN_CROP_SIZE: (512, 512)
+EVAL_CROP_SIZE: (512, 512)
+AUG:
+ AUG_METHOD: "unpadding"
+ FIX_RESIZE_SIZE: (512, 512)
+BATCH_SIZE: 4
+TRAIN:
+ PRETRAINED_MODEL_DIR: "./pretrained_model/hrnet_w18_bn_cityscapes/"
+ MODEL_SAVE_DIR: "./saved_model/hrnet_w18_bn_pet/"
+ SNAPSHOT_EPOCH: 10
+TEST:
+ TEST_MODEL: "./saved_model/hrnet_w18_bn_pet/final"
+SOLVER:
+ NUM_EPOCHS: 100
+ LR: 0.005
+ LR_POLICY: "poly"
+ OPTIMIZER: "sgd"
+```
+
+## 4. Configuration/Data Validation
+
+Before starting training and evaluation, we validate the configuration and data to make sure both are correct. Start the check with the following command:
+
+```shell
+python pdseg/check.py --cfg ./configs/hrnet_w18_pet.yaml
+```
+
+
+## 5. Start Training
+
+Once validation passes, start training with the following command:
+
+```shell
+python pdseg/train.py --use_gpu --cfg ./configs/hrnet_w18_pet.yaml
+```
+
+## 6. Evaluate
+
+After training completes, start evaluation with the following command:
+
+```shell
+python pdseg/eval.py --use_gpu --cfg ./configs/hrnet_w18_pet.yaml
+```
+
+## Model Combinations
+
+|Pretrained model name|Backbone|Norm Type|Dataset|Configuration|
+|-|-|-|-|-|
+|hrnet_w18_bn_cityscapes|-|bn| Cityscapes | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [18, 36] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [18, 36, 72] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [18, 36, 72, 144] <br> MODEL.DEFAULT_NORM_TYPE: bn|
+| hrnet_w18_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [18, 36] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [18, 36, 72] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [18, 36, 72, 144] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w30_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [30, 60] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [30, 60, 120] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [30, 60, 120, 240] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w32_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [32, 64] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [32, 64, 128] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [32, 64, 128, 256] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w40_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [40, 80] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [40, 80, 160] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [40, 80, 160, 320] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w44_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [44, 88] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [44, 88, 176] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [44, 88, 176, 352] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w48_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [48, 96] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [48, 96, 192] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [48, 96, 192, 384] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+| hrnet_w64_bn_imagenet |-|bn| ImageNet | MODEL.MODEL_NAME: hrnet <br> MODEL.HRNET.STAGE2.NUM_CHANNELS: [64, 128] <br> MODEL.HRNET.STAGE3.NUM_CHANNELS: [64, 128, 256] <br> MODEL.HRNET.STAGE4.NUM_CHANNELS: [64, 128, 256, 512] <br> MODEL.DEFAULT_NORM_TYPE: bn |
+