From b0ec6e84d1b45f6cf57fa3a27e6c95b62b92d343 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Wed, 24 Feb 2021 14:07:25 +0800 Subject: [PATCH] [cherry pick]add warning message when dtypes of operator are not same (#31136) (#31175) ATT, cherry-pick #31136 --- python/paddle/fluid/dygraph/math_op_patch.py | 13 +++- .../unittests/test_tensor_type_promotion.py | 59 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 5e26ba2b109..1df3e31ae4b 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -21,6 +21,7 @@ from . import no_grad import numpy as np import six +import warnings _supported_int_dtype_ = [ core.VarDesc.VarType.UINT8, @@ -51,6 +52,11 @@ _supported_promote_complex_types_ = [ '__matmul__', ] +_complex_dtypes = [ + core.VarDesc.VarType.COMPLEX64, + core.VarDesc.VarType.COMPLEX128, +] + _already_patch_varbase = False @@ -214,7 +220,9 @@ def monkey_patch_math_varbase(): # 3. promote types or unify right var type to left var rhs_dtype = other_var.dtype if lhs_dtype != rhs_dtype: - if method_name in _supported_promote_complex_types_: + if method_name in _supported_promote_complex_types_ and ( + lhs_dtype in _complex_dtypes or + rhs_dtype in _complex_dtypes): # only when lhs_dtype or rhs_dtype is complex type, # the dtype will promote, in other cases, directly # use lhs_dtype, this is consistent will original rule @@ -225,6 +233,9 @@ def monkey_patch_math_varbase(): other_var = other_var if rhs_dtype == promote_dtype else astype( other_var, promote_dtype) else: + warnings.warn( + 'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'. + format(lhs_dtype, rhs_dtype, lhs_dtype)) other_var = astype(other_var, lhs_dtype) if reverse: diff --git a/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py new file mode 100644 index 00000000000..c2543645853 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import unittest +import numpy as np +import warnings +import paddle + + +class TestTensorTypePromotion(unittest.TestCase): + def setUp(self): + self.x = paddle.to_tensor([2, 3]) + self.y = paddle.to_tensor([1.0, 2.0]) + + def test_operator(self): + with warnings.catch_warnings(record=True) as context: + warnings.simplefilter("always") + self.x + self.y + self.assertTrue( + "The dtype of left and right variables are not the same" in + str(context[-1].message)) + + with warnings.catch_warnings(record=True) as context: + warnings.simplefilter("always") + self.x - self.y + self.assertTrue( + "The dtype of left and right variables are not the same" in + str(context[-1].message)) + + with warnings.catch_warnings(record=True) as context: + warnings.simplefilter("always") + self.x * self.y + self.assertTrue( + "The dtype of left and right variables are not the same" in + str(context[-1].message)) + + with warnings.catch_warnings(record=True) as context: + warnings.simplefilter("always") + self.x / self.y + self.assertTrue( + "The dtype of left and right variables are not the same" in + str(context[-1].message)) + + +if __name__ == '__main__': + unittest.main() -- GitLab