diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 56765c19cba659001e714474124fef3e98409e97..34c6387a1643cb569e575654a253591f4c4e7f05 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -96,10 +96,6 @@ __all__ = [ 'clip_by_norm', 'mean', 'mul', - 'hash', - 'grid_sampler', - 'log_loss', - 'bilinear_tensor_product', 'merge_selected_rows', 'get_tensor_from_selected_rows', 'unfold', @@ -5223,292 +5219,6 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): return out -def hash(input, hash_size, num_hash=1, name=None): - """ - - This OP hash the input to an integer less than the hash_size. - The hash algorithm we used was xxHash - Extremely fast hash algorithm - (https://github.com/Cyan4973/xxHash/tree/v0.6.5) - - Args: - input(Variable): A **Two-Dimensional** LoDTensor with type int32, int64. - **Only support LoDTensor**. - num_hash(int, optional): The times of hash, default is 1. - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. - - Returns: - Variable: A LoDTensor with the same data type as input. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import numpy as np - import paddle - paddle.enable_static() - - place = fluid.core.CPUPlace() - - x = fluid.data(name="x", shape=[2,2], dtype="int32", lod_level=1) - res = fluid.layers.hash(name="res", input=x, hash_size=1000, num_hash=4) - - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - in1 = np.array([[1,2],[3,4]]).astype("int32") - print(in1) - x_i = fluid.create_lod_tensor(in1, [[0, 2]], place) - res = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res], return_numpy=False) - print(np.array(res[0])) - # [[[722] - # [407] - # [337] - # [395]] - # [[603] - # [590] - # [386] - # [901]]] - """ - check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'hash') - check_type(hash_size, 'hash_size', int, 'hash') - check_type(num_hash, 'num_hash', int, 'hash') - helper = LayerHelper('hash', **locals()) - out = helper.create_variable_for_type_inference( - helper.input_dtype(), stop_gradient=True - ) - helper.append_op( - type='hash', - inputs={'X': input}, - outputs={'Out': out}, - attrs={'num_hash': num_hash, 'mod_by': hash_size}, - ) - return out - - -@templatedoc() -def grid_sampler(x, grid, name=None): - """ - - This operation samples input X by using bilinear interpolation based on - flow field grid, which is usually generated by :code:`affine_grid` . The grid of - shape [N, H, W, 2] is the concatenation of (x, y) coordinates - with shape [N, H, W] each, where x is indexing the 4th dimension - (in width dimension) of input data x and y is indexing the 3rd - dimension (in height dimension), finally results is the bilinear - interpolation value of 4 nearest corner points. The output tensor - shape will be [N, C, H, W]. - - .. code-block:: text - - Step 1: - Get (x, y) grid coordinates and scale to [0, H-1/W-1]. - - .. code-block:: text - - grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1) - grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1) - - Step 2: - Indices input data X with grid (x, y) in each [H, W] area, and bilinear - interpolate point value by 4 nearest points. - - wn ------- y_n ------- en - | | | - | d_n | - | | | - x_w --d_w-- grid--d_e-- x_e - | | | - | d_s | - | | | - ws ------- y_s ------- wn - - x_w = floor(x) // west side x coord - x_e = x_w + 1 // east side x coord - y_n = floor(y) // north side y coord - y_s = y_s + 1 // south side y coord - - d_w = grid_x - x_w // distance to west side - d_e = x_e - grid_x // distance to east side - d_n = grid_y - y_n // distance to north side - d_s = y_s - grid_y // distance to south side - - wn = X[:, :, y_n, x_w] // north-west point value - en = X[:, :, y_n, x_e] // north-east point value - ws = X[:, :, y_s, x_w] // south-east point value - es = X[:, :, y_s, x_w] // north-east point value - - output = wn * d_e * d_s + en * d_w * d_s - + ws * d_e * d_n + es * d_w * d_n - - Args: - x(Variable): The input tensor, which is a 4-D tensor with shape - [N, C, H, W], N is the batch size, C is the channel - number, H and W is the feature height and width. - The data type is float32 or float64. - grid(Variable): Input grid tensor of shape [N, H, W, 2]. The - data type is float32 or float64. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - Variable: Output of shape [N, C, H, W] data samples input X - using bilnear interpolation based on input grid. - The data type is same as input tensor. - - Examples: - - .. code-block:: python - - import paddle.fluid as fluid - import paddle.fluid as fluid - import paddle - - paddle.enable_static() - # use with affine_grid - x = fluid.data(name='x', shape=[None, 10, 32, 32], dtype='float32') - theta = fluid.layers.data(name='theta', shape=[2, 3], dtype='float32') - grid = fluid.layers.affine_grid(theta=theta, out_shape=[3, 10, 32, 32]) - out = fluid.layers.grid_sampler(x=x, grid=grid) - - """ - helper = LayerHelper("grid_sampler", **locals()) - - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sampler') - check_variable_and_dtype( - grid, 'grid', ['float32', 'float64'], 'grid_sampler' - ) - if not isinstance(x, Variable): - return ValueError("The x should be a Variable") - - if not isinstance(grid, Variable): - return ValueError("The grid should be a Variable") - - out = helper.create_variable_for_type_inference(x.dtype) - ipts = {'X': x, 'Grid': grid} - - attrs = {'use_cudnn': False} if core.is_compiled_with_rocm() else {} - - helper.append_op( - type='grid_sampler', inputs=ipts, outputs={'Output': out}, attrs=attrs - ) - return out - - -def log_loss(input, label, epsilon=1e-4, name=None): - r""" - - **Negative Log Loss Layer** - - This layer accepts input predictions and target label and returns the - negative log loss. - - .. math:: - - Out = -label * \log{(input + \epsilon)} - - (1 - label) * \log{(1 - input + \epsilon)} - - Args: - input (Tensor|list): A 2-D tensor with shape [N x 1], where N is the - batch size. This input is a probability computed - by the previous operator. Data type float32. - label (Tensor|list): The ground truth which is a 2-D tensor with - shape [N x 1], where N is the batch size. - Data type float32. - epsilon (float, optional): A small number for numerical stability. Default 1e-4. - name(str|None): For detailed information, please refer to - :ref:`api_guide_Name` . Usually name is no need to set and None by default. - - Returns: - Tensor, which shape is [N x 1], data type is float32. - - Examples: - .. code-block:: python - - import paddle - import paddle.nn.functional as F - - label = paddle.randn((10,1)) - prob = paddle.randn((10,1)) - cost = F.log_loss(input=prob, label=label) - """ - return paddle.nn.functional.log_loss(input, label, epsilon, name) - - -def bilinear_tensor_product( - x, y, size, act=None, name=None, param_attr=None, bias_attr=None -): - r""" - :api_attr: Static Graph - - **Bilinear Tensor Product Layer** - - This layer performs bilinear tensor product on two inputs. - For example: - - .. math:: - out_{i} = x * W_{i} * {y^\mathrm{T}}, i=0,1,...,size-1 - - In this formula: - - :math:`x`: the first input contains M elements, shape is [batch_size, M]. - - :math:`y`: the second input contains N elements, shape is [batch_size, N]. - - :math:`W_{i}`: the i-th learned weight, shape is [M, N]. - - :math:`out_{i}`: the i-th element of out, shape is [batch_size, size]. - - :math:`y^\mathrm{T}`: the transpose of :math:`y_{2}`. - - Args: - x (Variable): 2-D input tensor with shape [batch_size, M]. Data type - is float32 or float64. - y (Variable): 2-D input tensor with shape [batch_size, N]. Data type - should be same as **x**. - size (int): The dimension of this layer. - act (str|None): Activation to be applied to the output of this layer. Default None. - name(str|None): For detailed information, please refer to - :ref:`api_guide_Name` . Usually name is no need to set and None by default. - param_attr (ParamAttr|None): To specify the weight parameter attribute. - Default: None, which means the default weight parameter property is - used. See usage for details in :ref:`api_fluid_ParamAttr` . - bias_attr (ParamAttr|None): To specify the bias parameter attribute. - Default: None, which means the default bias parameter property is - used. See usage for details in :ref:`api_fluid_ParamAttr` . - Returns: - Variable: A 2-D Tensor of shape [batch_size, size]. Data type is the same as input **x**. - - Examples: - .. code-block:: python - - import paddle - paddle.enable_static() - layer1 = paddle.static.data("t1", shape=[-1, 5], dtype="float32") - layer2 = paddle.static.data("t2", shape=[-1, 4], dtype="float32") - tensor = paddle.static.nn.bilinear_tensor_product(x=layer1, y=layer2, size=1000) - """ - helper = LayerHelper('bilinear_tensor_product', **locals()) - dtype = helper.input_dtype('x') - - param_shape = [size, x.shape[1], y.shape[1]] - - w = helper.create_parameter( - attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=False - ) - out = helper.create_variable_for_type_inference(dtype=dtype) - - inputs = {"X": x, "Y": y, "Weight": w} - if helper.bias_attr: - bias_size = [1, size] - bias = helper.create_parameter( - attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True - ) - inputs["Bias"] = bias - helper.append_op( - type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out} - ) - - # add activation - return helper.append_activation(out) - - @templatedoc() def get_tensor_from_selected_rows(x, name=None): """ diff --git a/python/paddle/fluid/tests/unittests/npu/test_log_loss_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_log_loss_op_npu.py index 87cd872e8cc9135bd8660e9c289d01f177d8921c..c47b42ee125bed04d047f1630cce2d9fe167dc21 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_log_loss_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_log_loss_op_npu.py @@ -76,37 +76,5 @@ class TestLogLossOp(OpTest): self.check_grad_with_place(self.place, ['Predicted'], 'Loss') -@unittest.skipIf( - not paddle.is_compiled_with_npu(), "core is not compiled with NPU" -) -class TestLogLossOpError(unittest.TestCase): - def test_errors(self): - with fluid.program_guard(fluid.Program()): - - def test_x_type(): - input_data = np.random.random(100, 1).astype("float32") - fluid.layers.log_loss(input_data) - - self.assertRaises(TypeError, test_x_type) - - def test_x_dtype(): - x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32') - fluid.layers.log_loss(x2) - - self.assertRaises(TypeError, test_x_dtype) - - def test_label_type(): - input_data = np.random.random(100, 1).astype("float32") - fluid.layers.log_loss(input_data) - - self.assertRaises(TypeError, test_label_type) - - def test_label_dtype(): - x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32') - fluid.layers.log_loss(x2) - - self.assertRaises(TypeError, test_label_dtype) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet.py b/python/paddle/fluid/tests/unittests/test_fleet.py index d0445c2c5e09b864e7520dffa1090d22b48e52c9..6092710a798c0ad3a08181e64a5049f510cda57d 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet.py +++ b/python/paddle/fluid/tests/unittests/test_fleet.py @@ -79,7 +79,7 @@ class TestFleet1(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) try: adam = fluid.optimizer.Adam(learning_rate=0.000005) adam = fleet.distributed_optimizer( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py index 577652037e53862fdbe74081c3b16f2f3ae66242..f5975ae990d7029f6bba659aec7132552b61206b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py @@ -16,6 +16,8 @@ import os import unittest +import paddle + class TestFleet1(unittest.TestCase): """ @@ -73,7 +75,7 @@ class TestFleet1(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) try: adam = fluid.optimizer.Adam(learning_rate=0.000005) adam = fleet.distributed_optimizer( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py index f64d8cb1692b2034a015e21923b777ef5ad09ddc..daee01f38f742c8c966a113c1e6970c35ca18ab2 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py @@ -16,6 +16,7 @@ import os import unittest +import paddle import paddle.fluid.incubate.fleet.base.role_maker as role_maker @@ -97,7 +98,7 @@ class TestCloudRoleMaker(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) try: adam = fluid.optimizer.Adam(learning_rate=0.000005) adam = fleet.distributed_optimizer(adam) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py index a657d3deb51a02549ea0a472e2e13ac591675ce2..7a6ba4248352a4ef83c737604d75c82b113db9a0 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_2.py @@ -79,7 +79,7 @@ class TestCloudRoleMaker2(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) try: adam = fluid.optimizer.Adam(learning_rate=0.000005) adam = fleet.distributed_optimizer(adam) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_3.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_3.py index 79b5e136f189a95f28e611c12bfb8e4ab18b0742..c3df410610ba9688e3bc44e164bb59cc2764033a 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_3.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_3.py @@ -16,6 +16,8 @@ import os import unittest +import paddle + class TestCloudRoleMaker(unittest.TestCase): """ @@ -70,7 +72,7 @@ class TestCloudRoleMaker(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) try: adam = fluid.optimizer.Adam(learning_rate=0.000005) adam = fleet.distributed_optimizer(adam) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_unitaccessor.py b/python/paddle/fluid/tests/unittests/test_fleet_unitaccessor.py index 9c7736a39384f1e6af26c8c8f4583e65bcf82f18..78c4a4541e3c06e9989cf4cdb64795cea1edbf09 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_unitaccessor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_unitaccessor.py @@ -16,6 +16,8 @@ import os import unittest +import paddle + class TestFleet1(unittest.TestCase): """ @@ -73,7 +75,7 @@ class TestFleet1(unittest.TestCase): append_batch_size=False, ) label_cast = fluid.layers.cast(label, dtype='float32') - cost = fluid.layers.log_loss(fc, label_cast) + cost = paddle.nn.functional.log_loss(fc, label_cast) strategy = {} strategy["embedding"] = {} diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py index 53b1551c7b8446273d332842a539b16ece3e8430..75ddd7bb89c8c9c7129e8f2ed8ed0ab1a3b29e8e 100644 --- a/python/paddle/fluid/tests/unittests/test_hash_op.py +++ b/python/paddle/fluid/tests/unittests/test_hash_op.py @@ -17,8 +17,6 @@ import unittest import numpy as np from op_test import OpTest -import paddle.fluid as fluid - class TestHashOp(OpTest): def setUp(self): @@ -120,44 +118,5 @@ class TestHashOp3(TestHashOp): self.check_output() -class TestHashOpError(unittest.TestCase): - def test_errors(self): - with fluid.program_guard(fluid.Program(), fluid.Program()): - input_data = np.random.randint(0, 10, (8, 1)).astype("int32") - - def test_Variable(): - # the input type must be Variable - fluid.layers.hash(input=input_data, hash_size=2**32) - - self.assertRaises(TypeError, test_Variable) - - def test_type(): - # dtype must be int32, int64. - x2 = fluid.layers.data( - name='x2', shape=[1], dtype="float32", lod_level=1 - ) - fluid.layers.hash(input=x2, hash_size=2**32) - - self.assertRaises(TypeError, test_type) - - def test_hash_size_type(): - # hash_size dtype must be int32, int64. - x3 = fluid.layers.data( - name='x3', shape=[1], dtype="int32", lod_level=1 - ) - fluid.layers.hash(input=x3, hash_size=1024.5) - - self.assertRaises(TypeError, test_hash_size_type) - - def test_num_hash_type(): - # num_hash dtype must be int32, int64. - x4 = fluid.layers.data( - name='x4', shape=[1], dtype="int32", lod_level=1 - ) - fluid.layers.hash(input=x4, hash_size=2**32, num_hash=2.5) - - self.assertRaises(TypeError, test_num_hash_type) - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py index c4e280ea46fd0b819989e66ffc898e4cbf95081b..ecb4600163c4b8aaa716a6c457ca9a8dc1d67a0c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py @@ -271,7 +271,7 @@ class TestDygraphDeepCF(unittest.TestCase): deepcf = DeepCF(num_users, num_items, matrix) prediction = deepcf(users, items) - loss = paddle.sum(fluid.layers.log_loss(prediction, labels)) + loss = paddle.sum(paddle.nn.functional.log_loss(prediction, labels)) adam = fluid.optimizer.AdamOptimizer(0.01) adam.minimize(loss) @@ -325,7 +325,7 @@ class TestDygraphDeepCF(unittest.TestCase): to_variable(items_np[slice : slice + self.batch_size]), ) loss = paddle.sum( - fluid.layers.log_loss( + paddle.nn.functional.log_loss( prediction, to_variable( labels_np[slice : slice + self.batch_size] @@ -359,7 +359,7 @@ class TestDygraphDeepCF(unittest.TestCase): to_variable(items_np[slice : slice + self.batch_size]), ) loss2 = paddle.sum( - fluid.layers.log_loss( + paddle.nn.functional.log_loss( prediction2, to_variable( labels_np[slice : slice + self.batch_size] @@ -402,7 +402,7 @@ class TestDygraphDeepCF(unittest.TestCase): ), ) loss = paddle.sum( - fluid.layers.log_loss( + paddle.nn.functional.log_loss( prediction, to_variable( labels_np[slice : slice + self.batch_size] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py index 169269cc03e31e9390c6aaa1ea9398c61855b5df..a98d9b994b33a9f228f6d0bd4bb69aa55f5538b4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py @@ -86,11 +86,15 @@ class TestDygraphLoadStatic(unittest.TestCase): "t2", shape=[None, 4], dtype="float32" ) - bilinear_tensor_pro_out_1 = fluid.layers.bilinear_tensor_product( - x=bilinear_tensor_pro_x, y=bilinear_tensor_pro_y, size=1000 - ) - bilinear_tensor_pro_out_2 = fluid.layers.bilinear_tensor_product( - x=bilinear_tensor_pro_x, y=bilinear_tensor_pro_y, size=1000 + bilinear_tensor_pro_out_1 = ( + paddle.static.nn.common.bilinear_tensor_product( + x=bilinear_tensor_pro_x, y=bilinear_tensor_pro_y, size=1000 + ) + ) + bilinear_tensor_pro_out_2 = ( + paddle.static.nn.common.bilinear_tensor_product( + x=bilinear_tensor_pro_x, y=bilinear_tensor_pro_y, size=1000 + ) ) conv2d_trans_in = fluid.data( diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 62def4247037f5078830011a2c4c53e886c5b070..3f7edb6022a859965ca962e1eecb2d49ed0396f0 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -750,7 +750,7 @@ class TestLayer(LayerTest): data_y = layers.data( name='y', shape=[1, 3], dtype="float32", append_batch_size=False ) - out = layers.bilinear_tensor_product( + out = paddle.static.nn.common.bilinear_tensor_product( data_x, data_y, 6, @@ -825,7 +825,7 @@ class TestLayer(LayerTest): data_y2 = layers.data( name='y', shape=[1, 3], dtype="float32", append_batch_size=False ) - out2 = layers.bilinear_tensor_product( + out2 = paddle.static.nn.common.bilinear_tensor_product( data_x2, data_y2, 6, act='sigmoid' ) @@ -3418,15 +3418,6 @@ class TestBook(LayerTest): out = layers.iou_similarity(x, y, name='iou_similarity') return out - def make_grid_sampler(self): - with program_guard( - fluid.default_main_program(), fluid.default_startup_program() - ): - x = self._get_data(name='x', shape=[3, 5, 7], dtype='float32') - grid = self._get_data(name='grid', shape=[5, 7, 2], dtype='float32') - out = layers.grid_sampler(x, grid) - return out - def make_bilinear_tensor_product_layer(self): with program_guard( fluid.default_main_program(), fluid.default_startup_program() @@ -3434,7 +3425,9 @@ class TestBook(LayerTest): data = self._get_data(name='data', shape=[4], dtype="float32") theta = self._get_data(name="theta", shape=[5], dtype="float32") - out = layers.bilinear_tensor_product(data, theta, 6) + out = paddle.static.nn.common.bilinear_tensor_product( + data, theta, 6 + ) return out def make_batch_norm(self): diff --git a/python/paddle/fluid/tests/unittests/test_log_loss_op.py b/python/paddle/fluid/tests/unittests/test_log_loss_op.py index 908f4bf94e510f7dc71704b0758fa9e24d07d2b8..25bede0af214b8bea9b65cd3b4e59b1a4f2b0f9c 100644 --- a/python/paddle/fluid/tests/unittests/test_log_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_log_loss_op.py @@ -17,8 +17,6 @@ import unittest import numpy as np from op_test import OpTest -import paddle.fluid as fluid - def sigmoid_array(x): return 1 / (1 + np.exp(-x)) @@ -51,34 +49,5 @@ class TestLogLossOp(OpTest): self.check_grad(['Predicted'], 'Loss', max_relative_error=0.03) -class TestLogLossOpError(unittest.TestCase): - def test_errors(self): - with fluid.program_guard(fluid.Program()): - - def test_x_type(): - input_data = np.random.random(100, 1).astype("float32") - fluid.layers.log_loss(input_data) - - self.assertRaises(TypeError, test_x_type) - - def test_x_dtype(): - x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32') - fluid.layers.log_loss(x2) - - self.assertRaises(TypeError, test_x_dtype) - - def test_label_type(): - input_data = np.random.random(100, 1).astype("float32") - fluid.layers.log_loss(input_data) - - self.assertRaises(TypeError, test_label_type) - - def test_label_dtype(): - x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32') - fluid.layers.log_loss(x2) - - self.assertRaises(TypeError, test_label_dtype) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 1849cfd395a55388452265a6d4a463be45c36ebe..9635811f6a818c875077be0d85aefea94525c2e8 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -20,11 +20,11 @@ from .common import deform_conv2d # noqa: F401 from .common import conv3d # noqa: F401 from .common import conv2d_transpose # noqa: F401 from .common import conv3d_transpose # noqa: F401 +from .common import bilinear_tensor_product # noqa: F401 from .common import py_func # noqa: F401 from ...tensor.creation import create_parameter # noqa: F401 from ...fluid.layers import batch_norm # noqa: F401 -from ...fluid.layers import bilinear_tensor_product # noqa: F401 from ...fluid.layers import case # noqa: F401 from ...fluid.layers import cond # noqa: F401 from ...fluid.layers import conv2d # noqa: F401 @@ -61,8 +61,8 @@ from ...fluid.layers.sequence_lod import sequence_reverse # noqa: F401 __all__ = [ # noqa 'fc', 'batch_norm', - 'embedding', 'bilinear_tensor_product', + 'embedding', 'case', 'cond', 'conv2d', diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index a8dec018ff14ab78af4743817cb3c1f225faf6cf..420a00ddbdc51e89a1b7e6660d8d3cd902f85b0e 100755 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2088,6 +2088,184 @@ def deform_conv2d( ) +def bilinear_tensor_product( + x, y, size, act=None, name=None, param_attr=None, bias_attr=None +): + r""" + This layer performs bilinear tensor product on two inputs. + + .. math:: + + out_{i} = x * W_{i} * {y^\mathrm{T}}, i=0,1,...,size-1 + + In this formula: + - :math:`x`: the first input contains M elements, shape is [batch_size, M]. + - :math:`y`: the second input contains N elements, shape is [batch_size, N]. + - :math:`W_{i}`: the i-th learned weight, shape is [M, N]. + - :math:`out_{i}`: the i-th element of out, shape is [batch_size, size]. + - :math:`y^\mathrm{T}`: the transpose of :math:`y_{2}`. + + Args: + x (Variable): 2-D input tensor with shape [batch_size, M]. Data type + is float32 or float64. + y (Variable): 2-D input tensor with shape [batch_size, N]. Data type + should be same as **x**. + size (int): The dimension of this layer. + act (str|None): Activation to be applied to the output of this layer. Default None. + name(str|None): For detailed information, please refer to + :ref:`api_guide_Name` . Usually name is no need to set and None by default. + param_attr (ParamAttr|None): To specify the weight parameter attribute. + Default: None, which means the default weight parameter property is + used. See usage for details in :ref:`api_fluid_ParamAttr` . + bias_attr (ParamAttr|None): To specify the bias parameter attribute. + Default: None, which means the default bias parameter property is + used. See usage for details in :ref:`api_fluid_ParamAttr` . + + Returns: + Tensor, A 2-D Tensor of shape [batch_size, size]. Data type is the same as input **x**. + + Examples: + .. code-block:: python + + import paddle + paddle.enable_static() + + x = paddle.static.data("t1", shape=[-1, 5], dtype="float32") + y = paddle.static.data("t2", shape=[-1, 4], dtype="float32") + tensor = paddle.static.nn.bilinear_tensor_product(x, y, size=1000) + + """ + helper = LayerHelper('bilinear_tensor_product', **locals()) + dtype = helper.input_dtype('x') + + param_shape = [size, x.shape[1], y.shape[1]] + + w = helper.create_parameter( + attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=False + ) + out = helper.create_variable_for_type_inference(dtype=dtype) + + inputs = {"X": x, "Y": y, "Weight": w} + if helper.bias_attr: + bias_size = [1, size] + bias = helper.create_parameter( + attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True + ) + inputs["Bias"] = bias + helper.append_op( + type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out} + ) + + # add activation + return helper.append_activation(out) + + +@static_only +def prelu(x, mode, param_attr=None, data_format="NCHW", name=None): + r""" + + prelu activation. + + .. math:: + prelu(x) = max(0, x) + \alpha * min(0, x) + + There are three modes for the activation: + + .. code-block:: text + + all: All elements share same alpha. + channel: Elements in same channel share same alpha. + element: All elements do not share alpha. Each element has its own alpha. + + Parameters: + x (Tensor): The input Tensor or LoDTensor with data type float32. + mode (str): The mode for weight sharing. + param_attr (ParamAttr|None, optional): The parameter attribute for the learnable \ + weight (alpha), it can be create by ParamAttr. None by default. \ + For detailed information, please refer to :ref:`api_paddle_ParamAttr`. + data_format(str, optional): Data format that specifies the layout of input. + It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". + name (str, optional): Name for the operation (optional, default is None). \ + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A tensor with the same shape and data type as x. + + Examples: + + .. code-block:: python + + import paddle + paddle.enable_static() + + x = paddle.static.data(name="x", shape=[None,5,10,10], dtype="float32") + mode = 'channel' + output = paddle.static.nn.prelu( + x,mode,param_attr=paddle.ParamAttr(name='alpha')) + + """ + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') + + helper = LayerHelper('prelu', **locals()) + if mode not in ['all', 'channel', 'element']: + raise ValueError('mode should be one of all, channel, element.') + + alpha_shape = [1] + if mode == 'channel': + + true_data_format = [ + 'NC', + 'NCL', + 'NCHW', + 'NCDHW', + 'NLC', + 'NHWC', + 'NDHWC', + ] + if data_format not in true_data_format: + raise ValueError( + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', " + "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format) + ) + + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' + + assert ( + len(x.shape) >= 2 + ), "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'" + # NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]). + # To be consistent with Prelu, it is simplified. + # NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version. + # NOTE(GuoxiaWang): support NHWC data format + if data_format == 'NHWC': + alpha_shape = [1, 1, 1, x.shape[-1]] + else: + alpha_shape = [1, x.shape[1], 1, 1] + + elif mode == 'element': + assert ( + len(x.shape) >= 1 + ), "The size of input shape should be equal or larger than 1 in prelu() when mode is 'element'" + alpha_shape = [1] + list(x.shape)[1:] + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=helper.param_attr, + shape=alpha_shape, + dtype=dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0.25), + ) + + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, 'Alpha': alpha}, + attrs={"mode": mode, "data_format": data_format}, + outputs={"Out": out}, + ) + return out + + class PyFuncRegistry: _register_funcs = [] @@ -2106,12 +2284,10 @@ class PyFuncRegistry: self._id = core._append_python_callable_object_and_return_id(self) ''' Why record self here? - 1. For debug usage. Users can call :code:`py_func.registered_func(idx)` method to find the registered function corresponding to :code:`idx`. - 2. For increasing reference count of self. It seems that to release Python object whose reference count is 1 would cause @@ -2169,25 +2345,20 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): This is used to register customized Python OP to Paddle. The design principe of py_func is that Tensor and numpy array can be converted to each other easily. So you can use Python and numpy API to register a python OP. - The forward function of the registered OP is ``func`` and the backward function of that is ``backward_func``. Paddle will call ``func`` at forward runtime and call ``backward_func`` at backward runtime(if ``backward_func`` is not None). ``x`` is the input of ``func``, whose type must be Tensor; ``out`` is the output of ``func``, whose type can be either Tensor or numpy array. - The input of the backward function ``backward_func`` is ``x``, ``out`` and the gradient of ``out``. If ``out`` have no gradient, the relevant input of ``backward_func`` is None. If ``x`` do not have a gradient, the user should return None in ``backward_func``. - The data type and shape of ``out`` should also be set correctly before this API is called, and the data type and shape of the gradient of ``out`` and ``x`` will be inferred automatically. - This API can also be used to debug the neural network by setting the ``func`` as a function that only print variables. - Args: func (callable): The forward function of the registered OP. When the network is running, the forward output ``out`` will be calculated according to this @@ -2211,61 +2382,47 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): that no tensors need to be removed from ``x`` and ``out``. If it is not None, these tensors will not be the input of ``backward_func``. This parameter is only useful when ``backward_func`` is not None. - Returns: Tensor|tuple(Tensor)|list[Tensor]: The output ``out`` of the forward function ``func``. - Examples: .. code-block:: python - # example 1: import paddle import numpy as np - paddle.enable_static() - # Creates a forward function, Tensor can be input directly without # being converted into numpy array. def tanh(x): return np.tanh(x) - # Skip x in backward function and return the gradient of x # Tensor must be actively converted to numpy array, otherwise, # operations such as +/- can't be used. def tanh_grad(y, dy): return np.array(dy) * (1 - np.square(np.array(y))) - # Creates a forward function for debugging running networks(print value) def debug_func(x): print(x) - def create_tmp_var(name, dtype, shape): return paddle.static.default_main_program().current_block().create_var( name=name, dtype=dtype, shape=shape) - def simple_net(img, label): hidden = img for idx in range(4): hidden = paddle.static.nn.fc(hidden, size=200) new_hidden = create_tmp_var(name='hidden_{}'.format(idx), dtype=hidden.dtype, shape=hidden.shape) - # User-defined forward and backward hidden = paddle.static.py_func(func=tanh, x=hidden, out=new_hidden, backward_func=tanh_grad, skip_vars_in_backward_input=hidden) - # User-defined debug functions that print out the input Tensor paddle.static.py_func(func=debug_func, x=hidden, out=None) - prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax') ce_loss = paddle.nn.loss.CrossEntropyLoss() return ce_loss(prediction, label) - x = paddle.static.data(name='x', shape=[1,4], dtype='float32') y = paddle.static.data(name='y', shape=[1], dtype='int64') res = simple_net(x, y) - exe = paddle.static.Executor(paddle.CPUPlace()) exe.run(paddle.static.default_startup_program()) input1 = np.random.random(size=[1,4]).astype('float32') @@ -2274,54 +2431,40 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): feed={'x':input1, 'y':input2}, fetch_list=[res.name]) print(out) - .. code-block:: python - # example 2: # This example shows how to turn Tensor into numpy array and # use numpy API to register an Python OP import paddle import numpy as np - paddle.enable_static() - def element_wise_add(x, y): # Tensor must be actively converted to numpy array, otherwise, # numpy.shape can't be used. x = np.array(x) y = np.array(y) - if x.shape != y.shape: raise AssertionError("the shape of inputs must be the same!") - result = np.zeros(x.shape, dtype='int32') for i in range(len(x)): for j in range(len(x[0])): result[i][j] = x[i][j] + y[i][j] - return result - def create_tmp_var(name, dtype, shape): return paddle.static.default_main_program().current_block().create_var( name=name, dtype=dtype, shape=shape) - def py_func_demo(): start_program = paddle.static.default_startup_program() main_program = paddle.static.default_main_program() - # Input of the forward function x = paddle.static.data(name='x', shape=[2,3], dtype='int32') y = paddle.static.data(name='y', shape=[2,3], dtype='int32') - # Output of the forward function, name/dtype/shape must be specified output = create_tmp_var('output','int32', [3,1]) - # Multiple Variable should be passed in the form of tuple(Variale) or list[Variale] paddle.static.py_func(func=element_wise_add, x=[x,y], out=output) - exe=paddle.static.Executor(paddle.CPUPlace()) exe.run(start_program) - # Feed numpy array to main_program input1 = np.random.randint(1, 10, size=[2,3], dtype='int32') input2 = np.random.randint(1, 10, size=[2,3], dtype='int32') @@ -2329,9 +2472,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): feed={'x':input1, 'y':input2}, fetch_list=[output.name]) print("{0} + {1} = {2}".format(input1, input2, out)) - py_func_demo() - # Reference output: # [[5, 9, 9] + [[7, 8, 4] = [array([[12, 17, 13] # [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)] @@ -2405,109 +2546,3 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): # For debug usage py_func.registered_func = PyFuncRegistry.registered_func py_func.registered_func_num = PyFuncRegistry.registered_func_num - - -@static_only -def prelu(x, mode, param_attr=None, data_format="NCHW", name=None): - r""" - - prelu activation. - - .. math:: - prelu(x) = max(0, x) + \alpha * min(0, x) - - There are three modes for the activation: - - .. code-block:: text - - all: All elements share same alpha. - channel: Elements in same channel share same alpha. - element: All elements do not share alpha. Each element has its own alpha. - - Parameters: - x (Tensor): The input Tensor or LoDTensor with data type float32. - mode (str): The mode for weight sharing. - param_attr (ParamAttr|None, optional): The parameter attribute for the learnable \ - weight (alpha), it can be create by ParamAttr. None by default. \ - For detailed information, please refer to :ref:`api_paddle_ParamAttr`. - data_format(str, optional): Data format that specifies the layout of input. - It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". - name (str, optional): Name for the operation (optional, default is None). \ - For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor: A tensor with the same shape and data type as x. - - Examples: - - .. code-block:: python - - import paddle - paddle.enable_static() - - x = paddle.static.data(name="x", shape=[None,5,10,10], dtype="float32") - mode = 'channel' - output = paddle.static.nn.prelu( - x,mode,param_attr=paddle.ParamAttr(name='alpha')) - - """ - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') - - helper = LayerHelper('prelu', **locals()) - if mode not in ['all', 'channel', 'element']: - raise ValueError('mode should be one of all, channel, element.') - - alpha_shape = [1] - if mode == 'channel': - - true_data_format = [ - 'NC', - 'NCL', - 'NCHW', - 'NCDHW', - 'NLC', - 'NHWC', - 'NDHWC', - ] - if data_format not in true_data_format: - raise ValueError( - "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', " - "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format) - ) - - data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' - - assert ( - len(x.shape) >= 2 - ), "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'" - # NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]). - # To be consistent with Prelu, it is simplified. - # NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version. - # NOTE(GuoxiaWang): support NHWC data format - if data_format == 'NHWC': - alpha_shape = [1, 1, 1, x.shape[-1]] - else: - alpha_shape = [1, x.shape[1], 1, 1] - - elif mode == 'element': - assert ( - len(x.shape) >= 1 - ), "The size of input shape should be equal or larger than 1 in prelu() when mode is 'element'" - alpha_shape = [1] + list(x.shape)[1:] - dtype = helper.input_dtype(input_param_name='x') - alpha = helper.create_parameter( - attr=helper.param_attr, - shape=alpha_shape, - dtype=dtype, - is_bias=False, - default_initializer=paddle.nn.initializer.Constant(0.25), - ) - - out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type="prelu", - inputs={"X": x, 'Alpha': alpha}, - attrs={"mode": mode, "data_format": data_format}, - outputs={"Out": out}, - ) - return out