未验证 提交 36f9a0d2 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #6478 from tink2123/rm_fluid

replace fluid to paddle
......@@ -27,12 +27,12 @@ class CosineEmbeddingLoss(nn.Layer):
self.epsilon = 1e-12
def forward(self, x1, x2, target):
similarity = paddle.fluid.layers.reduce_sum(
similarity = paddle.sum(
x1 * x2, dim=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1) + self.epsilon)
one_list = paddle.full_like(target, fill_value=1)
out = paddle.fluid.layers.reduce_mean(
out = paddle.mean(
paddle.where(
paddle.equal(target, one_list), 1. - similarity,
paddle.maximum(
......
......@@ -19,7 +19,6 @@ from __future__ import print_function
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import fluid
class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
......@@ -36,13 +35,13 @@ class TableAttentionLoss(nn.Layer):
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
# overlap
inters = iw * ih
......@@ -55,12 +54,12 @@ class TableAttentionLoss(nn.Layer):
# ious
ious = inters / uni
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
# enclose erea
enclose = ew * eh + eps
......
......@@ -175,7 +175,7 @@ class Kie_backbone(nn.Layer):
img, relations, texts, gt_bboxes, tag, img_size)
x = self.img_feat(img)
boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.fluid.layers.roi_align(
feats = paddle.vision.ops.roi_align(
x,
boxes,
spatial_scale=1.0,
......
......@@ -18,7 +18,6 @@ from __future__ import print_function
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import paddle.fluid as fluid
import paddle
import numpy as np
......
......@@ -20,13 +20,11 @@ import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import paddle.fluid as fluid
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from paddle.static import Program
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
import paddle.fluid.framework as framework
from collections import OrderedDict
gradient_clip = 10
......
......@@ -22,7 +22,6 @@ import paddle
from paddle import ParamAttr, nn
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import paddle.fluid as fluid
import numpy as np
gradient_clip = 10
......@@ -288,10 +287,10 @@ class PrePostProcessLayer(nn.Layer):
"layer_norm_%d" % len(self.sublayers()),
paddle.nn.LayerNorm(
normalized_shape=d_model,
weight_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))))
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(1.)),
bias_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.)))))
elif cmd == "d": # add dropout
self.functors.append(lambda x: F.dropout(
x, p=dropout_rate, mode="downscale_in_infer")
......@@ -324,7 +323,7 @@ class PrepareEncoder(nn.Layer):
def forward(self, src_word, src_pos):
src_word_emb = src_word
src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
src_word_emb = paddle.cast(src_word_emb, 'float32')
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
src_pos = paddle.squeeze(src_pos, axis=-1)
src_pos_enc = self.emb(src_pos)
......@@ -367,7 +366,7 @@ class PrepareDecoder(nn.Layer):
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
src_word = fluid.layers.cast(src_word, 'int64')
src_word = paddle.cast(src_word, 'int64')
src_word = paddle.squeeze(src_word, axis=-1)
src_word_emb = self.emb0(src_word)
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册