未验证 提交 c4f38454 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #1205 from weisy11/develop

add MixDataset, MixSampler and PKSampler
...@@ -54,7 +54,7 @@ Optimizer: ...@@ -54,7 +54,7 @@ Optimizer:
momentum: 0.9 momentum: 0.9
lr: lr:
name: Cosine name: Cosine
learning_rate: 0.01 learning_rate: 0.04
regularizer: regularizer:
name: 'L2' name: 'L2'
coeff: 0.0001 coeff: 0.0001
...@@ -84,10 +84,10 @@ DataLoader: ...@@ -84,10 +84,10 @@ DataLoader:
- RandomErasing: - RandomErasing:
EPSILON: 0.5 EPSILON: 0.5
sampler: sampler:
name: DistributedRandomIdentitySampler name: PKSampler
batch_size: 128 batch_size: 128
num_instances: 2 sample_per_id: 2
drop_last: False drop_last: True
loader: loader:
num_workers: 6 num_workers: 6
...@@ -97,7 +97,7 @@ DataLoader: ...@@ -97,7 +97,7 @@ DataLoader:
dataset: dataset:
name: LogoDataset name: LogoDataset
image_root: "dataset/LogoDet-3K-crop/val/" image_root: "dataset/LogoDet-3K-crop/val/"
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+query.txt" cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+val.txt"
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
...@@ -122,7 +122,7 @@ DataLoader: ...@@ -122,7 +122,7 @@ DataLoader:
dataset: dataset:
name: LogoDataset name: LogoDataset
image_root: "dataset/LogoDet-3K-crop/train/" image_root: "dataset/LogoDet-3K-crop/train/"
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+gallery.txt" cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt"
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
......
...@@ -54,7 +54,7 @@ Optimizer: ...@@ -54,7 +54,7 @@ Optimizer:
momentum: 0.9 momentum: 0.9
lr: lr:
name: MultiStepDecay name: MultiStepDecay
learning_rate: 0.01 learning_rate: 0.04
milestones: [30, 60, 70, 80, 90, 100] milestones: [30, 60, 70, 80, 90, 100]
gamma: 0.5 gamma: 0.5
verbose: False verbose: False
...@@ -90,10 +90,10 @@ DataLoader: ...@@ -90,10 +90,10 @@ DataLoader:
r1: 0.3 r1: 0.3
mean: [0., 0., 0.] mean: [0., 0., 0.]
sampler: sampler:
name: DistributedRandomIdentitySampler name: PKSampler
batch_size: 64 batch_size: 64
num_instances: 2 sample_per_id: 2
drop_last: False drop_last: True
shuffle: True shuffle: True
loader: loader:
num_workers: 4 num_workers: 4
......
...@@ -53,7 +53,7 @@ Optimizer: ...@@ -53,7 +53,7 @@ Optimizer:
momentum: 0.9 momentum: 0.9
lr: lr:
name: Cosine name: Cosine
learning_rate: 0.01 learning_rate: 0.04
regularizer: regularizer:
name: 'L2' name: 'L2'
coeff: 0.0005 coeff: 0.0005
...@@ -88,10 +88,10 @@ DataLoader: ...@@ -88,10 +88,10 @@ DataLoader:
mean: [0., 0., 0.] mean: [0., 0., 0.]
sampler: sampler:
name: DistributedRandomIdentitySampler name: PKSampler
batch_size: 128 batch_size: 128
num_instances: 2 sample_per_id: 2
drop_last: False drop_last: True
shuffle: True shuffle: True
loader: loader:
num_workers: 6 num_workers: 6
......
...@@ -26,9 +26,12 @@ from ppcls.data.dataloader.common_dataset import create_operators ...@@ -26,9 +26,12 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
# sampler # sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data import preprocess from ppcls.data import preprocess
from ppcls.data.preprocess import transform from ppcls.data.preprocess import transform
......
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import os
from paddle.io import Dataset
from .. import dataloader
class MixDataset(Dataset):
def __init__(self, datasets_config):
super().__init__()
self.dataset_list = []
start_idx = 0
end_idx = 0
for config_i in datasets_config:
dataset_name = config_i.pop('name')
dataset = getattr(dataloader, dataset_name)(**config_i)
end_idx += len(dataset)
self.dataset_list.append([end_idx, start_idx, dataset])
start_idx = end_idx
self.length = end_idx
def __getitem__(self, idx):
for dataset_i in self.dataset_list:
if dataset_i[0] > idx:
dataset_i_idx = idx - dataset_i[1]
return dataset_i[2][dataset_i_idx]
def __len__(self):
return self.length
def get_dataset_list(self):
return self.dataset_list
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from paddle.io import DistributedBatchSampler, Sampler
from ppcls.utils import logger
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data import dataloader
class MixSampler(DistributedBatchSampler):
def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch):
super().__init__(dataset, batch_size)
assert isinstance(dataset,
MixDataset), "MixSampler only support MixDataset"
self.sampler_list = []
self.batch_size = batch_size
self.start_list = []
self.length = iter_per_epoch
dataset_list = dataset.get_dataset_list()
batch_size_left = self.batch_size
self.iter_list = []
for i, config_i in enumerate(sample_configs):
self.start_list.append(dataset_list[i][1])
sample_method = config_i.pop("name")
ratio_i = config_i.pop("ratio")
if i < len(sample_configs) - 1:
batch_size_i = int(self.batch_size * ratio_i)
batch_size_left -= batch_size_i
else:
batch_size_i = batch_size_left
assert batch_size_i <= len(dataset_list[i][2])
config_i["batch_size"] = batch_size_i
if sample_method == "DistributedBatchSampler":
sampler_i = DistributedBatchSampler(dataset_list[i][2],
**config_i)
else:
sampler_i = getattr(dataloader, sample_method)(
dataset_list[i][2], **config_i)
self.sampler_list.append(sampler_i)
self.iter_list.append(iter(sampler_i))
self.length += len(dataset_list[i][2]) * ratio_i
self.iter_counter = 0
def __iter__(self):
while self.iter_counter < self.length:
batch = []
for i, iter_i in enumerate(self.iter_list):
batch_i = next(iter_i, None)
if batch_i is None:
iter_i = iter(self.sampler_list[i])
self.iter_list[i] = iter_i
batch_i = next(iter_i, None)
assert batch_i is not None, "dataset {} return None".format(
i)
batch += [idx + self.start_list[i] for idx in batch_i]
if len(batch) == self.batch_size:
self.iter_counter += 1
yield batch
else:
logger.info("Some dataset reaches end")
self.iter_counter = 0
def __len__(self):
return self.length
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from collections import defaultdict
import numpy as np
import random
from paddle.io import DistributedBatchSampler
from ppcls.utils import logger
class PKSampler(DistributedBatchSampler):
"""
First, randomly sample P identities.
Then for each identity randomly sample K instances.
Therefore batch size is P*K, and the sampler called PKSampler.
Args:
dataset (paddle.io.Dataset): list of (img_path, pid, cam_id).
sample_per_id(int): number of instances per identity in a batch.
batch_size (int): number of examples in a batch.
shuffle(bool): whether to shuffle indices order before generating
batch indices. Default False.
"""
def __init__(self,
dataset,
batch_size,
sample_per_id,
shuffle=True,
drop_last=True,
sample_method="sample_avg_prob"):
super().__init__(
dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
assert batch_size % sample_per_id == 0, \
"PKSampler configs error, Sample_per_id must be a divisor of batch_size."
assert hasattr(self.dataset,
"labels"), "Dataset must have labels attribute."
self.sample_per_label = sample_per_id
self.label_dict = defaultdict(list)
self.sample_method = sample_method
for idx, label in enumerate(self.dataset.labels):
self.label_dict[label].append(idx)
self.label_list = list(self.label_dict)
assert len(self.label_list) * self.sample_per_label > self.batch_size, \
"batch size should be smaller than "
if self.sample_method == "id_avg_prob":
self.prob_list = np.array([1 / len(self.label_list)] *
len(self.label_list))
elif self.sample_method == "sample_avg_prob":
counter = []
for label_i in self.label_list:
counter.append(len(self.label_dict[label_i]))
self.prob_list = np.array(counter) / sum(counter)
else:
logger.error(
"PKSampler only support id_avg_prob and sample_avg_prob sample method, "
"but receive {}.".format(self.sample_method))
if sum(np.abs(self.prob_list - 1) > 0.00000001):
self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
logger.error("PKSampler prob list error")
else:
logger.info(
"PKSampler: sum of prob list not equal to 1, change the last prob"
)
def __iter__(self):
label_per_batch = self.batch_size // self.sample_per_label
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(self.label_list)
for i in range(len(self)):
batch_index = []
batch_label_list = np.random.choice(
self.label_list,
size=label_per_batch,
replace=False,
p=self.prob_list)
for label_i in batch_label_list:
label_i_indexes = self.label_dict[label_i]
if self.sample_per_label <= len(label_i_indexes):
batch_index.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_label,
replace=False))
else:
batch_index.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_label,
replace=True))
if not self.drop_last or len(batch_index) == self.batch_size:
yield batch_index
...@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter ...@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger
def classification_eval(evaler, epoch_id=0): def classification_eval(engine, epoch_id=0):
output_info = dict() output_info = dict()
time_info = { time_info = {
"batch_cost": AverageMeter( "batch_cost": AverageMeter(
...@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0): ...@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0):
"reader_cost": AverageMeter( "reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"), "reader_cost", ".5f", postfix=" s,"),
} }
print_batch_step = evaler.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
metric_key = None metric_key = None
tic = time.time() tic = time.time()
eval_dataloader = evaler.eval_dataloader if evaler.use_dali else evaler.eval_dataloader( max_iter = len(engine.eval_dataloader) - 1 if platform.system(
) ) == "Windows" else len(engine.eval_dataloader)
max_iter = len(evaler.eval_dataloader) - 1 if platform.system( for iter_id, batch in enumerate(engine.eval_dataloader):
) == "Windows" else len(evaler.eval_dataloader)
for iter_id, batch in enumerate(eval_dataloader):
if iter_id >= max_iter: if iter_id >= max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
if evaler.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0): ...@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0):
batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[0] = paddle.to_tensor(batch[0]).astype("float32")
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
out = evaler.model(batch[0]) out = engine.model(batch[0])
# calc loss # calc loss
if evaler.eval_loss_func is not None: if engine.eval_loss_func is not None:
loss_dict = evaler.eval_loss_func(out, batch[1]) loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict: for key in loss_dict:
if key not in output_info: if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f') output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size) output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# calc metric # calc metric
if evaler.eval_metric_func is not None: if engine.eval_metric_func is not None:
metric_dict = evaler.eval_metric_func(out, batch[1]) metric_dict = engine.eval_metric_func(out, batch[1])
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
for key in metric_dict: for key in metric_dict:
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
...@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0): ...@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0):
]) ])
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format( logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id, epoch_id, iter_id,
len(evaler.eval_dataloader), metric_msg, time_msg, ips_msg)) len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
tic = time.time() tic = time.time()
if evaler.use_dali: if engine.use_dali:
evaler.eval_dataloader.reset() engine.eval_dataloader.reset()
metric_msg = ", ".join([ metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) for key in output_info "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
]) ])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model # do not try to save best eval.model
if evaler.eval_metric_func is None: if engine.eval_metric_func is None:
return -1 return -1
# return 1st metric in the dict # return 1st metric in the dict
return output_info[metric_key].avg return output_info[metric_key].avg
...@@ -20,21 +20,21 @@ import paddle ...@@ -20,21 +20,21 @@ import paddle
from ppcls.utils import logger from ppcls.utils import logger
def retrieval_eval(evaler, epoch_id=0): def retrieval_eval(engine, epoch_id=0):
evaler.model.eval() engine.model.eval()
# step1. build gallery # step1. build gallery
if evaler.gallery_query_dataloader is not None: if engine.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery_query') engine, name='gallery_query')
query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
else: else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery') engine, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature( query_feas, query_img_id, query_query_id = cal_feature(
evaler, name='query') engine, name='query')
# step2. do evaluation # step2. do evaluation
sim_block_size = evaler.config["Global"].get("sim_block_size", 64) sim_block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size) sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size: if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size) sections.append(len(query_feas) % sim_block_size)
...@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0):
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections) image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None metric_key = None
if evaler.eval_loss_func is None: if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.} metric_dict = {metric_key: 0.}
else: else:
metric_dict = dict() metric_dict = dict()
...@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0):
else: else:
keep_mask = None keep_mask = None
metric_tmp = evaler.eval_metric_func(similarity_matrix, metric_tmp = engine.eval_metric_func(similarity_matrix,
image_id_blocks[block_idx], image_id_blocks[block_idx],
gallery_img_id, keep_mask) gallery_img_id, keep_mask)
...@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0):
return metric_dict[metric_key] return metric_dict[metric_key]
def cal_feature(evaler, name='gallery'): def cal_feature(engine, name='gallery'):
all_feas = None all_feas = None
all_image_id = None all_image_id = None
all_unique_id = None all_unique_id = None
has_unique_id = False has_unique_id = False
if name == 'gallery': if name == 'gallery':
dataloader = evaler.gallery_dataloader dataloader = engine.gallery_dataloader
elif name == 'query': elif name == 'query':
dataloader = evaler.query_dataloader dataloader = engine.query_dataloader
elif name == 'gallery_query': elif name == 'gallery_query':
dataloader = evaler.gallery_query_dataloader dataloader = engine.gallery_query_dataloader
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len( max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
dataloader) dataloader)
dataloader_tmp = dataloader if evaler.use_dali else dataloader() for idx, batch in enumerate(dataloader): # load is very time-consuming
for idx, batch in enumerate(dataloader_tmp): # load is very time-consuming
if idx >= max_iter: if idx >= max_iter:
break break
if idx % evaler.config["Global"]["print_batch_step"] == 0: if idx % engine.config["Global"]["print_batch_step"] == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
if evaler.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'): ...@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'):
if len(batch) == 3: if len(batch) == 3:
has_unique_id = True has_unique_id = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
out = evaler.model(batch[0], batch[1]) out = engine.model(batch[0], batch[1])
batch_feas = out["features"] batch_feas = out["features"]
# do norm # do norm
if evaler.config["Global"].get("feature_normalize", True): if engine.config["Global"].get("feature_normalize", True):
feas_norm = paddle.sqrt( feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm) batch_feas = paddle.divide(batch_feas, feas_norm)
# do binarize # do binarize
if evaler.config["Global"].get("feature_binarize") == "round": if engine.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
if evaler.config["Global"].get("feature_binarize") == "sign": if engine.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32") batch_feas = paddle.sign(batch_feas).astype("float32")
if all_feas is None: if all_feas is None:
...@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'): ...@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'):
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if evaler.use_dali: if engine.use_dali:
dataloader_tmp.reset() dataloader.reset()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
......
...@@ -18,19 +18,16 @@ import paddle ...@@ -18,19 +18,16 @@ import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.engine.train.utils import update_loss, update_metric, log_info
def train_epoch(trainer, epoch_id, print_batch_step): def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time() tic = time.time()
for iter_id, batch in enumerate(engine.train_dataloader):
train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader( if iter_id >= engine.max_iter:
)
for iter_id, batch in enumerate(train_dataloader):
if iter_id >= trainer.max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in trainer.time_info: for key in engine.time_info:
trainer.time_info[key].reset() engine.time_info[key].reset()
trainer.time_info["reader_cost"].update(time.time() - tic) engine.time_info["reader_cost"].update(time.time() - tic)
if trainer.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step): ...@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step):
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
trainer.global_step += 1 engine.global_step += 1
# image input # image input
if trainer.amp: if engine.amp:
with paddle.amp.auto_cast(custom_black_list={ with paddle.amp.auto_cast(custom_black_list={
"flatten_contiguous_range", "greater_than" "flatten_contiguous_range", "greater_than"
}): }):
out = forward(trainer, batch) out = forward(engine, batch)
loss_dict = trainer.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
else: else:
out = forward(trainer, batch) out = forward(engine, batch)
# calc loss # calc loss
if trainer.config["DataLoader"]["Train"]["dataset"].get( if engine.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None): "batch_transform_ops", None):
loss_dict = trainer.train_loss_func(out, batch[1:]) loss_dict = engine.train_loss_func(out, batch[1:])
else: else:
loss_dict = trainer.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
# step opt and lr # step opt and lr
if trainer.amp: if engine.amp:
scaled = trainer.scaler.scale(loss_dict["loss"]) scaled = engine.scaler.scale(loss_dict["loss"])
scaled.backward() scaled.backward()
trainer.scaler.minimize(trainer.optimizer, scaled) engine.scaler.minimize(engine.optimizer, scaled)
else: else:
loss_dict["loss"].backward() loss_dict["loss"].backward()
trainer.optimizer.step() engine.optimizer.step()
trainer.optimizer.clear_grad() engine.optimizer.clear_grad()
trainer.lr_sch.step() engine.lr_sch.step()
# below code just for logging # below code just for logging
# update metric_for_logger # update metric_for_logger
update_metric(trainer, out, batch, batch_size) update_metric(engine, out, batch, batch_size)
# update_loss_for_logger # update_loss_for_logger
update_loss(trainer, loss_dict, batch_size) update_loss(engine, loss_dict, batch_size)
trainer.time_info["batch_cost"].update(time.time() - tic) engine.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0: if iter_id % print_batch_step == 0:
log_info(trainer, batch_size, epoch_id, iter_id) log_info(engine, batch_size, epoch_id, iter_id)
tic = time.time() tic = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册