Commit 229acda8 authored by littletomatodonkey

fix ips info and reduce interval of metric calc

Parent: cffb0316
@@ -31,7 +31,8 @@ class CTCLoss(nn.Layer):
         predicts = predicts[-1]
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
-        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
+        preds_lengths = paddle.to_tensor(
+            [N] * B, dtype='int64', place=paddle.CPUPlace())
         labels = batch[1].astype("int32")
         label_lengths = batch[2].astype('int64')
         loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
...
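The hunk above builds the CTC length tensor directly on the CPU. One plausible reading (not stated in the commit): the lengths `[N] * B` are plain host-side values, so constructing them with `place=paddle.CPUPlace()` avoids an unneeded host-to-device copy on every step when training runs on GPU. A minimal sketch of the changed pattern, with hypothetical shapes (assumes PaddlePaddle is installed):

```python
import paddle

# Hypothetical shapes: N time steps, B batch size (not from the commit).
N, B = 25, 8

# Old pattern: the tensor lands on the default device (GPU when available),
# triggering a copy even though only the constant values [N] * B matter.
# preds_lengths = paddle.to_tensor([N] * B, dtype='int64')

# New pattern: pin the length tensor to the CPU explicitly.
preds_lengths = paddle.to_tensor(
    [N] * B, dtype='int64', place=paddle.CPUPlace())
print(preds_lengths.place)  # Place(cpu)
```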
@@ -146,6 +146,7 @@ def train(config,
           scaler=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
+    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
     log_smooth_window = config['Global']['log_smooth_window']
     epoch_num = config['Global']['epoch_num']
     print_batch_step = config['Global']['print_batch_step']
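The new `calc_epoch_interval` knob is read with `dict.get` and a default of 1, so configs written before this commit keep the old per-epoch behavior. A small sketch of how that lookup behaves (the dict here is a made-up stand-in for the parsed YAML config):

```python
# Stand-in for the parsed Global section; the real config is loaded from
# YAML elsewhere in the training entry point.
config = {'Global': {'cal_metric_during_train': True}}

# Missing key -> default 1: metrics are computed every epoch, preserving
# backward compatibility for existing configs.
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
assert calc_epoch_interval == 1

# A config that opts in to less frequent metric computation:
config['Global']['calc_epoch_interval'] = 5
assert config['Global'].get('calc_epoch_interval', 1) == 5
```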
@@ -244,6 +245,16 @@ def train(config,
             optimizer.step()
             optimizer.clear_grad()

+            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
+                batch = [item.numpy() for item in batch]
+                if model_type in ['table', 'kie']:
+                    eval_class(preds, batch)
+                else:
+                    post_result = post_process_class(preds, batch[1])
+                    eval_class(post_result, batch)
+                metric = eval_class.get_metric()
+                train_stats.update(metric)
+
             train_batch_time = time.time() - reader_start
             train_batch_cost += train_batch_time
             eta_meter.update(train_batch_time)
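With the block moved here and gated on `epoch % calc_epoch_interval == 0`, the metric computation (which pulls every tensor in `batch` back to numpy) runs only on every `calc_epoch_interval`-th epoch rather than unconditionally. A toy sketch of the gating, with a made-up interval:

```python
cal_metric_during_train = True
calc_epoch_interval = 3  # hypothetical setting, not from the commit

for epoch in range(1, 10):
    # Same guard as the added code: skip the relatively expensive
    # GPU-to-host copies and metric update on non-multiple epochs.
    if cal_metric_during_train and epoch % calc_epoch_interval == 0:
        print(f"epoch {epoch}: computing training metrics")
# -> epochs 3, 6 and 9 compute metrics; the rest skip the work.
```

The trade-off is coarser training-metric curves in exchange for higher throughput, which matches the commit title's aim of reducing the interval of metric calculation.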
@@ -258,16 +269,6 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)

-            if cal_metric_during_train:  # only rec and cls need
-                batch = [item.numpy() for item in batch]
-                if model_type in ['table', 'kie']:
-                    eval_class(preds, batch)
-                else:
-                    post_result = post_process_class(preds, batch[1])
-                    eval_class(post_result, batch)
-                metric = eval_class.get_metric()
-                train_stats.update(metric)
-
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
...