diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc
index 06eaca0216b36a50028fd7cfd3c0866a5b7c1de0..b45fa7c791ff22be422ce12a8348a071c60ddd0f 100644
--- a/paddle/fluid/operators/size_op.cc
+++ b/paddle/fluid/operators/size_op.cc
@@ -54,5 +54,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker,
     paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel, ops::SizeKernel,
+                       ops::SizeKernel,
                        ops::SizeKernel, ops::SizeKernel,
                        ops::SizeKernel);
diff --git a/paddle/fluid/operators/size_op.cu b/paddle/fluid/operators/size_op.cu
index 4e5846660e62543638b669d586a92fc36b0c8e87..3ea3032693236d5618ff6f0c858cbd85e34633ab 100644
--- a/paddle/fluid/operators/size_op.cu
+++ b/paddle/fluid/operators/size_op.cu
@@ -14,8 +14,9 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/size_op.h"
 
-REGISTER_OP_CUDA_KERNEL(size, paddle::operators::SizeKernel,
-                        paddle::operators::SizeKernel,
-                        paddle::operators::SizeKernel,
-                        paddle::operators::SizeKernel,
-                        paddle::operators::SizeKernel);
+REGISTER_OP_CUDA_KERNEL(
+    size, paddle::operators::SizeKernel,
+    paddle::operators::SizeKernel,
+    paddle::operators::SizeKernel,
+    paddle::operators::SizeKernel, paddle::operators::SizeKernel,
+    paddle::operators::SizeKernel);
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 94de3fa0adb42f7b358688d4c1af78e822e64613..8404e82c544955732fd44040a9f57fd6eeb398bd 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -237,6 +237,7 @@ from .tensor.stat import reduce_mean #DEFINE_ALIAS
 from .tensor.stat import std #DEFINE_ALIAS
 from .tensor.stat import var #DEFINE_ALIAS
 from .fluid.data import data
+from .tensor.stat import numel #DEFINE_ALIAS
 from .device import get_cudnn_version
 from .device import set_device
 from .device import get_device
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 985db80a22297a41436205d14a91188d560a9c41..2115f3f8de0b687aab765f5f5d6c15cf4f245fba 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -11182,6 +11182,7 @@ def rank(input):
     return out
 
 
+@deprecated(since="2.0.0", update_to="paddle.numel")
 def size(input):
     """
     **Size Layer**
@@ -11189,11 +11190,14 @@ def size(input):
     Returns the number of elements for a tensor, which is a int64 Tensor
     with shape [1].
 
     Args:
-        input (Variable): The input variable.
+        input (Tensor): The input Tensor, its data type can be bool, float16, float32, float64, int32, int64.
 
     Returns:
-        Variable: The number of elements for the input variable.
+        Tensor: The number of elements for the input Tensor.
 
+    Raises:
+        TypeError: ``input`` must be a Tensor and the data type of ``input`` must be one of bool, float16, float32, float64, int32, int64.
+
     Examples:
         .. code-block:: python
@@ -11204,6 +11208,11 @@ def size(input):
             rank = layers.size(input) # 300
     """
 
+    if in_dygraph_mode():
+        return core.ops.size(input)
+    check_variable_and_dtype(
+        input, 'input', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+        "size")
     helper = LayerHelper('size', **locals())
     out = helper.create_variable_for_type_inference(dtype='int64')
     helper.append_op(type='size', inputs={'Input': input}, outputs={'Out': out})
diff --git a/python/paddle/fluid/tests/unittests/test_numel_op.py b/python/paddle/fluid/tests/unittests/test_numel_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..8512bc99e7451c73e5513b834fb6aa448717c646
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_numel_op.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+import functools
+import paddle
+
+
+class TestNumelOp(OpTest):
+    def setUp(self):
+        self.op_type = "size"
+        self.init()
+        x = np.random.random((self.shape)).astype("float64")
+        self.inputs = {'Input': x, }
+        self.outputs = {'Out': np.array([np.size(x)])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init(self):
+        self.shape = (6, 56, 8, 55)
+
+
+class TestNumelOp1(TestNumelOp):
+    def init(self):
+        self.shape = (11, 66)
+
+
+class TestNumelOp2(TestNumelOp):
+    def init(self):
+        self.shape = (0, )
+
+
+class TestNumelAPI(unittest.TestCase):
+    def test_numel_static(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            shape1 = [2, 1, 4, 5]
+            shape2 = [1, 4, 5]
+            x_1 = paddle.data(shape=shape1, dtype='int32', name='x_1')
+            x_2 = paddle.data(shape=shape2, dtype='int32', name='x_2')
+            input_1 = np.random.random(shape1).astype("int32")
+            input_2 = np.random.random(shape2).astype("int32")
+            out_1 = paddle.numel(x_1)
+            out_2 = paddle.numel(x_2)
+            exe = paddle.static.Executor(place=paddle.CPUPlace())
+            res_1, res_2 = exe.run(feed={
+                "x_1": input_1,
+                "x_2": input_2,
+            },
+                                   fetch_list=[out_1, out_2])
+            assert (np.array_equal(
+                res_1, np.array([np.size(input_1)]).astype("int64")))
+            assert (np.array_equal(
+                res_2, np.array([np.size(input_2)]).astype("int64")))
+
+    def test_numel_imperative(self):
+        paddle.disable_static(paddle.CPUPlace())
+        input_1 = np.random.random([2, 1, 4, 5]).astype("int32")
+        input_2 = np.random.random([1, 4, 5]).astype("int32")
+        x_1 = paddle.to_variable(input_1)
+        x_2 = paddle.to_variable(input_2)
+        out_1 = paddle.numel(x_1)
+        out_2 = paddle.numel(x_2)
+        assert (np.array_equal(out_1.numpy().item(0), np.size(input_1)))
+        assert (np.array_equal(out_2.numpy().item(0), np.size(input_2)))
+        paddle.enable_static()
+
+    def test_error(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+
+            def test_x_type():
+                shape = [1, 4, 5]
+                input_1 = np.random.random(shape).astype("int32")
+                out_1 = paddle.numel(input_1)
+
+            self.assertRaises(TypeError, test_x_type)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ba108beb0bd93f537065ae71fa76a41c9da23853..e97779ce23f7a97c1288136e0cf816408dc45f21 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -181,5 +181,7 @@ from .stat import mean #DEFINE_ALIAS
 from .stat import reduce_mean #DEFINE_ALIAS
 from .stat import std #DEFINE_ALIAS
 from .stat import var #DEFINE_ALIAS
+from .stat import numel #DEFINE_ALIAS
+# from .tensor import Tensor #DEFINE_ALIAS
 # from .tensor import LoDTensor #DEFINE_ALIAS
 # from .tensor import LoDTensorArray #DEFINE_ALIAS
diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py
index 7d22a0be5b0a9a2088f22535c6e2e56f7dc1f959..9ae9f5025257d72b7c214df482c7f24e56e3cb30 100644
--- a/python/paddle/tensor/stat.py
+++ b/python/paddle/tensor/stat.py
@@ -15,9 +15,10 @@
 # TODO: define statistical functions of a tensor
 from ..fluid.layers import reduce_mean #DEFINE_ALIAS
 
-__all__ = ['mean', 'reduce_mean', 'std', 'var']
+__all__ = ['mean', 'reduce_mean', 'std', 'var', 'numel']
 
 import numpy as np
+from ..fluid.framework import Variable
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.framework import core, in_dygraph_mode
 from ..fluid import layers
@@ -244,3 +245,41 @@ def std(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):
         return out
     else:
         return tmp
+
+
+def numel(x, name=None):
+    """
+    Returns the number of elements for a tensor, which is an int64 Tensor with shape [1] in static mode
+    or a scalar value in imperative mode.
+
+    Args:
+        x (Tensor): The input Tensor, its data type can be bool, float16, float32, float64, int32, int64.
+
+    Returns:
+        Tensor: The number of elements for the input Tensor.
+
+    Raises:
+        TypeError: ``x`` must be a Tensor and the data type of ``x`` must be one of bool, float16, float32, float64, int32, int64.
+
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.disable_static()
+            x = paddle.full(shape=[4, 5, 7], fill_value=0, dtype='int32')
+            numel = paddle.numel(x) # 140
+
+
+    """
+    if in_dygraph_mode():
+        return core.ops.size(x)
+
+    if not isinstance(x, Variable):
+        raise TypeError("x must be a Tensor in numel")
+    helper = LayerHelper('numel', **locals())
+    out = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.INT64)
+    helper.append_op(type='size', inputs={'Input': x}, outputs={'Out': out})
+    return out
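
Usage sketch (editorial note, not part of the patch): the snippet below exercises the new paddle.numel API in both imperative and static graph modes, relying only on calls that already appear in this diff (paddle.full, paddle.data, fluid.program_guard, paddle.static.Executor). The variable names, shapes, and the CPU place are illustrative assumptions, and it presumes a Paddle build that includes this change.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    # Imperative (dygraph) mode: numel returns the element count as a Tensor.
    paddle.disable_static()
    x = paddle.full(shape=[4, 5, 7], fill_value=0, dtype='int32')
    print(paddle.numel(x).numpy().item(0))  # 140
    paddle.enable_static()

    # Static graph mode: numel yields an int64 Tensor with shape [1],
    # fetched here through an executor run.
    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        y = paddle.data(name='y', shape=[2, 3], dtype='int32')
        n = paddle.numel(y)
        exe = paddle.static.Executor(place=paddle.CPUPlace())
        res = exe.run(feed={'y': np.zeros([2, 3], dtype='int32')},
                      fetch_list=[n])[0]
        print(res)  # [6]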