From 11acfe5124129ba88d529849ef79701beaf03b08 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 2 Jun 2022 09:53:01 +0000 Subject: [PATCH] replace fluid to paddle --- ppocr/losses/rec_aster_loss.py | 4 ++-- ppocr/losses/table_att_loss.py | 25 +++++++++++----------- ppocr/modeling/backbones/kie_unet_sdmgr.py | 2 +- ppocr/modeling/backbones/rec_resnet_fpn.py | 1 - ppocr/modeling/heads/rec_srn_head.py | 2 -- ppocr/modeling/heads/self_attention.py | 13 ++++++----- 6 files changed, 21 insertions(+), 26 deletions(-) diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py index fbb99d29..52605e46 100644 --- a/ppocr/losses/rec_aster_loss.py +++ b/ppocr/losses/rec_aster_loss.py @@ -27,12 +27,12 @@ class CosineEmbeddingLoss(nn.Layer): self.epsilon = 1e-12 def forward(self, x1, x2, target): - similarity = paddle.fluid.layers.reduce_sum( + similarity = paddle.sum( x1 * x2, dim=-1) / (paddle.norm( x1, axis=-1) * paddle.norm( x2, axis=-1) + self.epsilon) one_list = paddle.full_like(target, fill_value=1) - out = paddle.fluid.layers.reduce_mean( + out = paddle.mean( paddle.where( paddle.equal(target, one_list), 1. - similarity, paddle.maximum( diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py index d7fd99e6..51377efa 100644 --- a/ppocr/losses/table_att_loss.py +++ b/ppocr/losses/table_att_loss.py @@ -19,7 +19,6 @@ from __future__ import print_function import paddle from paddle import nn from paddle.nn import functional as F -from paddle import fluid class TableAttentionLoss(nn.Layer): def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): @@ -36,13 +35,13 @@ class TableAttentionLoss(nn.Layer): :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] :return: loss ''' - ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0]) - iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1]) - ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2]) - iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3]) + ix1 = paddle.maximum(preds[:, 0], bbox[:, 0]) + iy1 = paddle.maximum(preds[:, 1], bbox[:, 1]) + ix2 = paddle.minimum(preds[:, 2], bbox[:, 2]) + iy2 = paddle.minimum(preds[:, 3], bbox[:, 3]) - iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10) - ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10) + iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10) + ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10) # overlap inters = iw * ih @@ -55,12 +54,12 @@ class TableAttentionLoss(nn.Layer): # ious ious = inters / uni - ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0]) - ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1]) - ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2]) - ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3]) - ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10) - eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10) + ex1 = paddle.minimum(preds[:, 0], bbox[:, 0]) + ey1 = paddle.minimum(preds[:, 1], bbox[:, 1]) + ex2 = paddle.maximum(preds[:, 2], bbox[:, 2]) + ey2 = paddle.maximum(preds[:, 3], bbox[:, 3]) + ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10) + eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10) # enclose erea enclose = ew * eh + eps diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py index 545e4e75..793c68c6 100644 --- a/ppocr/modeling/backbones/kie_unet_sdmgr.py +++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -175,7 +175,7 @@ class Kie_backbone(nn.Layer): img, relations, texts, gt_bboxes, tag, img_size) x = self.img_feat(img) boxes, rois_num = self.bbox2roi(gt_bboxes) - feats = paddle.fluid.layers.roi_align( + feats = paddle.vision.ops.roi_align( x, boxes, spatial_scale=1.0, diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py index a7e876a2..79efd6e4 100644 --- a/ppocr/modeling/backbones/rec_resnet_fpn.py +++ b/ppocr/modeling/backbones/rec_resnet_fpn.py @@ -18,7 +18,6 @@ from __future__ import print_function from paddle import nn, ParamAttr from paddle.nn import functional as F -import paddle.fluid as fluid import paddle import numpy as np diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py index 8d59e471..1070d8cd 100644 --- a/ppocr/modeling/heads/rec_srn_head.py +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -20,13 +20,11 @@ import math import paddle from paddle import nn, ParamAttr from paddle.nn import functional as F -import paddle.fluid as fluid import numpy as np from .self_attention import WrapEncoderForFeature from .self_attention import WrapEncoder from paddle.static import Program from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN -import paddle.fluid.framework as framework from collections import OrderedDict gradient_clip = 10 diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py index 6c27fdbe..6e4c65e3 100644 --- a/ppocr/modeling/heads/self_attention.py +++ b/ppocr/modeling/heads/self_attention.py @@ -22,7 +22,6 @@ import paddle from paddle import ParamAttr, nn from paddle import nn, ParamAttr from paddle.nn import functional as F -import paddle.fluid as fluid import numpy as np gradient_clip = 10 @@ -288,10 +287,10 @@ class PrePostProcessLayer(nn.Layer): "layer_norm_%d" % len(self.sublayers()), paddle.nn.LayerNorm( normalized_shape=d_model, - weight_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant(1.)), - bias_attr=fluid.ParamAttr( - initializer=fluid.initializer.Constant(0.))))) + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(1.)), + bias_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.))))) elif cmd == "d": # add dropout self.functors.append(lambda x: F.dropout( x, p=dropout_rate, mode="downscale_in_infer") @@ -324,7 +323,7 @@ class PrepareEncoder(nn.Layer): def forward(self, src_word, src_pos): src_word_emb = src_word - src_word_emb = fluid.layers.cast(src_word_emb, 'float32') + src_word_emb = paddle.cast(src_word_emb, 'float32') src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) src_pos = paddle.squeeze(src_pos, axis=-1) src_pos_enc = self.emb(src_pos) @@ -367,7 +366,7 @@ class PrepareDecoder(nn.Layer): self.dropout_rate = dropout_rate def forward(self, src_word, src_pos): - src_word = fluid.layers.cast(src_word, 'int64') + src_word = paddle.cast(src_word, 'int64') src_word = paddle.squeeze(src_word, axis=-1) src_word_emb = self.emb0(src_word) src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) -- GitLab