diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 9bc603c0ecc2c9da9eaf34cf0791fe2767d52a9a..ee6e541c9e6c6eaa4a57c2d687466354975371e7 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -49,6 +49,8 @@ std::map> op_ins_map = { {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}}, {"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}}, + {"hierarchical_sigmoid", + {"X", "W", "Label", "PathTable", "PathCode", "Bias"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/tests/unittests/test_directory_migration.py b/python/paddle/fluid/tests/unittests/test_directory_migration.py index 7d48f2c419085ff7ea66076bfc561aa7655d41c6..fd014f3b4ecaf481c52738d3926718a67c329adc 100644 --- a/python/paddle/fluid/tests/unittests/test_directory_migration.py +++ b/python/paddle/fluid/tests/unittests/test_directory_migration.py @@ -64,11 +64,10 @@ class TestDirectory(unittest.TestCase): 'paddle.static.nn.create_parameter', 'paddle.static.nn.crf_decoding', 'paddle.static.nn.data_norm', 'paddle.static.nn.deformable_conv', 'paddle.static.nn.group_norm', - 'paddle.static.nn.hsigmoid', 'paddle.static.nn.instance_norm', - 'paddle.static.nn.layer_norm', 'paddle.static.nn.multi_box_head', - 'paddle.static.nn.nce', 'paddle.static.nn.prelu', - 'paddle.static.nn.row_conv', 'paddle.static.nn.spectral_norm', - 'paddle.static.nn.embedding' + 'paddle.static.nn.instance_norm', 'paddle.static.nn.layer_norm', + 'paddle.static.nn.multi_box_head', 'paddle.static.nn.nce', + 'paddle.static.nn.prelu', 'paddle.static.nn.row_conv', + 'paddle.static.nn.spectral_norm', 'paddle.static.nn.embedding' ] import_file = 'run_import_modules.py' diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid.py b/python/paddle/fluid/tests/unittests/test_hsigmoid.py deleted file mode 100644 index 80937640c2d2fde6d2ba30b7bccc396566801e63..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle import fluid, nn -import paddle.fluid.dygraph as dg -import paddle.nn.functional as F -import paddle.fluid.initializer as I -import numpy as np -import unittest - - -class HSigmoidTestCase(unittest.TestCase): - def __init__(self, - methodName="runTest", - batch_size=4, - feature_size=6, - num_classes=8, - labels=None, - path_code=None, - path_table=None, - is_sparse=False, - dtype="float32"): - super(HSigmoidTestCase, self).__init__() - self.batch_size = batch_size - self.feature_size = feature_size - self.num_classes = num_classes - self.dtype = dtype - self.is_sparse = is_sparse - - self.labels = labels - self.path_code = path_code - self.path_table = path_table - self.is_custom = path_code is not None and path_table is not None - - def setUp(self): - input_shape = (self.batch_size, self.feature_size) - self.input = np.random.uniform( - -1, 1, size=input_shape).astype(self.dtype) - if self.labels is None: - self.labels = np.random.randint( - 0, self.num_classes, size=(self.batch_size, 1)).astype(np.int64) - C = self.num_classes if self.is_custom else self.num_classes - 1 - self.weight_shape = (C, self.feature_size) - self.weight = np.random.randn(*self.weight_shape).astype(self.dtype) - self.bias_shape = (C, 1) - self.bias = np.random.randn(*self.bias_shape).astype(self.dtype) - - def fluid_layer(self, place): - main = fluid.Program() - start = fluid.Program() - with fluid.unique_name.guard(): - with fluid.program_guard(main, start): - x = fluid.data( - "input", [-1, self.feature_size], dtype=self.dtype) - label = fluid.data("labels", [-1, 1], dtype="int64") - if self.is_custom: - path_table = fluid.data( - "path_table", [-1, -1], dtype="int64") - path_code = fluid.data("path_code", [-1, -1], dtype="int64") - else: - path_table = path_code = None - y = fluid.layers.hsigmoid( - x, - label, - self.num_classes, - param_attr=I.NumpyArrayInitializer(self.weight), - bias_attr=I.NumpyArrayInitializer(self.bias), - path_table=path_table, - path_code=path_code, - is_custom=self.is_custom, - is_sparse=self.is_sparse, ) - exe = fluid.Executor(place) - exe.run(start) - feed_dict = {"input": self.input, "labels": self.labels} - if self.is_custom: - feed_dict["path_code"] = self.path_code - feed_dict["path_table"] = self.path_table - y_np, = exe.run(main, feed=feed_dict, fetch_list=[y]) - return y_np - - def functional(self, place): - main = fluid.Program() - start = fluid.Program() - with fluid.unique_name.guard(): - with fluid.program_guard(main, start): - x = fluid.data( - "input", [-1, self.feature_size], dtype=self.dtype) - label = fluid.data("labels", [-1, 1], dtype="int64") - if self.is_custom: - path_table = fluid.data( - "path_table", [-1, -1], dtype="int64") - path_code = fluid.data("path_code", [-1, -1], dtype="int64") - else: - path_table = path_code = None - w = fluid.data("weight", self.weight_shape, dtype=self.dtype) - b = fluid.data("bias", self.bias_shape, dtype=self.dtype) - y = F.hsigmoid( - x, - label, - w, - b, - self.num_classes, - is_sparse=self.is_sparse, - path_table=path_table, - path_code=path_code) - - exe = fluid.Executor(place) - exe.run(start) - feed_dict = { - "input": self.input, - "labels": self.labels, - "weight": self.weight, - "bias": self.bias - } - if self.is_custom: - feed_dict["path_code"] = self.path_code - feed_dict["path_table"] = self.path_table - y_np, = exe.run(main, feed=feed_dict, fetch_list=[y]) - return y_np - - def nn_layer(self, place): - with dg.guard(place): - x_var = dg.to_variable(self.input) - label_var = dg.to_variable(self.labels) - if self.is_custom: - path_code_var = dg.to_variable(self.path_code) - path_table_var = dg.to_variable(self.path_table) - else: - path_code_var = path_table_var = None - hierarchical_softmax = nn.HSigmoid( - self.feature_size, - self.num_classes, - is_custom=self.is_custom, - is_sparse=self.is_sparse, - param_attr=I.NumpyArrayInitializer(self.weight), - bias_attr=I.NumpyArrayInitializer(self.bias), - dtype=self.dtype) - y_var = hierarchical_softmax( - x_var, - label_var, - path_table=path_table_var, - path_code=path_code_var) - y_np = y_var.numpy() - return y_np - - def _test_equivalence(self, place): - result1 = self.fluid_layer(place) - result2 = self.functional(place) - result3 = self.nn_layer(place) - np.testing.assert_array_almost_equal(result1, result2) - np.testing.assert_array_almost_equal(result2, result3) - - def runTest(self): - place = fluid.CPUPlace() - self._test_equivalence(place) - - -class HSigmoidTestErrorCase(HSigmoidTestCase): - def runTest(self): - place = fluid.CPUPlace() - with dg.guard(place): - with self.assertRaises(ValueError): - self.nn_layer() - - def nn_layer(self): - x_var = dg.to_variable(self.input) - label_var = dg.to_variable(self.labels) - if self.is_custom: - path_code_var = dg.to_variable(self.path_code) - path_table_var = dg.to_variable(self.path_table) - else: - path_code_var = path_table_var = None - hierarchical_softmax = nn.HSigmoid( - self.feature_size, - self.num_classes, - is_custom=self.is_custom, - param_attr=I.NumpyArrayInitializer(self.weight), - bias_attr=I.NumpyArrayInitializer(self.bias), - dtype=self.dtype) - y_var = hierarchical_softmax( - x_var, - label_var, - path_table=path_table_var, - path_code=path_code_var) - y_np = y_var.numpy() - return y_np - - -def load_tests(loader, standard_tests, pattern): - suite = unittest.TestSuite() - suite.addTest(HSigmoidTestCase(methodName="runTest")) - suite.addTest( - HSigmoidTestCase( - methodName="runTest", - batch_size=4, - feature_size=6, - num_classes=8, - labels=np.array([0, 1, 4, 5]).astype(np.int64), - path_table=np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), ( - 0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]).astype(np.int64), - path_code=np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( - 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]).astype(np.int64))) - suite.addTest(HSigmoidTestErrorCase(methodName="runTest", num_classes=1)) - return suite - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 5c9867e681524f519e267fb744fc4090c836036a..3f8eed08adf68edc8d3db24f6a6593a8a3a8b0fc 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -19,10 +19,13 @@ import numpy as np import paddle import paddle.fluid.core as core import paddle.fluid as fluid +import paddle.nn.functional as F from paddle.fluid import Program, program_guard +import paddle.fluid.initializer as I import math from op_test import OpTest, skip_check_grad_ci +paddle.enable_static() np.random.seed(100) @@ -56,7 +59,6 @@ class CodeTableWithCustomTree(object): def get_length(self): length = 0 for ele in self.ptable_[self.index_]: # find the first -1 to stop trace - if ele >= 0: length = length + 1 else: @@ -388,8 +390,192 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label')) -class TestHSigmoidOpError(unittest.TestCase): +class TestHSigmoidLossAPI(unittest.TestCase): + # test paddle.nn.functional.hsigmoid_loss, paddle.nn.HSigmoidLoss + def setUp(self): + self.dtype = 'float32' + self.batch_size = 4 + self.feature_size = 6 + self.num_classes = 8 + self.is_custom = False + self.place = paddle.CPUPlace() + + paddle.set_default_dtype(self.dtype) + + self.x_np = np.random.uniform( + -1, 1, [self.batch_size, self.feature_size]).astype(self.dtype) + self.labels_np = np.random.randint( + self.num_classes, size=(self.batch_size, 1), dtype='int64') + self.weight_np = np.random.uniform( + -1, 1, [self.num_classes - 1, self.feature_size]).astype(self.dtype) + self.bias_np = np.random.uniform(-1, 1, ( + self.num_classes - 1, )).astype(self.dtype) + self.path_table_np = None + self.path_code_np = None + _, self.out_np = hsigmoid(self.x_np, self.weight_np, self.labels_np, + self.bias_np, self.num_classes) + self.set_attrs() + + if self.is_custom: + _, self.out_np = hsigmoidWithCustomTree( + self.x_np, self.weight_np, self.path_table_np, + self.path_code_np, self.labels_np, + self.bias_np.reshape(-1, 1), self.num_classes) + + def set_attrs(self): + pass + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + labels = paddle.to_tensor(self.labels_np) + weight = paddle.to_tensor(self.weight_np) + bias = paddle.to_tensor(self.bias_np) + path_table = None + path_code = None + if self.is_custom: + path_table = paddle.to_tensor(self.path_table_np) + path_code = paddle.to_tensor(self.path_code_np) + out1 = F.hsigmoid_loss(x, labels, self.num_classes, weight, bias, + path_table, path_code) + + weight_attr = I.NumpyArrayInitializer(self.weight_np) + bias_attr = I.NumpyArrayInitializer(self.bias_np) + m = paddle.nn.HSigmoidLoss(self.feature_size, self.num_classes, + weight_attr, bias_attr, self.is_custom) + out2 = m(x, labels, path_table, path_code) + + for out in [out1, out2]: + self.assertTrue(np.allclose(self.out_np, out.numpy())) + paddle.enable_static() + + def test_static_api(self): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(train_program, startup_program): + x = paddle.static.data('x', [-1, self.feature_size]) + labels = paddle.static.data('labels', [-1, 1], 'int64') + weight = paddle.static.data('weight', [-1, self.feature_size]) + bias = paddle.static.data('bias', [-1, ]) + path_table = None + path_code = None + if self.is_custom: + path_table = paddle.static.data('path_table', [-1, -1], 'int64') + path_code = paddle.static.data('path_code', [-1, -1], 'int64') + out1 = F.hsigmoid_loss(x, labels, self.num_classes, weight, bias, + path_table, path_code) + + weight_attr = paddle.framework.ParamAttr( + initializer=I.NumpyArrayInitializer(self.weight_np)) + bias_attr = paddle.framework.ParamAttr( + initializer=I.NumpyArrayInitializer(self.bias_np)) + m = paddle.nn.HSigmoidLoss(self.feature_size, self.num_classes, + weight_attr, bias_attr, self.is_custom) + out2 = m(x, labels, path_table, path_code) + + exe = paddle.static.Executor(self.place) + exe.run(startup_program) + feed_dict = { + 'x': self.x_np, + 'labels': self.labels_np, + 'weight': self.weight_np, + 'bias': self.bias_np + } + if self.is_custom: + feed_dict["path_code"] = self.path_code_np + feed_dict["path_table"] = self.path_table_np + ret1, ret2 = exe.run(train_program, + feed=feed_dict, + fetch_list=[out1, out2]) + + for ret in [ret1, ret2]: + self.assertTrue(np.allclose(self.out_np, ret)) + + def test_fluid_api(self): + train_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + x = fluid.data('x', [-1, self.feature_size]) + labels = fluid.data('labels', [-1, 1], 'int64') + path_table = None + path_code = None + if self.is_custom: + path_table = fluid.data('path_table', [-1, -1], 'int64') + path_code = fluid.data('path_code', [-1, -1], 'int64') + weight_attr = I.NumpyArrayInitializer(self.weight_np) + bias_attr = I.NumpyArrayInitializer(self.bias_np) + out = fluid.layers.hsigmoid(x, labels, self.num_classes, + weight_attr, bias_attr, 'out', + path_table, path_code, self.is_custom) + + exe = fluid.Executor(self.place) + exe.run(startup_program) + feed_dict = {'x': self.x_np, 'labels': self.labels_np} + if self.is_custom: + feed_dict["path_code"] = self.path_code_np + feed_dict["path_table"] = self.path_table_np + ret, = exe.run(train_program, feed=feed_dict, fetch_list=[out]) + + self.assertTrue(np.allclose(ret, self.out_np)) + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + # test paddle.nn.HSigmoidLoss + self.assertRaises(ValueError, paddle.nn.HSigmoidLoss, 6, 1) + + # test paddle.nn.functional.hsigmoid_loss + x = paddle.static.data('x', [4, 6]) + label = paddle.static.data('label', [4, 1], 'int64') + weight = paddle.static.data('weight', [7, 6]) + bias = paddle.static.data('bias', [7]) + + x_int32 = paddle.static.data('x_int32', [4, 6], 'int32') + self.assertRaises(TypeError, F.hsigmoid_loss, x_int32, label, 8, + weight) + + label_float32 = paddle.static.data('label_float32', [4, 1], + 'float32') + self.assertRaises(TypeError, F.hsigmoid_loss, x, label_float32, 8, + weight) + + weight_int32 = paddle.static.data('weight_int32', [7, 6], 'int32') + self.assertRaises(TypeError, F.hsigmoid_loss, x, label, 8, + weight_int32) + + bias_int32 = paddle.static.data('bias_int32', [7], 'int32') + self.assertRaises( + TypeError, + F.hsigmoid_loss, + x, + label, + 8, + weight, + bias=bias_int32) + + path_table_int32 = paddle.static.data('path_table_int32', [7], + 'int32') + self.assertRaises( + TypeError, + F.hsigmoid_loss, + x, + label, + 8, + weight, + path_table=path_table_int32) + + path_code_int32 = paddle.static.data('path_code_int32', [7], + 'int32') + self.assertRaises( + TypeError, + F.hsigmoid_loss, + x, + label, + 8, + weight, + path_code=path_code_int32) + + # test paddle.fluid.layers.hsigmoid with program_guard(Program()): label = fluid.data('label', [4, 1], 'int64') # The input type must be Variable. @@ -410,5 +596,17 @@ class TestHSigmoidOpError(unittest.TestCase): label_int32, 2) +class TestHSigmoidLossAPICustom(TestHSigmoidLossAPI): + def set_attrs(self): + self.is_custom = True + self.path_table_np = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), ( + 0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]).astype(np.int64) + self.path_code_np = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]).astype(np.int64) + + def test_errors(self): + pass + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index b16e95b7130f9fea68ce275339fc2940c698d6ea..1dddef0cace1d6ac4585d7d8c31c8f47c7a003ea 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -73,7 +73,6 @@ from .layer.activation import Swish #DEFINE_ALIAS from .layer.activation import Tanhshrink #DEFINE_ALIAS from .layer.activation import ThresholdedReLU #DEFINE_ALIAS from .layer.activation import LogSoftmax #DEFINE_ALIAS -from .layer.activation import HSigmoid #DEFINE_ALIAS from .layer.activation import Maxout #DEFINE_ALIAS from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS @@ -133,6 +132,7 @@ from .layer.common import Linear # from .layer.loss import NCELoss #DEFINE_ALIAS from .layer.loss import BCEWithLogitsLoss #DEFINE_ALIAS from .layer.loss import CrossEntropyLoss #DEFINE_ALIAS +from .layer.loss import HSigmoidLoss #DEFINE_ALIAS from .layer.loss import MSELoss #DEFINE_ALIAS from .layer.loss import L1Loss #DEFINE_ALIAS from .layer.loss import NLLLoss #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index d2e1832c6b63776f8e7ba5e6b33042b74b8d500a..30eefb2c3912b58a3470ae1502dfd67bdbcf4e6f 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -36,7 +36,6 @@ from .activation import hardshrink #DEFINE_ALIAS from .activation import hardtanh #DEFINE_ALIAS from .activation import hardsigmoid #DEFINE_ALIAS from .activation import hardswish #DEFINE_ALIAS -from .activation import hsigmoid #DEFINE_ALIAS from .activation import leaky_relu #DEFINE_ALIAS from .activation import log_sigmoid #DEFINE_ALIAS from .activation import maxout #DEFINE_ALIAS @@ -140,6 +139,7 @@ from .loss import center_loss #DEFINE_ALIAS from .loss import cross_entropy #DEFINE_ALIAS from .loss import dice_loss #DEFINE_ALIAS from .loss import edit_distance #DEFINE_ALIAS +from .loss import hsigmoid_loss #DEFINE_ALIAS from .loss import iou_similarity #DEFINE_ALIAS from .loss import kl_div #DEFINE_ALIAS from .loss import l1_loss #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 6e09e25b1ab0526395cc942f3fdc34e3602dfa40..33ecd29162c12c4ed64061a558641869e1f85312 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -26,7 +26,6 @@ __all__ = [ 'hardtanh', 'hardsigmoid', 'hardswish', - 'hsigmoid', 'leaky_relu', 'log_sigmoid', 'maxout', @@ -361,128 +360,6 @@ def hardswish(x, name=None): return out -def hsigmoid(input, - label, - weight, - bias, - num_classes, - path_table=None, - path_code=None, - is_sparse=False): - """ - :alias_main: paddle.nn.functional.hsigmoid - :alias: paddle.nn.functional.hsigmoid,paddle.nn.functional.activation.hsigmoid - - The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity - and speed up the model training, especially the training of language model. - Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier. - For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on - the path, and sum them to get a total cost. - Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N` - represents the number of classes or the size of word dict. - - The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural - Network Language Model `_. For the custom - tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example): - - 1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict. - 2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table. - 3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code. - Code means the label of each binary classifier, 1 indicate true, 0 indicate false. - 4. Now, each word should has its path and code along the path, you can pass a batch of path and code related - to the same batch of inputs. - - Parameters: - input (Variable): A tensor with the shape [N, D], where N is the size of mini-batch, - and D is the feature size. Its data type supports float32 and float64. - label (Variable): A tensor contains the labels of training data. Its shape is [N, 1] - and data type is int64. - weight (Variable): A tensor with shape (num_classes - 1, D) if not using custom tree(path_code and path_table is None), or (num_classes, D) if using custom tree. - bias (Variable): A tensor with shape (num_classes - 1, 1) if not using custom tree(path_code and path_table is None), or (num_classes, 1) if using custom tree. - num_classes (int): The number of classes or the size of word dict, must be greater than 2. - If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes` - should not be None. If the custom tree is used (:attr:`is_custom` is set to True), - :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of - classes using by the binary classifier. - path_table (Variable, optional): A tensor that stores each batch of samples' path from leaf to root - node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i, - path_table[i] is a np.array like structure and each element in this array is the indexes in parent - nodes' weight matrix. Default: None. - path_code (Variable, optional): A tensor that stores each batch of samples' code of path from leaf - to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`. - Each code of path is consisted with the code of nodes from leaf to root node. Default: None. - is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the - gradient of W and input will be sparse. Default: False. - - Returns: - Variable: A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as :attr:`input`. - - Examples: - .. code-block:: python - - from paddle import fluid, nn - import paddle.fluid.dygraph as dg - import paddle.nn.functional as F - import numpy as np - - main = fluid.Program() - start = fluid.Program() - feature_size = 6 - num_classes = 8 - with fluid.unique_name.guard(): - with fluid.program_guard(main, start): - x = fluid.data("input", [-1, feature_size], - dtype="float32") - label = fluid.data("labels", [-1, 1], dtype="int64") - w = fluid.data("weight", (num_classes -1, feature_size), dtype="float32") - b = fluid.data("bias", (num_classes -1, ), dtype="float32") - y = F.hsigmoid(x, label, w, b, num_classes) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(start) - feed_dict = { - "input": np.random.randn(4, feature_size).astype(np.float32), - "labels": np.random.randint(0, num_classes, (4, 1)).astype(np.int64), - "weight": np.random.randn(num_classes - 1, feature_size).astype(np.float32), - "bias": np.random.randn(num_classes - 1, ).astype(np.float32), - } - y_np, = exe.run(main, feed=feed_dict, fetch_list=[y]) - print(y_np.shape) - - # (4, 1) - """ - - attrs = { - "num_classes": num_classes, - "is_sparse": is_sparse, - "remote_prefetch": is_sparse - } - - inputs = { - "X": input, - "W": weight, - "Bias": bias, - "PathTable": path_table, - "PathCode": path_code, - "Label": label - } - - helper = LayerHelper('hierarchical_sigmoid', **locals()) - dtype = helper.input_dtype() - - out = helper.create_variable_for_type_inference(dtype) - pre_out = helper.create_variable_for_type_inference(dtype) - outputs = {"Out": out, "PreOut": pre_out, "W_Out": weight} - - helper.append_op( - type="hierarchical_sigmoid", - inputs=inputs, - outputs=outputs, - attrs=attrs) - return out - - def leaky_relu(x, negative_slope=0.01, name=None): """ leaky_relu activation diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c4b5606dddcf1aa4abbed430b0d83633bf5bccf5..d085213dffc234a67632670e5754b0aac7974f4d 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -54,6 +54,7 @@ __all__ = [ 'cross_entropy', 'dice_loss', 'edit_distance', + 'hsigmoid_loss', 'iou_similarity', 'kl_div', 'l1_loss', @@ -343,6 +344,138 @@ def binary_cross_entropy_with_logits(logit, return out +def hsigmoid_loss(input, + label, + num_classes, + weight, + bias=None, + path_table=None, + path_code=None, + is_sparse=False, + name=None): + """ + The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity + and speed up the model training, especially the training of language model. + Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier. + For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on + the path, and sum them to get a total cost. + Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N` + represents the number of classes or the size of word dict. + + The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural + Network Language Model `_. For the custom + tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example): + + 1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict. + 2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table. + 3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code. + Code means the label of each binary classifier, 1 indicate true, 0 indicate false. + 4. Now, each word should has its path and code along the path, you can pass a batch of path and code related + to the same batch of inputs. + + Parameters: + input (Tensor): A tensor with the shape [N, D], where N is the size of mini-batch, + and D is the feature size. Its data type supports float32 or float64. + label (Tensor): A tensor contains the labels of training data. Its shape is [N, 1] + and data type is int64. + num_classes (int): The number of classes or the size of word dict, must be greater than 2. + If the default tree is used (path_code and path_table is None are None), `num_classes` + should not be None. If the custom tree is used (path_code and path_table is None are not None), + `num_classes` should be the number of non-leaf nodes, which indicates the num of + classes using by the binary classifier. + weight (Tensor): A tensor with shape (num_classes - 1, D), with the same data type as `input`. + bias (Tensor, optional): A tensor with shape (num_classes - 1, 1), with the same data type as `input`. + If `bias` is None, no bias will be add. Default is None. + path_table (Tensor, optional): A tensor that stores each batch of samples' path from leaf to root + node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i, + path_table[i] is a np.array like structure and each element in this array is the indexes in parent + nodes' weight matrix. If `path_table` and `path_code` are None, the default tree will be used. + Default is None. + path_code (Tensor, optional): A tensor that stores each batch of samples' code of path from leaf + to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`. + Each code of path is consisted with the code of nodes from leaf to root node. If `path_table` and + `path_code` are None, the default tree will be used. Default is None. + is_sparse (bool, optional): Whether use sparse updating instead of dense updating. If `is_sparse` is True, + the gradient of `weight` and `input` will be sparse. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as `input`. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + paddle.set_device('cpu') + + input = paddle.uniform([2, 3]) + # [[-0.8018668 0.8736385 -0.9064771 ] # random + # [-0.10228515 -0.87188244 -0.8783718 ]] # random + label = paddle.to_tensor([0, 1, 4, 5]) + num_classes = 5 + weight=paddle.uniform([num_classes-1, 3]) + # [[-0.24148715 0.8449961 -0.7399121 ] # random + # [-0.9800559 0.43509364 0.9091208 ] # random + # [ 0.60194826 0.10430074 -0.4521166 ] # random + # [-0.4469818 -0.01536179 -0.604454 ]] # random + + out=F.hsigmoid_loss(input, label, num_classes, weight) + # [[3.0159328] + # [2.2407534]] + """ + + if in_dygraph_mode(): + out, _, _ = core.ops.hierarchical_sigmoid( + input, weight, label, path_table, path_code, bias, 'num_classes', + num_classes, 'is_sparse', is_sparse, 'remote_prefetch', is_sparse) + return out + + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'hsigmoid_loss') + check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid_loss') + check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], + 'hsigmoid_loss') + if bias is not None: + check_variable_and_dtype(bias, 'bias', ['float32', 'float64'], + 'hsigmoid_loss') + if path_table is not None: + check_variable_and_dtype(path_table, 'path_table', ['int64'], + 'hsigmoid_loss') + if path_code is not None: + check_variable_and_dtype(path_code, 'path_code', ['int64'], + 'hsigmoid_loss') + + attrs = { + "num_classes": num_classes, + "is_sparse": is_sparse, + "remote_prefetch": is_sparse + } + + inputs = { + "X": input, + "W": weight, + "Bias": bias, + "PathTable": path_table, + "PathCode": path_code, + "Label": label + } + + helper = LayerHelper('hsigmoid_loss', **locals()) + out = helper.create_variable_for_type_inference(input.dtype) + pre_out = helper.create_variable_for_type_inference(input.dtype) + outputs = {"Out": out, "PreOut": pre_out, "W_Out": weight} + + helper.append_op( + type="hierarchical_sigmoid", + inputs=inputs, + outputs=outputs, + attrs=attrs) + return out + + def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None): """ This operator calculates smooth_l1_loss. Creates a criterion that uses a squared diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 760af09f1f2f5af066058572f681ec21f9a93180..3a5bcaa21fe5b69b81d89c7eaf8a2815951a5d96 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -41,7 +41,6 @@ from .activation import LeakyReLU #DEFINE_ALIAS from .activation import Sigmoid #DEFINE_ALIAS # from .activation import Softmax #DEFINE_ALIAS from .activation import LogSoftmax #DEFINE_ALIAS -from .activation import HSigmoid #DEFINE_ALIAS from .common import BilinearTensorProduct #DEFINE_ALIAS from .common import Bilinear #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index cd17f26e09e37546dee753c591d02ef9327661cd..dbb9d00f365cfa6c80ec95d43d00e77ffe5874ee 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -38,7 +38,6 @@ __all__ = [ 'LogSigmoid', 'LogSoftmax', 'Maxout', - 'HSigmoid', ] from ...fluid.dygraph import layers @@ -319,142 +318,6 @@ class Hardtanh(layers.Layer): return F.hardtanh(x, self._min, self._max, self._name) -class HSigmoid(layers.Layer): - """ - :alias_main: paddle.nn.HSigmoid - :alias: paddle.nn.HSigmoid,paddle.nn.layer.HSigmoid,paddle.nn.layer.activation.HSigmoid - - Hierarchical Sigmoid Layer. - - The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity - and speed up the model training, especially the training of language model. - Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier. - For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on - the path, and sum them to get a total cost. - Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N` - represents the number of classes or the size of word dict. - - The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural - Network Language Model _`. For the custom - tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example): - - 1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict. - 2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table. - 3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code. - Code means the label of each binary classifier, 1 indicate true, 0 indicate false. - 4. Now, each word should has its path and code along the path, you can pass a batch of path and code related - to the same batch of inputs. - - Parameters: - feature_size (int): The feature size. - num_classes (int): The number of classes or the size of word dict, must be greater than 2. - If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes` - should not be None. If the custom tree is used (:attr:`is_custom` is set to True), - :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of - classes using by the binary classifier. - param_attr (ParamAttr, optional): The parameter attribute for the learnable parameters/weights - of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a - ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is - initialized with Xavier. Default: None. - bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it - is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr, - hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not - set, the bias is initialized zero. Default: None. - is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and - `path_code` should be passed to its forward method, otherwise `path_table` and `path_code` - should not be passed to its forward method. Default: False. - is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the - gradient of W and input will be sparse. Default: False. - - Returns: - None - - Examples: - .. code-block:: python - - from paddle import fluid, nn - import paddle.fluid.dygraph as dg - import paddle.nn.functional as F - import numpy as np - - main = fluid.Program() - start = fluid.Program() - feature_size = 6 - num_classes = 8 - with fluid.unique_name.guard(): - with fluid.program_guard(main, start): - x = fluid.data("input", [-1, feature_size], - dtype="float32") - label = fluid.data("labels", [-1, 1], dtype="int64") - hsm = nn.HSigmoid(feature_size, num_classes) - y = hsm(x, label) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(start) - feed_dict = { - "input": np.random.randn(4, feature_size).astype(np.float32), - "labels": np.random.randint(0, num_classes, (4, 1)).astype(np.int64), - } - y_np, = exe.run(main, feed=feed_dict, fetch_list=[y]) - print(y_np.shape) - - # (4, 1) - """ - - def __init__(self, - feature_size, - num_classes, - param_attr=None, - bias_attr=None, - is_custom=False, - is_sparse=False, - dtype="float32"): - super(HSigmoid, self).__init__() - if (num_classes < 2) and (not is_custom): - raise ValueError( - "num_classes must not be less than 2 with default tree") - - if (not is_custom) and (is_sparse): - print("Sparse mode should not be used without custom tree") - is_sparse = False - - self._feature_size = feature_size - self._num_classes = num_classes - self._is_custom = is_custom - self._is_sparse = is_sparse - - self._param_attr = param_attr - self._bias_attr = bias_attr - - self._dtype = dtype - - remote_prefetch = is_sparse - print("With sparse mode, if your models has only" - " small parameter prefetch may cause speed down") - - C = self._num_classes if is_custom else self._num_classes - 1 - self.weight = self.create_parameter( - [C, self._feature_size], - attr=self._param_attr, - is_bias=False, - dtype=self._dtype) - self.bias = self.create_parameter( - [C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype) - - def forward(self, input, label, path_table=None, path_code=None): - out = F.hsigmoid( - input, - label, - self.weight, - self.bias, - self._num_classes, - path_table=path_table, - path_code=path_code, - is_sparse=self._is_sparse) - return out - - class PReLU(layers.Layer): """ PReLU Activation. diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 98048bb7e64cf6944460f666e93702351e69fd78..5ce4baca55749a35718b12bee0b875bf226160ba 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -23,6 +23,7 @@ from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator __all__ = [ 'BCEWithLogitsLoss', 'CrossEntropyLoss', + 'HSigmoidLoss', 'MSELoss', 'L1Loss', 'NLLLoss', @@ -251,6 +252,128 @@ class CrossEntropyLoss(fluid.dygraph.Layer): reduction=self.reduction) +class HSigmoidLoss(fluid.dygraph.Layer): + """ + Hierarchical Sigmoid Layer. + + The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity + and speed up the model training, especially the training of language model. + Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier. + For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on + the path, and sum them to get a total cost. + Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N` + represents the number of classes or the size of word dict. + + The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural + Network Language Model _`. For the custom + tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example): + + 1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict. + 2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table. + 3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code. + Code means the label of each binary classifier, 1 indicate true, 0 indicate false. + 4. Now, each word should has its path and code along the path, you can pass a batch of path and code related + to the same batch of inputs. + + Parameters: + feature_size (int): The number of features. + num_classes (int): The number of classes or the size of word dict, must be greater than 2. + If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes` + should not be None. If the custom tree is used (:attr:`is_custom` is set to True), + :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of + classes using by the binary classifier. + weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights + of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a + ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is + initialized with Xavier. Default is None. + bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it + is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr, + hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not + set, the bias is initialized zero. Default is None. + is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and + `path_code` should be passed to its forward method, otherwise `path_table` and `path_code` + should not be passed to its forward method. Default is False. + is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, + the gradient of weight and input will be sparse. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + input (Tensor): The input tensor. The shapes is [N, D], where N is batch size and D is feature size. It's data type should be float32, float64. + label (Tensor): It's shapes is [N, 1]. It's data type should be int64. + output (Tensor): The HSigmoid Loss of ``input`` and ``label``. Shape is [N, 1] + + Examples: + .. code-block:: python + + import paddle + paddle.set_device('cpu') + + input = paddle.uniform([2, 3]) + # [[-0.2820413 0.9528898 -0.81638825] # random + # [-0.6733154 -0.33866507 0.25770962]] # random + label = paddle.to_tensor([0, 1, 4, 5]) + m = paddle.nn.HSigmoidLoss(3, 5) + out = m(input, label) + # [[2.4543471] + # [1.9359267]] + """ + + def __init__(self, + feature_size, + num_classes, + weight_attr=None, + bias_attr=None, + is_custom=False, + is_sparse=False, + name=None): + super(HSigmoidLoss, self).__init__() + if (num_classes < 2) and (not is_custom): + raise ValueError( + "num_classes must not be less than 2 with default tree") + + if (not is_custom) and (is_sparse): + print("Sparse mode should not be used without custom tree") + is_sparse = False + + self._feature_size = feature_size + self._num_classes = num_classes + self._is_custom = is_custom + self._is_sparse = is_sparse + + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + self._name = name + self._dtype = paddle.get_default_dtype() + + remote_prefetch = is_sparse + print("With sparse mode, if your models has only" + " small parameter prefetch may cause speed down") + + C = self._num_classes if is_custom else self._num_classes - 1 + self.weight = self.create_parameter( + [C, self._feature_size], + attr=self._weight_attr, + is_bias=False, + dtype=self._dtype) + self.bias = self.create_parameter( + [C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype) + + def forward(self, input, label, path_table=None, path_code=None): + out = F.hsigmoid_loss( + input, + label, + self._num_classes, + self.weight, + self.bias, + path_table=path_table, + path_code=path_code, + is_sparse=self._is_sparse, + name=self._name) + return out + + class MSELoss(fluid.dygraph.layers.Layer): """ **Mean Square Error Loss** diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index d50bb33f240012259f9ee58cb029ed525e3e42c8..cd089432b1ca37dcff3c70f6b48834487f285926 100644 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -27,7 +27,6 @@ __all__ = [ 'data_norm', 'deformable_conv', 'group_norm', - 'hsigmoid', 'instance_norm', 'layer_norm', 'multi_box_head', @@ -53,7 +52,6 @@ from ...fluid.layers import crf_decoding #DEFINE_ALIAS from ...fluid.layers import data_norm #DEFINE_ALIAS from ...fluid.layers import deformable_conv #DEFINE_ALIAS from ...fluid.layers import group_norm #DEFINE_ALIAS -from ...fluid.layers import hsigmoid #DEFINE_ALIAS from ...fluid.layers import instance_norm #DEFINE_ALIAS from ...fluid.layers import layer_norm #DEFINE_ALIAS from ...fluid.layers import multi_box_head #DEFINE_ALIAS