diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py index d2c7fc028d28c79057708d4e6f306c417ba6306a..8d59e4711a043afd9234f430a62c9876c0a8f6f4 100644 --- a/ppocr/modeling/heads/rec_srn_head.py +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -250,7 +250,8 @@ class SRNHead(nn.Layer): self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 - def forward(self, inputs, others): + def forward(self, inputs, targets=None): + others = targets[-4:] encoder_word_pos = others[0] gsrm_word_pos = others[1] gsrm_slf_attn_bias1 = others[2] diff --git a/tools/program.py b/tools/program.py index 0cc60c3877fa4fd3b9eb2ec12f3a0c5d378db15b..bd17db4afec459468b6428611308cd8c41920ca5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -209,14 +209,8 @@ def train(config, lr = optimizer.get_lr() images = batch[0] if use_srn: - others = batch[-4:] - preds = model(images, others) model_average = True - elif model_type == "table": - others = batch[1:] - preds = model(images, others) - else: - preds = model(images) + preds = model(images, data=batch[1:]) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -358,13 +352,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - - if use_srn: - others = batch[-4:] - preds = model(images, others) - else: - preds = model(images) - + preds = model(images, data=batch[1:]) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start