Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
f8b6ded0
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8b6ded0
编写于
3月 19, 2021
作者:
E
Edward Yang
提交者:
GitHub
3月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Spinalnet to Release/v2.0 from v2.0rc. (#1326)
* Added spinalnet gemstone classification module
上级
f1d51561
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
1193 addition
and
0 deletion
+1193
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/README.md
...dparty/image/classification/SpinalNet_Gemstones/README.md
+118
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/gem_dataset.py
...y/image/classification/SpinalNet_Gemstones/gem_dataset.py
+53
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md
...n/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md
+21
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt
...nalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt
+87
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py
...n/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py
+255
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md
...on/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md
+21
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt
...inalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt
+87
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py
...on/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py
+255
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md
...on/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md
+21
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt
...inalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt
+87
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py
...on/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py
+188
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg
...on/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg
+0
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg
...n/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg
+0
-0
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg
...ion/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg
+0
-0
未找到文件。
modules/thirdparty/image/classification/SpinalNet_Gemstones/README.md
0 → 100644
浏览文件 @
f8b6ded0
# PaddleHub SpinalNet
本示例将展示如何使用PaddleHub的SpinalNet预训练模型进行宝石识别或finetune并完成宝石的预测任务。
## 1. 首先要安装PaddleHub2.0版
```
shell
$pip
install
-U
paddlehub
==
2.0.0
```
## 2. 在本地加载封装的模型
```
Python
import paddlehub as hub
```
### 加载spinalnet_res50_gemstone
```
Python
spinal_res50 = hub.Module(name="spinalnet_res50_gemstone")
```
### 加载spinalnet_vgg16_gemstone
```
Python
spinal_vgg16 = hub.Module(name="spinalnet_vgg16_gemstone")
```
### 加载spinalnet_res101_gemstone
```
Python
spinal_res101 = hub.Module(name="spinalnet_res101_gemstone")
```
## 3. 预测
### 使用spinalnet_res50_gemstone预测
```
Python
result_res50 = spinal_res50.predict(['/PATH/TO/IMAGE'])
print(result_res50)
```
### 使用spinalnet_vgg16_gemstone预测
```
Python
result_vgg16 = spinal_vgg16.predict(['/PATH/TO/IMAGE'])
print(result_vgg16)
```
### 使用spinalnet_res101_gemstone预测
```
Python
sresult_res101 = spinal_res101.predict(['/PATH/TO/IMAGE'])
print(result_res101)
```
## 4. 命令行预测
```
shell
$
hub run spinalnet_res50_gemstone
--input_path
"/PATH/TO/IMAGE"
--top_k
5
```
## 5. 对PaddleHub模型进行训练微调
## 如何开始Fine-tune
在完成安装PaddlePaddle与PaddleHub后,即可对Spinalnet模型进行针对宝石数据集的Fine-tune。
## 代码步骤
使用PaddleHub Fine-tune API进行Fine-tune可以分为5个步骤。
### Step1: 加载必要的库
```
python
from
paddlehub.finetune.trainer
import
Trainer
from
gem_dataset
import
GemStones
from
paddlehub.vision
import
transforms
as
T
import
paddle
```
### Step2: 定义数据预处理方式
```
python
train_transforms
=
T
.
Compose
([
T
.
Resize
((
256
,
256
)),
T
.
CenterCrop
(
224
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])],
to_rgb
=
True
)
eval_transforms
=
T
.
Compose
([
T
.
Resize
((
256
,
256
)),
T
.
CenterCrop
(
224
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])],
to_rgb
=
True
)
```
`transforms`
数据增强模块定义了丰富的数据预处理方式,用户可按照需求替换自己需要的数据预处理方式。
### Step3: 定义数据集
```
python
gem_train
=
GemStones
(
transforms
=
train_transforms
,
mode
=
'train'
)
gem_validate
=
GemStones
(
transforms
=
eval_transforms
,
mode
=
'eval'
)
```
数据集的准备代码可以参考
[
gem_dataset.py
](
PaddleHub/modules/thirdparty/image/classification/SpinanlNet_Gemstones/gem_dataset.py
)
。
### Step4: 开始训练微调
```
python
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.001
,
momentum
=
0.9
,
parameters
=
spinal_res50
.
parameters
())
trainer
=
Trainer
(
spinal_res50
,
optimizer
,
use_gpu
=
True
,
checkpoint_dir
=
'fine_tuned_model'
)
trainer
.
train
(
gem_train
,
epochs
=
5
,
batch_size
=
128
,
eval_dataset
=
gem_validate
,
save_interval
=
1
,
log_interval
=
10
)
```
### Step5: 微调后再预测
```
python
spinal_res50
=
hub
.
Module
(
name
=
"spinalnet_res50_gemstone"
)
result_res50
=
spinal_res50
.
predict
([
'/PATH/TO/IMAGE'
])
print
(
result_res50
)
```
### 查看代码
https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py
https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py
https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py
### 依赖
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
modules/thirdparty/image/classification/SpinalNet_Gemstones/gem_dataset.py
0 → 100644
浏览文件 @
f8b6ded0
import
paddle
import
numpy
as
np
from
typing
import
Callable
from
code.config
import
config_parameters
class
GemStones
(
paddle
.
io
.
Dataset
):
"""
step 1:paddle.io.Dataset
"""
def
__init__
(
self
,
transforms
:
Callable
,
mode
:
str
=
'train'
):
"""
step 2:create reader
"""
super
(
GemStones
,
self
).
__init__
()
self
.
mode
=
mode
self
.
transforms
=
transforms
train_image_dir
=
config_parameters
[
'train_image_dir'
]
eval_image_dir
=
config_parameters
[
'eval_image_dir'
]
test_image_dir
=
config_parameters
[
'test_image_dir'
]
train_data_folder
=
paddle
.
vision
.
DatasetFolder
(
train_image_dir
)
eval_data_folder
=
paddle
.
vision
.
DatasetFolder
(
eval_image_dir
)
test_data_folder
=
paddle
.
vision
.
DatasetFolder
(
test_image_dir
)
config_parameters
[
'label_dict'
]
=
train_data_folder
.
class_to_idx
if
self
.
mode
==
'train'
:
self
.
data
=
train_data_folder
elif
self
.
mode
==
'eval'
:
self
.
data
=
eval_data_folder
elif
self
.
mode
==
'test'
:
self
.
data
=
test_data_folder
def
__getitem__
(
self
,
index
):
"""
step 3:implement __getitem__
"""
data
=
np
.
array
(
self
.
data
[
index
][
0
]).
astype
(
'float32'
)
data
=
self
.
transforms
(
data
)
label
=
np
.
array
(
self
.
data
[
index
][
1
]).
astype
(
'int64'
)
return
data
,
label
def
__len__
(
self
):
"""
step 4:implement __len__
"""
return
len
(
self
.
data
)
\ No newline at end of file
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md
0 → 100755
浏览文件 @
f8b6ded0
## 概述
*
[
SpinalNet
](
https://arxiv.org/abs/2007.03347
)
的网络结构如下图,
[
网络结构图
](
https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f
)
该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。
## 预训练模型
预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923
## API
加载该模型后,使用PadduleHub2.0的默认图像分类API
```
def Predict(images, batch_size, top_k):
```
**参数**
*
images (list[str: 图片路径]) : 输入图像数据列表
*
batch_size: 默认值为1
*
top_k: 每张图片的前k个预测类别
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt
0 → 100755
浏览文件 @
f8b6ded0
Alexandrite
Almandine
Amazonite
Amber
Amethyst
Ametrine
Andalusite
Andradite
Aquamarine
Aventurine Green
Aventurine Yellow
Benitoite
Beryl Golden
Bixbite
Bloodstone
Blue Lace Agate
Carnelian
Cats Eye
Chalcedony
Chalcedony Blue
Chrome Diopside
Chrysoberyl
Chrysocolla
Chrysoprase
Citrine
Coral
Danburite
Diamond
Diaspore
Dumortierite
Emerald
Fluorite
Garnet Red
Goshenite
Grossular
Hessonite
Hiddenite
Iolite
Jade
Jasper
Kunzite
Kyanite
Labradorite
Lapis Lazuli
Larimar
Malachite
Moonstone
Morganite
Onyx Black
Onyx Green
Onyx Red
Opal
Pearl
Peridot
Prehnite
Pyrite
Pyrope
Quartz Beer
Quartz Lemon
Quartz Rose
Quartz Rutilated
Quartz Smoky
Rhodochrosite
Rhodolite
Rhodonite
Ruby
Sapphire Blue
Sapphire Pink
Sapphire Purple
Sapphire Yellow
Scapolite
Serpentine
Sodalite
Spessartite
Sphene
Spinel
Spodumene
Sunstone
Tanzanite
Tigers Eye
Topaz
Tourmaline
Tsavorite
Turquoise
Variscite
Zircon
Zoisite
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py
0 → 100755
浏览文件 @
f8b6ded0
# copyright (c) 2021 nanting03. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
Union
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddlehub.vision.transforms
as
T
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
class
BottleneckBlock
(
nn
.
Layer
):
expansion
=
4
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
,
groups
=
1
,
base_width
=
64
,
dilation
=
1
,
norm_layer
=
None
):
super
(
BottleneckBlock
,
self
).
__init__
()
if
norm_layer
is
None
:
norm_layer
=
nn
.
BatchNorm2D
width
=
int
(
planes
*
(
base_width
/
64.
))
*
groups
self
.
conv1
=
nn
.
Conv2D
(
inplanes
,
width
,
1
,
bias_attr
=
False
)
self
.
bn1
=
norm_layer
(
width
)
self
.
conv2
=
nn
.
Conv2D
(
width
,
width
,
3
,
padding
=
dilation
,
stride
=
stride
,
groups
=
groups
,
dilation
=
dilation
,
bias_attr
=
False
)
self
.
bn2
=
norm_layer
(
width
)
self
.
conv3
=
nn
.
Conv2D
(
width
,
planes
*
self
.
expansion
,
1
,
bias_attr
=
False
)
self
.
bn3
=
norm_layer
(
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
()
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
block
=
BottleneckBlock
,
depth
=
101
,
with_pool
=
True
):
super
(
ResNet
,
self
).
__init__
()
layer_cfg
=
{
18
:
[
2
,
2
,
2
,
2
],
34
:
[
3
,
4
,
6
,
3
],
50
:
[
3
,
4
,
6
,
3
],
101
:
[
3
,
4
,
23
,
3
],
152
:
[
3
,
8
,
36
,
3
]
}
layers
=
layer_cfg
[
depth
]
self
.
with_pool
=
with_pool
self
.
_norm_layer
=
nn
.
BatchNorm2D
self
.
inplanes
=
64
self
.
dilation
=
1
self
.
conv1
=
nn
.
Conv2D
(
3
,
self
.
inplanes
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias_attr
=
False
)
self
.
bn1
=
self
.
_norm_layer
(
self
.
inplanes
)
self
.
relu
=
nn
.
ReLU
()
self
.
maxpool
=
nn
.
MaxPool2D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
layer1
=
self
.
_make_layer
(
block
,
64
,
layers
[
0
])
self
.
layer2
=
self
.
_make_layer
(
block
,
128
,
layers
[
1
],
stride
=
2
)
self
.
layer3
=
self
.
_make_layer
(
block
,
256
,
layers
[
2
],
stride
=
2
)
self
.
layer4
=
self
.
_make_layer
(
block
,
512
,
layers
[
3
],
stride
=
2
)
if
with_pool
:
self
.
avgpool
=
nn
.
AdaptiveAvgPool2D
((
1
,
1
))
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
,
dilate
=
False
):
norm_layer
=
self
.
_norm_layer
downsample
=
None
previous_dilation
=
self
.
dilation
if
dilate
:
self
.
dilation
*=
stride
stride
=
1
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
planes
*
block
.
expansion
,
1
,
stride
=
stride
,
bias_attr
=
False
),
norm_layer
(
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
,
1
,
64
,
previous_dilation
,
norm_layer
))
self
.
inplanes
=
planes
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
norm_layer
=
norm_layer
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
x
=
self
.
layer1
(
x
)
x
=
self
.
layer2
(
x
)
x
=
self
.
layer3
(
x
)
x
=
self
.
layer4
(
x
)
if
self
.
with_pool
:
x
=
self
.
avgpool
(
x
)
return
x
@
moduleinfo
(
name
=
"spinalnet_res101_gemstone"
,
type
=
"CV/classification"
,
author
=
"nanting03"
,
author_email
=
"975348977@qq.com"
,
summary
=
"spinalnet_res101_gemstone is a classification model, "
"this module is trained with Gemstone dataset."
,
version
=
"1.0.0"
,
meta
=
ImageClassifierModule
)
class
SpinalNet_ResNet101
(
nn
.
Layer
):
def
__init__
(
self
,
label_list
:
list
=
None
,
load_checkpoint
:
str
=
None
):
super
(
SpinalNet_ResNet101
,
self
).
__init__
()
if
label_list
is
not
None
:
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
else
:
label_list
=
[]
label_file
=
os
.
path
.
join
(
self
.
directory
,
'label_list.txt'
)
files
=
open
(
label_file
)
for
line
in
files
.
readlines
():
line
=
line
.
strip
(
'
\n
'
)
label_list
.
append
(
line
)
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
self
.
backbone
=
ResNet
()
half_in_size
=
round
(
2048
/
2
)
layer_width
=
20
self
.
half_in_size
=
half_in_size
self
.
fc_spinal_layer1
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer2
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer3
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer4
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_out
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
layer_width
*
4
,
class_dim
),
)
if
load_checkpoint
is
not
None
:
self
.
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
self
.
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'spinalnet_res101.pdparams'
)
self
.
model_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_dict
(
self
.
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transforms
(
self
,
images
:
Union
[
str
,
np
.
ndarray
]):
transforms
=
T
.
Compose
([
T
.
Resize
((
256
,
256
)),
T
.
CenterCrop
(
224
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
],
to_rgb
=
True
)
return
transforms
(
images
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
backbone
(
inputs
)
feature
=
y
y
=
paddle
.
flatten
(
y
,
1
)
y1
=
self
.
fc_spinal_layer1
(
y
[:,
0
:
self
.
half_in_size
])
y2
=
self
.
fc_spinal_layer2
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y1
],
axis
=
1
))
y3
=
self
.
fc_spinal_layer3
(
paddle
.
concat
([
y
[:,
0
:
self
.
half_in_size
],
y2
],
axis
=
1
))
y4
=
self
.
fc_spinal_layer4
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y3
],
axis
=
1
))
y
=
paddle
.
concat
([
y1
,
y2
,
y3
,
y4
],
axis
=
1
)
y
=
self
.
fc_out
(
y
)
return
y
,
feature
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md
0 → 100755
浏览文件 @
f8b6ded0
## 概述
*
[
SpinalNet
](
https://arxiv.org/abs/2007.03347
)
的网络结构如下图,
[
网络结构图
](
https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f
)
该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。
## 预训练模型
预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923
## API
加载该模型后,使用PadduleHub2.0的默认图像分类API
```
def Predict(images, batch_size, top_k):
```
**参数**
*
images (list[str: 图片路径]) : 输入图像数据列表
*
batch_size: 默认值为1
*
top_k: 每张图片的前k个预测类别
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt
0 → 100755
浏览文件 @
f8b6ded0
Alexandrite
Almandine
Amazonite
Amber
Amethyst
Ametrine
Andalusite
Andradite
Aquamarine
Aventurine Green
Aventurine Yellow
Benitoite
Beryl Golden
Bixbite
Bloodstone
Blue Lace Agate
Carnelian
Cats Eye
Chalcedony
Chalcedony Blue
Chrome Diopside
Chrysoberyl
Chrysocolla
Chrysoprase
Citrine
Coral
Danburite
Diamond
Diaspore
Dumortierite
Emerald
Fluorite
Garnet Red
Goshenite
Grossular
Hessonite
Hiddenite
Iolite
Jade
Jasper
Kunzite
Kyanite
Labradorite
Lapis Lazuli
Larimar
Malachite
Moonstone
Morganite
Onyx Black
Onyx Green
Onyx Red
Opal
Pearl
Peridot
Prehnite
Pyrite
Pyrope
Quartz Beer
Quartz Lemon
Quartz Rose
Quartz Rutilated
Quartz Smoky
Rhodochrosite
Rhodolite
Rhodonite
Ruby
Sapphire Blue
Sapphire Pink
Sapphire Purple
Sapphire Yellow
Scapolite
Serpentine
Sodalite
Spessartite
Sphene
Spinel
Spodumene
Sunstone
Tanzanite
Tigers Eye
Topaz
Tourmaline
Tsavorite
Turquoise
Variscite
Zircon
Zoisite
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py
0 → 100755
浏览文件 @
f8b6ded0
# copyright (c) 2021 nanting03. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
Union
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddlehub.vision.transforms
as
T
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
class
BottleneckBlock
(
nn
.
Layer
):
expansion
=
4
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
,
groups
=
1
,
base_width
=
64
,
dilation
=
1
,
norm_layer
=
None
):
super
(
BottleneckBlock
,
self
).
__init__
()
if
norm_layer
is
None
:
norm_layer
=
nn
.
BatchNorm2D
width
=
int
(
planes
*
(
base_width
/
64.
))
*
groups
self
.
conv1
=
nn
.
Conv2D
(
inplanes
,
width
,
1
,
bias_attr
=
False
)
self
.
bn1
=
norm_layer
(
width
)
self
.
conv2
=
nn
.
Conv2D
(
width
,
width
,
3
,
padding
=
dilation
,
stride
=
stride
,
groups
=
groups
,
dilation
=
dilation
,
bias_attr
=
False
)
self
.
bn2
=
norm_layer
(
width
)
self
.
conv3
=
nn
.
Conv2D
(
width
,
planes
*
self
.
expansion
,
1
,
bias_attr
=
False
)
self
.
bn3
=
norm_layer
(
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
()
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
block
=
BottleneckBlock
,
depth
=
50
,
with_pool
=
True
):
super
(
ResNet
,
self
).
__init__
()
layer_cfg
=
{
18
:
[
2
,
2
,
2
,
2
],
34
:
[
3
,
4
,
6
,
3
],
50
:
[
3
,
4
,
6
,
3
],
101
:
[
3
,
4
,
23
,
3
],
152
:
[
3
,
8
,
36
,
3
]
}
layers
=
layer_cfg
[
depth
]
self
.
with_pool
=
with_pool
self
.
_norm_layer
=
nn
.
BatchNorm2D
self
.
inplanes
=
64
self
.
dilation
=
1
self
.
conv1
=
nn
.
Conv2D
(
3
,
self
.
inplanes
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias_attr
=
False
)
self
.
bn1
=
self
.
_norm_layer
(
self
.
inplanes
)
self
.
relu
=
nn
.
ReLU
()
self
.
maxpool
=
nn
.
MaxPool2D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
layer1
=
self
.
_make_layer
(
block
,
64
,
layers
[
0
])
self
.
layer2
=
self
.
_make_layer
(
block
,
128
,
layers
[
1
],
stride
=
2
)
self
.
layer3
=
self
.
_make_layer
(
block
,
256
,
layers
[
2
],
stride
=
2
)
self
.
layer4
=
self
.
_make_layer
(
block
,
512
,
layers
[
3
],
stride
=
2
)
if
with_pool
:
self
.
avgpool
=
nn
.
AdaptiveAvgPool2D
((
1
,
1
))
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
,
dilate
=
False
):
norm_layer
=
self
.
_norm_layer
downsample
=
None
previous_dilation
=
self
.
dilation
if
dilate
:
self
.
dilation
*=
stride
stride
=
1
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
planes
*
block
.
expansion
,
1
,
stride
=
stride
,
bias_attr
=
False
),
norm_layer
(
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
,
1
,
64
,
previous_dilation
,
norm_layer
))
self
.
inplanes
=
planes
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
norm_layer
=
norm_layer
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
x
=
self
.
layer1
(
x
)
x
=
self
.
layer2
(
x
)
x
=
self
.
layer3
(
x
)
x
=
self
.
layer4
(
x
)
if
self
.
with_pool
:
x
=
self
.
avgpool
(
x
)
return
x
@
moduleinfo
(
name
=
"spinalnet_res50_gemstone"
,
type
=
"CV/classification"
,
author
=
"nanting03"
,
author_email
=
"975348977@qq.com"
,
summary
=
"spinalnet_res50_gemstone is a classification model, "
"this module is trained with Gemstone dataset."
,
version
=
"1.0.0"
,
meta
=
ImageClassifierModule
)
class
SpinalNet_ResNet50
(
nn
.
Layer
):
def
__init__
(
self
,
label_list
:
list
=
None
,
load_checkpoint
:
str
=
None
):
super
(
SpinalNet_ResNet50
,
self
).
__init__
()
if
label_list
is
not
None
:
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
else
:
label_list
=
[]
label_file
=
os
.
path
.
join
(
self
.
directory
,
'label_list.txt'
)
files
=
open
(
label_file
)
for
line
in
files
.
readlines
():
line
=
line
.
strip
(
'
\n
'
)
label_list
.
append
(
line
)
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
self
.
backbone
=
ResNet
()
half_in_size
=
round
(
2048
/
2
)
layer_width
=
20
self
.
half_in_size
=
half_in_size
self
.
fc_spinal_layer1
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer2
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer3
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_spinal_layer4
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
())
self
.
fc_out
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
layer_width
*
4
,
class_dim
),
)
if
load_checkpoint
is
not
None
:
self
.
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
self
.
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'spinalnet_res50.pdparams'
)
self
.
model_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_dict
(
self
.
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transforms
(
self
,
images
:
Union
[
str
,
np
.
ndarray
]):
transforms
=
T
.
Compose
([
T
.
Resize
((
256
,
256
)),
T
.
CenterCrop
(
224
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
],
to_rgb
=
True
)
return
transforms
(
images
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
backbone
(
inputs
)
feature
=
y
y
=
paddle
.
flatten
(
y
,
1
)
y1
=
self
.
fc_spinal_layer1
(
y
[:,
0
:
self
.
half_in_size
])
y2
=
self
.
fc_spinal_layer2
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y1
],
axis
=
1
))
y3
=
self
.
fc_spinal_layer3
(
paddle
.
concat
([
y
[:,
0
:
self
.
half_in_size
],
y2
],
axis
=
1
))
y4
=
self
.
fc_spinal_layer4
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y3
],
axis
=
1
))
y
=
paddle
.
concat
([
y1
,
y2
,
y3
,
y4
],
axis
=
1
)
y
=
self
.
fc_out
(
y
)
return
y
,
feature
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md
0 → 100755
浏览文件 @
f8b6ded0
## 概述
*
[
SpinalNet
](
https://arxiv.org/abs/2007.03347
)
的网络结构如下图,
[
网络结构图
](
https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f
)
该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。
## 预训练模型
预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923
## API
加载该模型后,使用PadduleHub2.0的默认图像分类API
```
def Predict(images, batch_size, top_k):
```
**参数**
*
images (list[str: 图片路径]) : 输入图像数据列表
*
batch_size: 默认值为1
*
top_k: 每张图片的前k个预测类别
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt
0 → 100755
浏览文件 @
f8b6ded0
Alexandrite
Almandine
Amazonite
Amber
Amethyst
Ametrine
Andalusite
Andradite
Aquamarine
Aventurine Green
Aventurine Yellow
Benitoite
Beryl Golden
Bixbite
Bloodstone
Blue Lace Agate
Carnelian
Cats Eye
Chalcedony
Chalcedony Blue
Chrome Diopside
Chrysoberyl
Chrysocolla
Chrysoprase
Citrine
Coral
Danburite
Diamond
Diaspore
Dumortierite
Emerald
Fluorite
Garnet Red
Goshenite
Grossular
Hessonite
Hiddenite
Iolite
Jade
Jasper
Kunzite
Kyanite
Labradorite
Lapis Lazuli
Larimar
Malachite
Moonstone
Morganite
Onyx Black
Onyx Green
Onyx Red
Opal
Pearl
Peridot
Prehnite
Pyrite
Pyrope
Quartz Beer
Quartz Lemon
Quartz Rose
Quartz Rutilated
Quartz Smoky
Rhodochrosite
Rhodolite
Rhodonite
Ruby
Sapphire Blue
Sapphire Pink
Sapphire Purple
Sapphire Yellow
Scapolite
Serpentine
Sodalite
Spessartite
Sphene
Spinel
Spodumene
Sunstone
Tanzanite
Tigers Eye
Topaz
Tourmaline
Tsavorite
Turquoise
Variscite
Zircon
Zoisite
modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py
0 → 100755
浏览文件 @
f8b6ded0
# copyright (c) 2021 nanting03. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
Union
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddlehub.vision.transforms
as
T
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
import
paddle
from
paddle
import
nn
class
VGG
(
nn
.
Layer
):
def
__init__
(
self
,
features
,
with_pool
=
True
):
super
(
VGG
,
self
).
__init__
()
self
.
features
=
features
self
.
with_pool
=
with_pool
if
with_pool
:
self
.
avgpool
=
nn
.
AdaptiveAvgPool2D
((
7
,
7
))
def
forward
(
self
,
x
):
x
=
self
.
features
(
x
)
if
self
.
with_pool
:
x
=
self
.
avgpool
(
x
)
return
x
def
make_layers
(
cfg
,
batch_norm
=
False
):
layers
=
[]
in_channels
=
3
for
v
in
cfg
:
if
v
==
'M'
:
layers
+=
[
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)]
else
:
conv2d
=
nn
.
Conv2D
(
in_channels
,
v
,
kernel_size
=
3
,
padding
=
1
)
if
batch_norm
:
layers
+=
[
conv2d
,
nn
.
BatchNorm2D
(
v
),
nn
.
ReLU
()]
else
:
layers
+=
[
conv2d
,
nn
.
ReLU
()]
in_channels
=
v
return
nn
.
Sequential
(
*
layers
)
cfgs
=
{
'A'
:
[
64
,
'M'
,
128
,
'M'
,
256
,
256
,
'M'
,
512
,
512
,
'M'
,
512
,
512
,
'M'
],
'B'
:
[
64
,
64
,
'M'
,
128
,
128
,
'M'
,
256
,
256
,
'M'
,
512
,
512
,
'M'
,
512
,
512
,
'M'
],
'D'
:
[
64
,
64
,
'M'
,
128
,
128
,
'M'
,
256
,
256
,
256
,
'M'
,
512
,
512
,
512
,
'M'
,
512
,
512
,
512
,
'M'
],
'E'
:
[
64
,
64
,
'M'
,
128
,
128
,
'M'
,
256
,
256
,
256
,
256
,
'M'
,
512
,
512
,
512
,
512
,
'M'
,
512
,
512
,
512
,
512
,
'M'
],
}
def
_vgg
(
arch
,
cfg
,
batch_norm
,
**
kwargs
):
model
=
VGG
(
make_layers
(
cfgs
[
cfg
],
batch_norm
=
batch_norm
),
**
kwargs
)
return
model
def
vgg16
(
batch_norm
=
False
,
**
kwargs
):
model_name
=
'vgg16'
if
batch_norm
:
model_name
+=
(
'_bn'
)
return
_vgg
(
model_name
,
'D'
,
batch_norm
,
**
kwargs
)
@
moduleinfo
(
name
=
"spinalnet_vgg16_gemstone"
,
type
=
"CV/classification"
,
author
=
"nanting03"
,
author_email
=
"975348977@qq.com"
,
summary
=
"spinalnet_vgg16_gemstone is a classification model, "
"this module is trained with Gemstone dataset."
,
version
=
"1.0.0"
,
meta
=
ImageClassifierModule
)
class
SpinalNet_VGG16
(
nn
.
Layer
):
def
__init__
(
self
,
label_list
:
list
=
None
,
load_checkpoint
:
str
=
None
):
super
(
SpinalNet_VGG16
,
self
).
__init__
()
if
label_list
is
not
None
:
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
else
:
label_list
=
[]
label_file
=
os
.
path
.
join
(
self
.
directory
,
'label_list.txt'
)
files
=
open
(
label_file
)
for
line
in
files
.
readlines
():
line
=
line
.
strip
(
'
\n
'
)
label_list
.
append
(
line
)
self
.
labels
=
label_list
class_dim
=
len
(
self
.
labels
)
self
.
backbone
=
vgg16
()
half_in_size
=
round
(
512
*
7
*
7
/
2
)
layer_width
=
4096
self
.
half_in_size
=
half_in_size
self
.
fc_spinal_layer1
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
(),
)
self
.
fc_spinal_layer2
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
(),
)
self
.
fc_spinal_layer3
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
(),
)
self
.
fc_spinal_layer4
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
half_in_size
+
layer_width
,
layer_width
),
nn
.
BatchNorm1D
(
layer_width
),
nn
.
ReLU
(),
)
self
.
fc_out
=
nn
.
Sequential
(
nn
.
Dropout
(
p
=
0.5
),
nn
.
Linear
(
layer_width
*
4
,
class_dim
))
if
load_checkpoint
is
not
None
:
self
.
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
self
.
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'spinalnet_vgg16.pdparams'
)
self
.
model_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_dict
(
self
.
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transforms
(
self
,
images
:
Union
[
str
,
np
.
ndarray
]):
transforms
=
T
.
Compose
([
T
.
Resize
((
256
,
256
)),
T
.
CenterCrop
(
224
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
],
to_rgb
=
True
)
return
transforms
(
images
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
backbone
(
inputs
)
feature
=
y
y
=
paddle
.
flatten
(
y
,
1
)
y1
=
self
.
fc_spinal_layer1
(
y
[:,
0
:
self
.
half_in_size
])
y2
=
self
.
fc_spinal_layer2
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y1
],
axis
=
1
))
y3
=
self
.
fc_spinal_layer3
(
paddle
.
concat
([
y
[:,
0
:
self
.
half_in_size
],
y2
],
axis
=
1
))
y4
=
self
.
fc_spinal_layer4
(
paddle
.
concat
([
y
[:,
self
.
half_in_size
:
2
*
self
.
half_in_size
],
y3
],
axis
=
1
))
y
=
paddle
.
concat
([
y1
,
y2
,
y3
,
y4
],
axis
=
1
)
y
=
self
.
fc_out
(
y
)
return
y
,
feature
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg
0 → 100644
浏览文件 @
f8b6ded0
2.9 KB
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg
0 → 100644
浏览文件 @
f8b6ded0
64.3 KB
modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg
0 → 100644
浏览文件 @
f8b6ded0
37.7 KB
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录