Commit d8f049ae, authored by zh-hike, committed by Walter

Apply code style conventions and remove unused whitespace

Parent c5a29dba
@@ -91,9 +91,7 @@ class RecModel(TheseusLayer):
        out = dict()
        x = self.backbone(x)
        out["backbone"] = x
        if self.neck is not None:
            feat = self.neck(x)
            out["neck"] = feat
...
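The hunk above keeps RecModel's dict-style forward: each stage's output is exposed under a string key so downstream components can select features by name. A minimal runnable sketch of the same pattern (TinyRecModel and its layer sizes are hypothetical, not part of this commit):

import paddle
import paddle.nn as nn

class TinyRecModel(nn.Layer):          # hypothetical stand-in for RecModel
    def __init__(self, with_neck=True):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.neck = nn.Linear(8, 4) if with_neck else None

    def forward(self, x):
        out = dict()
        x = self.backbone(x)
        out["backbone"] = x            # raw backbone features
        if self.neck is not None:
            out["neck"] = self.neck(x) # optional embedding head
        return out

feats = TinyRecModel()(paddle.randn([2, 8]))
print(feats.keys())  # dict_keys(['backbone', 'neck'])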
@@ -143,7 +143,6 @@ class Wide_ResNet(TheseusLayer):
        # if use the output of projection head for classification
        self.proj_after = proj_after
        self.low_dim = low_dim
        channels = [
            16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
        ]
@@ -184,7 +183,6 @@ class Wide_ResNet(TheseusLayer):
        else:
            self.fc = nn.Linear(channels[3], num_classes)
            self.channels = channels[3]
        # projection head
        if self.proj:
            self.l2norm = Normalize(2)
@@ -202,7 +200,6 @@ class Wide_ResNet(TheseusLayer):
        feat = self.relu(self.bn1(feat))
        feat = F.adaptive_avg_pool2d(feat, 1)
        feat = paddle.reshape(feat, [-1, self.channels])
        if self.proj:
            pfeat = self.fc1(feat)
            pfeat = self.relu_mlp(pfeat)
...
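For context, the projection-head path touched above pools backbone features and, when self.proj is set, passes them through a small MLP before L2 normalization. A hedged sketch of that flow; only fc1, relu_mlp, and the Normalize(2) step appear in the diff, while fc2 and all sizes are assumptions:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

channels, low_dim = 64, 16              # illustrative sizes, not from the commit
fc1 = nn.Linear(channels, channels)     # first MLP layer, as in the diff
relu_mlp = nn.LeakyReLU(0.1)            # activation name follows the diff
fc2 = nn.Linear(channels, low_dim)      # assumed second layer down to low_dim

feat = paddle.randn([4, channels])      # stands in for the pooled feature map
pfeat = fc2(relu_mlp(fc1(feat)))
pfeat = F.normalize(pfeat, p=2, axis=1) # the effect of Normalize(2)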
@@ -60,7 +60,7 @@ Loss:
  UnLabelLoss:
    Train:
-      - CCSSLCeLoss:
+      - CCSSLCELoss:
          weight: 1.
      - SoftSupConLoss:
          weight: 1.0
...
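Note that the renamed key must match a class exported by ppcls.loss (see the import hunk below), or config resolution fails at startup. A minimal sketch of that coupling; the getattr-based lookup is an assumption about how such configs are consumed, not this repo's exact builder code:

import importlib

# mirrors the UnLabelLoss.Train list above
cfg = [{"CCSSLCELoss": {"weight": 1.0}}, {"SoftSupConLoss": {"weight": 1.0}}]

loss_mod = importlib.import_module("ppcls.loss")
for item in cfg:
    name, params = next(iter(item.items()))
    loss_cls = getattr(loss_mod, name)  # raises AttributeError if the YAML name
                                        # and the exported class name drift apart
    # params["weight"] would be applied by the loss combiner, not the class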
@@ -53,14 +53,12 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
        inputs_x, targets_x = label_data_batch
        inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch[:3]
        batch_size_label = inputs_x.shape[0]
        inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0)
        loss_dict, logits_label = get_loss(engine, inputs, batch_size_label,
                                           temperture, threshold, targets_x,
                                           )
        loss = loss_dict['loss']
        loss.backward()
@@ -76,13 +74,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
        if engine.ema:
            engine.model_ema.update(engine.model)
        update_metric(engine, logits_label, label_data_batch, batch_size)
        update_loss(engine, loss_dict, batch_size)
        engine.time_info['batch_cost'].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
@@ -101,6 +95,7 @@ def get_loss(engine,
             **kwargs
             ):
    out = engine.model(inputs)
    logits, feats = out['logits'], out['features']
    feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
    feat_x = feats[:batch_size_label]
@@ -118,9 +113,7 @@ def get_loss(engine,
        'mask': mask,
        'max_probs': max_probs,
    }
    unlabel_loss = engine.unlabel_train_loss_func(feats, batch)
    loss_dict = {}
    for k, v in loss_dict_label.items():
        if k != 'loss':
...
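The slicing in get_loss above relies on the batch layout built in train_epoch_fixmatch_ccssl: labeled inputs first, then the weak and two strong views concatenated along axis 0, so the unlabeled tail splits evenly into three chunks. A small numeric illustration with made-up sizes:

import paddle

batch_size_label, n_unlabel, dim = 2, 3, 5        # made-up sizes
feats = paddle.randn([batch_size_label + 3 * n_unlabel, dim])

feat_x = feats[:batch_size_label]                 # labeled features
feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)  # one per view
assert feat_w.shape == [n_unlabel, dim]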
@@ -17,7 +17,7 @@ from .triplet import TripletLoss, TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss, TripletAngularMarginLoss_XBM
from .supconloss import SupConLoss
from .softsuploss import SoftSupConLoss
-from .ccssl_loss import CCSSLCeLoss
+from .ccssl_loss import CCSSLCELoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
...
@@ -4,9 +4,9 @@ import copy
import paddle.nn as nn

-class CCSSLCeLoss(nn.Layer):
+class CCSSLCELoss(nn.Layer):
    def __init__(self, **kwargs):
-        super(CCSSLCeLoss, self).__init__()
+        super(CCSSLCELoss, self).__init__()
        self.celoss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, batch, **kwargs):
@@ -16,4 +16,4 @@ class CCSSLCeLoss(nn.Layer):
        loss_u = self.celoss(logits_s1, p_targets_u_w) * mask
        loss_u = loss_u.mean()
-        return {'CCSSLCeLoss': loss_u}
+        return {'CCSSLCELoss': loss_u}
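Numerically, the renamed CCSSLCELoss is an unreduced cross-entropy on the strongly augmented logits, gated by a confidence mask and then averaged. A self-contained sketch with made-up tensors:

import paddle
import paddle.nn as nn

celoss = nn.CrossEntropyLoss(reduction='none')

logits_s1 = paddle.randn([6, 10])            # strong-view logits
p_targets_u_w = paddle.randint(0, 10, [6])   # pseudo-labels from the weak view
mask = (paddle.rand([6]) > 0.5).astype('float32')  # 1 where confidence passed

loss_u = (celoss(logits_s1, p_targets_u_w) * mask).mean()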
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -47,9 +47,7 @@ class SoftSupConLoss(nn.Layer):
            labels = labels.reshape((-1, 1))
            mask = paddle.equal(labels, labels.T).astype('float32')
            max_probs = max_probs.reshape((-1, 1))
            score_mask = paddle.matmul(max_probs, max_probs.T)
            mask = paddle.multiply(mask, score_mask)
        contrast_count = feat.shape[1]
...
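The hunk above builds SoftSupConLoss's soft mask: a same-label indicator weighted by the product of the two samples' pseudo-label confidences. A tiny worked example:

import paddle

labels = paddle.to_tensor([0, 1, 0]).reshape((-1, 1))
max_probs = paddle.to_tensor([0.9, 0.8, 0.6]).reshape((-1, 1))

mask = paddle.equal(labels, labels.T).astype('float32')  # 1 iff same label
score_mask = paddle.matmul(max_probs, max_probs.T)       # confidence products
mask = paddle.multiply(mask, score_mask)
# mask[0][2] == 0.9 * 0.6 == 0.54; mask[0][1] == 0 (different labels)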