From 2e80aab477be38425620a2116a54efc741d7335c Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 18 May 2022 19:03:18 +0800 Subject: [PATCH] add support for svtr static training (#6328) --- ppocr/modeling/architectures/__init__.py | 26 ++++++++++++++++++++---- ppocr/modeling/heads/rec_sar_head.py | 20 +++++++++--------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 3f47f64a..1c955ef3 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -40,11 +40,29 @@ def apply_to_static(model, config, logger): return model assert "image_shape" in config[ "Global"], "image_shape must be assigned for static training mode..." - supported_list = ["DB"] - assert config["Architecture"][ - "algorithm"] in supported_list, f"algorithms that supports static training must in in {supported_list} but got {config['Architecture']['algorithm']}" + supported_list = ["DB", "SVTR"] + if config["Architecture"]["algorithm"] in ["Distillation"]: + algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] + else: + algo = config["Architecture"]["algorithm"] + assert algo in supported_list, f"algorithms that supports static training must in in {supported_list} but got {algo}" + + specs = [ + InputSpec( + [None] + config["Global"]["image_shape"], dtype='float32') + ] + + if algo == "SVTR": + specs.append([ + InputSpec( + [None, config["Global"]["max_text_length"]], + dtype='int64'), InputSpec( + [None, config["Global"]["max_text_length"]], dtype='int64'), + InputSpec( + [None], dtype='int64'), InputSpec( + [None], dtype='float64') + ]) - specs = [InputSpec([None] + config["Global"]["image_shape"])] model = to_static(model, input_spec=specs) logger.info("Successfully to apply @to_static with specs: {}".format(specs)) return model diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py index 0e6b3440..5e64cae8 100644 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -83,7 +83,7 @@ class SAREncoder(nn.Layer): def forward(self, feat, img_metas=None): if img_metas is not None: - assert len(img_metas[0]) == feat.shape[0] + assert len(img_metas[0]) == paddle.shape(feat)[0] valid_ratios = None if img_metas is not None and self.mask: @@ -98,9 +98,10 @@ class SAREncoder(nn.Layer): if valid_ratios is not None: valid_hf = [] - T = holistic_feat.shape[1] - for i in range(len(valid_ratios)): - valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1 + T = paddle.shape(holistic_feat)[1] + for i in range(paddle.shape(valid_ratios)[0]): + valid_step = paddle.minimum( + T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1 valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf = paddle.stack(valid_hf, axis=0) else: @@ -247,13 +248,14 @@ class ParallelSARDecoder(BaseDecoder): # bsz * (seq_len + 1) * h * w * attn_size attn_weight = self.conv1x1_2(attn_weight) # bsz * (seq_len + 1) * h * w * 1 - bsz, T, h, w, c = attn_weight.shape + bsz, T, h, w, c = paddle.shape(attn_weight) assert c == 1 if valid_ratios is not None: # cal mask of attention weight - for i in range(len(valid_ratios)): - valid_width = min(w, math.ceil(w * valid_ratios[i])) + for i in range(paddle.shape(valid_ratios)[0]): + valid_width = paddle.minimum( + w, paddle.ceil(valid_ratios[i] * w).astype("int32")) if valid_width < w: attn_weight[i, :, :, valid_width:, :] = float('-inf') @@ -288,7 +290,7 @@ class ParallelSARDecoder(BaseDecoder): img_metas: [label, valid_ratio] ''' if img_metas is not None: - assert len(img_metas[0]) == feat.shape[0] + assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0] valid_ratios = None if img_metas is not None and self.mask: @@ -302,7 +304,6 @@ class ParallelSARDecoder(BaseDecoder): # bsz * (seq_len + 1) * C out_dec = self._2d_attention( in_dec, feat, out_enc, valid_ratios=valid_ratios) - # bsz * (seq_len + 1) * num_classes return out_dec[:, 1:, :] # bsz * seq_len * num_classes @@ -395,7 +396,6 @@ class SARHead(nn.Layer): if self.training: label = targets[0] # label - label = paddle.to_tensor(label, dtype='int64') final_out = self.decoder( feat, holistic_feat, label, img_metas=targets) else: -- GitLab