From 2b5d4f7d404d728bf6439bc65f2d6a9006c28d14 Mon Sep 17 00:00:00 2001 From: guosheng Date: Fri, 13 Oct 2017 13:53:49 +0800 Subject: [PATCH] Add three trained image classification models on ImageNet. --- image_classification/README.md | 5 + image_classification/index.html | 292 +++++++++++++++++++++++++ image_classification/model_download.sh | 55 +++++ 3 files changed, 352 insertions(+) create mode 100644 image_classification/index.html create mode 100644 image_classification/model_download.sh diff --git a/image_classification/README.md b/image_classification/README.md index 94a0a1b7..2414321c 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -221,3 +221,8 @@ for file_name, result in zip(file_list, lab): ``` 首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。 + +## 使用预训练模型 +为方便进行测试和fine-tuning,我们提供了一些对应于示例中模型配置的预训练模型,目前包括ResNet50、ResNet101和Vgg16这几种模型,并提供脚本`model_download.sh`进行模型下载,如下载ResNet50可执行"`sh model_download.sh ResNet50`",完成后`Paddle_ResNet50.tar.gz`即是相应模型,可参照示例代码进行加载。 + +需要注意,模型压缩包中所含各文件名对应了模型中的各参数名,这是模型参数加载的依据,所以需要保证网络配置中的参数名能够正确对应到相应的文件。这里提供的模型均使用了示例代码中的配置,如需修改网络配置并使用提供的模型请多加注意。 diff --git a/image_classification/index.html b/image_classification/index.html new file mode 100644 index 00000000..2a9dea2e --- /dev/null +++ b/image_classification/index.html @@ -0,0 +1,292 @@ + + + + + + + + + + + + + + + + + +
+
+ + + + + + + diff --git a/image_classification/model_download.sh b/image_classification/model_download.sh new file mode 100644 index 00000000..2a73c56e --- /dev/null +++ b/image_classification/model_download.sh @@ -0,0 +1,55 @@ +#! /usr/bin/env bash + +function download() { + URL=$1 + MD5=$2 + TARGET=$3 + + if [ -e $TARGET ]; then + md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'` + if [ $MD5 == $md5_result ]; then + echo "$TARGET already exists, download skipped." + return 0 + fi + fi + + wget -c $URL -O "$TARGET" + if [ $? -ne 0 ]; then + return 1 + fi + + md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'` + if [ ! $MD5 == $md5_result ]; then + return 1 + fi +} + +case "$1" in + "ResNet50") + URL="http://cloud.dlnel.org/filepub/?uuid=f63f237a-698e-4a22-9782-baf5bb183019" + MD5="eb4d7b5962c9954340207788af0d6967" + ;; + "ResNet101") + URL="http://cloud.dlnel.org/filepub/?uuid=3d5fb996-83d0-4745-8adc-13ee960fc55c" + MD5="7e71f24998aa8e434fa164a7c4fc9c02" + ;; + "Vgg16") + URL="http://cloud.dlnel.org/filepub/?uuid=aa0e397e-474a-4cc1-bd8f-65a214039c2e" + MD5="e73dc42507e6acd3a8b8087f66a9f395" + ;; + *) + echo "The "$1" model is not provided currently." + exit 1 + ;; +esac +TARGET="Paddle_"$1".tar.gz" + +echo "Download "$1" model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the model!" + exit 1 +fi + + +exit 0 -- GitLab