提交 2f8a95fe 编写于 作者: B breezedeus

refactor for onnx exportation

上级 f9f74d81
......@@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.
# Credits: adapted from https://github.com/mindee/doctr
from collections import OrderedDict
from typing import Tuple, Dict, Any, Optional, List, Union
from copy import deepcopy
......@@ -177,7 +177,7 @@ class OcrModel(nn.Module):
input_lengths: torch.Tensor,
target: Optional[List[str]] = None,
candidates: Optional[Union[str, List[str]]] = None,
return_logits: bool = False,
return_logits: bool = True,
return_preds: bool = False,
) -> Dict[str, Any]:
"""
......@@ -191,7 +191,9 @@ class OcrModel(nn.Module):
:return: 预测结果
"""
features = self.encoder(x)
input_lengths = input_lengths // self.encoder.compress_ratio
input_lengths = torch.div(
input_lengths, self.encoder.compress_ratio, rounding_mode='floor'
)
# B x C x H x W --> B x C*H x W --> B x W x C*H
c, h, w = features.shape[1], features.shape[2], features.shape[3]
features_seq = torch.reshape(features, shape=(-1, h * c, w))
......@@ -200,20 +202,24 @@ class OcrModel(nn.Module):
logits = self._decode(features_seq, input_lengths)
logits = self.linear(logits)
logits = self._mask_by_candidates(logits, candidates)
logits = self.mask_by_candidates(
logits, candidates, self.vocab, self.letter2id
)
out: Dict[str, Any] = {}
out: OrderedDict[str, Any] = {}
if return_logits:
out["logits"] = logits
out['output_lengths'] = input_lengths
if target is None or return_preds:
# Post-process boxes
out["preds"] = self.postprocessor(logits, input_lengths)
if self.postprocessor is not None:
out["preds"] = self.postprocessor(logits, input_lengths)
if target is not None:
out['loss'] = self._compute_loss(logits, target, input_lengths)
return out
return dict(out)
def _decode(self, features_seq, input_lengths):
if not isinstance(self.decoder, (nn.LSTM, nn.GRU)):
......@@ -232,18 +238,23 @@ class OcrModel(nn.Module):
)
return logits
def _mask_by_candidates(
self, logits: torch.Tensor, candidates: Optional[Union[str, List[str]]]
@classmethod
def mask_by_candidates(
cls,
logits: torch.Tensor,
candidates: Optional[Union[str, List[str]]],
vocab: List[str],
letter2id: Dict[str, int],
):
if candidates is None:
return logits
_candidates = [self.letter2id[word] for word in candidates]
_candidates = [letter2id[word] for word in candidates]
_candidates.sort()
_candidates = torch.tensor(_candidates, dtype=torch.int64)
candidates = torch.zeros(
(len(self.vocab) + 1,), dtype=torch.bool, device=logits.device
(len(vocab) + 1,), dtype=torch.bool, device=logits.device
)
candidates[_candidates] = True
candidates[-1] = True # 间隔符号/填充符号,必须为真
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册