From 4086f48ea1845f7bd88047a8b3757e00015a9714 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 30 Oct 2020 19:00:51 +0800 Subject: [PATCH] Check and fix tensor and scalar type promotion (#28299) * check and fix tensor and scalar type promotion * fix else branch error * fix scalar method error * fix test_math_op_path unittest * add future division for unittest * rm useless bin file --- python/paddle/fluid/dygraph/math_op_patch.py | 46 ++- python/paddle/fluid/layers/math_op_patch.py | 43 +- .../tests/unittests/test_math_op_patch.py | 10 +- ...st_tensor_scalar_type_promotion_dynamic.py | 318 +++++++++++++++ ...est_tensor_scalar_type_promotion_static.py | 369 ++++++++++++++++++ 5 files changed, 758 insertions(+), 28 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_static.py diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index d1781fdb010..203a5e0f86a 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -149,28 +149,46 @@ def monkey_patch_math_varbase(): reverse=False, scalar_method=None): def __impl__(self, other_var): - # tensor and ComplexVariable opetator + # 0. check tensor and ComplexVariable opetator if isinstance(other_var, ComplexVariable): # need import paddle in closure import paddle math_op = getattr(paddle.incubate.complex.tensor, op_type) return math_op(self, other_var) - # FIXME(zjl): elementwise_div between integers cannot be converted to scale, - # which may lose accuracy. This is a hot fix for release 1.6. - if scalar_method is not None and not ( - op_type == 'elementwise_div' and - self.dtype in _supported_int_dtype_): - if isinstance(other_var, float): - if self.dtype in _supported_int_dtype_: - assert other_var == int(other_var), \ - "float value {} cannot convert to integer".format(other_var) + # 1. scalar exists cases + # we need combine the tensor.dtype and scalar.dtype, cast correct object + if isinstance(other_var, float): + # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float + if self.dtype in _supported_int_dtype_: + self = astype(self, 'float32') + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: return scalar_method(self, other_var) - elif isinstance(other_var, int): - return scalar_method(self, float(other_var)) + elif isinstance(other_var, int): + # in all cases(+, -, *, /, **, //, %), we can cast it to float + # because the output tensor.dtype depend on the type of input tensor + other_var = float(other_var) + # division is a special case + # NOTE(chenweihang): because we cast tensor to float32 instead float64, + # the division result can only guarantee the numerical accuracy of 6 digits + # after the decimal point. The result of numpy calculation is of float64 type, + # so the calculation result here and the calculation result of numpy are + # different after 6 decimal point. If necessary, we can also use float64 here. + # torch's behavior here is consistent with ours + if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: + self = astype(self, 'float32') + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: + return scalar_method(self, other_var) + else: + # do nothing + pass + # 2. create varbase for scalar lhs_dtype = self.dtype - if not isinstance(other_var, core.VarBase): if reverse: other_var = create_tensor( @@ -179,6 +197,7 @@ def monkey_patch_math_varbase(): # add fill_op other_var = create_scalar(value=other_var, dtype=lhs_dtype) + # 3. unify right var type to left var rhs_dtype = other_var.dtype if lhs_dtype != rhs_dtype: other_var = astype(other_var, lhs_dtype) @@ -187,6 +206,7 @@ def monkey_patch_math_varbase(): self = other_var other_var = tmp + # 4. calculation axis = -1 math_op = getattr(core.ops, op_type) return math_op(self, other_var, 'axis', axis) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 92b58a7e2ee..8f5fdf52d95 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -215,21 +215,39 @@ def monkey_patch_variable(): reverse=False, scalar_method=None): def __impl__(self, other_var): - # FIXME(zjl): elementwise_div between integers cannot be converted to scale, - # which may lose accuracy. This is a hot fix for release 1.6. - if scalar_method is not None and not ( - op_type == 'elementwise_div' and - self.dtype in _supported_int_dtype_): - if isinstance(other_var, float): - if self.dtype in _supported_int_dtype_: - assert other_var == int(other_var), \ - "float value {} cannot convert to integer".format(other_var) + # 1. scalar exists cases + # we need combine the tensor.dtype and scalar.dtype, cast correct object + if isinstance(other_var, float): + # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float + if self.dtype in _supported_int_dtype_: + self = astype(self, 'float32') + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: return scalar_method(self, other_var) - elif isinstance(other_var, int): - return scalar_method(self, float(other_var)) + elif isinstance(other_var, int): + # in all cases(+, -, *, /, **, //, %), we can cast it to float + # because the output tensor.dtype depend on the type of input tensor + other_var = float(other_var) + # division is a special case + # NOTE(chenweihang): because we cast tensor to float32 instead float64, + # the division result can only guarantee the numerical accuracy of 6 digits + # after the decimal point. The result of numpy calculation is of float64 type, + # so the calculation result here and the calculation result of numpy are + # different after 6 decimal point. If necessary, we can also use float64 here. + # torch's behavior here is consistent with ours + if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: + self = astype(self, 'float32') + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: + return scalar_method(self, other_var) + else: + # do nothing + pass + # 2. create variable for scalar lhs_dtype = safe_get_dtype(self) - if not isinstance(other_var, Variable): if reverse: has_batch_size = False @@ -251,6 +269,7 @@ def monkey_patch_variable(): other_var = create_scalar( current_block(self), value=other_var, dtype=lhs_dtype) + # 3. unify right var type to left var rhs_dtype = safe_get_dtype(other_var) if lhs_dtype != rhs_dtype: other_var = astype(other_var, lhs_dtype) diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index f6eff22d6ce..76e371b2167 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function +from __future__ import print_function, division import unittest from decorator_helper import prog_scope +import paddle import paddle.fluid as fluid import numpy class TestMathOpPatches(unittest.TestCase): + def setUp(self): + paddle.enable_static() + @prog_scope() def test_add_scalar(self): a = fluid.layers.data(name="a", shape=[1]) @@ -197,8 +201,8 @@ class TestMathOpPatches(unittest.TestCase): feed={"a": a_np}, fetch_list=[b]) - b_np_actual = (a_np / 7).astype('int64') - self.assertTrue(numpy.array_equal(b_np, b_np_actual)) + b_np_actual = (a_np / 7).astype('float32') + self.assertTrue(numpy.allclose(b_np, b_np_actual)) @prog_scope() def test_equal(self): diff --git a/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py new file mode 100644 index 00000000000..5f2dfbdd99e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_dynamic.py @@ -0,0 +1,318 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import unittest +import numpy as np + +import paddle + +# Support types are ref from `paddle.tensor.math` +# - Related paddle dtypes: +# - int type: int64, (no test here: uint8, int8, int16, int32) +# - float type: float32, (no test here: float64) +# - Python scalar dtypes: +# - int(64) +# - float(64) + + +class TestTensorScalarTypePromotionDynamic(unittest.TestCase): + def check_operation(self, a, b, c, op): + if op == '+': + c_rlt = a + b + elif op == '-': + c_rlt = a - b + elif op == '*': + c_rlt = a * b + elif op == '/': + c_rlt = a / b + elif op == '**': + c_rlt = a**b + elif op == '//': + c_rlt = a // b + elif op == '%': + c_rlt = a % b + else: + raise ValueError("Unsupported operation.") + + self.assertEqual(c_rlt.dtype, c.dtype) + self.assertTrue(np.array_equal(c_rlt.numpy(), c.numpy())) + + def test_tensor_add_scalar(self): + # tensor(int64) + scalar(int) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.full([2, 2, 2], 2, dtype="int64") + self.check_operation(a, b, c, '+') + + # tensor(float32) + scalar(int) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(int64) + scalar(float, .0) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(int64) + scalar(float, .5) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 2.5, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(float32) + scalar(float) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 2.5, dtype="float32") + self.check_operation(a, b, c, '+') + + def test_tensor_sub_scalar(self): + # tensor(int64) - scalar(int) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.zeros([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '-') + + # tensor(float32) - scalar(int) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(int64) - scalar(float, .0) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(int64) - scalar(float, .5) + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(float32) - scalar(float) + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + def test_scalar_sub_tensor(self): + # scalar(int) - tensor(int64) + a = 1 + b = paddle.ones([2, 2, 2], dtype='int64') + c = paddle.zeros([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '-') + + # scalar(int) - tensor(float32) + a = 1 + b = paddle.ones([2, 2, 2], dtype='float32') + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float, .0) - tensor(int64) + a = 1.0 + b = paddle.ones([2, 2, 2], dtype='int64') + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float, .5) - tensor(int64) + a = 1.5 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], -0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float) - tensor(float32) + a = 1.5 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], -0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + def test_tensor_mul_tensor(self): + # tensor(int64) * scalar(int) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.ones([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '*') + + # tensor(float32) * scalar(int) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.ones([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(int64) * scalar(float, .0) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.ones([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(int64) * scalar(float, .5) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 1.5, dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(float32) * scalar(float) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 1.5, dtype="float32") + self.check_operation(a, b, c, '*') + + def test_tensor_div_scalar(self): + # tensor(int64) / scalar(int) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(float32) / scalar(int) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 2 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(int64) / scalar(float, .0) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 2.0 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(int64) / scalar(float, .5) + a = paddle.ones([2, 2, 2], dtype='int64') + b = 0.5 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(float32) / scalar(float) + a = paddle.ones([2, 2, 2], dtype='float32') + b = 0.5 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + def test_scalar_div_tensor(self): + # scalar(int) / tensor(int64) + a = 1 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(int) / tensor(float32) + a = 1 + b = paddle.full([2, 2, 2], 0.5, dtype='float32') + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(float) / tensor(int64) + a = 1.0 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(float) / tensor(float32) + a = 1.0 + b = paddle.full([2, 2, 2], 0.5, dtype='float32') + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + def test_tensor_pow_scalar(self): + # tensor(int64) ** scalar(int) + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 3 + c = paddle.full([2, 2, 2], 8, dtype="int64") + self.check_operation(a, b, c, '**') + + # tensor(int64) ** scalar(float) + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 3.0 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(int) + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 3 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(float) + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 3.0 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + def test_scalar_pow_tensor(self): + # scalar(int) ** tensor(int64) + a = 3 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 9, dtype="int64") + self.check_operation(a, b, c, '**') + + # scalar(float) ** tensor(int64) + a = 3.0 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + # scalar(int) ** tensor(float32) + a = 3 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(float) + a = 3.0 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + ## TODO: floordiv op kernel doesn't support float + def test_tensor_floordiv_scalar(self): + # tensor(int64) // scalar(int) + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="int64") + self.check_operation(a, b, c, '//') + + def test_tensor_mod_scalar(self): + # tensor(int64) % scalar(int) + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="int64") + self.check_operation(a, b, c, '%') + + # tensor(int64) % scalar(float) + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2.0 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + # tensor(float32) % scalar(int) + a = paddle.full([2, 2, 2], 3, dtype='float32') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + # tensor(float32) % scalar(float) + a = paddle.full([2, 2, 2], 3, dtype='float32') + b = 2.0 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_static.py b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_static.py new file mode 100644 index 00000000000..d697666e12d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_scalar_type_promotion_static.py @@ -0,0 +1,369 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import unittest +import numpy as np + +import paddle +from paddle.static import program_guard +from paddle.static import Program + +# Support types are ref from `paddle.tensor.math` +# - Related paddle dtypes: +# - int type: int64, (no test here: uint8, int8, int16, int32) +# - float type: float32, (no test here: float64) +# - Python scalar dtypes: +# - int(64) +# - float(64) + + +class TestTensorScalarTypePromotionStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def check_operation(self, a, b, c, op): + exe = paddle.static.Executor() + + if op == '+': + c_rlt = a + b + elif op == '-': + c_rlt = a - b + elif op == '*': + c_rlt = a * b + elif op == '/': + c_rlt = a / b + elif op == '**': + c_rlt = a**b + elif op == '//': + c_rlt = a // b + elif op == '%': + c_rlt = a % b + else: + raise ValueError("Unsupported operation.") + + rlt = exe.run(fetch_list=[c_rlt.name, c.name]) + + self.assertEqual(rlt[0].dtype, rlt[1].dtype) + self.assertTrue(np.array_equal(rlt[0], rlt[1])) + + def test_tensor_add_scalar(self): + # tensor(int64) + scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.full([2, 2, 2], 2, dtype="int64") + self.check_operation(a, b, c, '+') + + # tensor(float32) + scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(int64) + scalar(float, .0) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(int64) + scalar(float, .5) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 2.5, dtype="float32") + self.check_operation(a, b, c, '+') + + # tensor(float32) + scalar(float) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 2.5, dtype="float32") + self.check_operation(a, b, c, '+') + + def test_tensor_sub_scalar(self): + # tensor(int64) - scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.zeros([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '-') + + # tensor(float32) - scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(int64) - scalar(float, .0) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(int64) - scalar(float, .5) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + # tensor(float32) - scalar(float) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + def test_scalar_sub_tensor(self): + # scalar(int) - tensor(int64) + with program_guard(Program()): + a = 1 + b = paddle.ones([2, 2, 2], dtype='int64') + c = paddle.zeros([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '-') + + # scalar(int) - tensor(float32) + with program_guard(Program()): + a = 1 + b = paddle.ones([2, 2, 2], dtype='float32') + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float, .0) - tensor(int64) + with program_guard(Program()): + a = 1.0 + b = paddle.ones([2, 2, 2], dtype='int64') + c = paddle.zeros([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float, .5) - tensor(int64) + with program_guard(Program()): + a = 1.5 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], -0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + # scalar(float) - tensor(float32) + with program_guard(Program()): + a = 1.5 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], -0.5, dtype="float32") + self.check_operation(a, b, c, '-') + + def test_tensor_mul_tensor(self): + # tensor(int64) * scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1 + c = paddle.ones([2, 2, 2], dtype="int64") + self.check_operation(a, b, c, '*') + + # tensor(float32) * scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1 + c = paddle.ones([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(int64) * scalar(float, .0) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.0 + c = paddle.ones([2, 2, 2], dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(int64) * scalar(float, .5) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 1.5 + c = paddle.full([2, 2, 2], 1.5, dtype="float32") + self.check_operation(a, b, c, '*') + + # tensor(float32) * scalar(float) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 1.5 + c = paddle.full([2, 2, 2], 1.5, dtype="float32") + self.check_operation(a, b, c, '*') + + def test_tensor_div_scalar(self): + # tensor(int64) / scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(float32) / scalar(int) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 2 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(int64) / scalar(float, .0) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 2.0 + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(int64) / scalar(float, .5) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='int64') + b = 0.5 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + # tensor(float32) / scalar(float) + with program_guard(Program()): + a = paddle.ones([2, 2, 2], dtype='float32') + b = 0.5 + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + def test_scalar_div_tensor(self): + # scalar(int) / tensor(int64) + with program_guard(Program()): + a = 1 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(int) / tensor(float32) + with program_guard(Program()): + a = 1 + b = paddle.full([2, 2, 2], 0.5, dtype='float32') + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(float) / tensor(int64) + with program_guard(Program()): + a = 1.0 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 0.5, dtype="float32") + self.check_operation(a, b, c, '/') + + # scalar(float) / tensor(float32) + with program_guard(Program()): + a = 1.0 + b = paddle.full([2, 2, 2], 0.5, dtype='float32') + c = paddle.full([2, 2, 2], 2, dtype="float32") + self.check_operation(a, b, c, '/') + + def test_tensor_pow_scalar(self): + # tensor(int64) ** scalar(int) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 3 + c = paddle.full([2, 2, 2], 8, dtype="int64") + self.check_operation(a, b, c, '**') + + # tensor(int64) ** scalar(float) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='int64') + b = 3.0 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(int) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 3 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(float) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 2, dtype='float32') + b = 3.0 + c = paddle.full([2, 2, 2], 8, dtype="float32") + self.check_operation(a, b, c, '**') + + def test_scalar_pow_tensor(self): + # scalar(int) ** tensor(int64) + with program_guard(Program()): + a = 3 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 9, dtype="int64") + self.check_operation(a, b, c, '**') + + # scalar(float) ** tensor(int64) + with program_guard(Program()): + a = 3.0 + b = paddle.full([2, 2, 2], 2, dtype='int64') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + # scalar(int) ** tensor(float32) + with program_guard(Program()): + a = 3 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + # tensor(float32) ** scalar(float) + with program_guard(Program()): + a = 3.0 + b = paddle.full([2, 2, 2], 2, dtype='float32') + c = paddle.full([2, 2, 2], 9, dtype="float32") + self.check_operation(a, b, c, '**') + + # ## TODO: floordiv op kernel doesn't support float + def test_tensor_floordiv_scalar(self): + # tensor(int64) // scalar(int) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="int64") + self.check_operation(a, b, c, '//') + + def test_tensor_mod_scalar(self): + # tensor(int64) % scalar(int) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="int64") + self.check_operation(a, b, c, '%') + + # tensor(int64) % scalar(float) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 3, dtype='int64') + b = 2.0 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + # tensor(float32) % scalar(int) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 3, dtype='float32') + b = 2 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + # tensor(float32) % scalar(float) + with program_guard(Program()): + a = paddle.full([2, 2, 2], 3, dtype='float32') + b = 2.0 + c = paddle.full([2, 2, 2], 1, dtype="float32") + self.check_operation(a, b, c, '%') + + +if __name__ == '__main__': + unittest.main() -- GitLab