From d983fc34364cf5d231a94297f841724054f47bdf Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Nov 2022 10:34:55 +0800 Subject: [PATCH] clear fluid api: warpctc, nce, identity_loss (#48142) * clear fluid api: warpctc, nce, identity_loss * fix test_layers.py __init__.py * fix loss.py * change __init__.py and api calling method * fix nce * fix nce * fix fluid.data * delete warpctc api document * fix loss.py * fix ctc_loss * fix test_warpctc_op.py * fix test_layers.py * fix some bug * fix conflict * fix ci bug * Empty Commit test=allcase * fix ci bug --- python/paddle/fluid/layers/loss.py | 468 ------------------ .../unittests/ipu/test_warpctc_op_ipu.py | 3 +- .../tests/unittests/test_dist_transpiler.py | 2 +- .../test_imperative_load_static_param.py | 4 +- .../fluid/tests/unittests/test_layers.py | 25 +- .../paddle/fluid/tests/unittests/test_nce.py | 18 +- .../fluid/tests/unittests/test_warpctc_op.py | 227 ++++----- .../unittests/xpu/test_warpctc_op_xpu.py | 64 ++- python/paddle/incubate/__init__.py | 2 +- python/paddle/incubate/nn/loss.py | 75 +++ python/paddle/nn/functional/loss.py | 67 ++- python/paddle/static/nn/__init__.py | 2 +- python/paddle/static/nn/loss.py | 259 ++++++++++ 13 files changed, 578 insertions(+), 638 deletions(-) create mode 100644 python/paddle/incubate/nn/loss.py create mode 100644 python/paddle/static/nn/loss.py diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index 5af7111e58..306437f754 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -37,8 +37,6 @@ from paddle import _C_ops, _legacy_C_ops __all__ = [ 'cross_entropy', 'square_error_cost', - 'warpctc', - 'nce', 'softmax_with_cross_entropy', 'sigmoid_cross_entropy_with_logits', ] @@ -182,416 +180,6 @@ def square_error_cost(input, label): return paddle.nn.functional.square_error_cost(input, label) -def warpctc( - input, - label, - blank=0, - norm_by_times=False, - input_length=None, - label_length=None, -): - """ - An operator integrating the open source Warp-CTC library - (https://github.com/baidu-research/warp-ctc) - to compute Connectionist Temporal Classification (CTC) loss. - It can be aliased as softmax with CTC, since a native softmax activation is - interated to the Warp-CTC library to normalize values for each row of the - input tensor. - - Args: - input (Variable): The unscaled probabilities of variable-length sequences, - which is a 2-D Tensor with LoD information, or a 3-D Tensor without Lod - information. When it is a 2-D LodTensor, its shape is - `[Lp, num_classes + 1]`, where `Lp` is the sum of all input - sequences' length and `num_classes` is the true number of classes. - (not including the blank label). When it is a 3-D Tensor, its shape - is `[max_logit_length, batch_size, num_classes + 1]`, - where `max_logit_length` is the longest length of - input logit sequence. The data type should be float32 or float64. - label (Variable): The ground truth of variable-length sequence, - which must be a 2-D Tensor with LoD information or a 3-D Tensor without - LoD information, needs to be consistent with the coressponding input. - When it is a 2-D LoDTensor, its shape is `[Lg, 1]`, where `Lg` is the sum - of all labels' length. When it is a 3-D Tensor, its shape is - `[batch_size, max_label_length]`, where `max_label_length` is the longest - length of label sequence. Data type must be int32. 
- blank (int, default 0): The blank label index of Connectionist - Temporal Classification (CTC) loss, which is in the - half-opened interval `[0, num_classes + 1)`. The data type must be int32. - norm_by_times(bool, default false): Whether to normalize the gradients - by the number of time-step, which is also the sequence's length. - There is no need to normalize the gradients if warpctc layer was - followed by a mean_op. - input_length(Variable): The length for each input sequence if it is - of Tensor type, it should have shape `[batch_size]` and dtype int64. - label_length(Variable): The length for each label sequence if it is - of Tensor type, it should have shape `[batch_size]` and dtype int64. - - Returns: - Variable: The Connectionist Temporal Classification (CTC) loss, - which is a 2-D Tensor with the shape `[batch_size, 1]`. - The date type is the same as input. - - Examples: - - .. code-block:: python - - # using LoDTensor - import paddle - import paddle.fluid as fluid - import numpy as np - - # lengths of logit sequences - seq_lens = [2,6] - # lengths of label sequences - label_lens = [2,3] - # class num - class_num = 5 - - paddle.enable_static() - logits = fluid.data(name='logits',shape=[None, class_num+1], - dtype='float32',lod_level=1) - label = fluid.data(name='label', shape=[None, 1], - dtype='int32', lod_level=1) - cost = fluid.layers.warpctc(input=logits, label=label) - place = fluid.CPUPlace() - x = fluid.create_lod_tensor( - np.random.rand(np.sum(seq_lens), class_num+1).astype("float32"), - [seq_lens], place) - y = fluid.create_lod_tensor( - np.random.randint(0, class_num, [np.sum(label_lens), 1]).astype("int32"), - [label_lens], place) - exe = fluid.Executor(place) - output= exe.run(fluid.default_main_program(), - feed={"logits": x,"label": y}, - fetch_list=[cost.name]) - print(output) - - .. code-block:: python - - # using Tensor - import paddle - import paddle.fluid as fluid - import numpy as np - - # length of the longest logit sequence - max_seq_length = 5 - #length of the longest label sequence - max_label_length = 3 - # number of logit sequences - batch_size = 16 - # class num - class_num = 5 - paddle.enable_static() - logits = fluid.data(name='logits', - shape=[max_seq_length, batch_size, class_num+1], - dtype='float32') - logits_length = fluid.data(name='logits_length', shape=[None], - dtype='int64') - label = fluid.data(name='label', shape=[batch_size, max_label_length], - dtype='int32') - label_length = fluid.data(name='labels_length', shape=[None], - dtype='int64') - cost = fluid.layers.warpctc(input=logits, label=label, - input_length=logits_length, - label_length=label_length) - place = fluid.CPUPlace() - x = np.random.rand(max_seq_length, batch_size, class_num+1).astype("float32") - y = np.random.randint(0, class_num, [batch_size, max_label_length]).astype("int32") - exe = fluid.Executor(place) - output= exe.run(fluid.default_main_program(), - feed={"logits": x, - "label": y, - "logits_length": np.array([max_seq_length]*batch_size).astype("int64"), - "labels_length": np.array([max_label_length]*batch_size).astype("int64")}, - fetch_list=[cost.name]) - print(output) - """ - if in_dygraph_mode(): - if input_length is None or label_length is None: - raise ValueError( - "input_length and label_length must not be None in dygraph mode!" 
- ) - loss_out = _C_ops.warpctc( - input, label, input_length, label_length, blank, norm_by_times - ) - return loss_out - if _non_static_mode(): - if input_length is None or label_length is None: - raise ValueError( - "input_length and label_length must not be None in dygraph mode!" - ) - grad, loss_out = _legacy_C_ops.warpctc( - input, - label, - input_length, - label_length, - 'blank', - blank, - 'norm_by_times', - norm_by_times, - ) - return loss_out - helper = LayerHelper('warpctc', **locals()) - check_variable_and_dtype(input, 'input', ['float32', 'float64'], "warpctc") - check_variable_and_dtype(label, 'label', ['int32'], "warpctc") - this_inputs = {'Logits': [input], 'Label': [label]} - if input_length is not None and label_length is not None: - check_variable_and_dtype( - input_length, 'LogitsLength', ['int64'], "warpctc" - ) - check_variable_and_dtype( - label_length, 'LabelLength', ['int64'], "warpctc" - ) - this_inputs['LogitsLength'] = [input_length] - this_inputs['LabelLength'] = [label_length] - - loss_out = helper.create_variable_for_type_inference(dtype=input.dtype) - grad_out = helper.create_variable_for_type_inference(dtype=input.dtype) - - helper.append_op( - type='warpctc', - inputs=this_inputs, - outputs={'WarpCTCGrad': [grad_out], 'Loss': [loss_out]}, - attrs={ - 'blank': blank, - 'norm_by_times': norm_by_times, - }, - ) - return loss_out - - -# FIXME(wuyi): let docstring_checker.py understand @autodoc. -# For now, the comments in c++ use types like Tensor, but in python side -# the type is often "Variable", and arguments may vary. -@static_only -@templatedoc(op_type="nce") -def nce( - input, - label, - num_total_classes, - sample_weight=None, - param_attr=None, - bias_attr=None, - num_neg_samples=None, - name=None, - sampler="uniform", - custom_dist=None, - seed=0, - is_sparse=False, -): - """ - :api_attr: Static Graph - - ${comment} - - Args: - input (Tensor): Input tensor, 2-D tensor with shape [batch_size, dim], - and data type is float32 or float64. - label (Tensor): Input label, 2-D tensor with shape [batch_size, num_true_class], - and data type is int64. - num_total_classes (int):${num_total_classes_comment}. - sample_weight (Tensor|None): A Tensor of shape [batch_size, 1] - storing a weight for each sample. The default weight for each - sample is 1.0. - param_attr (ParamAttr|None): To specify the weight parameter attribute. - Default: None, which means the default weight parameter property is - used. See usage for details in :ref:`api_fluid_ParamAttr` . - bias_attr (ParamAttr|None): To specify the bias parameter attribute. - Default: None, which means the default bias parameter property is - used. See usage for details in :ref:`api_fluid_ParamAttr` . - num_neg_samples (int): ${num_neg_samples_comment}. - name(str|None): For detailed information, please refer to - :ref:`api_guide_Name` . Usually name is no need to set and None by default. - sampler (str, optional): The sampler used to sample class from negative classes. - It can be 'uniform', 'log_uniform' or 'custom_dist'. - default: 'uniform'. - custom_dist (nd.array|None): A numpy ndarray with size=num_total_classes. - It is used when sampler is set to 'custom_dist'. - custom_dist[i] is the probability of i-th class to be sampled. - default: None. - seed (int, optional): The seed used in sampler. Default 0, means no random seed. - is_sparse(bool, optional): The flag indicating whether to use sparse update, - the weight@GRAD and bias@GRAD will be changed to SelectedRows. Default False. 
- - Returns: - Tensor: The output nce loss. - - Examples: - .. code-block:: python - - - import paddle - import numpy as np - - paddle.enable_static() - - window_size = 5 - words = [] - for i in range(window_size): - words.append(paddle.static.data( - name='word_{0}'.format(i), shape=[-1, 1], dtype='int64')) - - dict_size = 10000 - label_word = int(window_size / 2) + 1 - - embs = [] - for i in range(window_size): - if i == label_word: - continue - - emb = paddle.static.nn.embedding(input=words[i], size=[dict_size, 32], - param_attr='embed', is_sparse=True) - embs.append(emb) - - embs = paddle.concat(x=embs, axis=1) - loss = paddle.static.nn.nce(input=embs, label=words[label_word], - num_total_classes=dict_size, param_attr='nce.w_0', - bias_attr='nce.b_0') - - #or use custom distribution - dist = np.array([0.05,0.5,0.1,0.3,0.05]) - loss = paddle.static.nn.nce(input=embs, label=words[label_word], - num_total_classes=5, param_attr='nce.w_1', - bias_attr='nce.b_1', - num_neg_samples=3, - sampler="custom_dist", - custom_dist=dist) - """ - helper = LayerHelper('nce', **locals()) - check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce') - check_variable_and_dtype(label, 'label', ['int64'], 'nce') - - dim = input.shape[1] - num_true_class = label.shape[1] - w = helper.create_parameter( - attr=helper.param_attr, - shape=[num_total_classes, dim], - is_bias=False, - dtype=input.dtype, - ) - inputs = {} - if helper.bias_attr: - b = helper.create_parameter( - attr=helper.bias_attr, - shape=[num_total_classes, 1], - is_bias=True, - dtype=input.dtype, - ) - inputs['Bias'] = b - cost = helper.create_variable_for_type_inference(dtype=input.dtype) - sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype) - sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype) - - inputs['Input'] = input - inputs['Label'] = label - inputs['Weight'] = w - inputs['SampleWeight'] = sample_weight if sample_weight is not None else [] - - if sampler == "uniform": - sampler = 0 - elif sampler == "log_uniform": - sampler = 1 - elif sampler == "custom_dist": - assert custom_dist is not None - - custom_dist_len = num_total_classes - alias_probs_ = [0] * custom_dist_len - alias_ = [0] * custom_dist_len - bigs = [] - littles = [] - for i in range(custom_dist_len): - normal_prob = custom_dist[i] * custom_dist_len - if normal_prob - 1.0 > 0: - bigs.append((i, normal_prob)) - elif 1.0 - normal_prob > 0: - littles.append((i, normal_prob)) - else: - alias_probs_[i] = normal_prob - alias_[i] = -1 - - while len(bigs) and len(littles): - big = bigs.pop(0) - little = littles.pop(0) - - big_idx = big[0] - big_prob = big[1] - - alias_probs_[little[0]] = little[1] - alias_[little[0]] = big_idx - big_left = big[1] + little[1] - 1 - if big_left - 1.0 > 0: - bigs.append((big_idx, big_left)) - elif 1.0 - big_left > 0: - littles.append((big_idx, big_left)) - else: - alias_probs_[big_idx] = big_left - alias_[big_idx] = -1 - - if len(bigs): - big = bigs.pop(0) - alias_probs_[big[0]] = 1.0 - alias_[big[0]] = -1 - if len(littles): - little = littles.pop(0) - alias_probs_[little[0]] = 1.0 - alias_[little[0]] = -1 - - def _init_by_numpy_array(numpy_array): - ret = helper.create_parameter( - attr=ParamAttr(), - shape=numpy_array.shape, - dtype=numpy_array.dtype, - default_initializer=NumpyArrayInitializer(numpy_array), - ) - ret.stop_gradient = True - return ret - - inputs['CustomDistProbs'] = _init_by_numpy_array( - np.array(custom_dist).astype('float32') - ) - inputs['CustomDistAlias'] = 
_init_by_numpy_array( - np.array(alias_).astype('int32') - ) - inputs['CustomDistAliasProbs'] = _init_by_numpy_array( - np.array(alias_probs_).astype('float32') - ) - sampler = 2 - else: - raise Exception("Unsupported sampler type.") - - if num_neg_samples is None: - num_neg_samples = 10 - else: - num_neg_samples = int(num_neg_samples) - - remote_prefetch = is_sparse - print( - "With sparse mode, if your models has only small parameter prefetch may cause speed down" - ) - - attrs = { - 'num_total_classes': int(num_total_classes), - 'num_neg_samples': num_neg_samples, - 'seed': seed, - 'sampler': sampler, - 'is_sparse': is_sparse, - 'remote_prefetch': remote_prefetch, - } - - helper.append_op( - type='nce', - inputs=inputs, - outputs={ - 'Cost': cost, - 'SampleLogits': sample_logits, - 'SampleLabels': sample_labels, - }, - attrs=attrs, - ) - return cost / (num_neg_samples + 1) - - def softmax_with_cross_entropy( logits, label, @@ -706,62 +294,6 @@ def softmax_with_cross_entropy( ) -def identity_loss(x, reduction="none"): - r"""Marks a tensor as being part of the loss calculation for IPU. - - This operator is used to handle on the (final) loss of a model so that - it is used as the start of backpropagation. - - When `reduction` is `none`, return raw `Out`. - - When `reduction` is `mean`, return - - .. math:: - Out = MEAN(Out) - - When `reduction` is `sum`, return - - .. math:: - Out = SUM(Out) - - Parameters: - x (Variable): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of - additional dimensions. It's data type should be float32, float64 on CPU and float16, float32 on IPU. - reduction(str|int, optional): Reduce the loss output. Supported string values are: 'sum', 'mean', 'none' - the corresponding int values are 0, 1, 2 respectively. The default value is "none". - - Returns: - Variable: The loss ``Tensor`` with the specified reduction applied. - - Examples: - - .. 
code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - loss = fluid.data(name="loss", shape=[-1, 1], dtype="float32") - out = paddle.incubate.identity_loss(loss, reduction=1) - """ - if isinstance(reduction, str): - reduction = {"sum": 0, "mean": 1, "none": 2}.get(reduction.lower()) - if reduction is None: - raise Exception("Unsupported reduction type.") - - if _non_static_mode(): - return _legacy_C_ops.identity_loss(x, "reduction", reduction) - - check_variable_and_dtype(x, 'x', ['float32', 'float64'], "identity_loss") - attrs = {'reduction': reduction} - helper = LayerHelper('identity_loss', **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type="identity_loss", inputs={"X": x}, outputs={"Out": out}, attrs=attrs - ) - return out - - @templatedoc() def sigmoid_cross_entropy_with_logits( x, label, ignore_index=kIgnoreIndex, name=None, normalize=False diff --git a/python/paddle/fluid/tests/unittests/ipu/test_warpctc_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_warpctc_op_ipu.py index a162f20a0a..8491f15ece 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_warpctc_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_warpctc_op_ipu.py @@ -98,11 +98,12 @@ class TestBase(IPUOpTest): label_length = paddle.static.data( name=self.feed_list[3], shape=self.feed_shape[3], dtype='int64' ) - out = paddle.fluid.layers.warpctc( + out = paddle.nn.functional.ctc_loss( logits, labels, input_length=input_length, label_length=label_length, + reduction='mean', **self.attrs ) loss = paddle.mean(out) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 8b563dbc33..2c53d27efa 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -1351,7 +1351,7 @@ class TestRemoteNce(TestDistLookupTableBase): ) ) - cost = fluid.layers.nce( + cost = paddle.static.nn.nce( input=input, label=label, num_total_classes=num_total_classes, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py index 7147e924a1..3ee24ec982 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py @@ -76,8 +76,8 @@ class TestDygraphLoadStatic(unittest.TestCase): nce_label = fluid.data( name="nce_label", shape=[None, 10], dtype='int64' ) - nce_out_1 = fluid.layers.nce(nce_in, nce_label, 10000) - nce_out_2 = fluid.layers.nce(nce_in, nce_label, 10000) + nce_out_1 = paddle.static.nn.nce(nce_in, nce_label, 10000) + nce_out_2 = paddle.static.nn.nce(nce_in, nce_label, 10000) prelu_in = fluid.data( name="prelu_in", shape=[None, 5, 10, 10], dtype='float32' diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 4a9c5f907a..22b4e22061 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1314,7 +1314,7 @@ class TestLayer(LayerTest): embs = layers.concat(input=embs, axis=1) wl = fluid.layers.unsqueeze(words[label_word], axes=[0]) - nce_loss = layers.nce( + nce_loss = paddle.static.nn.nce( input=embs, label=wl, num_total_classes=dict_size, @@ -3274,7 +3274,7 @@ class TestBook(LayerTest): 
embs.append(emb) embs = layers.concat(input=embs, axis=1) - loss = layers.nce( + loss = paddle.static.nn.nce( input=embs, label=words[label_word], num_total_classes=dict_size, @@ -4343,21 +4343,24 @@ class TestBook(LayerTest): def test_warpctc_with_padding(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph(): - input_length = layers.data( + input_length = paddle.static.data( name='logits_length', shape=[11], dtype='int64' ) - label_length = layers.data( + label_length = paddle.static.data( name='labels_length', shape=[12], dtype='int64' ) - label = layers.data(name='label', shape=[12, 1], dtype='int32') - predict = layers.data( + label = paddle.static.data( + name='label', shape=[12, 1], dtype='int32' + ) + predict = paddle.static.data( name='predict', shape=[4, 4, 8], dtype='float32' ) - output = layers.warpctc( - input=predict, - label=label, - input_length=input_length, - label_length=label_length, + output = paddle.nn.functional.ctc_loss( + log_probs=predict, + labels=label, + input_lengths=input_length, + label_lengths=label_length, + reduction='none', ) return output diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index 885d12c6fa..2691bf2c98 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -210,7 +210,7 @@ class TestNCECase1SelectedRows(unittest.TestCase): ) ) - cost = fluid.layers.nce( + cost = paddle.static.nn.nce( input=input, label=label, num_total_classes=num_total_classes, @@ -291,7 +291,9 @@ class TestNCE_OpError(unittest.TestCase): name='label1', shape=[-1, 4], dtype="int64" ) # the input(input) of nce layer must be Variable. - self.assertRaises(TypeError, fluid.layers.nce, input1, label1, 5) + self.assertRaises( + TypeError, paddle.static.nn.nce, input1, label1, 5 + ) input2 = fluid.layers.data( name='input2', shape=[-1, 4], dtype="float32" @@ -300,7 +302,9 @@ class TestNCE_OpError(unittest.TestCase): np.array([0.0, 3.0, 2.0, 4.0]), [[1, 1, 2]], fluid.CPUPlace() ) # the input(label) of nce layer must be Variable. - self.assertRaises(TypeError, fluid.layers.nce, input2, label2, 5) + self.assertRaises( + TypeError, paddle.static.nn.nce, input2, label2, 5 + ) input3 = fluid.layers.data( name='input3', shape=[-1, 4], dtype="float16" @@ -309,7 +313,9 @@ class TestNCE_OpError(unittest.TestCase): name='label3', shape=[-1, 1], dtype="int64" ) # the data type of input(input) must be float32 or float64. - self.assertRaises(TypeError, fluid.layers.nce, input3, label3, 5) + self.assertRaises( + TypeError, paddle.static.nn.nce, input3, label3, 5 + ) input4 = fluid.layers.data( name='input4', shape=[-1, 4], dtype="float32" @@ -318,7 +324,9 @@ class TestNCE_OpError(unittest.TestCase): name='label4', shape=[-1, 1], dtype="int32" ) # the data type of input(label) must be int64. 
- self.assertRaises(TypeError, fluid.layers.nce, input4, label4, 5) + self.assertRaises( + TypeError, paddle.static.nn.nce, input4, label4, 5 + ) class TestDygraphNCE_OpError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index b3febb9b40..d9c129505e 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -17,7 +17,7 @@ import unittest import numpy as np from op_test import OpTest from test_softmax_op import stable_softmax -import paddle.fluid as fluid +from paddle.fluid.framework import _test_eager_guard import paddle.fluid.core as core from paddle.fluid import Program, program_guard import paddle @@ -206,19 +206,6 @@ class CTCForward: return self.loss -def python_api( - logits, - label, - logits_length=None, - labels_length=None, - blank=0, - norm_by_times=False, -): - return paddle.fluid.layers.warpctc( - logits, label, blank, norm_by_times, logits_length, labels_length - ) - - class TestWarpCTCOp(OpTest): def config(self): self.batch_size = 4 @@ -317,7 +304,6 @@ class TestWarpCTCOpWithPadding(OpTest): def setUp(self): self.op_type = "warpctc" - self.python_api = python_api self.python_out_sig = ["Loss"] self.config() @@ -394,7 +380,7 @@ class TestWarpCTCOpWithPadding(OpTest): } def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.outputs['WarpCTCGrad'] = self.gradient @@ -439,7 +425,6 @@ class TestWarpCTCOpFp64(OpTest): def setUp(self): self.op_type = "warpctc" - self.python_api = python_api self.python_out_sig = ["Loss"] self.config() @@ -516,67 +501,74 @@ class TestWarpCTCOpFp64(OpTest): } def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.outputs['WarpCTCGrad'] = self.gradient - self.check_grad(["Logits"], "Loss", check_eager=True) + self.check_grad(["Logits"], "Loss", check_eager=False) class TestWarpCTCOpError(unittest.TestCase): def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): - logits = fluid.data( + logits = paddle.static.data( name='logits', shape=[5, 16, 6], dtype='float32' ) - logits_length = fluid.data( + logits_length = paddle.static.data( name='logits_length', shape=[None], dtype='int64' ) - label = fluid.data(name='label', shape=[16, 3], dtype='int32') - label_length = fluid.data( + label = paddle.static.data( + name='label', shape=[16, 3], dtype='int32' + ) + label_length = paddle.static.data( name='labels_length', shape=[None], dtype='int64' ) def test_logits_Variable(): logits_data = np.random.rand(5, 16, 6).astype(logits.dtype) - fluid.layers.warpctc( - input=logits_data, - label=label, - input_length=logits_length, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits_data, + labels=label, + input_lengths=logits_length, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_logits_Variable) def test_label_Variable(): label_data = np.random.randint(0, 5, [5, 1]).astype("int32") - fluid.layers.warpctc( - input=logits, - label=label_data, - input_length=logits_length, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label_data, + input_lengths=logits_length, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_label_Variable) def test_logits_len_Variable(): 
logits_length_data = np.array([5] * 16).astype("int64") - fluid.layers.warpctc( - input=logits, - label=label, - input_length=logits_length_data, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label, + input_lengths=logits_length_data, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_logits_len_Variable) def test_label_len_Variable(): label_length_data = np.array([3] * 16).astype("int64") - fluid.layers.warpctc( - input=logits, - label=label, - input_length=logits_length, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label, + input_lengths=logits_length, label_length=label_length_data, + reduction='none', ) self.assertRaises(TypeError, test_label_len_Variable) @@ -590,7 +582,13 @@ class TestWarpCTCOpError(unittest.TestCase): softmax = paddle.to_tensor(logits) labels = paddle.to_tensor(labels) - fluid.layers.warpctc(input=softmax, label=labels) + paddle.nn.functional.ctc_loss( + log_probs=softmax, + labels=labels, + input_lengths=None, + label_lengths=None, + reduction='none', + ) paddle.disable_static() self.assertRaises(ValueError, test_dygraph_with_lod) @@ -598,74 +596,6 @@ class TestWarpCTCOpError(unittest.TestCase): class TestCTCLossAPICase(unittest.TestCase): - def test_functinal_api(self): - self.batch_size = 4 - self.num_classes = CUDA_BLOCK_SIZE + 2 - self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64) - self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64) - self.blank = self.num_classes - 1 - self.norm_by_times = False - - logits = np.random.uniform( - 0.1, - 1.0, - [max(self.logits_length), self.batch_size, self.num_classes], - ).astype("float32") - softmax = np.apply_along_axis(stable_softmax, -1, logits) - # labels should not be blank - labels = np.random.randint( - 0, - self.num_classes - 1, - [self.batch_size, max(self.labels_length)], - dtype="int32", - ) - - ctc = CTCForward( - softmax, - self.logits_length, - labels, - self.labels_length, - self.num_classes, - self.batch_size, - self.blank, - self.norm_by_times, - ) - loss_np = ctc.forward() - - paddle.disable_static() - softmax = paddle.to_tensor(logits) - labels = paddle.to_tensor(labels) - logits_length = paddle.to_tensor(self.logits_length) - labels_length = paddle.to_tensor(self.labels_length) - loss_pd_mean = F.ctc_loss( - softmax, - labels, - logits_length, - labels_length, - blank=self.blank, - reduction='mean', - ) - loss_pd_mean = loss_pd_mean.numpy() - - loss_pd_sum = F.ctc_loss( - softmax, - labels, - logits_length, - labels_length, - blank=self.blank, - reduction='sum', - ) - loss_pd_sum = loss_pd_sum.numpy() - paddle.enable_static() - loss_np = np.squeeze(loss_np, axis=-1) - loss_np_mean = (loss_np / labels_length.numpy()).mean() - loss_np_sum = loss_np.sum() - - np.testing.assert_allclose( - loss_pd_mean, loss_np_mean, rtol=1e-05, atol=1 - ) - np.testing.assert_allclose(loss_pd_sum, loss_np_sum, rtol=1e-05, atol=1) - def test_class_api(self): self.batch_size = 3 self.num_classes = 15 @@ -715,6 +645,81 @@ class TestCTCLossAPICase(unittest.TestCase): np.testing.assert_allclose(loss_pd, loss_np, rtol=1e-05, atol=1) + def test_eager_ctcloss(self): + def test_functinal_api(): + self.batch_size = 4 + self.num_classes = CUDA_BLOCK_SIZE + 2 + self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64) + self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64) + self.blank = self.num_classes - 1 + self.norm_by_times = False + + logits = np.random.uniform( + 0.1, + 1.0, + [max(self.logits_length), 
self.batch_size, self.num_classes], + ).astype("float32") + softmax = np.apply_along_axis(stable_softmax, -1, logits) + # labels should not be blank + labels = np.random.randint( + 0, + self.num_classes - 1, + [self.batch_size, max(self.labels_length)], + dtype="int32", + ) + + ctc = CTCForward( + softmax, + self.logits_length, + labels, + self.labels_length, + self.num_classes, + self.batch_size, + self.blank, + self.norm_by_times, + ) + loss_np = ctc.forward() + + paddle.disable_static() + softmax = paddle.to_tensor(logits) + labels = paddle.to_tensor(labels) + logits_length = paddle.to_tensor(self.logits_length) + labels_length = paddle.to_tensor(self.labels_length) + loss_pd_mean = F.ctc_loss( + softmax, + labels, + logits_length, + labels_length, + blank=self.blank, + reduction='mean', + ) + loss_pd_mean = loss_pd_mean.numpy() + + loss_pd_sum = F.ctc_loss( + softmax, + labels, + logits_length, + labels_length, + blank=self.blank, + reduction='sum', + ) + loss_pd_sum = loss_pd_sum.numpy() + paddle.enable_static() + loss_np = np.squeeze(loss_np, axis=-1) + loss_np_mean = (loss_np / labels_length.numpy()).mean() + loss_np_sum = loss_np.sum() + + np.testing.assert_allclose( + loss_pd_mean, loss_np_mean, rtol=1e-05, atol=1 + ) + np.testing.assert_allclose( + loss_pd_sum, loss_np_sum, rtol=1e-05, atol=1 + ) + + with _test_eager_guard(): + test_functinal_api() + test_functinal_api() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_warpctc_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_warpctc_op_xpu.py index 3dcefb0e1e..8807b7be35 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_warpctc_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_warpctc_op_xpu.py @@ -212,19 +212,6 @@ class CTCForward(object): return self.loss -def python_api( - logits, - label, - logits_length=None, - labels_length=None, - blank=0, - norm_by_times=False, -): - return paddle.fluid.layers.warpctc( - logits, label, blank, norm_by_times, logits_length, labels_length - ) - - class XPUTestWarpCTCOp(XPUOpTestWrapper): def __init__(self): self.op_name = 'warpctc' @@ -244,7 +231,6 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper): self.op_type = "warpctc" self.dtype = self.in_type self.place = paddle.XPUPlace(0) - self.python_api = python_api self.python_out_sig = ["Loss"] self.config() @@ -325,7 +311,7 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper): } def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.outputs['WarpCTCGrad'] = self.gradient @@ -367,44 +353,48 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper): def test_logits_Variable(): logits_data = np.random.rand(5, 16, 6).astype(logits.dtype) - fluid.layers.warpctc( - input=logits_data, - label=label, - input_length=logits_length, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits_data, + labels=label, + input_lengths=logits_length, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_logits_Variable) def test_label_Variable(): label_data = np.random.randint(0, 5, [5, 1]).astype("int32") - fluid.layers.warpctc( - input=logits, - label=label_data, - input_length=logits_length, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label_data, + input_lengths=logits_length, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_label_Variable) def test_logits_len_Variable(): logits_length_data = 
np.array([5] * 16).astype("int64") - fluid.layers.warpctc( - input=logits, - label=label, - input_length=logits_length_data, - label_length=label_length, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label, + input_lengths=logits_length_data, + label_lengths=label_length, + reduction='none', ) self.assertRaises(TypeError, test_logits_len_Variable) def test_label_len_Variable(): label_length_data = np.array([3] * 16).astype("int64") - fluid.layers.warpctc( - input=logits, - label=label, - input_length=logits_length, - label_length=label_length_data, + paddle.nn.functional.ctc_loss( + log_probs=logits, + labels=label, + input_lengths=logits_length, + label_lengths=label_length_data, + reduction='none', ) self.assertRaises(TypeError, test_label_len_Variable) @@ -423,7 +413,9 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper): softmax = paddle.to_tensor(logits) labels = paddle.to_tensor(labels) - fluid.layers.warpctc(input=softmax, label=labels) + paddle.nn.functional.ctc_loss( + log_probs=softmax, labels=labels, reduction='none' + ) paddle.disable_static() self.assertRaises(ValueError, test_dygraph_with_lod) diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index d5ff9454a8..97dd8353be 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -36,7 +36,7 @@ from . import nn # noqa: F401 from . import asp # noqa: F401 from . import multiprocessing # noqa: F401 -from ..fluid.layers.loss import identity_loss +from .nn.loss import identity_loss from ..fluid.incubate import fleet from . import xpu diff --git a/python/paddle/incubate/nn/loss.py b/python/paddle/incubate/nn/loss.py new file mode 100644 index 0000000000..7175834084 --- /dev/null +++ b/python/paddle/incubate/nn/loss.py @@ -0,0 +1,75 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.layer_helper import LayerHelper +from paddle.framework import ( + _non_static_mode, +) +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle import _legacy_C_ops + + +def identity_loss(x, reduction="none"): + r"""Marks a tensor as being part of the loss calculation for IPU. + + This operator is used to handle on the (final) loss of a model so that + it is used as the start of backpropagation. + + When `reduction` is `none`, return raw `Out`. + + When `reduction` is `mean`, return + + .. math:: + Out = MEAN(Out) + + When `reduction` is `sum`, return + + .. math:: + Out = SUM(Out) + + Parameters: + x (Variable): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of + additional dimensions. It's data type should be float32, float64 on CPU and float16, float32 on IPU. + reduction(str|int, optional): Reduce the loss output. Supported string values are: 'sum', 'mean', 'none' + the corresponding int values are 0, 1, 2 respectively. The default value is "none". + + Returns: + Variable: The loss ``Tensor`` with the specified reduction applied. 
+ + Examples: + + .. code-block:: python + + import paddle + paddle.enable_static() + loss = paddle.static.data(name="loss", shape=[-1, 1], dtype="float32") + out = paddle.incubate.identity_loss(loss, reduction=1) + """ + if isinstance(reduction, str): + reduction = {"sum": 0, "mean": 1, "none": 2}.get(reduction.lower()) + if reduction is None: + raise Exception("Unsupported reduction type.") + + if _non_static_mode(): + return _legacy_C_ops.identity_loss(x, "reduction", reduction) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], "identity_loss") + attrs = {'reduction': reduction} + helper = LayerHelper('identity_loss', **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="identity_loss", inputs={"X": x}, outputs={"Out": out}, attrs=attrs + ) + return out diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e93e52b31a..8516765109 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -37,6 +37,8 @@ from ...fluid.framework import ( __all__ = [] +kIgnoreIndex = -100 + def dice_loss(input, label, epsilon=0.00001, name=None): r""" @@ -1818,7 +1820,70 @@ def ctc_loss( """ - loss_out = fluid.layers.warpctc( + def warpctc( + input, + label, + blank=0, + norm_by_times=False, + input_length=None, + label_length=None, + ): + if in_dygraph_mode(): + if input_length is None or label_length is None: + raise ValueError( + "input_length and label_length must not be None in dygraph mode!" + ) + loss_out = _C_ops.warpctc( + input, label, input_length, label_length, blank, norm_by_times + ) + return loss_out + if _non_static_mode(): + if input_length is None or label_length is None: + raise ValueError( + "input_length and label_length must not be None in dygraph mode!" 
+ ) + grad, loss_out = _legacy_C_ops.warpctc( + input, + label, + input_length, + label_length, + 'blank', + blank, + 'norm_by_times', + norm_by_times, + ) + return loss_out + helper = LayerHelper('warpctc', **locals()) + check_variable_and_dtype( + input, 'input', ['float32', 'float64'], "warpctc" + ) + check_variable_and_dtype(label, 'label', ['int32'], "warpctc") + this_inputs = {'Logits': [input], 'Label': [label]} + if input_length is not None and label_length is not None: + check_variable_and_dtype( + input_length, 'LogitsLength', ['int64'], "warpctc" + ) + check_variable_and_dtype( + label_length, 'LabelLength', ['int64'], "warpctc" + ) + this_inputs['LogitsLength'] = [input_length] + this_inputs['LabelLength'] = [label_length] + + loss_out = helper.create_variable_for_type_inference(dtype=input.dtype) + grad_out = helper.create_variable_for_type_inference(dtype=input.dtype) + + helper.append_op( + type='warpctc', + inputs=this_inputs, + outputs={'WarpCTCGrad': [grad_out], 'Loss': [loss_out]}, + attrs={ + 'blank': blank, + 'norm_by_times': norm_by_times, + }, + ) + return loss_out + + loss_out = warpctc( log_probs, labels, blank, norm_by_times, input_lengths, label_lengths ) diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index ef966ecd98..6f27289efc 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -30,7 +30,7 @@ from ...fluid.layers import group_norm # noqa: F401 from ...fluid.layers import instance_norm # noqa: F401 from ...fluid.layers import layer_norm # noqa: F401 from ...fluid.layers import multi_box_head # noqa: F401 -from ...fluid.layers import nce # noqa: F401 +from .loss import nce # noqa: F401 from ...fluid.layers import prelu # noqa: F401 from ...fluid.layers import py_func # noqa: F401 from ...fluid.layers import row_conv # noqa: F401 diff --git a/python/paddle/static/nn/loss.py b/python/paddle/static/nn/loss.py new file mode 100644 index 0000000000..1cba5dfe67 --- /dev/null +++ b/python/paddle/static/nn/loss.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -* +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.layers.layer_function_generator import templatedoc + +# TODO: define loss functions of neural network +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import NumpyArrayInitializer +import numpy as np +from paddle.fluid.framework import ( + static_only, +) + +__all__ = [] + + +# FIXME(wuyi): let docstring_checker.py understand @autodoc. +# For now, the comments in c++ use types like Tensor, but in python side +# the type is often "Variable", and arguments may vary. 
+@static_only +@templatedoc(op_type="nce") +def nce( + input, + label, + num_total_classes, + sample_weight=None, + param_attr=None, + bias_attr=None, + num_neg_samples=None, + name=None, + sampler="uniform", + custom_dist=None, + seed=0, + is_sparse=False, +): + """ + :api_attr: Static Graph + + ${comment} + + Args: + input (Tensor): Input tensor, 2-D tensor with shape [batch_size, dim], + and data type is float32 or float64. + label (Tensor): Input label, 2-D tensor with shape [batch_size, num_true_class], + and data type is int64. + num_total_classes (int):${num_total_classes_comment}. + sample_weight (Tensor|None): A Tensor of shape [batch_size, 1] + storing a weight for each sample. The default weight for each + sample is 1.0. + param_attr (ParamAttr|None): To specify the weight parameter attribute. + Default: None, which means the default weight parameter property is + used. See usage for details in :ref:`api_fluid_ParamAttr` . + bias_attr (ParamAttr|None): To specify the bias parameter attribute. + Default: None, which means the default bias parameter property is + used. See usage for details in :ref:`api_fluid_ParamAttr` . + num_neg_samples (int): ${num_neg_samples_comment}. + name(str|None): For detailed information, please refer to + :ref:`api_guide_Name` . Usually name is no need to set and None by default. + sampler (str, optional): The sampler used to sample class from negative classes. + It can be 'uniform', 'log_uniform' or 'custom_dist'. + default: 'uniform'. + custom_dist (nd.array|None): A numpy ndarray with size=num_total_classes. + It is used when sampler is set to 'custom_dist'. + custom_dist[i] is the probability of i-th class to be sampled. + default: None. + seed (int, optional): The seed used in sampler. Default 0, means no random seed. + is_sparse(bool, optional): The flag indicating whether to use sparse update, + the weight@GRAD and bias@GRAD will be changed to SelectedRows. Default False. + + Returns: + Tensor: The output nce loss. + + Examples: + .. 
code-block:: python + + + import paddle + import numpy as np + + paddle.enable_static() + + window_size = 5 + words = [] + for i in range(window_size): + words.append(paddle.static.data( + name='word_{0}'.format(i), shape=[-1, 1], dtype='int64')) + + dict_size = 10000 + label_word = int(window_size / 2) + 1 + + embs = [] + for i in range(window_size): + if i == label_word: + continue + + emb = paddle.static.nn.embedding(input=words[i], size=[dict_size, 32], + param_attr='embed', is_sparse=True) + embs.append(emb) + + embs = paddle.concat(x=embs, axis=1) + loss = paddle.static.nn.nce(input=embs, label=words[label_word], + num_total_classes=dict_size, param_attr='nce.w_0', + bias_attr='nce.b_0') + + #or use custom distribution + dist = np.array([0.05,0.5,0.1,0.3,0.05]) + loss = paddle.static.nn.nce(input=embs, label=words[label_word], + num_total_classes=5, param_attr='nce.w_1', + bias_attr='nce.b_1', + num_neg_samples=3, + sampler="custom_dist", + custom_dist=dist) + """ + helper = LayerHelper('nce', **locals()) + check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce') + check_variable_and_dtype(label, 'label', ['int64'], 'nce') + + dim = input.shape[1] + num_true_class = label.shape[1] + w = helper.create_parameter( + attr=helper.param_attr, + shape=[num_total_classes, dim], + is_bias=False, + dtype=input.dtype, + ) + inputs = {} + if helper.bias_attr: + b = helper.create_parameter( + attr=helper.bias_attr, + shape=[num_total_classes, 1], + is_bias=True, + dtype=input.dtype, + ) + inputs['Bias'] = b + cost = helper.create_variable_for_type_inference(dtype=input.dtype) + sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype) + sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype) + + inputs['Input'] = input + inputs['Label'] = label + inputs['Weight'] = w + inputs['SampleWeight'] = sample_weight if sample_weight is not None else [] + + if sampler == "uniform": + sampler = 0 + elif sampler == "log_uniform": + sampler = 1 + elif sampler == "custom_dist": + assert custom_dist is not None + + custom_dist_len = num_total_classes + alias_probs_ = [0] * custom_dist_len + alias_ = [0] * custom_dist_len + bigs = [] + littles = [] + for i in range(custom_dist_len): + normal_prob = custom_dist[i] * custom_dist_len + if normal_prob - 1.0 > 0: + bigs.append((i, normal_prob)) + elif 1.0 - normal_prob > 0: + littles.append((i, normal_prob)) + else: + alias_probs_[i] = normal_prob + alias_[i] = -1 + + while len(bigs) and len(littles): + big = bigs.pop(0) + little = littles.pop(0) + + big_idx = big[0] + big_prob = big[1] + + alias_probs_[little[0]] = little[1] + alias_[little[0]] = big_idx + big_left = big[1] + little[1] - 1 + if big_left - 1.0 > 0: + bigs.append((big_idx, big_left)) + elif 1.0 - big_left > 0: + littles.append((big_idx, big_left)) + else: + alias_probs_[big_idx] = big_left + alias_[big_idx] = -1 + + if len(bigs): + big = bigs.pop(0) + alias_probs_[big[0]] = 1.0 + alias_[big[0]] = -1 + if len(littles): + little = littles.pop(0) + alias_probs_[little[0]] = 1.0 + alias_[little[0]] = -1 + + def _init_by_numpy_array(numpy_array): + ret = helper.create_parameter( + attr=ParamAttr(), + shape=numpy_array.shape, + dtype=numpy_array.dtype, + default_initializer=NumpyArrayInitializer(numpy_array), + ) + ret.stop_gradient = True + return ret + + inputs['CustomDistProbs'] = _init_by_numpy_array( + np.array(custom_dist).astype('float32') + ) + inputs['CustomDistAlias'] = _init_by_numpy_array( + np.array(alias_).astype('int32') + ) + 
inputs['CustomDistAliasProbs'] = _init_by_numpy_array( + np.array(alias_probs_).astype('float32') + ) + sampler = 2 + else: + raise Exception("Unsupported sampler type.") + + if num_neg_samples is None: + num_neg_samples = 10 + else: + num_neg_samples = int(num_neg_samples) + + remote_prefetch = is_sparse + print( + "With sparse mode, if your models has only small parameter prefetch may cause speed down" + ) + + attrs = { + 'num_total_classes': int(num_total_classes), + 'num_neg_samples': num_neg_samples, + 'seed': seed, + 'sampler': sampler, + 'is_sparse': is_sparse, + 'remote_prefetch': remote_prefetch, + } + + helper.append_op( + type='nce', + inputs=inputs, + outputs={ + 'Cost': cost, + 'SampleLogits': sample_logits, + 'SampleLabels': sample_labels, + }, + attrs=attrs, + ) + return cost / (num_neg_samples + 1) -- GitLab
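
For reference, the call pattern this patch migrates to can be exercised end to end in
dygraph mode. The sketch below is illustrative only: it assumes toy shapes (5 time steps,
a batch of 4, 8 classes) and random data, and uses paddle.nn.functional.ctc_loss in place
of the removed fluid.layers.warpctc, mirroring the updated tests above.

.. code-block:: python

    # Minimal dygraph sketch of paddle.nn.functional.ctc_loss, the replacement
    # for the removed fluid.layers.warpctc. All sizes below are assumed toy values.
    import paddle
    import paddle.nn.functional as F

    max_seq_len, batch_size, num_classes = 5, 4, 8   # assumed toy sizes
    blank = num_classes - 1                          # blank label index

    logits = paddle.rand([max_seq_len, batch_size, num_classes], dtype='float32')
    # normalize the logits, matching the log_probs parameter name
    log_probs = F.log_softmax(logits, axis=-1)
    # labels must not contain the blank index
    labels = paddle.randint(0, blank, [batch_size, 3], dtype='int32')
    input_lengths = paddle.full([batch_size], max_seq_len, dtype='int64')
    label_lengths = paddle.full([batch_size], 3, dtype='int64')

    loss = F.ctc_loss(log_probs, labels, input_lengths, label_lengths,
                      blank=blank, reduction='mean')
    print(loss.numpy())

The static-graph nce layer keeps its original signature and is now exposed as
paddle.static.nn.nce, and identity_loss is reached via paddle.incubate.identity_loss,
as the updated tests and the new loss.py modules above show.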