diff --git a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md index 2548754004570f05821df2332c671512bcd57b2f..03bd87b18e3b00b5332d148432f53aa0b4342543 100644 --- a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md +++ b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md @@ -71,7 +71,7 @@ cifar10数据在训练过程中会自动下载到默认缓存路径 `~/.cache/pa 单卡训练执行以下命令 ``` -python tools/train.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml +python tools/train.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000.yaml ``` 4卡训练执行以下操作 @@ -86,7 +86,7 @@ python -m paddle.distributed.launch --gpus='0,1,2,3' tools/train.py -c ppcls/con 准备用于评估的 `*.pdparams` 模型参数文件,可以使用训练好的模型,也可以使用 *4. 模型训练* 中保存的模型。 * 以训练过程中保存的 `best_model_ema.ema.pdparams`为例,执行如下命令即可进行评估。 ``` -python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml -o Global.pretrained_model="./output/WideResNetCCSSL/best_model_ema.ema" +python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml -o Global.pretrained_model="./output/WideResNet/best_model_ema.ema" ``` * 以训练好的模型为例,下载提供的已经训练好的模型,到 `PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。 diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 0ce446ca669adefaa170506abb5a75d3db95d648..1792b6306607dd4ed154b7ac1fb711d8b181f456 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -85,10 +85,18 @@ class RecModel(TheseusLayer): self.head = build_gear(config["Head"]) else: self.head = None + + if "Decoup" in config: + self.decoup = build_gear(config['Decoup']) + else: + self.decoup = None def forward(self, x, label=None): + out = dict() x = self.backbone(x) + if self.decoup is not None: + return self.decoup(x) out["backbone"] = x if self.neck is not None: x = self.neck(x) diff --git a/ppcls/arch/backbone/model_zoo/wideresnet.py b/ppcls/arch/backbone/model_zoo/wideresnet.py index f33a3f264452873c41fe7f03dcce97e8905ccadc..12b8f77dafce7c31600a17066169d15175ce37b5 100644 --- a/ppcls/arch/backbone/model_zoo/wideresnet.py +++ b/ppcls/arch/backbone/model_zoo/wideresnet.py @@ -201,8 +201,6 @@ class Wide_ResNet(nn.Layer): feat = self.relu(self.bn1(feat)) feat = F.adaptive_avg_pool2d(feat, 1) feat = paddle.reshape(feat, [-1, self.channels]) - if not self.training: - return self.fc(feat) if self.proj: pfeat = self.fc1(feat) diff --git a/ppcls/arch/gears/__init__.py b/ppcls/arch/gears/__init__.py index 871967804e21c362935915942aa3f621207b934e..f2eed40e62d499fce67cff7e7a99b5ae7e1b3c5c 100644 --- a/ppcls/arch/gears/__init__.py +++ b/ppcls/arch/gears/__init__.py @@ -20,6 +20,7 @@ from .vehicle_neck import VehicleNeck from paddle.nn import Tanh from .bnneck import BNNeck from .adamargin import AdaMargin +from .decoup import Decoup __all__ = ['build_gear'] @@ -27,7 +28,7 @@ __all__ = ['build_gear'] def build_gear(config): support_dict = [ 'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh', - 'BNNeck', 'AdaMargin' + 'BNNeck', 'AdaMargin', 'FRFBNeck', 'Decoup' ] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppcls/arch/gears/decoup.py b/ppcls/arch/gears/decoup.py new file mode 100644 index 0000000000000000000000000000000000000000..6520e6978cd32c57909378371d06a25a8ec477c9 --- /dev/null +++ b/ppcls/arch/gears/decoup.py @@ -0,0 +1,16 @@ +import paddle +import paddle.nn as nn + + +class Decoup(nn.Layer): + def __init__(self, logits_index, features_index, **kwargs): + super(Decoup, self).__init__() + self.logits_index = logits_index + self.features_index = features_index + + + def forward(self, out, **kwargs): + assert isinstance(out, (list, tuple)), 'out must be list or tuple' + out = {'logits': out[self.logits_index], 'features':out[self.features_index]} + return out + diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml index 3ff233ac1c054d1ab089fb700278390e0409f30f..75a57915cee110eb17f45fbee965cfafc9ad2fe2 100644 --- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml +++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml @@ -23,14 +23,23 @@ EMA: decay: 0.999 Arch: - name: WideResNet - widen_factor: 2 - depth: 28 - dropout: 0 # CCSSL为 drop_rate - num_classes: &sign_num_classes 10 - low_dim: 64 - proj: true - proj_after: false + name: RecModel + infer_output_key: logits + infer_add_softmax: false + Backbone: + name: WideResNet + widen_factor: 2 + depth: 28 + dropout: 0 # CCSSL为 drop_rate + num_classes: &sign_num_classes 10 + low_dim: 64 + proj: true + proj_after: false + + Decoup: + name: Decoup + logits_index: 0 + features_index: 1 use_sync_bn: true diff --git a/ppcls/engine/train/train_fixmatch_ccssl.py b/ppcls/engine/train/train_fixmatch_ccssl.py index f74e658881feb45f6378ce26d5381e8e708c3fec..1d16f1206de4d29b5d23d8d0e1592253dd57777c 100644 --- a/ppcls/engine/train/train_fixmatch_ccssl.py +++ b/ppcls/engine/train/train_fixmatch_ccssl.py @@ -10,6 +10,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.utils import profiler from paddle.nn import functional as F import numpy as np +import paddle # from reprod_log import ReprodLogger @@ -20,6 +21,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step): # epoch = 0 ############################################################## + paddle.save(engine.model.state_dict(), '../recmodel.pdparams') + + assert 1==0 tic = time.time() if not hasattr(engine, 'train_dataloader_iter'): engine.train_dataloader_iter = iter(engine.train_dataloader) @@ -81,7 +85,7 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step): inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch batch_size_label = inputs_x.shape[0] - inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2]) + inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0) loss_dict, logits_label = get_loss(engine, inputs, batch_size_label, temperture, threshold, targets_x, @@ -134,8 +138,9 @@ def get_loss(engine, targets_x, **kwargs ): - - logits, feats = engine.model(inputs) + out = engine.model(inputs) + logits, feats = out['logits'], out['features'] + # logits, feats = engine.model(inputs) feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3) feat_x = feats[:batch_size_label] logits_x = logits[:batch_size_label]