提交 7bcabe0f 编写于 作者: M MissPenguin

refine

上级 6e1cfb05
...@@ -210,7 +210,10 @@ def train(config, ...@@ -210,7 +210,10 @@ def train(config,
images = batch[0] images = batch[0]
if use_srn: if use_srn:
model_average = True model_average = True
preds = model(images, data=batch[1:]) if use_srn or model_type == 'table':
preds = model(images, data=batch[1:])
else:
preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
avg_loss = loss['loss'] avg_loss = loss['loss']
avg_loss.backward() avg_loss.backward()
...@@ -356,7 +359,10 @@ def eval(model, ...@@ -356,7 +359,10 @@ def eval(model,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
preds = model(images, data=batch[1:]) if use_srn or model_type == 'table':
preds = model(images, data=batch[1:])
else:
preds = model(images)
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册