提交 09200a31 编写于 作者: H HydrogenSulfate

remove redundant code, fix bugs in lr.step, merge GoodsDataset into Vehicle

上级 30cbb183
...@@ -88,7 +88,7 @@ Optimizer: ...@@ -88,7 +88,7 @@ Optimizer:
DataLoader: DataLoader:
Train: Train:
dataset: dataset:
name: GoodsDataset name: VeriWild
image_root: ./dataset/SOP image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/train_list.txt cls_label_path: ./dataset/SOP/train_list.txt
backend: pil backend: pil
...@@ -117,7 +117,7 @@ DataLoader: ...@@ -117,7 +117,7 @@ DataLoader:
Eval: Eval:
Gallery: Gallery:
dataset: dataset:
name: GoodsDataset name: VeriWild
image_root: ./dataset/SOP image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/test_list.txt cls_label_path: ./dataset/SOP/test_list.txt
backend: pil backend: pil
...@@ -141,7 +141,7 @@ DataLoader: ...@@ -141,7 +141,7 @@ DataLoader:
Query: Query:
dataset: dataset:
name: GoodsDataset name: VeriWild
image_root: ./dataset/SOP image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/test_list.txt cls_label_path: ./dataset/SOP/test_list.txt
backend: pil backend: pil
......
...@@ -25,7 +25,6 @@ from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset ...@@ -25,7 +25,6 @@ from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
from ppcls.data.dataloader.common_dataset import create_operators from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.goods_dataset import GoodsDataset
from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset from ppcls.data.dataloader.mix_dataset import MixDataset
......
...@@ -82,26 +82,6 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler): ...@@ -82,26 +82,6 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
avai_pids = copy.deepcopy(self.pids) avai_pids = copy.deepcopy(self.pids)
return batch_idxs_dict, avai_pids, count return batch_idxs_dict, avai_pids, count
def __iter__(self):
    """Yield ``self.max_iters`` batches of sample indices.

    Each batch draws ``self.num_pids_per_batch`` identities without
    replacement, weighted by how many unconsumed index-chunks each
    identity still has; exhausted identities are dropped on the fly.
    """
    batch_idxs_dict, remaining_pids, chunk_counts = self._prepare_batch()
    for _ in range(self.max_iters):
        # Restock when too few identities remain to fill one batch.
        if len(remaining_pids) < self.num_pids_per_batch:
            batch_idxs_dict, remaining_pids, chunk_counts = self._prepare_batch()
        # Weighted sampling without replacement over the surviving pids.
        chosen = np.random.choice(remaining_pids, self.num_pids_per_batch,
                                  False, chunk_counts / chunk_counts.sum())
        batch = []
        for pid in chosen:
            batch.extend(batch_idxs_dict[pid].pop(0))
            where = remaining_pids.index(pid)
            if batch_idxs_dict[pid]:
                # Still has chunks left: just refresh its weight.
                chunk_counts[where] = len(batch_idxs_dict[pid])
            else:
                # Exhausted: remove the pid and its weight entry together.
                remaining_pids.pop(where)
                chunk_counts = np.delete(chunk_counts, where)
        yield batch
