未验证 提交 57f01253 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #4965 from LDOUBLEV/sdmgr

fix train
...@@ -227,10 +227,6 @@ def train(config, ...@@ -227,10 +227,6 @@ def train(config,
images = batch[0] images = batch[0]
if use_srn: if use_srn:
model_average = True model_average = True
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
train_start = time.time() train_start = time.time()
# use amp # use amp
...@@ -243,6 +239,8 @@ def train(config, ...@@ -243,6 +239,8 @@ def train(config,
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type == "kie":
preds = model(batch)
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册