提交 760cd682 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3337 add check for implicit conversion when scalar is not number and bool

Merge pull request !3337 from zhangbuxue/add_check_for_implict_conversion_when_scalar_is_string
......@@ -99,7 +99,7 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
}
// Resolve class member, two possible: method, member variable
AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) {
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
py::object namespace_var =
parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
......
......@@ -68,7 +68,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); }
bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); }
AnfNodePtr MakeResolveAstOp(const py::object &op);
AnfNodePtr MakeResolveClassMember(std::string attr);
AnfNodePtr MakeResolveClassMember(const std::string &attr);
AnfNodePtr MakeResolveSymbol(const std::string &value);
AnfNodePtr MakeResolveOperation(const std::string &value);
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
......
......@@ -268,6 +268,12 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
TypeIdToMsTypeStr(it->second));
}
}
if (!py::isinstance<tensor::Tensor>(py_args[i]) && !py::isinstance<py::int_>(py_args[i]) &&
!py::isinstance<py::float_>(py_args[i])) {
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: "
<< py::cast<std::string>(py_args[1].attr("__class__").attr("__name__"))
<< ", and the value is " << py::cast<py::str>(py_args[i]) << ".";
}
py::object cast_output = DoAutoCast(py_args[i], it->second);
(*out_args)[i] = cast_output;
(*out_args_list)[i] = cast_output;
......
......@@ -14,6 +14,7 @@
# ============================================================================
""" test implicit conversion """
import numpy as np
import pytest
from mindspore import Tensor, nn
from mindspore.ops import composite as C
......@@ -90,6 +91,30 @@ def test_float_tensor_and_bool_tensors_add():
assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
def test_float_tensor_and_str_add():
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
y = "ok"
with pytest.raises(TypeError) as er:
ret = x + y
assert "For 'TensorAdd', the 1th input is a not support type: str" in str(er.value)
def test_float_tensor_and_tuple_add():
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
y = (1, 2, 3)
with pytest.raises(TypeError) as er:
ret = x + y
assert "For 'TensorAdd', the 1th input is a not support type: tuple" in str(er.value)
def test_float_tensor_and_list_add():
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
y = [1, 2, 3]
with pytest.raises(TypeError) as er:
ret = x + y
assert "For 'TensorAdd', the 1th input is a not support type: list" in str(er.value)
def test_float_tensor_and_bool_tensors_add_grad():
class Net(nn.Cell):
def __init__(self):
......@@ -104,7 +129,6 @@ def test_float_tensor_and_bool_tensors_add_grad():
self.net = net
def construct(self, x, y, sens):
return C.grad_all_with_sens(self.net)(x, y, sens)
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
......@@ -133,7 +157,6 @@ def test_float_tensor_and_int_tensors_sub_grad():
self.net = net
def construct(self, x, y, sens):
return C.grad_all_with_sens(self.net)(x, y, sens)
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
......@@ -163,7 +186,6 @@ def test_float16_tensor_and_float32_tensors_sub_grad():
self.net = net
def construct(self, x, y, sens):
return C.grad_all_with_sens(self.net)(x, y, sens)
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.int32))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册