From 92b3a717050ced0fac88e329dc905518909892f3 Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Tue, 13 Oct 2020 12:04:06 +0800 Subject: [PATCH] Update api 2.0 for some ops * 1. remove paddle.unique_with_counts api, which counts as unique api 2. add paddle.math.increment(x, value=1.0, name=None) api 3. replace paddle.sums with paddle.add_n api 4. update paddle.metric.accuracy api (add name parameter) --- python/paddle/__init__.py | 4 +- python/paddle/fluid/layers/nn.py | 2 +- .../fluid/tests/unittests/test_accuracy_op.py | 17 ++++ .../fluid/tests/unittests/test_increment.py | 44 +++++++++ .../unittests/test_math_op_patch_var_base.py | 9 +- .../fluid/tests/unittests/test_sum_op.py | 14 ++- python/paddle/metric/__init__.py | 3 +- python/paddle/metric/metrics.py | 73 ++++++++++++++- python/paddle/tensor/__init__.py | 4 +- python/paddle/tensor/manipulation.py | 2 - python/paddle/tensor/math.py | 91 +++++++++++-------- 11 files changed, 207 insertions(+), 56 deletions(-) create mode 100755 python/paddle/fluid/tests/unittests/test_increment.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e1d9450cd59..ad7632dc138 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -130,7 +130,6 @@ from .tensor.manipulation import stack #DEFINE_ALIAS from .tensor.manipulation import strided_slice #DEFINE_ALIAS from .tensor.manipulation import transpose #DEFINE_ALIAS from .tensor.manipulation import unique #DEFINE_ALIAS -from .tensor.manipulation import unique_with_counts #DEFINE_ALIAS from .tensor.manipulation import unsqueeze #DEFINE_ALIAS from .tensor.manipulation import unstack #DEFINE_ALIAS from .tensor.manipulation import flip #DEFINE_ALIAS @@ -172,9 +171,8 @@ from .tensor.math import sqrt #DEFINE_ALIAS from .tensor.math import square #DEFINE_ALIAS from .tensor.math import stanh #DEFINE_ALIAS from .tensor.math import sum #DEFINE_ALIAS -from .tensor.math import sums #DEFINE_ALIAS from .tensor.math import tanh #DEFINE_ALIAS -from .tensor.math import elementwise_sum #DEFINE_ALIAS +from .tensor.math import add_n #DEFINE_ALIAS from .tensor.math import max #DEFINE_ALIAS from .tensor.math import maximum #DEFINE_ALIAS from .tensor.math import min #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 19b431dce0c..dcfead697b9 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10852,7 +10852,7 @@ def sum(x): # and '__int64' on Windows. They both represent 64-bit integer variables. """ - return paddle.elementwise_sum(x) + return paddle.add_n(x) @templatedoc() diff --git a/python/paddle/fluid/tests/unittests/test_accuracy_op.py b/python/paddle/fluid/tests/unittests/test_accuracy_op.py index e4412b1b24e..00cf7d5e987 100755 --- a/python/paddle/fluid/tests/unittests/test_accuracy_op.py +++ b/python/paddle/fluid/tests/unittests/test_accuracy_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard @@ -67,11 +68,27 @@ class TestAccuracyOpError(unittest.TestCase): label = fluid.layers.data( name='label', shape=[-1, 1], dtype="int32") self.assertRaises(TypeError, fluid.layers.accuracy, x1, label) + self.assertRaises(TypeError, paddle.metric.accuracy, x1, label) # The input dtype of accuracy_op must be float32 or float64. x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") self.assertRaises(TypeError, fluid.layers.accuracy, x2, label) + self.assertRaises(TypeError, paddle.metric.accuracy, x2, label) x3 = fluid.layers.data(name='input', shape=[-1, 2], dtype="float16") fluid.layers.accuracy(input=x3, label=label) + paddle.metric.accuracy(input=x3, label=label) + + +class TestAccuracyAPI(unittest.TestCase): + def test_api(self): + with fluid.dygraph.guard(): + predictions = paddle.to_tensor( + [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], + dtype='float32') + label = paddle.to_tensor([[2], [0]], dtype="int64") + result = paddle.metric.accuracy(input=predictions, label=label, k=1) + expect_value = np.array([0.5], dtype='float32') + + self.assertEqual((result.numpy() == expect_value).all(), True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_increment.py b/python/paddle/fluid/tests/unittests/test_increment.py new file mode 100755 index 00000000000..e8cc7c8cf18 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_increment.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy as np +import paddle +import paddle.fluid as fluid + + +class TestIncrement(unittest.TestCase): + def test_api(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.layers.fill_constant( + shape=[1], dtype='int64', value=5) + expected_result = np.array([8], dtype='int64') + + output = paddle.tensor.math.increment(input, value=3) + exe = fluid.Executor(fluid.CPUPlace()) + result = exe.run(fetch_list=[output]) + self.assertEqual((result == expected_result).all(), True) + + with fluid.dygraph.guard(): + input = paddle.ones(shape=[1], dtype='int64') + expected_result = np.array([2], dtype='int64') + output = paddle.tensor.math.increment(input, value=1) + self.assertEqual((output.numpy() == expected_result).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index d85521f7662..37bea9deae7 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -397,8 +397,10 @@ class TestMathOpPatchesVarBase(unittest.TestCase): self.assertTrue( np.array_equal(m.unique()[0].numpy(), paddle.unique(m)[0].numpy())) self.assertTrue( - np.array_equal(m.unique_with_counts()[2], - paddle.unique_with_counts(m)[2])) + np.array_equal( + m.unique(return_counts=True)[1], + paddle.unique( + m, return_counts=True)[1])) self.assertTrue(np.array_equal(x.flip([0]), paddle.flip(x, [0]))) self.assertTrue(np.array_equal(x.unbind(0), paddle.unbind(x, 0))) self.assertTrue(np.array_equal(x.roll(1), paddle.roll(x, 1))) @@ -513,8 +515,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase): self.assertTrue(inspect.ismethod(a.reduce_sum)) self.assertTrue(inspect.ismethod(a.scale)) self.assertTrue(inspect.ismethod(a.stanh)) - self.assertTrue(inspect.ismethod(a.sums)) - self.assertTrue(inspect.ismethod(a.elementwise_sum)) + self.assertTrue(inspect.ismethod(a.add_n)) self.assertTrue(inspect.ismethod(a.max)) self.assertTrue(inspect.ismethod(a.maximum)) self.assertTrue(inspect.ismethod(a.min)) diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index c0cd88a0a6a..35dc92ffb08 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -225,7 +225,7 @@ def create_test_sum_fp16_class(parent): globals()[cls_name] = TestSumFp16Case -class API_Test_Elementwise_Sum(unittest.TestCase): +class API_Test_Add_n(unittest.TestCase): def test_api(self): with fluid.program_guard(fluid.Program(), fluid.Program()): input0 = fluid.layers.fill_constant( @@ -234,11 +234,19 @@ class API_Test_Elementwise_Sum(unittest.TestCase): shape=[2, 3], dtype='int64', value=3) expected_result = np.empty((2, 3)) expected_result.fill(8) - sum_value = paddle.elementwise_sum([input0, input1]) + sum_value = paddle.add_n([input0, input1]) exe = fluid.Executor(fluid.CPUPlace()) result = exe.run(fetch_list=[sum_value]) - self.assertEqual((result == expected_result).all(), True) + self.assertEqual((result == expected_result).all(), True) + + with fluid.dygraph.guard(): + input0 = paddle.ones(shape=[2, 3], dtype='float32') + expected_result = np.empty((2, 3)) + expected_result.fill(2) + sum_value = paddle.add_n([input0, input0]) + + self.assertEqual((sum_value.numpy() == expected_result).all(), True) class TestRaiseSumError(unittest.TestCase): diff --git a/python/paddle/metric/__init__.py b/python/paddle/metric/__init__.py index fba45523889..d62fc137432 100644 --- a/python/paddle/metric/__init__.py +++ b/python/paddle/metric/__init__.py @@ -15,11 +15,10 @@ from .metrics import * from . import metrics -from ..fluid.layers.metric_op import accuracy, auc +from ..fluid.layers.metric_op import auc from ..fluid.layers.nn import chunk_eval, mean_iou __all__ = metrics.__all__ + [ - 'accuracy', 'auc', 'chunk_eval', 'mean_iou', diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index f4a9b8c01d0..fed659562cb 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -20,9 +20,13 @@ import six import abc import numpy as np +from ..fluid.data_feeder import check_variable_and_dtype +from ..fluid.layer_helper import LayerHelper +from ..fluid.layers.nn import topk +from ..fluid.framework import core, _varbase_creator, in_dygraph_mode import paddle -__all__ = ['Metric', 'Accuracy', 'Precision', 'Recall', 'Auc'] +__all__ = ['Metric', 'Accuracy', 'Precision', 'Recall', 'Auc', 'accuracy'] def _is_numpy_(var): @@ -733,3 +737,70 @@ class Auc(Metric): Returns metric name """ return self._name + + +def accuracy(input, label, k=1, correct=None, total=None, name=None): + """ + accuracy layer. + Refer to the https://en.wikipedia.org/wiki/Precision_and_recall + + This function computes the accuracy using the input and label. + If the correct label occurs in top k predictions, then correct will increment by one. + Note: the dtype of accuracy is determined by input. the input and label dtype can be different. + + Args: + input(Tensor): The input of accuracy layer, which is the predictions of network. A Tensor with type float32,float64. + The shape is ``[sample_number, class_dim]`` . + label(Tensor): The label of dataset. Tensor with type int32,int64. The shape is ``[sample_number, 1]`` . + k(int, optional): The top k predictions for each class will be checked. Data type is int64 or int32. + correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32. + total(Tensor, optional): The total entries count. A tensor with type int64 or int32. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor, the correct rate. A Tensor with type float32. + + Examples: + .. code-block:: python + + import paddle + + predictions = paddle.to_tensor([[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], dtype='float32') + label = paddle.to_tensor([[2], [0]], dtype="int64") + result = paddle.metric.accuracy(input=predictions, label=label, k=1) + # [0.5] + """ + if in_dygraph_mode(): + if correct is None: + correct = _varbase_creator(dtype="int32") + if total is None: + total = _varbase_creator(dtype="int32") + + topk_out, topk_indices = topk(input, k=k) + _acc, _, _ = core.ops.accuracy(topk_out, topk_indices, label, correct, + total) + return _acc + + helper = LayerHelper("accuracy", **locals()) + check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], + 'accuracy') + topk_out, topk_indices = topk(input, k=k) + acc_out = helper.create_variable_for_type_inference(dtype="float32") + if correct is None: + correct = helper.create_variable_for_type_inference(dtype="int32") + if total is None: + total = helper.create_variable_for_type_inference(dtype="int32") + helper.append_op( + type="accuracy", + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) + return acc_out diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 940bd1a4674..cfbaa961ddf 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -94,7 +94,6 @@ from .manipulation import stack #DEFINE_ALIAS from .manipulation import strided_slice #DEFINE_ALIAS from .manipulation import transpose #DEFINE_ALIAS from .manipulation import unique #DEFINE_ALIAS -from .manipulation import unique_with_counts #DEFINE_ALIAS from .manipulation import unsqueeze #DEFINE_ALIAS from .manipulation import unstack #DEFINE_ALIAS from .manipulation import flip #DEFINE_ALIAS @@ -137,9 +136,8 @@ from .math import sqrt #DEFINE_ALIAS from .math import square #DEFINE_ALIAS from .math import stanh #DEFINE_ALIAS from .math import sum #DEFINE_ALIAS -from .math import sums #DEFINE_ALIAS from .math import tanh #DEFINE_ALIAS -from .math import elementwise_sum #DEFINE_ALIAS +from .math import add_n #DEFINE_ALIAS from .math import max #DEFINE_ALIAS from .math import maximum #DEFINE_ALIAS from .math import min #DEFINE_ALIAS diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 531629c573f..73a37253828 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -30,7 +30,6 @@ from ..fluid.layers import unstack #DEFINE_ALIAS from ..fluid.layers import scatter_nd #DEFINE_ALIAS from ..fluid.layers import shard_index #DEFINE_ALIAS -from ..fluid.layers import unique_with_counts #DEFINE_ALIAS from ..fluid import layers import paddle @@ -57,7 +56,6 @@ __all__ = [ 'strided_slice', 'transpose', 'unique', - 'unique_with_counts', 'unsqueeze', 'unstack', 'flip', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2e71ed26a89..c0cb846042d 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -60,9 +60,7 @@ from ..fluid.layers import erf #DEFINE_ALIAS from ..fluid.layers import sqrt #DEFINE_ALIAS from ..fluid.layers import sin #DEFINE_ALIAS -from ..fluid.layers import increment #DEFINE_ALIAS from ..fluid.layers import multiplex #DEFINE_ALIAS -from ..fluid.layers import sums #DEFINE_ALIAS from ..fluid import layers @@ -105,9 +103,8 @@ __all__ = [ 'square', 'stanh', 'sum', - 'sums', 'tanh', - 'elementwise_sum', + 'add_n', 'max', 'maximum', 'min', @@ -728,11 +725,8 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): @templatedoc(op_type="sum") -def elementwise_sum(inputs, name=None): +def add_n(inputs, name=None): """ - :alias_main: paddle.elementwise_sum - :alias: paddle.elementwise_sum,paddle.tensor.elementwise_sum,paddle.tensor.math.elementwise_sum - ${comment} Case 1: @@ -766,53 +760,40 @@ def elementwise_sum(inputs, name=None): [14, 16, 18]] Args: - inputs (Variable|list(Variable)): A Varaible list. The shape and data type of the list elementsshould be consistent. - Variable can be multi-dimensional Tensoror LoDTensor, and data types can be: float32, float64, int32, int64. + inputs (Tensor|list(Tensor)): A Tensor list. The shape and data type of the list elements should be consistent. + Input can be multi-dimensional Tensor, and data types can be: float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` Returns: - Variable: the sum of input :math:`inputs` . its shape and data types are consistent with :math:`inputs` . + Tensor, the sum of input :math:`inputs` , its shape and data types are consistent with :math:`inputs`. Examples: .. code-block:: python import paddle - import paddle.fluid as fluid - input0 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5) - input1 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=3) - sum = paddle.elementwise_sum([input0, input1]) - - # You can print out 'sum' via executor. - out = fluid.layers.Print(sum, message="the sum of input0 and input1: ") - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(fluid.default_main_program()) - - # The printed result is: - # 1570701754 the sum of input0 and input1: The place is:CPUPlace - # Tensor[elementwise_sum_0.tmp_0] - # shape: [2,3,] - # dtype: l - # data: 8,8,8,8,8,8, - - # the sum of input0 and input1 is 2-D Tensor with shape [2,3]. - # dtype is the corresponding C++ data type, which may vary in different environments. - # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, - # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, - # and '__int64' on Windows. They both represent 64-bit integer variables. + input0 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32') + input1 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]], dtype='float32') + output = paddle.add_n([input0, input1]) + # [[8., 10., 12.], + # [14., 16., 18.]] """ + if in_dygraph_mode(): + if isinstance(inputs, Variable): + inputs = [inputs] + return core.ops.sum(inputs, 'use_mkldnn', False) - helper = LayerHelper('elementwise_sum', **locals()) - check_type(inputs, 'inputs', (Variable, tuple, list), 'elementwise_sum') + helper = LayerHelper('add_n', **locals()) + check_type(inputs, 'inputs', (Variable, tuple, list), 'add_n') if isinstance(inputs, list) or isinstance(inputs, tuple): if len(inputs) > 0: for input in inputs: check_variable_and_dtype(input, "inputs", \ - ['float32', 'float64', 'int32', 'int64'], 'elementwise_sum') + ['float32', 'float64', 'int32', 'int64'], 'add_n') else: check_variable_and_dtype(inputs, "inputs", \ - ['float32', 'float64', 'int32', 'int64'], 'elementwise_sum') + ['float32', 'float64', 'int32', 'int64'], 'add_n') out = helper.create_variable_for_type_inference( @@ -1924,3 +1905,39 @@ def tanh(x, name=None): out = helper.create_variable_for_type_inference(x.dtype) helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out}) return out + +def increment(x, value=1.0, name=None): + """ + The OP is usually used for control flow to increment the data of :attr:`x` by an amount :attr:`value`. + Notice that the number of elements in :attr:`x` must be equal to 1. + + Args: + x (Tensor): A tensor that must always contain only one element, its data type supports float32, float64, int32 and int64. + value(float, optional): The amount to increment the data of :attr:`x`. Default: 1.0. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the elementwise-incremented tensor with the same shape and data type as :attr:`x`. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.zeros(shape=[1], dtype='float32') + counter = paddle.increment(data) + # [1.] + + """ + if in_dygraph_mode(): + return core.ops.increment(x, 'step', value) + + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + 'increment') + helper = LayerHelper("increment", **locals()) + helper.append_op( + type='increment', + inputs={'X': [x]}, + outputs={'Out': [x]}, + attrs={'step': float(value)}) + return x -- GitLab