def __iter__(self): def __iter__(self):
# prepare # prepare
batch_idxs_dict, avai_pids, count = self._prepare_batch() batch_idxs_dict, avai_pids, count = self._prepare_batch()
......
from __future__ import print_function
import os
from typing import Callable, List
import numpy as np
import paddle
from paddle.io import Dataset
from PIL import Image
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators
class GoodsDataset(Dataset):
    """Dataset for goods/product retrieval data, such as SOP, Inshop...

    Each annotation line is expected to be:
    ``<relative_image_path> <label> <camera_id>``.

    Args:
        image_root (str): image root directory.
        cls_label_path (str): path to annotation file.
        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
        backend (str, optional): "pil" or "cv2". Defaults to "cv2".
        relabel (bool, optional): whether to relabel when original labels do not
            start from 0 or are discontinuous. Defaults to False.
    """

    def __init__(self,
                 image_root: str,
                 cls_label_path: str,
                 transform_ops: List[Callable]=None,
                 backend="cv2",
                 relabel=False):
        self._img_root = image_root
        self._cls_path = cls_label_path
        # Always define the attribute so __getitem__ never raises
        # AttributeError when no transforms are configured.
        self._transform_ops = create_operators(
            transform_ops) if transform_ops else None
        self.backend = backend
        self._dtype = paddle.get_default_dtype()
        # BUG FIX: was `self._load_anno(relabel)`, which bound `relabel` to the
        # `seed` parameter of _load_anno and silently disabled relabeling.
        self._load_anno(relabel=relabel)

    def _load_anno(self, seed=None, relabel=False):
        """Parse the annotation file into images / labels / cameras lists.

        Args:
            seed (int, optional): if given, shuffle lines with this fixed seed.
            relabel (bool): remap labels to a contiguous 0-based range.
        """
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
            self._img_root), f"path {self._img_root} does not exist."
        self.images = []
        self.labels = []
        self.cameras = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            if relabel:
                label_set = set()
                for line in lines:
                    line = line.strip().split()
                    label_set.add(np.int64(line[1]))
                # Sort so the old->new label mapping is deterministic
                # (plain set iteration order is not guaranteed).
                label_map = {
                    oldlabel: newlabel
                    for newlabel, oldlabel in enumerate(sorted(label_set))
                }
            if seed is not None:
                np.random.RandomState(seed).shuffle(lines)
            for line in lines:
                line = line.strip().split()
                self.images.append(os.path.join(self._img_root, line[0]))
                if relabel:
                    self.labels.append(label_map[np.int64(line[1])])
                else:
                    self.labels.append(np.int64(line[1]))
                self.cameras.append(np.int64(line[2]))
                assert os.path.exists(self.images[
                    -1]), f"path {self.images[-1]} does not exist."

    def __getitem__(self, idx):
        """Return ``(image, label, camera_id)``; on failure, retry a random sample."""
        try:
            img = Image.open(self.images[idx]).convert("RGB")
            if self.backend == "cv2":
                # cv2-style ops expect an HWC uint8 ndarray instead of a PIL image.
                img = np.array(img, dtype="float32").astype(np.uint8)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            if self.backend == "cv2":
                # HWC -> CHW for the training pipeline.
                img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        """Number of distinct labels in the dataset."""
        return len(set(self.labels))
