Commit 4db13244 authored by zh-hike, committed by Walter

Add code for using RecModel together with WideResNet, and convert the parameters to the RecModel format

Parent 7823f340
...@@ -71,7 +71,7 @@ The cifar10 data is automatically downloaded during training to the default cache path `~/.cache/pa
For single-GPU training, run the following command:
```
-python tools/train.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml
+python tools/train.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000.yaml
```
For 4-GPU training, run the following:
...@@ -86,7 +86,7 @@ python -m paddle.distributed.launch --gpus='0,1,2,3' tools/train.py -c ppcls/con
Prepare a `*.pdparams` model parameter file for evaluation; you can use an already trained model or one of the models saved in *4. Model Training*.
* Taking the `best_model_ema.ema.pdparams` saved during training as an example, run the following command to evaluate:
```
-python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml -o Global.pretrained_model="./output/WideResNetCCSSL/best_model_ema.ema"
+python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml -o Global.pretrained_model="./output/WideResNet/best_model_ema.ema"
```
* To evaluate an already trained model, download the provided trained model into the `PaddleClas/pretrained_models` folder and run the following command:
......
...@@ -86,9 +86,17 @@ class RecModel(TheseusLayer):
        else:
            self.head = None

        # build the optional Decoup gear that reorganizes the backbone output
        if "Decoup" in config:
            self.decoup = build_gear(config['Decoup'])
        else:
            self.decoup = None

    def forward(self, x, label=None):
        out = dict()
        x = self.backbone(x)

        # with a Decoup gear configured, return its dict output directly
        if self.decoup is not None:
            return self.decoup(x)

        out["backbone"] = x
        if self.neck is not None:
            x = self.neck(x)
......
...@@ -201,8 +201,6 @@ class Wide_ResNet(nn.Layer):
        feat = self.relu(self.bn1(feat))
        feat = F.adaptive_avg_pool2d(feat, 1)
        feat = paddle.reshape(feat, [-1, self.channels])
-        if not self.training:
-            return self.fc(feat)
        if self.proj:
            pfeat = self.fc1(feat)
......
...@@ -20,6 +20,7 @@ from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
from .bnneck import BNNeck
from .adamargin import AdaMargin
from .decoup import Decoup
__all__ = ['build_gear']
...@@ -27,7 +28,7 @@ __all__ = ['build_gear']
def build_gear(config):
    support_dict = [
        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
-        'BNNeck', 'AdaMargin'
+        'BNNeck', 'AdaMargin', 'FRFBNeck', 'Decoup'
    ]
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
......
import paddle
import paddle.nn as nn


class Decoup(nn.Layer):
    """Decouple a backbone's multi-output (e.g. a (logits, features) tuple)
    into a dict with 'logits' and 'features' keys."""

    def __init__(self, logits_index, features_index, **kwargs):
        super(Decoup, self).__init__()
        self.logits_index = logits_index
        self.features_index = features_index

    def forward(self, out, **kwargs):
        assert isinstance(out, (list, tuple)), 'out must be list or tuple'
        out = {
            'logits': out[self.logits_index],
            'features': out[self.features_index]
        }
        return out
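For context, a minimal usage sketch (not part of the commit): it assumes the `Decoup` class defined above is in scope and uses a hypothetical `ToyBackbone` that, like the modified WideResNet in training mode, returns a `(logits, features)` tuple; all shapes are illustrative.
```
import paddle
import paddle.nn as nn

# Hypothetical stand-in for a backbone that, like the modified WideResNet,
# returns a (logits, projected-features) tuple during training.
class ToyBackbone(nn.Layer):
    def __init__(self, num_classes=10, feat_dim=64):
        super().__init__()
        self.fc = nn.Linear(128, num_classes)
        self.proj = nn.Linear(128, feat_dim)

    def forward(self, x):
        return self.fc(x), self.proj(x)

backbone = ToyBackbone()
decoup = Decoup(logits_index=0, features_index=1)  # Decoup as defined above

x = paddle.randn([4, 128])
out = decoup(backbone(x))
print(out['logits'].shape)    # [4, 10]
print(out['features'].shape)  # [4, 64]
```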
...@@ -23,6 +23,10 @@ EMA:
  decay: 0.999

Arch:
  name: RecModel
  infer_output_key: logits
  infer_add_softmax: false
  Backbone:
    name: WideResNet
    widen_factor: 2
    depth: 28
...@@ -32,6 +36,11 @@ Arch:
    proj: true
    proj_after: false
  Decoup:
    name: Decoup
    logits_index: 0
    features_index: 1
  use_sync_bn: true

Loss:
......
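To make the wiring explicit, here is a small sketch (not from the commit) of how the `Decoup` section of this `Arch` config becomes a gear: `build_gear` pops the `name` key and, following the usual PaddleClas pattern, passes the remaining keys as keyword arguments. The dict literal and the keyword-argument instantiation step are assumptions based on the hunks above.
```
# The Decoup sub-config as a plain dict, mirroring the YAML above.
decoup_cfg = {'name': 'Decoup', 'logits_index': 0, 'features_index': 1}

# build_gear pops 'name' to select the class from support_dict, then
# (assumed here) forwards the remaining keys as keyword arguments:
gear_name = decoup_cfg.pop('name')   # 'Decoup'
gear = Decoup(**decoup_cfg)          # Decoup(logits_index=0, features_index=1)

# RecModel stores this gear as self.decoup, so during training the model
# returns {'logits': ..., 'features': ...} instead of a raw tuple.
```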
...@@ -10,6 +10,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.utils import profiler
from paddle.nn import functional as F
import numpy as np
import paddle
# from reprod_log import ReprodLogger
...@@ -20,6 +21,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
    # epoch = 0
    ##############################################################
    # temporary debug code: dump the parameters converted to the RecModel
    # format, then halt the run
    paddle.save(engine.model.state_dict(), '../recmodel.pdparams')
    assert 1 == 0
    tic = time.time()
    if not hasattr(engine, 'train_dataloader_iter'):
        engine.train_dataloader_iter = iter(engine.train_dataloader)
...@@ -81,7 +85,7 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
        inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch
        batch_size_label = inputs_x.shape[0]
-        inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2])
+        inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0)
        loss_dict, logits_label = get_loss(engine, inputs, batch_size_label,
                                           temperture, threshold, targets_x,
...@@ -134,8 +138,9 @@ def get_loss(engine,
             targets_x,
             **kwargs
             ):
    # the model wrapped as RecModel with a Decoup gear now returns a dict
    out = engine.model(inputs)
-    logits, feats = engine.model(inputs)
+    logits, feats = out['logits'], out['features']
    # logits, feats = engine.model(inputs)
    feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
    feat_x = feats[:batch_size_label]
    logits_x = logits[:batch_size_label]
......
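As an aside, a small self-contained illustration (shapes invented for the example) of the slicing and `chunk(3)` split performed above: the concatenated batch holds the labeled images first, followed by the three augmented views of each unlabeled image.
```
import paddle

batch_size_label = 4   # labeled examples in the batch
num_unlabel = 7        # unlabeled images, each contributing three views
feats = paddle.randn([batch_size_label + 3 * num_unlabel, 64])

feat_x = feats[:batch_size_label]                             # labeled features, shape [4, 64]
feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)  # one view each, shape [7, 64]
```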