From 0de05b39491f0655ce0149d8a6b3567809331937 Mon Sep 17 00:00:00 2001 From: candanzg Date: Sat, 18 Apr 2020 16:20:31 +0800 Subject: [PATCH] [bug] fixed bool check for cast op Signed-off-by: candanzg --- .../ccsrc/operator/composite/do_signature.cc | 13 +++++++++++++ mindspore/ops/operations/array_ops.py | 2 +- tests/ut/python/ops/test_math_ops.py | 16 ++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index a4a26377f..70fc0f591 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -137,6 +137,19 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac if (it == dst_type.end() || it->second == i || !arg_value->isa()) { continue; } + // When scalar is of bool type, the type of tensor must also be of bool type, + // otherwise the cast operator will not be added. + auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + if (scalar_type->type_id() == kNumberTypeBool) { + auto tensor = args_spec_list[it->second]->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + if (tensor_type->type_id() != kNumberTypeBool) { + continue; + } + } // get source node for cast AnfNodePtr source_node = (*op_inputs)[it->second + 1]; (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], source_node, graph); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2e03676a4..b4c4796d5 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -745,7 +745,7 @@ class Fill(PrimitiveWithInfer): out = { 'value': Tensor(ret), 'shape': dims['value'], - 'dtype': x_nptype, + 'dtype': x_dtype, } return out diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 8b7f627e8..7f8717d4e 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -30,6 +30,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config from ....mindspore_test_framework.pipeline.forward.verify_exception \ import pipeline_for_verify_exception_for_case_by_case_config +import pytest # pylint: disable=W0613 @@ -81,14 +82,29 @@ def test_sqrt(): assert np.all(result.asnumpy() == expect) +class PowNet(nn.Cell): + def __init__(self): + super(PowNet, self).__init__() + self.pow = P.Pow() + + def construct(self, x, y): + return self.pow(x, y) + + def test_pow(): """ test_pow """ input_tensor = Tensor(np.array([[2, 2], [3, 3]])) power = Tensor(np.array(3.0, np.int64)) + power2 = Tensor(np.array(True, np.bool)) testpow = P.Pow() expect = np.array([[8, 8], [27, 27]]) result = testpow(input_tensor, power) assert np.all(result.asnumpy() == expect) + net = PowNet() + with pytest.raises(TypeError): + net(input_tensor, True) + with pytest.raises(TypeError): + net(input_tensor, power2) def test_exp(): -- GitLab