diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h old mode 100644 new mode 100755 index ff55d2f2040a17c32720df08c1ac0b00cc1d7a02..a910c326196bc61758c3be7db3b8ac5d85b0095c --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.h @@ -22,15 +22,20 @@ namespace operators { template struct PowFunctor { inline HOSTDEVICE T operator()(T a, T b) const { -#ifdef __CUDA_ARCH__ - // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and - // it will return a float number like 2.99... , which floor to 2 - // when cast to int by default and it is wrong. - // Use llrint to cast it to the nearest integer, which is 3. + // TODO(wujionghao): A potential speed improvement is supporting different + // types in C++. + // #ifdef __CUDA_ARCH__ + // // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and + // // it will return a float number like 2.99... , which floor to 2 + // // when cast to int by default and it is wrong. + // // Use llrint to cast it to the nearest integer, which is 3. + // if (std::is_integral::value) { + // return std::llrint(std::pow(a, b)); + // } + // #endif if (std::is_integral::value) { return std::llrint(std::pow(a, b)); } -#endif return std::pow(a, b); } }; diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py old mode 100644 new mode 100755 diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py old mode 100644 new mode 100755 diff --git a/python/paddle/fluid/tests/unittests/test_pow.py b/python/paddle/fluid/tests/unittests/test_pow.py new file mode 100755 index 0000000000000000000000000000000000000000..0764cb580e40d115d8703278380a9ced12e24201 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pow.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle +import paddle.tensor as tensor +import paddle.fluid as fluid +from paddle.static import Program, program_guard +import numpy as np +import unittest + +DYNAMIC = 1 +STATIC = 2 + + +def _run_power(mode, x, y): + # dynamic mode + if mode == DYNAMIC: + paddle.disable_static() + # y is scalar + if isinstance(y, (int, float)): + x_ = paddle.to_tensor(x) + y_ = y + res = paddle.pow(x_, y_) + return res.numpy() + # y is tensor + else: + x_ = paddle.to_tensor(x) + y_ = paddle.to_tensor(y) + res = paddle.pow(x_, y_) + return res.numpy() + # static mode + elif mode == STATIC: + paddle.enable_static() + # y is scalar + if isinstance(y, (int, float)): + with program_guard(Program(), Program()): + x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype) + y_ = y + res = paddle.pow(x_, y_) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + outs = exe.run(feed={'x': x}, fetch_list=[res]) + return outs[0] + # y is tensor + else: + with program_guard(Program(), Program()): + x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype) + y_ = paddle.static.data(name="y", shape=y.shape, dtype=y.dtype) + res = paddle.pow(x_, y_) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + outs = exe.run(feed={'x': x, 'y': y}, fetch_list=[res]) + return outs[0] + + +class TestPowerAPI(unittest.TestCase): + """TestPowerAPI.""" + + def test_power(self): + """test_power.""" + np.random.seed(7) + # test 1-d float tensor ** float scalar + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = np.random.rand() * 10 + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** int scalar + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = int(np.random.rand() * 10) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = int(np.random.rand() * 10) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** 1-d float tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d float tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int32) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int32) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float32) + y = (np.random.rand(*dims) * 10).astype(np.float32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.float32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float32) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test broadcast + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1]) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + +class TestPowerError(unittest.TestCase): + """TestPowerError.""" + + def test_errors(self): + """test_errors.""" + np.random.seed(7) + + # test dynamic computation graph: inputs must be broadcastable + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1] + 1) * 10).astype(np.float64) + self.assertRaises(fluid.core.EnforceNotMet, _run_power, DYNAMIC, x, y) + self.assertRaises(fluid.core.EnforceNotMet, _run_power, STATIC, x, y) + + # test dynamic computation graph: inputs must be broadcastable + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1] + 1) * 10).astype(np.int8) + self.assertRaises(TypeError, paddle.pow, x, y) + + # test 1-d float tensor ** int string + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = int(np.random.rand() * 10) + self.assertRaises(TypeError, paddle.pow, x, str(y)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py old mode 100644 new mode 100755 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9dfb31a5ac25b2afc9fe52bfc8bab5ad277d80b8..e0317f4faceedd379a1d84dd2f3e31ee71e46469 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -17,6 +17,8 @@ math functions from __future__ import print_function from paddle.common_ops_import import * +from paddle.tensor import cast +import paddle from ..fluid import layers from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable from ..fluid.layer_helper import LayerHelper @@ -64,6 +66,7 @@ from ..fluid.layers import sums #DEFINE_ALIAS from ..fluid import layers import paddle + __all__ = [ 'abs', 'acos', @@ -86,8 +89,8 @@ __all__ = [ 'logsumexp', 'mul', 'multiplex', - 'prod', 'pow', + 'prod', 'reciprocal', 'reduce_max', 'reduce_min', @@ -147,64 +150,87 @@ _supported_float_dtype_ = [ VarDesc.VarType.FP64, ] -@templatedoc() -def pow(input, exponent, name=None): +def pow(x, y, name=None): """ - :alias_main: paddle.pow - :alias: paddle.pow,paddle.tensor.pow,paddle.tensor.math.pow + Compute the power of tensor elements. The equation is: - This is Pow Activation Operator. + .. math:: + out = x^{y} - :math:`out = input^{exponent}` + **Note**: + ``paddle.pow`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . - Args: - input(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float32`` or ``float64``. - exponent(float32|Variable): A scalar with type ``float32`` or a ``Tensor`` with shape [1] and type ``float32``. - name(str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64. + y (Tensor): An N-D Tensor with type float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Returns: - Variable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``input``. + N-D Tensor. A location into which the result is stored. Its dimension equals with $x$. Examples: - .. code-block:: python + .. code-block:: python import paddle - import paddle.fluid as fluid - - x = fluid.data(name="x", shape=[32,32], dtype="float32") + import numpy as np - # example 1: argument exponent is float - y_1 = paddle.pow(x, 2.0) - # y_1 is x^{2.0} + paddle.disable_static() + + # example 1: y is a float + x_data = np.array([1, 2, 3]) + y = 2 + x = paddle.to_tensor(x_data) + res = paddle.pow(x, y) + print(res.numpy()) # [1 4 9] + + # example 2: y is a Tensor + y = paddle.fill_constant(shape=[1], value=2, dtype='float32') + res = paddle.pow(x, y) + print(res.numpy()) # [1 4 9] - # example 2: argument exponent is Variable - exponent_tensor = fluid.layers.fill_constant([1], "float32", 3.0) - y_2 = paddle.pow(x, exponent_tensor) - # y_2 is x^{3.0} """ + # in dynamic graph mode if in_dygraph_mode(): - return core.ops.pow(input, "exponent", exponent) - - helper = LayerHelper('pow', **locals()) - inputs = {'X': input} - attrs = {} - if isinstance(exponent, Variable): - exponent.stop_gradient = True - inputs['FactorTensor'] = exponent + if isinstance(y, (int, float)): + return core.ops.pow(x, 'factor', y) + elif isinstance(y, (paddle.Tensor, Variable)): + + if x.dtype != y.dtype: + y = cast(y, dtype='float64') + x = cast(x, dtype='float64') + out_dygraph = _elementwise_op_in_dygraph( + x, y, axis=-1, act=None, op_name='elementwise_pow') + return out_dygraph + + return _elementwise_op_in_dygraph( + x, y, axis=-1, act=None, op_name='elementwise_pow') + else: + raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) + # in static graph mode else: - attrs['factor'] = exponent - - out = helper.create_variable_for_type_inference(dtype=input.dtype) - check_dtype( - out.dtype, out.name, - convert_dtype(input.dtype), 'pow', - '(The out data type in pow must be the same with input data type.)') + if isinstance(y, (int, float)): + helper = LayerHelper('pow', **locals()) + inputs = {'X': x} + attrs = {'factor': y} + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) + return out + elif isinstance(y, (paddle.Tensor, Variable)): + # TODO A potential speed improvement is supporting different types in C++ and removing the cast ops here + helper = LayerHelper('elementwise_pow', **locals()) + if x.dtype != y.dtype: + y = cast(y, dtype='float64') + x = cast(x, dtype='float64') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + return _elementwise_op(LayerHelper('elementwise_pow', **locals())) + else: + raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y))) - helper.append_op( - type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) - return out @dygraph_only @@ -227,6 +253,8 @@ def _elementwise_op(helper): x = helper.kwargs.get('x', None) y = helper.kwargs.get('y', None) + out = helper.kwargs.get('out', None) + assert x is not None, 'x cannot be None in {}'.format(original_op_type) assert y is not None, 'y cannot be None in {}'.format(original_op_type) check_variable_and_dtype( @@ -239,11 +267,12 @@ def _elementwise_op(helper): axis = helper.kwargs.get('axis', -1) use_mkldnn = helper.kwargs.get('use_mkldnn', False) name = helper.kwargs.get('name', None) - if name is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - else: - out = helper.create_variable( - name=name, dtype=x.dtype, persistable=False) + + if out is None: + if name is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable(name=name, dtype=x.dtype, persistable=False) helper.append_op( type=op_type,