diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/README.md b/modules/thirdparty/image/classification/SpinalNet_Gemstones/README.md new file mode 100644 index 0000000000000000000000000000000000000000..432711cb9adfd959fb42d2c7ec01211683e10810 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/README.md @@ -0,0 +1,118 @@ +# PaddleHub SpinalNet + +本示例将展示如何使用PaddleHub的SpinalNet预训练模型进行宝石识别或finetune并完成宝石的预测任务。 + +## 1. 首先要安装PaddleHub2.0版 + +```shell +$pip install -U paddlehub==2.0.0 +``` + +## 2. 在本地加载封装的模型 + +```Python +import paddlehub as hub +``` +### 加载spinalnet_res50_gemstone +```Python +spinal_res50 = hub.Module(name="spinalnet_res50_gemstone") +``` +### 加载spinalnet_vgg16_gemstone +```Python +spinal_vgg16 = hub.Module(name="spinalnet_vgg16_gemstone") +``` +### 加载spinalnet_res101_gemstone +```Python +spinal_res101 = hub.Module(name="spinalnet_res101_gemstone") +``` +## 3. 预测 + +### 使用spinalnet_res50_gemstone预测 +```Python +result_res50 = spinal_res50.predict(['/PATH/TO/IMAGE']) +print(result_res50) +``` +### 使用spinalnet_vgg16_gemstone预测 +```Python +result_vgg16 = spinal_vgg16.predict(['/PATH/TO/IMAGE']) +print(result_vgg16) +``` +### 使用spinalnet_res101_gemstone预测 +```Python +sresult_res101 = spinal_res101.predict(['/PATH/TO/IMAGE']) +print(result_res101) +``` +## 4. 命令行预测 + +```shell +$ hub run spinalnet_res50_gemstone --input_path "/PATH/TO/IMAGE" --top_k 5 +``` + +## 5. 对PaddleHub模型进行训练微调 + +## 如何开始Fine-tune + +在完成安装PaddlePaddle与PaddleHub后,即可对Spinalnet模型进行针对宝石数据集的Fine-tune。 + +## 代码步骤 + +使用PaddleHub Fine-tune API进行Fine-tune可以分为5个步骤。 + +### Step1: 加载必要的库 +```python +from paddlehub.finetune.trainer import Trainer +from gem_dataset import GemStones +from paddlehub.vision import transforms as T +import paddle +``` + + +### Step2: 定义数据预处理方式 +```python + +train_transforms = T.Compose([T.Resize((256, 256)), T.CenterCrop(224), T.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])], to_rgb=True) +eval_transforms = T.Compose([T.Resize((256, 256)), T.CenterCrop(224), T.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])], to_rgb=True) +``` + +`transforms` 数据增强模块定义了丰富的数据预处理方式,用户可按照需求替换自己需要的数据预处理方式。 + +### Step3: 定义数据集 +```python +gem_train = GemStones(transforms=train_transforms, mode='train') +gem_validate = GemStones(transforms=eval_transforms, mode='eval') +``` + + +数据集的准备代码可以参考 [gem_dataset.py](PaddleHub/modules/thirdparty/image/classification/SpinanlNet_Gemstones/gem_dataset.py)。 + + +### Step4: 开始训练微调 + +```python +optimizer = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=spinal_res50.parameters()) +trainer = Trainer(spinal_res50, optimizer, use_gpu=True, checkpoint_dir='fine_tuned_model') +trainer.train(gem_train, epochs=5, batch_size=128, eval_dataset=gem_validate, save_interval=1, log_interval=10) +``` + +### Step5: 微调后再预测 + +```python +spinal_res50 = hub.Module(name="spinalnet_res50_gemstone") +result_res50 = spinal_res50.predict(['/PATH/TO/IMAGE']) +print(result_res50) +``` + + +### 查看代码 + +https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py + +https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py + +https://github.com/PaddleHub/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py + +### 依赖 + +paddlepaddle >= 2.0.0 + +paddlehub >= 2.0.0 diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/gem_dataset.py b/modules/thirdparty/image/classification/SpinalNet_Gemstones/gem_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5e77a899d2473fd05e9ddc0396ec0dbd50fdbf5a --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/gem_dataset.py @@ -0,0 +1,53 @@ +import paddle +import numpy as np +from typing import Callable +from code.config import config_parameters + +class GemStones(paddle.io.Dataset): + """ + step 1:paddle.io.Dataset + """ + def __init__(self, transforms: Callable, mode: str ='train'): + """ + step 2:create reader + """ + super(GemStones, self).__init__() + + self.mode = mode + self.transforms = transforms + + train_image_dir = config_parameters['train_image_dir'] + eval_image_dir = config_parameters['eval_image_dir'] + test_image_dir = config_parameters['test_image_dir'] + + train_data_folder = paddle.vision.DatasetFolder(train_image_dir) + eval_data_folder = paddle.vision.DatasetFolder(eval_image_dir) + test_data_folder = paddle.vision.DatasetFolder(test_image_dir) + + config_parameters['label_dict'] = train_data_folder.class_to_idx + + if self.mode == 'train': + self.data = train_data_folder + elif self.mode == 'eval': + self.data = eval_data_folder + elif self.mode == 'test': + self.data = test_data_folder + + + def __getitem__(self, index): + """ + step 3:implement __getitem__ + """ + data = np.array(self.data[index][0]).astype('float32') + + data = self.transforms(data) + + label = np.array(self.data[index][1]).astype('int64') + + return data, label + + def __len__(self): + """ + step 4:implement __len__ + """ + return len(self.data) \ No newline at end of file diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md new file mode 100755 index 0000000000000000000000000000000000000000..55fdee68aeeea4b4d194d8d5c59a1a08f88455f0 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/README.md @@ -0,0 +1,21 @@ +## 概述 +* [SpinalNet](https://arxiv.org/abs/2007.03347)的网络结构如下图, + +[网络结构图](https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f) + +该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。 + +## 预训练模型 + +预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923 + +## API +加载该模型后,使用PadduleHub2.0的默认图像分类API +``` +def Predict(images, batch_size, top_k): +``` + +**参数** +* images (list[str: 图片路径]) : 输入图像数据列表 +* batch_size: 默认值为1 +* top_k: 每张图片的前k个预测类别 diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt new file mode 100755 index 0000000000000000000000000000000000000000..d64659423aca69b2788a40754855ace9010d735d --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/label_list.txt @@ -0,0 +1,87 @@ +Alexandrite +Almandine +Amazonite +Amber +Amethyst +Ametrine +Andalusite +Andradite +Aquamarine +Aventurine Green +Aventurine Yellow +Benitoite +Beryl Golden +Bixbite +Bloodstone +Blue Lace Agate +Carnelian +Cats Eye +Chalcedony +Chalcedony Blue +Chrome Diopside +Chrysoberyl +Chrysocolla +Chrysoprase +Citrine +Coral +Danburite +Diamond +Diaspore +Dumortierite +Emerald +Fluorite +Garnet Red +Goshenite +Grossular +Hessonite +Hiddenite +Iolite +Jade +Jasper +Kunzite +Kyanite +Labradorite +Lapis Lazuli +Larimar +Malachite +Moonstone +Morganite +Onyx Black +Onyx Green +Onyx Red +Opal +Pearl +Peridot +Prehnite +Pyrite +Pyrope +Quartz Beer +Quartz Lemon +Quartz Rose +Quartz Rutilated +Quartz Smoky +Rhodochrosite +Rhodolite +Rhodonite +Ruby +Sapphire Blue +Sapphire Pink +Sapphire Purple +Sapphire Yellow +Scapolite +Serpentine +Sodalite +Spessartite +Sphene +Spinel +Spodumene +Sunstone +Tanzanite +Tigers Eye +Topaz +Tourmaline +Tsavorite +Turquoise +Variscite +Zircon +Zoisite diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py new file mode 100755 index 0000000000000000000000000000000000000000..e4a67b7f8db92e0490dadfb16b6fd81050205130 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res101_gemstone/module.py @@ -0,0 +1,255 @@ +# copyright (c) 2021 nanting03. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddlehub.vision.transforms as T +from paddlehub.module.module import moduleinfo +from paddlehub.module.cv_module import ImageClassifierModule + + +class BottleneckBlock(nn.Layer): + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) + self.bn1 = norm_layer(width) + + self.conv2 = nn.Conv2D(width, + width, + 3, + padding=dilation, + stride=stride, + groups=groups, + dilation=dilation, + bias_attr=False) + self.bn2 = norm_layer(width) + + self.conv3 = nn.Conv2D(width, + planes * self.expansion, + 1, + bias_attr=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + def __init__(self, block=BottleneckBlock, depth=101, with_pool=True): + super(ResNet, self).__init__() + layer_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3] + } + layers = layer_cfg[depth] + self.with_pool = with_pool + self._norm_layer = nn.BatchNorm2D + + self.inplanes = 64 + self.dilation = 1 + + self.conv1 = nn.Conv2D(3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, + planes * block.expansion, + 1, + stride=stride, + bias_attr=False), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, 1, 64, + previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.with_pool: + x = self.avgpool(x) + + return x + + +@moduleinfo(name="spinalnet_res101_gemstone", + type="CV/classification", + author="nanting03", + author_email="975348977@qq.com", + summary="spinalnet_res101_gemstone is a classification model, " + "this module is trained with Gemstone dataset.", + version="1.0.0", + meta=ImageClassifierModule) +class SpinalNet_ResNet101(nn.Layer): + def __init__(self, label_list: list = None, load_checkpoint: str = None): + super(SpinalNet_ResNet101, self).__init__() + + if label_list is not None: + self.labels = label_list + class_dim = len(self.labels) + else: + label_list = [] + label_file = os.path.join(self.directory, 'label_list.txt') + files = open(label_file) + for line in files.readlines(): + line = line.strip('\n') + label_list.append(line) + self.labels = label_list + class_dim = len(self.labels) + + self.backbone = ResNet() + + half_in_size = round(2048 / 2) + layer_width = 20 + + self.half_in_size = half_in_size + + self.fc_spinal_layer1 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size, layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer2 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer3 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer4 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_out = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(layer_width * 4, class_dim), + ) + + if load_checkpoint is not None: + self.model_dict = paddle.load(load_checkpoint)[0] + self.set_dict(self.model_dict) + print("load custom checkpoint success") + + else: + checkpoint = os.path.join(self.directory, + 'spinalnet_res101.pdparams') + self.model_dict = paddle.load(checkpoint) + self.set_dict(self.model_dict) + print("load pretrained checkpoint success") + + def transforms(self, images: Union[str, np.ndarray]): + transforms = T.Compose([ + T.Resize((256, 256)), + T.CenterCrop(224), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ], + to_rgb=True) + return transforms(images) + + def forward(self, inputs: paddle.Tensor): + y = self.backbone(inputs) + feature = y + y = paddle.flatten(y, 1) + + y1 = self.fc_spinal_layer1(y[:, 0:self.half_in_size]) + y2 = self.fc_spinal_layer2( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y1], + axis=1)) + y3 = self.fc_spinal_layer3( + paddle.concat([y[:, 0:self.half_in_size], y2], axis=1)) + y4 = self.fc_spinal_layer4( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y3], + axis=1)) + + y = paddle.concat([y1, y2, y3, y4], axis=1) + + y = self.fc_out(y) + return y, feature diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md new file mode 100755 index 0000000000000000000000000000000000000000..55fdee68aeeea4b4d194d8d5c59a1a08f88455f0 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/README.md @@ -0,0 +1,21 @@ +## 概述 +* [SpinalNet](https://arxiv.org/abs/2007.03347)的网络结构如下图, + +[网络结构图](https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f) + +该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。 + +## 预训练模型 + +预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923 + +## API +加载该模型后,使用PadduleHub2.0的默认图像分类API +``` +def Predict(images, batch_size, top_k): +``` + +**参数** +* images (list[str: 图片路径]) : 输入图像数据列表 +* batch_size: 默认值为1 +* top_k: 每张图片的前k个预测类别 diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt new file mode 100755 index 0000000000000000000000000000000000000000..d64659423aca69b2788a40754855ace9010d735d --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/label_list.txt @@ -0,0 +1,87 @@ +Alexandrite +Almandine +Amazonite +Amber +Amethyst +Ametrine +Andalusite +Andradite +Aquamarine +Aventurine Green +Aventurine Yellow +Benitoite +Beryl Golden +Bixbite +Bloodstone +Blue Lace Agate +Carnelian +Cats Eye +Chalcedony +Chalcedony Blue +Chrome Diopside +Chrysoberyl +Chrysocolla +Chrysoprase +Citrine +Coral +Danburite +Diamond +Diaspore +Dumortierite +Emerald +Fluorite +Garnet Red +Goshenite +Grossular +Hessonite +Hiddenite +Iolite +Jade +Jasper +Kunzite +Kyanite +Labradorite +Lapis Lazuli +Larimar +Malachite +Moonstone +Morganite +Onyx Black +Onyx Green +Onyx Red +Opal +Pearl +Peridot +Prehnite +Pyrite +Pyrope +Quartz Beer +Quartz Lemon +Quartz Rose +Quartz Rutilated +Quartz Smoky +Rhodochrosite +Rhodolite +Rhodonite +Ruby +Sapphire Blue +Sapphire Pink +Sapphire Purple +Sapphire Yellow +Scapolite +Serpentine +Sodalite +Spessartite +Sphene +Spinel +Spodumene +Sunstone +Tanzanite +Tigers Eye +Topaz +Tourmaline +Tsavorite +Turquoise +Variscite +Zircon +Zoisite diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py new file mode 100755 index 0000000000000000000000000000000000000000..a3ad0e3d2e16e30a18c29c2806e7e56e78070cd0 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_res50_gemstone/module.py @@ -0,0 +1,255 @@ +# copyright (c) 2021 nanting03. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddlehub.vision.transforms as T +from paddlehub.module.module import moduleinfo +from paddlehub.module.cv_module import ImageClassifierModule + + +class BottleneckBlock(nn.Layer): + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) + self.bn1 = norm_layer(width) + + self.conv2 = nn.Conv2D(width, + width, + 3, + padding=dilation, + stride=stride, + groups=groups, + dilation=dilation, + bias_attr=False) + self.bn2 = norm_layer(width) + + self.conv3 = nn.Conv2D(width, + planes * self.expansion, + 1, + bias_attr=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + def __init__(self, block=BottleneckBlock, depth=50, with_pool=True): + super(ResNet, self).__init__() + layer_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3] + } + layers = layer_cfg[depth] + self.with_pool = with_pool + self._norm_layer = nn.BatchNorm2D + + self.inplanes = 64 + self.dilation = 1 + + self.conv1 = nn.Conv2D(3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, + planes * block.expansion, + 1, + stride=stride, + bias_attr=False), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, 1, 64, + previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.with_pool: + x = self.avgpool(x) + + return x + + +@moduleinfo(name="spinalnet_res50_gemstone", + type="CV/classification", + author="nanting03", + author_email="975348977@qq.com", + summary="spinalnet_res50_gemstone is a classification model, " + "this module is trained with Gemstone dataset.", + version="1.0.0", + meta=ImageClassifierModule) +class SpinalNet_ResNet50(nn.Layer): + def __init__(self, label_list: list = None, load_checkpoint: str = None): + super(SpinalNet_ResNet50, self).__init__() + + if label_list is not None: + self.labels = label_list + class_dim = len(self.labels) + else: + label_list = [] + label_file = os.path.join(self.directory, 'label_list.txt') + files = open(label_file) + for line in files.readlines(): + line = line.strip('\n') + label_list.append(line) + self.labels = label_list + class_dim = len(self.labels) + + self.backbone = ResNet() + + half_in_size = round(2048 / 2) + layer_width = 20 + + self.half_in_size = half_in_size + + self.fc_spinal_layer1 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size, layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer2 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer3 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_spinal_layer4 = nn.Sequential( + nn.Dropout(p=0.5), nn.Linear(half_in_size + layer_width, + layer_width), + nn.BatchNorm1D(layer_width), nn.ReLU()) + self.fc_out = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(layer_width * 4, class_dim), + ) + + if load_checkpoint is not None: + self.model_dict = paddle.load(load_checkpoint)[0] + self.set_dict(self.model_dict) + print("load custom checkpoint success") + + else: + checkpoint = os.path.join(self.directory, + 'spinalnet_res50.pdparams') + self.model_dict = paddle.load(checkpoint) + self.set_dict(self.model_dict) + print("load pretrained checkpoint success") + + def transforms(self, images: Union[str, np.ndarray]): + transforms = T.Compose([ + T.Resize((256, 256)), + T.CenterCrop(224), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ], + to_rgb=True) + return transforms(images) + + def forward(self, inputs: paddle.Tensor): + y = self.backbone(inputs) + feature = y + y = paddle.flatten(y, 1) + + y1 = self.fc_spinal_layer1(y[:, 0:self.half_in_size]) + y2 = self.fc_spinal_layer2( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y1], + axis=1)) + y3 = self.fc_spinal_layer3( + paddle.concat([y[:, 0:self.half_in_size], y2], axis=1)) + y4 = self.fc_spinal_layer4( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y3], + axis=1)) + + y = paddle.concat([y1, y2, y3, y4], axis=1) + + y = self.fc_out(y) + return y, feature diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md new file mode 100755 index 0000000000000000000000000000000000000000..55fdee68aeeea4b4d194d8d5c59a1a08f88455f0 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/README.md @@ -0,0 +1,21 @@ +## 概述 +* [SpinalNet](https://arxiv.org/abs/2007.03347)的网络结构如下图, + +[网络结构图](https://ai-studio-static-online.cdn.bcebos.com/0c58fff63018401089f92085a2aea5d46921351012e64ac4b7d5a8e1370c463f) + +该模型为SpinalNet在宝石数据集上的预训练模型,可以安装PaddleHub后完成一键预测及微调。 + +## 预训练模型 + +预训练模型位于https://aistudio.baidu.com/asistudio/datasetdetail/69923 + +## API +加载该模型后,使用PadduleHub2.0的默认图像分类API +``` +def Predict(images, batch_size, top_k): +``` + +**参数** +* images (list[str: 图片路径]) : 输入图像数据列表 +* batch_size: 默认值为1 +* top_k: 每张图片的前k个预测类别 diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt new file mode 100755 index 0000000000000000000000000000000000000000..d64659423aca69b2788a40754855ace9010d735d --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/label_list.txt @@ -0,0 +1,87 @@ +Alexandrite +Almandine +Amazonite +Amber +Amethyst +Ametrine +Andalusite +Andradite +Aquamarine +Aventurine Green +Aventurine Yellow +Benitoite +Beryl Golden +Bixbite +Bloodstone +Blue Lace Agate +Carnelian +Cats Eye +Chalcedony +Chalcedony Blue +Chrome Diopside +Chrysoberyl +Chrysocolla +Chrysoprase +Citrine +Coral +Danburite +Diamond +Diaspore +Dumortierite +Emerald +Fluorite +Garnet Red +Goshenite +Grossular +Hessonite +Hiddenite +Iolite +Jade +Jasper +Kunzite +Kyanite +Labradorite +Lapis Lazuli +Larimar +Malachite +Moonstone +Morganite +Onyx Black +Onyx Green +Onyx Red +Opal +Pearl +Peridot +Prehnite +Pyrite +Pyrope +Quartz Beer +Quartz Lemon +Quartz Rose +Quartz Rutilated +Quartz Smoky +Rhodochrosite +Rhodolite +Rhodonite +Ruby +Sapphire Blue +Sapphire Pink +Sapphire Purple +Sapphire Yellow +Scapolite +Serpentine +Sodalite +Spessartite +Sphene +Spinel +Spodumene +Sunstone +Tanzanite +Tigers Eye +Topaz +Tourmaline +Tsavorite +Turquoise +Variscite +Zircon +Zoisite diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py new file mode 100755 index 0000000000000000000000000000000000000000..a0ae5b774035d7c48d43541692ff786e287c1ac3 --- /dev/null +++ b/modules/thirdparty/image/classification/SpinalNet_Gemstones/spinalnet_vgg16_gemstone/module.py @@ -0,0 +1,188 @@ +# copyright (c) 2021 nanting03. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddlehub.vision.transforms as T +from paddlehub.module.module import moduleinfo +from paddlehub.module.cv_module import ImageClassifierModule + +import paddle +from paddle import nn + + +class VGG(nn.Layer): + def __init__(self, features, with_pool=True): + super(VGG, self).__init__() + self.features = features + self.with_pool = with_pool + + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) + + def forward(self, x): + x = self.features(x) + + if self.with_pool: + x = self.avgpool(x) + + return x + + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2D(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2D(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2D(v), nn.ReLU()] + else: + layers += [conv2d, nn.ReLU()] + in_channels = v + return nn.Sequential(*layers) + + +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': + [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [ + 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, + 512, 512, 'M' + ], + 'E': [ + 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, + 'M', 512, 512, 512, 512, 'M' + ], +} + + +def _vgg(arch, cfg, batch_norm, **kwargs): + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) + return model + + +def vgg16(batch_norm=False, **kwargs): + model_name = 'vgg16' + if batch_norm: + model_name += ('_bn') + return _vgg(model_name, 'D', batch_norm, **kwargs) + + +@moduleinfo(name="spinalnet_vgg16_gemstone", + type="CV/classification", + author="nanting03", + author_email="975348977@qq.com", + summary="spinalnet_vgg16_gemstone is a classification model, " + "this module is trained with Gemstone dataset.", + version="1.0.0", + meta=ImageClassifierModule) +class SpinalNet_VGG16(nn.Layer): + def __init__(self, label_list: list = None, load_checkpoint: str = None): + super(SpinalNet_VGG16, self).__init__() + + if label_list is not None: + self.labels = label_list + class_dim = len(self.labels) + else: + label_list = [] + label_file = os.path.join(self.directory, 'label_list.txt') + files = open(label_file) + for line in files.readlines(): + line = line.strip('\n') + label_list.append(line) + self.labels = label_list + class_dim = len(self.labels) + + self.backbone = vgg16() + + half_in_size = round(512 * 7 * 7 / 2) + layer_width = 4096 + + self.half_in_size = half_in_size + + self.fc_spinal_layer1 = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(half_in_size, layer_width), + nn.BatchNorm1D(layer_width), + nn.ReLU(), + ) + self.fc_spinal_layer2 = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(half_in_size + layer_width, layer_width), + nn.BatchNorm1D(layer_width), + nn.ReLU(), + ) + self.fc_spinal_layer3 = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(half_in_size + layer_width, layer_width), + nn.BatchNorm1D(layer_width), + nn.ReLU(), + ) + self.fc_spinal_layer4 = nn.Sequential( + nn.Dropout(p=0.5), + nn.Linear(half_in_size + layer_width, layer_width), + nn.BatchNorm1D(layer_width), + nn.ReLU(), + ) + self.fc_out = nn.Sequential(nn.Dropout(p=0.5), + nn.Linear(layer_width * 4, class_dim)) + + if load_checkpoint is not None: + self.model_dict = paddle.load(load_checkpoint)[0] + self.set_dict(self.model_dict) + print("load custom checkpoint success") + + else: + checkpoint = os.path.join(self.directory, + 'spinalnet_vgg16.pdparams') + self.model_dict = paddle.load(checkpoint) + self.set_dict(self.model_dict) + print("load pretrained checkpoint success") + + def transforms(self, images: Union[str, np.ndarray]): + transforms = T.Compose([ + T.Resize((256, 256)), + T.CenterCrop(224), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ], + to_rgb=True) + return transforms(images) + + def forward(self, inputs: paddle.Tensor): + + y = self.backbone(inputs) + feature = y + y = paddle.flatten(y, 1) + + y1 = self.fc_spinal_layer1(y[:, 0:self.half_in_size]) + y2 = self.fc_spinal_layer2( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y1], + axis=1)) + y3 = self.fc_spinal_layer3( + paddle.concat([y[:, 0:self.half_in_size], y2], axis=1)) + y4 = self.fc_spinal_layer4( + paddle.concat([y[:, self.half_in_size:2 * self.half_in_size], y3], + axis=1)) + + y = paddle.concat([y1, y2, y3, y4], axis=1) + + y = self.fc_out(y) + return y, feature diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b32a8aa53994a18d34f927976f946f2b09853b2 Binary files /dev/null and b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Cats Eye/cats_eye_3.jpg differ diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6234cf9c889f002bb589fdd2ff754b62f11c84d Binary files /dev/null and b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Fluorite/fluorite_18.jpg differ diff --git a/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg new file mode 100644 index 0000000000000000000000000000000000000000..70149bc7ad37f9632200e28826b1c07b21ac2142 Binary files /dev/null and b/modules/thirdparty/image/classification/SpinalNet_Gemstones/testImages/Kunzite/kunzite_28.jpg differ