Commit adeb8a17 authored by andyjpaddle

fix amp vqa

Parent 91600fcc
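The patch below makes the AMP (paddle.amp.auto_cast) branch of the training loop dispatch kie/vqa models the same way as the other model types: those models take the whole batch as input instead of only the image tensor. As a minimal sketch only, assuming the non-AMP branch already dispatched kie/vqa this way, the resulting forward dispatch looks roughly like the following (forward_step is a hypothetical helper, not part of this commit):

import paddle

def forward_step(model, batch, model_type, extra_input, use_amp):
    # batch[0] holds the image tensor; the remaining entries carry labels /
    # extra inputs, depending on the model type.
    images = batch[0]
    if use_amp:
        with paddle.amp.auto_cast():
            if model_type == 'table' or extra_input:
                preds = model(images, data=batch[1:])
            elif model_type in ["kie", 'vqa']:
                # kie/vqa models consume the full batch, not just the images
                # (this is the branch added by this commit for the AMP path)
                preds = model(batch)
            else:
                preds = model(images)
    else:
        # assumed to mirror the pre-existing non-AMP dispatch
        if model_type == 'table' or extra_input:
            preds = model(images, data=batch[1:])
        elif model_type in ["kie", 'vqa']:
            preds = model(batch)
        else:
            preds = model(images)
    return preds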
@@ -255,6 +255,8 @@ def train(config,
                 with paddle.amp.auto_cast():
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
+                    elif model_type in ["kie", 'vqa']:
+                        preds = model(batch)
                     else:
                         preds = model(images)
             else:
@@ -307,7 +309,8 @@ def train(config,
             train_stats.update(stats)
 
             if log_writer is not None and dist.get_rank() == 0:
-                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+                log_writer.log_metrics(
+                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)
 
             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +357,8 @@ def train(config,
                 # logger metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics=cur_metric, prefix="EVAL", step=global_step)
 
                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
@@ -377,11 +381,18 @@ def train(config,
                 logger.info(best_str)
                 # logger best metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics={
-                        "best_{}".format(main_indicator): best_model_dict[main_indicator]
-                    }, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics={
+                            "best_{}".format(main_indicator):
+                            best_model_dict[main_indicator]
+                        },
+                        prefix="EVAL",
+                        step=global_step)
 
-                    log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+                    log_writer.log_model(
+                        is_best=True,
+                        prefix="best_accuracy",
+                        metadata=best_model_dict)
 
             reader_start = time.time()
         if dist.get_rank() == 0:
@@ -413,7 +424,8 @@ def train(config,
                     epoch=epoch,
                     global_step=global_step)
                 if log_writer is not None:
-                    log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+                    log_writer.log_model(
+                        is_best=False, prefix='iter_epoch_{}'.format(epoch))
 
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -585,7 +597,8 @@ def preprocess(is_train=False):
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
         log_writer = VDLLogger(save_model_dir)
         loggers.append(log_writer)
-    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+    if ('use_wandb' in config['Global'] and
+            config['Global']['use_wandb']) or 'wandb' in config:
         save_dir = config['Global']['save_model_dir']
         wandb_writer_path = "{}/wandb".format(save_dir)
         if "wandb" in config:
...