From 62f7dc49e534cadeeb4bda42b2255a9ab08c3501 Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 1 Jun 2020 16:04:41 +0800 Subject: [PATCH] support convert bool scalar and tensor to number tensor --- mindspore/ccsrc/operator/composite/do_signature.cc | 3 --- tests/ut/python/ops/test_math_ops.py | 6 ++---- tests/ut/python/ops/test_math_ops_check.py | 7 ------- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index dd4d3a87c..305f07584 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -255,9 +255,6 @@ void DoAutoCast(const std::vector &signature, const abstract::Abstrac if (arg_value->isa() && arg_type_id == it->second) { continue; } - if ((arg_type_id == kNumberTypeBool || it->second == kNumberTypeBool) && arg_type_id != it->second) { - continue; - } (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); } } diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index d600ce16b..e280cc109 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -101,10 +101,8 @@ def test_pow(): result = testpow(input_tensor, power) assert np.all(result.asnumpy() == expect) net = PowNet() - with pytest.raises(TypeError): - net(input_tensor, True) - with pytest.raises(TypeError): - net(input_tensor, power2) + net(input_tensor, True) + net(input_tensor, power2) def test_exp(): diff --git a/tests/ut/python/ops/test_math_ops_check.py b/tests/ut/python/ops/test_math_ops_check.py index 355e35f93..9772de82e 100755 --- a/tests/ut/python/ops/test_math_ops_check.py +++ b/tests/ut/python/ops/test_math_ops_check.py @@ -293,13 +293,6 @@ raise_set = [ 'desc_inputs': [5.0], 'skip': ['backward']}), - # input x is Tensor(bool) - ('Pow1', { - 'block': (P.Pow(), - {'exception': TypeError, 'error_keywords': ['Pow']}), - 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0], - 'skip': ['backward']}), - # input is not Tensor ('Exp1', { 'block': (P.Exp(), -- GitLab