未验证 提交 55a785bb 编写于 作者: J jerrywgz 提交者: GitHub

Merge pull request #15949 from ceci3/npair_loss0

add npair loss op
......@@ -144,7 +144,7 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', 'c317aa595deb31649083c8faa91cdb97'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '12c5bbb8b38c42e623fbc47611d766e1'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '1ba0508d573f65feecf3564dce22aa1d'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', 'b3ecb819454832885c1f0f3ab9a5b938'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '7a1966d7c3a48f1fc0881cdaf5d83b0b'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', '06211aefc50c5a3e940d7204d859cdf7'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', 'e4fb4ed511b2293b8f04f7e872afbfd7'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '735fa9758a6d7ff3b47d7b827f961c1d'))
......@@ -221,6 +221,7 @@ paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels'
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7'))
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
......
......@@ -187,6 +187,7 @@ __all__ = [
'teacher_student_sigmoid_loss',
'huber_loss',
'tree_conv',
'npair_loss',
]
kIgnoreIndex = -100
......@@ -6977,7 +6978,6 @@ def image_resize(input,
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
......@@ -10652,3 +10652,60 @@ def tree_conv(nodes_vector,
else:
pre_activation = out
return helper.append_activation(pre_activation)
from .ops import square
from .control_flow import equal
def npair_loss(anchor, positive, labels, l2_reg=0.002):
'''
**Npair Loss Layer**
Read `Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_ .
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
Args:
anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims]
positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims]
labels(Variable): 1-D tensor. shape=[batch_size]
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002
Returns:
npair loss(Variable): return npair loss, shape=[1]
Examples:
.. code-block:: python
anchor = fluid.layers.data(
name = 'anchor', shape = [18, 6], dtype = 'float32', append_batch_size=False)
positive = fluid.layers.data(
name = 'positive', shape = [18, 6], dtype = 'float32', append_batch_size=False)
labels = fluid.layers.data(
name = 'labels', shape = [18], dtype = 'float32', append_batch_size=False)
npair_loss = fluid.layers.npair_loss(anchor, positive, labels, l2_reg = 0.002)
'''
Beta = 0.25
batch_size = labels.shape[0]
labels = reshape(labels, shape=[batch_size, 1], inplace=True)
labels = expand(labels, expand_times=[1, batch_size])
labels = equal(labels, transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / reduce_sum(labels, dim=1, keep_dim=True)
l2loss = reduce_mean(reduce_sum(square(anchor), 1)) \
+ reduce_mean(reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_value = softmax(similarity_matrix)
cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0)
celoss = reduce_mean(cross_entropy)
return l2loss + celoss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
def npairloss(anchor, positive, labels, l2_reg=0.002):
def softmax_cross_entropy_with_logits(logits, labels):
logits = np.exp(logits)
logits = logits / np.sum(logits, axis=1).reshape(-1, 1)
return np.mean(
-np.sum(labels * np.log(logits), axis=1), dtype=np.float32)
batch_size = labels.shape[0]
labels = np.reshape(labels, (batch_size, 1))
labels = np.equal(labels, labels.transpose()).astype(float)
labels = labels / np.sum(labels, axis=1, keepdims=True)
l2loss = np.mean(np.sum(np.power(anchor, 2), 1)) + np.mean(
np.sum(np.power(positive, 2), 1))
l2loss = (l2loss * 0.25 * l2_reg).astype(np.float32)
similarity_matrix = np.matmul(anchor, positive.transpose())
celoss = np.mean(
softmax_cross_entropy_with_logits(similarity_matrix, labels))
return l2loss + celoss
class TestNpairLossOp(unittest.TestCase):
def setUp(self):
self.dtype = np.float32
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_npair_loss(self):
reg_lambda = 0.002
num_data, feat_dim, num_classes = 18, 6, 3
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
embeddings_anchor = np.random.rand(num_data,
feat_dim).astype(np.float32)
embeddings_positive = np.random.rand(num_data,
feat_dim).astype(np.float32)
row_labels = np.random.randint(
0, num_classes, size=(num_data)).astype(np.float32)
out_loss = npairloss(
embeddings_anchor,
embeddings_positive,
row_labels,
l2_reg=reg_lambda)
anc = fluid.layers.create_tensor(
dtype='float32', persistable=True, name='anc')
pos = fluid.layers.create_tensor(
dtype='float32', persistable=True, name='pos')
lab = fluid.layers.create_tensor(
dtype='float32', persistable=True, name='lab')
fluid.layers.assign(input=embeddings_anchor, output=anc)
fluid.layers.assign(input=embeddings_positive, output=pos)
fluid.layers.assign(input=row_labels, output=lab)
npair_loss_op = fluid.layers.npair_loss(
anchor=anc, positive=pos, labels=lab, l2_reg=reg_lambda)
out_tensor = exe.run(feed={'anc': anc,
'pos': pos,
'lab': lab},
fetch_list=[npair_loss_op.name])
self.__assert_close(
out_tensor,
out_loss,
"inference output are different at " + str(place) + ", " +
str(np.dtype('float32')) + str(np.array(out_tensor)) +
str(out_loss),
atol=1e-3)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册