npair_loss_cn.rst 2.1 KB
Newer Older
H
Hao Wang 已提交
1 2 3 4 5 6 7 8 9 10 11
.. _cn_api_fluid_layers_npair_loss:

npair_loss
-------------------------------

.. py:function:: paddle.fluid.layers.npair_loss(anchor, positive, labels, l2_reg=0.002)

**Npair Loss Layer**

参考阅读 `Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_

L
lvmengsi 已提交
12
NPair损失需要成对的数据。NPair损失分为两部分:第一部分是对嵌入向量进行L2正则化;第二部分是每一对数据的相似性矩阵的每一行和映射到ont-hot之后的标签的交叉熵损失的和。
H
Hao Wang 已提交
13 14

参数:
L
lvmengsi 已提交
15 16 17 18
    - **anchor** (Variable) -  锚点图像的嵌入Tensor,形状为[batch_size, embedding_dims]的2-D Tensor。数据类型:float32和float64。
    - **positive** (Variable) -  正例图像的嵌入Tensor,形状为[batch_size, embedding_dims]的2-D Tensor。数据类型:float32和float64。
    - **labels** (Variable) - 标签向量,形状为[batch_size]的1-DTensor。数据类型:float32、float64和int64。
    - **l2_reg** (float) - 嵌入向量的L2正则化系数,默认:0.002。
H
Hao Wang 已提交
19

L
lvmengsi 已提交
20
返回: Tensor。经过npair loss计算之后的结果,是一个值。
H
Hao Wang 已提交
21

L
lvmengsi 已提交
22
返回类型:Variable
H
Hao Wang 已提交
23 24 25 26 27 28

**代码示例**:

.. code-block:: python

    import paddle.fluid as fluid
L
lvmengsi 已提交
29
    import numpy as np
H
Hao Wang 已提交
30 31 32 33 34 35 36
    anchor = fluid.layers.data(
              name = 'anchor', shape = [18, 6], dtype = 'float32', append_batch_size=False)
    positive = fluid.layers.data(
              name = 'positive', shape = [18, 6], dtype = 'float32', append_batch_size=False)
    labels = fluid.layers.data(
              name = 'labels', shape = [18], dtype = 'float32', append_batch_size=False)

L
lvmengsi 已提交
37 38 39 40 41 42 43 44 45
    res = fluid.layers.npair_loss(anchor, positive, labels, l2_reg = 0.002)
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    a = np.random.rand(18, 6).astype("float32")
    p = np.random.rand(18, 6).astype("float32")
    l = np.random.rand(18).astype("float32")
    output = exe.run(feed={"anchor": a, "positive": p, "labels": l}, fetch_list=[res])
    print(output)
H
Hao Wang 已提交
46 47 48 49 50 51