未验证 提交 4086f48e 编写于 作者: C Chen Weihang 提交者: GitHub

Check and fix tensor and scalar type promotion (#28299)

* check and fix tensor and scalar type promotion

* fix else branch error

* fix scalar method error

* fix test_math_op_path unittest

* add future division for unittest

* rm useless bin file
上级 fb1e0c93
...@@ -149,28 +149,46 @@ def monkey_patch_math_varbase(): ...@@ -149,28 +149,46 @@ def monkey_patch_math_varbase():
reverse=False, reverse=False,
scalar_method=None): scalar_method=None):
def __impl__(self, other_var): def __impl__(self, other_var):
# tensor and ComplexVariable opetator # 0. check tensor and ComplexVariable opetator
if isinstance(other_var, ComplexVariable): if isinstance(other_var, ComplexVariable):
# need import paddle in closure # need import paddle in closure
import paddle import paddle
math_op = getattr(paddle.incubate.complex.tensor, op_type) math_op = getattr(paddle.incubate.complex.tensor, op_type)
return math_op(self, other_var) return math_op(self, other_var)
# FIXME(zjl): elementwise_div between integers cannot be converted to scale, # 1. scalar exists cases
# which may lose accuracy. This is a hot fix for release 1.6. # we need combine the tensor.dtype and scalar.dtype, cast correct object
if scalar_method is not None and not ( if isinstance(other_var, float):
op_type == 'elementwise_div' and # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
self.dtype in _supported_int_dtype_): if self.dtype in _supported_int_dtype_:
if isinstance(other_var, float): self = astype(self, 'float32')
if self.dtype in _supported_int_dtype_: # here use `scale` replace `elementwise` to get better performance
assert other_var == int(other_var), \ # but only +, -, *, / can use this method
"float value {} cannot convert to integer".format(other_var) if scalar_method is not None:
return scalar_method(self, other_var) return scalar_method(self, other_var)
elif isinstance(other_var, int): elif isinstance(other_var, int):
return scalar_method(self, float(other_var)) # in all cases(+, -, *, /, **, //, %), we can cast it to float
# because the output tensor.dtype depend on the type of input tensor
other_var = float(other_var)
# division is a special case
# NOTE(chenweihang): because we cast tensor to float32 instead float64,
# the division result can only guarantee the numerical accuracy of 6 digits
# after the decimal point. The result of numpy calculation is of float64 type,
# so the calculation result here and the calculation result of numpy are
# different after 6 decimal point. If necessary, we can also use float64 here.
# torch's behavior here is consistent with ours
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method
if scalar_method is not None:
return scalar_method(self, other_var)
else:
# do nothing
pass
# 2. create varbase for scalar
lhs_dtype = self.dtype lhs_dtype = self.dtype
if not isinstance(other_var, core.VarBase): if not isinstance(other_var, core.VarBase):
if reverse: if reverse:
other_var = create_tensor( other_var = create_tensor(
...@@ -179,6 +197,7 @@ def monkey_patch_math_varbase(): ...@@ -179,6 +197,7 @@ def monkey_patch_math_varbase():
# add fill_op # add fill_op
other_var = create_scalar(value=other_var, dtype=lhs_dtype) other_var = create_scalar(value=other_var, dtype=lhs_dtype)
# 3. unify right var type to left var
rhs_dtype = other_var.dtype rhs_dtype = other_var.dtype
if lhs_dtype != rhs_dtype: if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype) other_var = astype(other_var, lhs_dtype)
...@@ -187,6 +206,7 @@ def monkey_patch_math_varbase(): ...@@ -187,6 +206,7 @@ def monkey_patch_math_varbase():
self = other_var self = other_var
other_var = tmp other_var = tmp
# 4. calculation
axis = -1 axis = -1
math_op = getattr(core.ops, op_type) math_op = getattr(core.ops, op_type)
return math_op(self, other_var, 'axis', axis) return math_op(self, other_var, 'axis', axis)
......
...@@ -215,21 +215,39 @@ def monkey_patch_variable(): ...@@ -215,21 +215,39 @@ def monkey_patch_variable():
reverse=False, reverse=False,
scalar_method=None): scalar_method=None):
def __impl__(self, other_var): def __impl__(self, other_var):
# FIXME(zjl): elementwise_div between integers cannot be converted to scale, # 1. scalar exists cases
# which may lose accuracy. This is a hot fix for release 1.6. # we need combine the tensor.dtype and scalar.dtype, cast correct object
if scalar_method is not None and not ( if isinstance(other_var, float):
op_type == 'elementwise_div' and # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
self.dtype in _supported_int_dtype_): if self.dtype in _supported_int_dtype_:
if isinstance(other_var, float): self = astype(self, 'float32')
if self.dtype in _supported_int_dtype_: # here use `scale` replace `elementwise` to get better performance
assert other_var == int(other_var), \ # but only +, -, *, / can use this method
"float value {} cannot convert to integer".format(other_var) if scalar_method is not None:
return scalar_method(self, other_var) return scalar_method(self, other_var)
elif isinstance(other_var, int): elif isinstance(other_var, int):
return scalar_method(self, float(other_var)) # in all cases(+, -, *, /, **, //, %), we can cast it to float
# because the output tensor.dtype depend on the type of input tensor
other_var = float(other_var)
# division is a special case
# NOTE(chenweihang): because we cast tensor to float32 instead float64,
# the division result can only guarantee the numerical accuracy of 6 digits
# after the decimal point. The result of numpy calculation is of float64 type,
# so the calculation result here and the calculation result of numpy are
# different after 6 decimal point. If necessary, we can also use float64 here.
# torch's behavior here is consistent with ours
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method
if scalar_method is not None:
return scalar_method(self, other_var)
else:
# do nothing
pass
# 2. create variable for scalar
lhs_dtype = safe_get_dtype(self) lhs_dtype = safe_get_dtype(self)
if not isinstance(other_var, Variable): if not isinstance(other_var, Variable):
if reverse: if reverse:
has_batch_size = False has_batch_size = False
...@@ -251,6 +269,7 @@ def monkey_patch_variable(): ...@@ -251,6 +269,7 @@ def monkey_patch_variable():
other_var = create_scalar( other_var = create_scalar(
current_block(self), value=other_var, dtype=lhs_dtype) current_block(self), value=other_var, dtype=lhs_dtype)
# 3. unify right var type to left var
rhs_dtype = safe_get_dtype(other_var) rhs_dtype = safe_get_dtype(other_var)
if lhs_dtype != rhs_dtype: if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype) other_var = astype(other_var, lhs_dtype)
......
...@@ -12,15 +12,19 @@ ...@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function, division
import unittest import unittest
from decorator_helper import prog_scope from decorator_helper import prog_scope
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy import numpy
class TestMathOpPatches(unittest.TestCase): class TestMathOpPatches(unittest.TestCase):
def setUp(self):
paddle.enable_static()
@prog_scope() @prog_scope()
def test_add_scalar(self): def test_add_scalar(self):
a = fluid.layers.data(name="a", shape=[1]) a = fluid.layers.data(name="a", shape=[1])
...@@ -197,8 +201,8 @@ class TestMathOpPatches(unittest.TestCase): ...@@ -197,8 +201,8 @@ class TestMathOpPatches(unittest.TestCase):
feed={"a": a_np}, feed={"a": a_np},
fetch_list=[b]) fetch_list=[b])
b_np_actual = (a_np / 7).astype('int64') b_np_actual = (a_np / 7).astype('float32')
self.assertTrue(numpy.array_equal(b_np, b_np_actual)) self.assertTrue(numpy.allclose(b_np, b_np_actual))
@prog_scope() @prog_scope()
def test_equal(self): def test_equal(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import unittest
import numpy as np
import paddle
# Support types are ref from `paddle.tensor.math`
# - Related paddle dtypes:
# - int type: int64, (no test here: uint8, int8, int16, int32)
# - float type: float32, (no test here: float64)
# - Python scalar dtypes:
# - int(64)
# - float(64)
class TestTensorScalarTypePromotionDynamic(unittest.TestCase):
def check_operation(self, a, b, c, op):
if op == '+':
c_rlt = a + b
elif op == '-':
c_rlt = a - b
elif op == '*':
c_rlt = a * b
elif op == '/':
c_rlt = a / b
elif op == '**':
c_rlt = a**b
elif op == '//':
c_rlt = a // b
elif op == '%':
c_rlt = a % b
else:
raise ValueError("Unsupported operation.")
self.assertEqual(c_rlt.dtype, c.dtype)
self.assertTrue(np.array_equal(c_rlt.numpy(), c.numpy()))
def test_tensor_add_scalar(self):
# tensor(int64) + scalar(int)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.full([2, 2, 2], 2, dtype="int64")
self.check_operation(a, b, c, '+')
# tensor(float32) + scalar(int)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(int64) + scalar(float, .0)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(int64) + scalar(float, .5)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 2.5, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(float32) + scalar(float)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 2.5, dtype="float32")
self.check_operation(a, b, c, '+')
def test_tensor_sub_scalar(self):
# tensor(int64) - scalar(int)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.zeros([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '-')
# tensor(float32) - scalar(int)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(int64) - scalar(float, .0)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(int64) - scalar(float, .5)
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(float32) - scalar(float)
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '-')
def test_scalar_sub_tensor(self):
# scalar(int) - tensor(int64)
a = 1
b = paddle.ones([2, 2, 2], dtype='int64')
c = paddle.zeros([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '-')
# scalar(int) - tensor(float32)
a = 1
b = paddle.ones([2, 2, 2], dtype='float32')
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float, .0) - tensor(int64)
a = 1.0
b = paddle.ones([2, 2, 2], dtype='int64')
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float, .5) - tensor(int64)
a = 1.5
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], -0.5, dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float) - tensor(float32)
a = 1.5
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], -0.5, dtype="float32")
self.check_operation(a, b, c, '-')
def test_tensor_mul_tensor(self):
# tensor(int64) * scalar(int)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.ones([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '*')
# tensor(float32) * scalar(int)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.ones([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(int64) * scalar(float, .0)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.ones([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(int64) * scalar(float, .5)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 1.5, dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(float32) * scalar(float)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 1.5, dtype="float32")
self.check_operation(a, b, c, '*')
def test_tensor_div_scalar(self):
# tensor(int64) / scalar(int)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(float32) / scalar(int)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 2
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .0)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 2.0
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .5)
a = paddle.ones([2, 2, 2], dtype='int64')
b = 0.5
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(float32) / scalar(float)
a = paddle.ones([2, 2, 2], dtype='float32')
b = 0.5
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
def test_scalar_div_tensor(self):
# scalar(int) / tensor(int64)
a = 1
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(int) / tensor(float32)
a = 1
b = paddle.full([2, 2, 2], 0.5, dtype='float32')
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(float) / tensor(int64)
a = 1.0
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(float) / tensor(float32)
a = 1.0
b = paddle.full([2, 2, 2], 0.5, dtype='float32')
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
def test_tensor_pow_scalar(self):
# tensor(int64) ** scalar(int)
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 3
c = paddle.full([2, 2, 2], 8, dtype="int64")
self.check_operation(a, b, c, '**')
# tensor(int64) ** scalar(float)
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 3.0
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(int)
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 3
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(float)
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 3.0
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
def test_scalar_pow_tensor(self):
# scalar(int) ** tensor(int64)
a = 3
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 9, dtype="int64")
self.check_operation(a, b, c, '**')
# scalar(float) ** tensor(int64)
a = 3.0
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
# scalar(int) ** tensor(float32)
a = 3
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(float)
a = 3.0
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
## TODO: floordiv op kernel doesn't support float
def test_tensor_floordiv_scalar(self):
# tensor(int64) // scalar(int)
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="int64")
self.check_operation(a, b, c, '//')
def test_tensor_mod_scalar(self):
# tensor(int64) % scalar(int)
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="int64")
self.check_operation(a, b, c, '%')
# tensor(int64) % scalar(float)
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2.0
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
# tensor(float32) % scalar(int)
a = paddle.full([2, 2, 2], 3, dtype='float32')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
# tensor(float32) % scalar(float)
a = paddle.full([2, 2, 2], 3, dtype='float32')
b = 2.0
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import unittest
import numpy as np
import paddle
from paddle.static import program_guard
from paddle.static import Program
# Support types are ref from `paddle.tensor.math`
# - Related paddle dtypes:
# - int type: int64, (no test here: uint8, int8, int16, int32)
# - float type: float32, (no test here: float64)
# - Python scalar dtypes:
# - int(64)
# - float(64)
class TestTensorScalarTypePromotionStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def check_operation(self, a, b, c, op):
exe = paddle.static.Executor()
if op == '+':
c_rlt = a + b
elif op == '-':
c_rlt = a - b
elif op == '*':
c_rlt = a * b
elif op == '/':
c_rlt = a / b
elif op == '**':
c_rlt = a**b
elif op == '//':
c_rlt = a // b
elif op == '%':
c_rlt = a % b
else:
raise ValueError("Unsupported operation.")
rlt = exe.run(fetch_list=[c_rlt.name, c.name])
self.assertEqual(rlt[0].dtype, rlt[1].dtype)
self.assertTrue(np.array_equal(rlt[0], rlt[1]))
def test_tensor_add_scalar(self):
# tensor(int64) + scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.full([2, 2, 2], 2, dtype="int64")
self.check_operation(a, b, c, '+')
# tensor(float32) + scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(int64) + scalar(float, .0)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(int64) + scalar(float, .5)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 2.5, dtype="float32")
self.check_operation(a, b, c, '+')
# tensor(float32) + scalar(float)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 2.5, dtype="float32")
self.check_operation(a, b, c, '+')
def test_tensor_sub_scalar(self):
# tensor(int64) - scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.zeros([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '-')
# tensor(float32) - scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(int64) - scalar(float, .0)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(int64) - scalar(float, .5)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '-')
# tensor(float32) - scalar(float)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '-')
def test_scalar_sub_tensor(self):
# scalar(int) - tensor(int64)
with program_guard(Program()):
a = 1
b = paddle.ones([2, 2, 2], dtype='int64')
c = paddle.zeros([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '-')
# scalar(int) - tensor(float32)
with program_guard(Program()):
a = 1
b = paddle.ones([2, 2, 2], dtype='float32')
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float, .0) - tensor(int64)
with program_guard(Program()):
a = 1.0
b = paddle.ones([2, 2, 2], dtype='int64')
c = paddle.zeros([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float, .5) - tensor(int64)
with program_guard(Program()):
a = 1.5
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], -0.5, dtype="float32")
self.check_operation(a, b, c, '-')
# scalar(float) - tensor(float32)
with program_guard(Program()):
a = 1.5
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], -0.5, dtype="float32")
self.check_operation(a, b, c, '-')
def test_tensor_mul_tensor(self):
# tensor(int64) * scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1
c = paddle.ones([2, 2, 2], dtype="int64")
self.check_operation(a, b, c, '*')
# tensor(float32) * scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1
c = paddle.ones([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(int64) * scalar(float, .0)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.0
c = paddle.ones([2, 2, 2], dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(int64) * scalar(float, .5)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 1.5
c = paddle.full([2, 2, 2], 1.5, dtype="float32")
self.check_operation(a, b, c, '*')
# tensor(float32) * scalar(float)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 1.5
c = paddle.full([2, 2, 2], 1.5, dtype="float32")
self.check_operation(a, b, c, '*')
def test_tensor_div_scalar(self):
# tensor(int64) / scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(float32) / scalar(int)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 2
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .0)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 2.0
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .5)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64')
b = 0.5
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(float32) / scalar(float)
with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='float32')
b = 0.5
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
def test_scalar_div_tensor(self):
# scalar(int) / tensor(int64)
with program_guard(Program()):
a = 1
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(int) / tensor(float32)
with program_guard(Program()):
a = 1
b = paddle.full([2, 2, 2], 0.5, dtype='float32')
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(float) / tensor(int64)
with program_guard(Program()):
a = 1.0
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/')
# scalar(float) / tensor(float32)
with program_guard(Program()):
a = 1.0
b = paddle.full([2, 2, 2], 0.5, dtype='float32')
c = paddle.full([2, 2, 2], 2, dtype="float32")
self.check_operation(a, b, c, '/')
def test_tensor_pow_scalar(self):
# tensor(int64) ** scalar(int)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 3
c = paddle.full([2, 2, 2], 8, dtype="int64")
self.check_operation(a, b, c, '**')
# tensor(int64) ** scalar(float)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='int64')
b = 3.0
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(int)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 3
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(float)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 2, dtype='float32')
b = 3.0
c = paddle.full([2, 2, 2], 8, dtype="float32")
self.check_operation(a, b, c, '**')
def test_scalar_pow_tensor(self):
# scalar(int) ** tensor(int64)
with program_guard(Program()):
a = 3
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 9, dtype="int64")
self.check_operation(a, b, c, '**')
# scalar(float) ** tensor(int64)
with program_guard(Program()):
a = 3.0
b = paddle.full([2, 2, 2], 2, dtype='int64')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
# scalar(int) ** tensor(float32)
with program_guard(Program()):
a = 3
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
# tensor(float32) ** scalar(float)
with program_guard(Program()):
a = 3.0
b = paddle.full([2, 2, 2], 2, dtype='float32')
c = paddle.full([2, 2, 2], 9, dtype="float32")
self.check_operation(a, b, c, '**')
# ## TODO: floordiv op kernel doesn't support float
def test_tensor_floordiv_scalar(self):
# tensor(int64) // scalar(int)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="int64")
self.check_operation(a, b, c, '//')
def test_tensor_mod_scalar(self):
# tensor(int64) % scalar(int)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="int64")
self.check_operation(a, b, c, '%')
# tensor(int64) % scalar(float)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 3, dtype='int64')
b = 2.0
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
# tensor(float32) % scalar(int)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 3, dtype='float32')
b = 2
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
# tensor(float32) % scalar(float)
with program_guard(Program()):
a = paddle.full([2, 2, 2], 3, dtype='float32')
b = 2.0
c = paddle.full([2, 2, 2], 1, dtype="float32")
self.check_operation(a, b, c, '%')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册