提交 4db13244 编写于 作者: Z zh-hike 提交者: Walter

增加RecModel配合WideResNet代码以及参数转换成RecModel

上级 7823f340
......@@ -71,7 +71,7 @@ cifar10数据在训练过程中会自动下载到默认缓存路径 `~/.cache/pa
单卡训练执行以下命令
```
python tools/train.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml
python tools/train.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000.yaml
```
4卡训练执行以下操作
......@@ -86,7 +86,7 @@ python -m paddle.distributed.launch --gpus='0,1,2,3' tools/train.py -c ppcls/con
准备用于评估的 `*.pdparams` 模型参数文件,可以使用训练好的模型,也可以使用 *4. 模型训练* 中保存的模型。
* 以训练过程中保存的 `best_model_ema.ema.pdparams`为例,执行如下命令即可进行评估。
```
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml -o Global.pretrained_model="./output/WideResNetCCSSL/best_model_ema.ema"
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml -o Global.pretrained_model="./output/WideResNet/best_model_ema.ema"
```
* 以训练好的模型为例,下载提供的已经训练好的模型,到 `PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。
......
......@@ -86,9 +86,17 @@ class RecModel(TheseusLayer):
else:
self.head = None
if "Decoup" in config:
self.decoup = build_gear(config['Decoup'])
else:
self.decoup = None
def forward(self, x, label=None):
out = dict()
x = self.backbone(x)
if self.decoup is not None:
return self.decoup(x)
out["backbone"] = x
if self.neck is not None:
x = self.neck(x)
......
......@@ -201,8 +201,6 @@ class Wide_ResNet(nn.Layer):
feat = self.relu(self.bn1(feat))
feat = F.adaptive_avg_pool2d(feat, 1)
feat = paddle.reshape(feat, [-1, self.channels])
if not self.training:
return self.fc(feat)
if self.proj:
pfeat = self.fc1(feat)
......
......@@ -20,6 +20,7 @@ from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
from .bnneck import BNNeck
from .adamargin import AdaMargin
from .decoup import Decoup
__all__ = ['build_gear']
......@@ -27,7 +28,7 @@ __all__ = ['build_gear']
def build_gear(config):
support_dict = [
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
'BNNeck', 'AdaMargin'
'BNNeck', 'AdaMargin', 'FRFBNeck', 'Decoup'
]
module_name = config.pop('name')
assert module_name in support_dict, Exception(
......
import paddle
import paddle.nn as nn
class Decoup(nn.Layer):
def __init__(self, logits_index, features_index, **kwargs):
super(Decoup, self).__init__()
self.logits_index = logits_index
self.features_index = features_index
def forward(self, out, **kwargs):
assert isinstance(out, (list, tuple)), 'out must be list or tuple'
out = {'logits': out[self.logits_index], 'features':out[self.features_index]}
return out
......@@ -23,6 +23,10 @@ EMA:
decay: 0.999
Arch:
name: RecModel
infer_output_key: logits
infer_add_softmax: false
Backbone:
name: WideResNet
widen_factor: 2
depth: 28
......@@ -32,6 +36,11 @@ Arch:
proj: true
proj_after: false
Decoup:
name: Decoup
logits_index: 0
features_index: 1
use_sync_bn: true
Loss:
......
......@@ -10,6 +10,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.utils import profiler
from paddle.nn import functional as F
import numpy as np
import paddle
# from reprod_log import ReprodLogger
......@@ -20,6 +21,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
# epoch = 0
##############################################################
paddle.save(engine.model.state_dict(), '../recmodel.pdparams')
assert 1==0
tic = time.time()
if not hasattr(engine, 'train_dataloader_iter'):
engine.train_dataloader_iter = iter(engine.train_dataloader)
......@@ -81,7 +85,7 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch
batch_size_label = inputs_x.shape[0]
inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2])
inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0)
loss_dict, logits_label = get_loss(engine, inputs, batch_size_label,
temperture, threshold, targets_x,
......@@ -134,8 +138,9 @@ def get_loss(engine,
targets_x,
**kwargs
):
logits, feats = engine.model(inputs)
out = engine.model(inputs)
logits, feats = out['logits'], out['features']
# logits, feats = engine.model(inputs)
feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
feat_x = feats[:batch_size_label]
logits_x = logits[:batch_size_label]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册