提交 4b7bf06e 编写于 作者: C ceci3

test=develop

上级 454f4f21
......@@ -220,6 +220,7 @@ paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels',
paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.tree_conv ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None))
paddle.fluid.layers.npair_loss ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
......@@ -186,6 +186,7 @@ __all__ = [
'teacher_student_sigmoid_loss',
'huber_loss',
'tree_conv',
'npair_loss',
]
kIgnoreIndex = -100
......@@ -10560,3 +10561,52 @@ def tree_conv(nodes_vector,
else:
pre_activation = out
return helper.append_activation(pre_activation)
def npair_loss(anchor, positive, labels, l2_reg=0.002):
'''
**Npair Loss Layer**
see http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
Npair loss requires paired data. Npair loss has two parts, the first part is L2
regularizer on the embedding vector, the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
Args:
anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims]
positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims]
labels(Varieble): 1-D tensor. shape=[batch_size]
l2_res(float32): L2 regularization term on embedding vector, default: 0.02
Returns:
npair loss(Variable): return npair loss, shape=[1]
Examples:
.. code-block:: python
npair_loss = fluid.layers.npair_loss(anchor, positive, labels, l2_reg)
'''
Beta = 0.25
batch_size = labels.shape[0]
labels = reshape(labels, shape=[batch_size, 1], inplace=True)
labels = expand(labels, expand_times=[1, batch_size])
from .control_flow import equal
from .ops import square
labels = equal(labels, transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / reduce_sum(labels, dim=1, keep_dim=True)
l2loss = reduce_mean(reduce_sum(square(anchor), 1)) \
+ reduce_mean(reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_value = softmax(similarity_matrix)
cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0)
celoss = reduce_mean(cross_entropy)
return l2loss + celoss
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
def npairloss(anchor, positive, labels, l2_reg=0.002):
def softmax_cross_entropy_with_logits(logits, labels):
logits = np.exp(logits)
logits = logits / np.sum(logits, axis=1).reshape(-1, 1)
return np.mean(
-np.sum(labels * np.log(logits), axis=1), dtype=np.float32)
batch_size = labels.shape[0]
labels = np.reshape(labels, (batch_size, 1))
labels = np.equal(labels, labels.transpose()).astype(float)
labels = labels / np.sum(labels, axis=1, keepdims=True)
l2loss = np.mean(np.sum(np.power(anchor, 2), 1)) + np.mean(
np.sum(np.power(positive, 2), 1))
l2loss = (l2loss * 0.25 * l2_reg).astype(np.float32)
similarity_matrix = np.matmul(anchor, positive.transpose())
celoss = np.mean(
softmax_cross_entropy_with_logits(similarity_matrix, labels))
return l2loss + celoss
def create_or_get_tensor(scope, var_name, var, place):
tensor = scope.var(var_name).get_tensor()
if var is not None:
assert isinstance(var, np.ndarray)
tensor.set_recursive_sequence_lengths([])
tensor.set(var, place)
return tensor
class TestNpairLossOp(unittest.TestCase):
def setUp(self):
self.dtype = np.float32
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_with_place(self, place, dtype, shape):
reg_lambda = 0.002
num_data, feat_dim, num_classes = shape[0], shape[1], shape[2]
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
embeddings_anchor = np.random.rand(num_data,
feat_dim).astype(np.float32)
embeddings_positive = np.random.rand(num_data,
feat_dim).astype(np.float32)
labels = np.random.randint(
0, num_classes, size=(num_data)).astype(np.float32)
out_loss = npairloss(
embeddings_anchor, embeddings_positive, labels, l2_reg=reg_lambda)
anchor_tensor = fluid.layers.data(
name='anchor',
shape=[num_data, feat_dim],
dtype=self.dtype,
append_batch_size=False)
positive_tensor = fluid.layers.data(
name='positive',
shape=[num_data, feat_dim],
dtype=self.dtype,
append_batch_size=False)
labels_tensor = fluid.layers.data(
name='labels',
shape=[num_data],
dtype=self.dtype,
append_batch_size=False)
npair_loss_op = fluid.layers.npair_loss(
anchor=anchor_tensor,
positive=positive_tensor,
labels=labels_tensor,
l2_reg=reg_lambda)
out_tensor = exe.run(feed={
'anchor': embeddings_anchor,
'positive': embeddings_positive,
'labels': labels
},
fetch_list=[npair_loss_op.name])
self.__assert_close(
out_tensor,
out_loss,
"inference output are different at " + str(place) + ", " +
str(np.dtype(dtype)) + str(np.array(out_tensor)) + str(out_loss),
atol=1e-3)
def test_check_output(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.ops_support_gpu("npair_loss"):
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place, self.dtype, [18, 6, 3])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册