未验证 提交 d2d1483c 编写于 作者: B Bin Lu 提交者: GitHub

adjust some code format and wrongly spelling (#4319)

* Update ace_loss.py

* Update center_loss.py

* Update enhanced_ctc_loss.md
上级 98046bb1
......@@ -16,7 +16,7 @@ Focal Loss 出自论文《Focal Loss for Dense Object Detection》, 该loss最
从上图可以看到, 当γ> 0时,调整系数(1-y’)^γ 赋予易分类样本损失一个更小的权重,使得网络更关注于困难的、错分的样本。 调整因子γ用于调节简单样本权重降低的速率,当γ为0时即为交叉熵损失函数,当γ增加时,调整因子的影响也会随之增大。实验发现γ为2是最优。平衡因子α用来平衡正负样本本身的比例不均,文中α取0.25。
对于经典的CTC算法,假设某个特征序列(f<sub>1</sub>, f<sub>2</sub>, ......f<sub>t</sub>), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现 CTCLoss值和y’有如下关系:
对于经典的CTC算法,假设某个特征序列(f<sub>1</sub>, f<sub>2</sub>, ......f<sub>t</sub>), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现, CTCLoss值和y’有如下关系:
<div align="center">
<img src="./equation_ctcloss.png" width = "250" />
</div>
......@@ -38,7 +38,7 @@ A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggrega
<img src="./rec_algo_compare.png" width = "1000" />
</div>
虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。
虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。
A_CTC Loss定义如下:
<div align="center">
<img src="./equation_a_ctc.png" width = "300" />
......@@ -47,7 +47,7 @@ A_CTC Loss定义如下:
实验中,λ = 0.1. ACE loss实现代码见: [ace_loss.py](../../ppocr/losses/ace_loss.py)
## 3. C-CTC Loss
C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
在中文OCR识别任务中,通过对badcase分析, 我们发现中文识别的一大难点是相似字符多,容易误识。 由此我们想到是否可以借鉴Metric Learing的想法, 增大相似字符的类间距,从而提高识别准确率。然而,MetricLearning主要用于图像识别领域,训练数据的标签为一个固定的值;而对于OCR识别来说,其本质上是一个序列识别任务,特征和label之间并不具有显式的对齐关系,因此两者如何结合依然是一个值得探索的方向。
通过尝试Arcmargin, Cosmargin等方法, 我们最终发现Centerloss 有助于进一步提升识别的准确率。C_CTC Loss定义如下:
<div align="center">
......
......@@ -32,6 +32,7 @@ class ACELoss(nn.Layer):
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
......@@ -42,9 +43,7 @@ class ACELoss(nn.Layer):
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss}
......@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def __init__(self,
num_classes=6625,
feat_dim=96,
......@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer):
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center
shape=[self.num_classes, self.feat_dim]).astype("float64")
if init_center:
assert os.path.exists(
......@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer):
batch_size = feats_reshape.shape[0]
#calc feat * feat
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
#calc l2 distance between feats and centers
square_feat = paddle.sum(paddle.square(feats_reshape),
axis=1,
keepdim=True)
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
#dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0])
square_center = paddle.sum(paddle.square(self.centers),
axis=1,
keepdim=True)
square_center = paddle.expand(
square_center, [self.num_classes, batch_size]).astype("float64")
square_center = paddle.transpose(square_center, [1, 0])
#first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp
distmat = paddle.add(square_feat, square_center)
feat_dot_center = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * feat_dot_center
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
......@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer):
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64") #get mask
label).astype("float64")
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册