From 053cc43d82aec3343745b8871672bce9504feaef Mon Sep 17 00:00:00 2001 From: MissPenguin Date: Tue, 22 Jun 2021 04:23:27 +0000 Subject: [PATCH] refine --- ppocr/modeling/heads/rec_srn_head.py | 3 ++- tools/program.py | 16 ++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py index d2c7fc02..8d59e471 100644 --- a/ppocr/modeling/heads/rec_srn_head.py +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -250,7 +250,8 @@ class SRNHead(nn.Layer): self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 - def forward(self, inputs, others): + def forward(self, inputs, targets=None): + others = targets[-4:] encoder_word_pos = others[0] gsrm_word_pos = others[1] gsrm_slf_attn_bias1 = others[2] diff --git a/tools/program.py b/tools/program.py index 0cc60c38..bd17db4a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -209,14 +209,8 @@ def train(config, lr = optimizer.get_lr() images = batch[0] if use_srn: - others = batch[-4:] - preds = model(images, others) model_average = True - elif model_type == "table": - others = batch[1:] - preds = model(images, others) - else: - preds = model(images) + preds = model(images, data=batch[1:]) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -358,13 +352,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - - if use_srn: - others = batch[-4:] - preds = model(images, others) - else: - preds = model(images) - + preds = model(images, data=batch[1:]) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start -- GitLab