提交 053cc43d 编写于 作者: M MissPenguin

refine

上级 c0492e02
...@@ -250,7 +250,8 @@ class SRNHead(nn.Layer): ...@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
def forward(self, inputs, others): def forward(self, inputs, targets=None):
others = targets[-4:]
encoder_word_pos = others[0] encoder_word_pos = others[0]
gsrm_word_pos = others[1] gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2] gsrm_slf_attn_bias1 = others[2]
......
...@@ -209,14 +209,8 @@ def train(config, ...@@ -209,14 +209,8 @@ def train(config,
lr = optimizer.get_lr() lr = optimizer.get_lr()
images = batch[0] images = batch[0]
if use_srn: if use_srn:
others = batch[-4:]
preds = model(images, others)
model_average = True model_average = True
elif model_type == "table": preds = model(images, data=batch[1:])
others = batch[1:]
preds = model(images, others)
else:
preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
avg_loss = loss['loss'] avg_loss = loss['loss']
avg_loss.backward() avg_loss.backward()
...@@ -358,13 +352,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, ...@@ -358,13 +352,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
preds = model(images, data=batch[1:])
if use_srn:
others = batch[-4:]
preds = model(images, others)
else:
preds = model(images)
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册