提交 af25e256 编写于 作者: W weishengyu

modify format

上级 9e975699
...@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter ...@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger
def classification_eval(evaler, epoch_id=0): def classification_eval(engine, epoch_id=0):
output_info = dict() output_info = dict()
time_info = { time_info = {
"batch_cost": AverageMeter( "batch_cost": AverageMeter(
...@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0): ...@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0):
"reader_cost": AverageMeter( "reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"), "reader_cost", ".5f", postfix=" s,"),
} }
print_batch_step = evaler.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
metric_key = None metric_key = None
tic = time.time() tic = time.time()
eval_dataloader = evaler.eval_dataloader if evaler.use_dali else evaler.eval_dataloader( max_iter = len(engine.eval_dataloader) - 1 if platform.system(
) ) == "Windows" else len(engine.eval_dataloader)
max_iter = len(evaler.eval_dataloader) - 1 if platform.system( for iter_id, batch in enumerate(engine.eval_dataloader):
) == "Windows" else len(evaler.eval_dataloader)
for iter_id, batch in enumerate(eval_dataloader):
if iter_id >= max_iter: if iter_id >= max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
if evaler.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0): ...@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0):
batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[0] = paddle.to_tensor(batch[0]).astype("float32")
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
out = evaler.model(batch[0]) out = engine.model(batch[0])
# calc loss # calc loss
if evaler.eval_loss_func is not None: if engine.eval_loss_func is not None:
loss_dict = evaler.eval_loss_func(out, batch[1]) loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict: for key in loss_dict:
if key not in output_info: if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f') output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size) output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# calc metric # calc metric
if evaler.eval_metric_func is not None: if engine.eval_metric_func is not None:
metric_dict = evaler.eval_metric_func(out, batch[1]) metric_dict = engine.eval_metric_func(out, batch[1])
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
for key in metric_dict: for key in metric_dict:
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
...@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0): ...@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0):
]) ])
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format( logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id, epoch_id, iter_id,
len(evaler.eval_dataloader), metric_msg, time_msg, ips_msg)) len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
tic = time.time() tic = time.time()
if evaler.use_dali: if engine.use_dali:
evaler.eval_dataloader.reset() engine.eval_dataloader.reset()
metric_msg = ", ".join([ metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) for key in output_info "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
]) ])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model # do not try to save best eval.model
if evaler.eval_metric_func is None: if engine.eval_metric_func is None:
return -1 return -1
# return 1st metric in the dict # return 1st metric in the dict
return output_info[metric_key].avg return output_info[metric_key].avg
...@@ -20,21 +20,21 @@ import paddle ...@@ -20,21 +20,21 @@ import paddle
from ppcls.utils import logger from ppcls.utils import logger
def retrieval_eval(evaler, epoch_id=0): def retrieval_eval(engine, epoch_id=0):
evaler.model.eval() engine.model.eval()
# step1. build gallery # step1. build gallery
if evaler.gallery_query_dataloader is not None: if engine.gallery_query_dataloader is not None:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery_query') engine, name='gallery_query')
query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
else: else:
gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
evaler, name='gallery') engine, name='gallery')
query_feas, query_img_id, query_query_id = cal_feature( query_feas, query_img_id, query_query_id = cal_feature(
evaler, name='query') engine, name='query')
# step2. do evaluation # step2. do evaluation
sim_block_size = evaler.config["Global"].get("sim_block_size", 64) sim_block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size) sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size: if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size) sections.append(len(query_feas) % sim_block_size)
...@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0):
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections) image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None metric_key = None
if evaler.eval_loss_func is None: if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.} metric_dict = {metric_key: 0.}
else: else:
metric_dict = dict() metric_dict = dict()
...@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0):
else: else:
keep_mask = None keep_mask = None
metric_tmp = evaler.eval_metric_func(similarity_matrix, metric_tmp = engine.eval_metric_func(similarity_matrix,
image_id_blocks[block_idx], image_id_blocks[block_idx],
gallery_img_id, keep_mask) gallery_img_id, keep_mask)
...@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0): ...@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0):
return metric_dict[metric_key] return metric_dict[metric_key]
def cal_feature(evaler, name='gallery'): def cal_feature(engine, name='gallery'):
all_feas = None all_feas = None
all_image_id = None all_image_id = None
all_unique_id = None all_unique_id = None
has_unique_id = False has_unique_id = False
if name == 'gallery': if name == 'gallery':
dataloader = evaler.gallery_dataloader dataloader = engine.gallery_dataloader
elif name == 'query': elif name == 'query':
dataloader = evaler.query_dataloader dataloader = engine.query_dataloader
elif name == 'gallery_query': elif name == 'gallery_query':
dataloader = evaler.gallery_query_dataloader dataloader = engine.gallery_query_dataloader
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len( max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
dataloader) dataloader)
dataloader_tmp = dataloader if evaler.use_dali else dataloader() for idx, batch in enumerate(dataloader): # load is very time-consuming
for idx, batch in enumerate(dataloader_tmp): # load is very time-consuming
if idx >= max_iter: if idx >= max_iter:
break break
if idx % evaler.config["Global"]["print_batch_step"] == 0: if idx % engine.config["Global"]["print_batch_step"] == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
if evaler.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'): ...@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'):
if len(batch) == 3: if len(batch) == 3:
has_unique_id = True has_unique_id = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
out = evaler.model(batch[0], batch[1]) out = engine.model(batch[0], batch[1])
batch_feas = out["features"] batch_feas = out["features"]
# do norm # do norm
if evaler.config["Global"].get("feature_normalize", True): if engine.config["Global"].get("feature_normalize", True):
feas_norm = paddle.sqrt( feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm) batch_feas = paddle.divide(batch_feas, feas_norm)
# do binarize # do binarize
if evaler.config["Global"].get("feature_binarize") == "round": if engine.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
if evaler.config["Global"].get("feature_binarize") == "sign": if engine.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32") batch_feas = paddle.sign(batch_feas).astype("float32")
if all_feas is None: if all_feas is None:
...@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'): ...@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'):
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if evaler.use_dali: if engine.use_dali:
dataloader_tmp.reset() dataloader.reset()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
......
...@@ -18,19 +18,16 @@ import paddle ...@@ -18,19 +18,16 @@ import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.engine.train.utils import update_loss, update_metric, log_info
def train_epoch(trainer, epoch_id, print_batch_step): def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time() tic = time.time()
for iter_id, batch in enumerate(engine.train_dataloader):
train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader( if iter_id >= engine.max_iter:
)
for iter_id, batch in enumerate(train_dataloader):
if iter_id >= trainer.max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in trainer.time_info: for key in engine.time_info:
trainer.time_info[key].reset() engine.time_info[key].reset()
trainer.time_info["reader_cost"].update(time.time() - tic) engine.time_info["reader_cost"].update(time.time() - tic)
if trainer.use_dali: if engine.use_dali:
batch = [ batch = [
paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label']) paddle.to_tensor(batch[0]['label'])
...@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step): ...@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step):
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
trainer.global_step += 1 engine.global_step += 1
# image input # image input
if trainer.amp: if engine.amp:
with paddle.amp.auto_cast(custom_black_list={ with paddle.amp.auto_cast(custom_black_list={
"flatten_contiguous_range", "greater_than" "flatten_contiguous_range", "greater_than"
}): }):
out = forward(trainer, batch) out = forward(engine, batch)
loss_dict = trainer.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
else: else:
out = forward(trainer, batch) out = forward(engine, batch)
# calc loss # calc loss
if trainer.config["DataLoader"]["Train"]["dataset"].get( if engine.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None): "batch_transform_ops", None):
loss_dict = trainer.train_loss_func(out, batch[1:]) loss_dict = engine.train_loss_func(out, batch[1:])
else: else:
loss_dict = trainer.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
# step opt and lr # step opt and lr
if trainer.amp: if engine.amp:
scaled = trainer.scaler.scale(loss_dict["loss"]) scaled = engine.scaler.scale(loss_dict["loss"])
scaled.backward() scaled.backward()
trainer.scaler.minimize(trainer.optimizer, scaled) engine.scaler.minimize(engine.optimizer, scaled)
else: else:
loss_dict["loss"].backward() loss_dict["loss"].backward()
trainer.optimizer.step() engine.optimizer.step()
trainer.optimizer.clear_grad() engine.optimizer.clear_grad()
trainer.lr_sch.step() engine.lr_sch.step()
# below code just for logging # below code just for logging
# update metric_for_logger # update metric_for_logger
update_metric(trainer, out, batch, batch_size) update_metric(engine, out, batch, batch_size)
# update_loss_for_logger # update_loss_for_logger
update_loss(trainer, loss_dict, batch_size) update_loss(engine, loss_dict, batch_size)
trainer.time_info["batch_cost"].update(time.time() - tic) engine.time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0: if iter_id % print_batch_step == 0:
log_info(trainer, batch_size, epoch_id, iter_id) log_info(engine, batch_size, epoch_id, iter_id)
tic = time.time() tic = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册