...@@ -19,8 +19,7 @@ import paddle ...@@ -19,8 +19,7 @@ import paddle
from paddle.io import Dataset from paddle.io import Dataset
import os import os
import cv2 import cv2
from PIL import Image
from ppcls.data import preprocess
from ppcls.data.preprocess import transform from ppcls.data.preprocess import transform
from ppcls.utils import logger from ppcls.utils import logger
from .common_dataset import create_operators from .common_dataset import create_operators
...@@ -89,15 +88,30 @@ class CompCars(Dataset): ...@@ -89,15 +88,30 @@ class CompCars(Dataset):
class VeriWild(Dataset): class VeriWild(Dataset):
def __init__(self, image_root, cls_label_path, transform_ops=None): """Dataset for Vehicle and other similar data structure, such as VeRI-Wild, SOP, Inshop...
Args:
image_root (str): image root
cls_label_path (str): path to annotation file
transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
backend (str, optional): pil or cv2. Defaults to "cv2".
relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
"""
def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
backend="cv2",
relabel=False):
self._img_root = image_root self._img_root = image_root
self._cls_path = cls_label_path self._cls_path = cls_label_path
if transform_ops: if transform_ops:
self._transform_ops = create_operators(transform_ops) self._transform_ops = create_operators(transform_ops)
self.backend = backend
self._dtype = paddle.get_default_dtype() self._dtype = paddle.get_default_dtype()
self._load_anno() self._load_anno(relabel)
def _load_anno(self): def _load_anno(self, relabel):
assert os.path.exists( assert os.path.exists(
self._cls_path), f"path {self._cls_path} does not exist." self._cls_path), f"path {self._cls_path} does not exist."
assert os.path.exists( assert os.path.exists(
...@@ -107,22 +121,40 @@ class VeriWild(Dataset): ...@@ -107,22 +121,40 @@ class VeriWild(Dataset):
self.cameras = [] self.cameras = []
with open(self._cls_path) as fd: with open(self._cls_path) as fd:
lines = fd.readlines() lines = fd.readlines()
if relabel:
label_set = set()
for line in lines:
line = line.strip().split()
label_set.add(np.int64(line[1]))
label_map = {
oldlabel: newlabel
for newlabel, oldlabel in enumerate(label_set)
}
for line in lines: for line in lines:
line = line.strip().split() line = line.strip().split()
self.images.append(os.path.join(self._img_root, line[0])) self.images.append(os.path.join(self._img_root, line[0]))
self.labels.append(np.int64(line[1])) if relabel:
self.labels.append(label_map[np.int64(line[1])])
else:
self.labels.append(np.int64(line[1]))
if len(line) >= 3: if len(line) >= 3:
self.cameras.append(np.int64(line[2])) self.cameras.append(np.int64(line[2]))
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1]), \
f"path {self.images[-1]} does not exist."
self.has_camera = len(self.cameras) > 0 self.has_camera = len(self.cameras) > 0
def __getitem__(self, idx): def __getitem__(self, idx):
try: try:
with open(self.images[idx], 'rb') as f: if self.backend == "cv2":
img = f.read() with open(self.images[idx], 'rb') as f:
img = f.read()
else:
img = Image.open(self.images[idx]).convert("RGB")
if self._transform_ops: if self._transform_ops:
img = transform(img, self._transform_ops) img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1)) if self.backend == "cv2":
img = img.transpose((2, 0, 1))
if self.has_camera: if self.has_camera:
return (img, self.labels[idx], self.cameras[idx]) return (img, self.labels[idx], self.cameras[idx])
else: else:
......
...@@ -42,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list ...@@ -42,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators from ppcls.data import create_operators
from ppcls.engine.train import train_epoch from ppcls.engine.train import train_epoch
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead from ppcls.arch.gears.identity_head import IdentityHead
...@@ -377,7 +378,7 @@ class Engine(object): ...@@ -377,7 +378,7 @@ class Engine(object):
# step lr (by epoch) according to given metric, such as acc # step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)): for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \ if getattr(self.lr_sch[i], "by_epoch", False) and \
self.lr_sch[i].__class__.__name__ == "ReduceOnPlateau": type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc) self.lr_sch[i].step(acc)
if acc > best_metric["metric"]: if acc > best_metric["metric"]:
......
...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function ...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
import time import time
import paddle import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler from ppcls.utils import profiler
...@@ -98,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step): ...@@ -98,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
# step lr(by epoch) # step lr(by epoch)
for i in range(len(engine.lr_sch)): for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], "by_epoch", False): if getattr(engine.lr_sch[i], "by_epoch", False) and \
type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
engine.lr_sch[i].step() engine.lr_sch[i].step()
......
...@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id): ...@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
ips_msg = "ips: {:.5f} samples/s".format( ips_msg = "ips: {:.5f} samples/s".format(
batch_size / trainer.time_info["batch_cost"].avg) batch_size / trainer.time_info["batch_cost"].avg)
eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1 eta_sec = (
) * trainer.max_iter - iter_id (trainer.config["Global"]["epochs"] - epoch_id + 1
) * trainer.time_info["batch_cost"].avg ) * trainer.max_iter - iter_id) * trainer.time_info["batch_cost"].avg
eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec)))) eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format( logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
epoch_id, trainer.config["Global"]["epochs"], iter_id, epoch_id, trainer.config["Global"]["epochs"], iter_id,
trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
eta_msg))
for i, lr in enumerate(trainer.lr_sch): for i, lr in enumerate(trainer.lr_sch):
logger.scaler( logger.scaler(
...@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id): ...@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
value=trainer.output_info[key].avg, value=trainer.output_info[key].avg,
step=trainer.global_step, step=trainer.global_step,
writer=trainer.vdl_writer) writer=trainer.vdl_writer)
def type_name(obj: object) -> str:
    """Return the class name of *obj*, e.g. ``type_name(3) == "int"``.

    The parameter was renamed from ``object`` to ``obj`` to avoid shadowing
    the ``object`` builtin; all known call sites pass it positionally.
    """
    return obj.__class__.__name__
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册