提交 d8f049ae 编写于 作者: Z zh-hike 提交者: Walter

增加代码规范,删除无用空格

上级 c5a29dba
......@@ -91,9 +91,7 @@ class RecModel(TheseusLayer):
out = dict()
x = self.backbone(x)
out["backbone"] = x
if self.neck is not None:
feat = self.neck(x)
out["neck"] = feat
......
......@@ -143,7 +143,6 @@ class Wide_ResNet(TheseusLayer):
# if use the output of projection head for classification
self.proj_after = proj_after
self.low_dim = low_dim
channels = [
16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
]
......@@ -184,7 +183,6 @@ class Wide_ResNet(TheseusLayer):
else:
self.fc = nn.Linear(channels[3], num_classes)
self.channels = channels[3]
# projection head
if self.proj:
self.l2norm = Normalize(2)
......@@ -202,7 +200,6 @@ class Wide_ResNet(TheseusLayer):
feat = self.relu(self.bn1(feat))
feat = F.adaptive_avg_pool2d(feat, 1)
feat = paddle.reshape(feat, [-1, self.channels])
if self.proj:
pfeat = self.fc1(feat)
pfeat = self.relu_mlp(pfeat)
......
......@@ -60,7 +60,7 @@ Loss:
UnLabelLoss:
Train:
- CCSSLCeLoss:
- CCSSLCELoss:
weight: 1.
- SoftSupConLoss:
weight: 1.0
......
......@@ -53,14 +53,12 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
inputs_x, targets_x = label_data_batch
inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch[:3]
batch_size_label = inputs_x.shape[0]
inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0)
loss_dict, logits_label = get_loss(engine, inputs, batch_size_label,
temperture, threshold, targets_x,
)
loss = loss_dict['loss']
loss.backward()
......@@ -76,13 +74,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
if engine.ema:
engine.model_ema.update(engine.model)
update_metric(engine, logits_label, label_data_batch, batch_size)
update_loss(engine, loss_dict, batch_size)
engine.time_info['batch_cost'].update(time.time() - tic)
if iter_id % print_batch_step == 0:
log_info(engine, batch_size, epoch_id, iter_id)
......@@ -101,6 +95,7 @@ def get_loss(engine,
**kwargs
):
out = engine.model(inputs)
logits, feats = out['logits'], out['features']
feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
feat_x = feats[:batch_size_label]
......@@ -118,9 +113,7 @@ def get_loss(engine,
'mask': mask,
'max_probs': max_probs,
}
unlabel_loss = engine.unlabel_train_loss_func(feats, batch)
loss_dict = {}
for k, v in loss_dict_label.items():
if k != 'loss':
......
......@@ -17,7 +17,7 @@ from .triplet import TripletLoss, TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss, TripletAngularMarginLoss_XBM
from .supconloss import SupConLoss
from .softsuploss import SoftSupConLoss
from .ccssl_loss import CCSSLCeLoss
from .ccssl_loss import CCSSLCELoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
......
......@@ -4,9 +4,9 @@ import copy
import paddle.nn as nn
class CCSSLCeLoss(nn.Layer):
class CCSSLCELoss(nn.Layer):
def __init__(self, **kwargs):
super(CCSSLCeLoss, self).__init__()
super(CCSSLCELoss, self).__init__()
self.celoss = nn.CrossEntropyLoss(reduction='none')
def forward(self, inputs, batch, **kwargs):
......@@ -16,4 +16,4 @@ class CCSSLCeLoss(nn.Layer):
loss_u = self.celoss(logits_s1, p_targets_u_w) * mask
loss_u = loss_u.mean()
return {'CCSSLCeLoss': loss_u}
return {'CCSSLCELoss': loss_u}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -47,9 +47,7 @@ class SoftSupConLoss(nn.Layer):
labels = labels.reshape((-1, 1))
mask = paddle.equal(labels, labels.T).astype('float32')
max_probs = max_probs.reshape((-1, 1))
score_mask = paddle.matmul(max_probs, max_probs.T)
mask = paddle.multiply(mask, score_mask)
contrast_count = feat.shape[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册