From 1c68b63a66635e840ef4b0f9d0d4cf402317fdb1 Mon Sep 17 00:00:00 2001 From: zengshao0622 Date: Wed, 15 Feb 2023 03:39:02 +0000 Subject: [PATCH] fix some nots --- ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml | 2 +- ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml | 2 +- ppcls/engine/evaluation/retrieval.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml b/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml index 72e39816..a2382817 100644 --- a/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml +++ b/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml @@ -162,7 +162,7 @@ DataLoader: contrast: 0.4 saturation: 0.4 hue: 0.1 - p: 1.0 + p: 1.0 # refer to official settings - RandomGrayscale: p: 0.2 - NormalizeImage: diff --git a/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml b/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml index d7f640bb..79667edc 100644 --- a/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml +++ b/ppcls/configs/ssl/CCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml @@ -162,7 +162,7 @@ DataLoader: contrast: 0.4 saturation: 0.4 hue: 0.1 - p: 1.0 + p: 1.0 # refer to official settings - RandomGrayscale: p: 0.2 - NormalizeImage: diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 4d823809..875a01c3 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -15,7 +15,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform from collections import defaultdict import numpy as np @@ -257,7 +256,7 @@ def compute_re_ranking_dist(query_feat: paddle.Tensor, original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) V = np.zeros_like(original_dist).astype(np.float16) - initial_rank = np.argpartition(original_dist, range(1, k1 + 1)) # 22.2s + initial_rank = np.argpartition(original_dist, range(1, k1 + 1)) logger.info("Start re-ranking...") for p in range(num_all): -- GitLab