From 4df1edf5b386c7f3a54de22fe55971a98bb82cd8 Mon Sep 17 00:00:00 2001
From: buxue
Date: Thu, 9 Jul 2020 17:26:20 +0800
Subject: [PATCH] Improving implicit type conversion

---
 mindspore/ccsrc/ir/dtype/type_id.h            |   6 +-
 .../ccsrc/operator/composite/do_signature.cc  |  15 ++-
 .../ccsrc/operator/composite/do_signature.h   |   3 +
 mindspore/ccsrc/pynative/pynative_execute.cc  |  44 ++----
 .../pynative_mode/test_implicit_conversion.py | 125 +++++++++++++++++-
 5 files changed, 151 insertions(+), 42 deletions(-)

diff --git a/mindspore/ccsrc/ir/dtype/type_id.h b/mindspore/ccsrc/ir/dtype/type_id.h
index a711779e9..6fb2a354c 100644
--- a/mindspore/ccsrc/ir/dtype/type_id.h
+++ b/mindspore/ccsrc/ir/dtype/type_id.h
@@ -86,8 +86,8 @@ enum TypeId : int {
 // TypeId name map
 //
 const std::unordered_map type_name_map = {
-  {kNumberTypeBool, "Bool"},       {kNumberTypeInt8, "Int8"},       {kNumberTypeUInt8, "UInt8"},
-  {kNumberTypeInt16, "Int16"},     {kNumberTypeInt32, "Int32"},     {kNumberTypeInt64, "Int64"},
-  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat32, "Float32"}, {kNumberTypeFloat64, "Float64"}};
+  {kNumberTypeBool, "bool_"},      {kNumberTypeInt8, "int8"},       {kNumberTypeUInt8, "uint8"},
+  {kNumberTypeInt16, "int16"},     {kNumberTypeInt32, "int32"},     {kNumberTypeInt64, "int64"},
+  {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}};
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_
diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc
index 0b619eecc..c70cfe5d4 100644
--- a/mindspore/ccsrc/operator/composite/do_signature.cc
+++ b/mindspore/ccsrc/operator/composite/do_signature.cc
@@ -223,11 +223,7 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign
         if (it_name_map == type_name_map.end()) {
           continue;
         }
-        MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
-                          << "the type of writable argument is '" << it_map->second << "', "
-                          << "but the largest type in the same SignatureEumDtype is '" << it_name_map->second
-                          << "'. The writable arg type is not equal to the largest type, "
-                          << "so can not cast automatically.";
+        RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second);
       }
       continue;
     }
@@ -311,5 +307,14 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi
   func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   return func_graph;
 }
+
+void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type,
+                                      const std::string &target_type) {
+  MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
+                    << "the type of writable argument is '" << ref_type << "', "
+                    << "but the largest type in the same SignatureEumDtype is '" << target_type
+                    << "'. The writable arg type is not equal to the largest type, "
+                    << "so can not cast automatically.";
+}
 }  // namespace prim
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h
index 6905a7835..97f6d7e7a 100644
--- a/mindspore/ccsrc/operator/composite/do_signature.h
+++ b/mindspore/ccsrc/operator/composite/do_signature.h
@@ -58,6 +58,9 @@ using RWSignaturePtr = std::shared_ptr;
 
 extern const std::map type_map;
 
+void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type,
+                                      const std::string &target_type);
+
 AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
                          const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs);
 }  // namespace prim
diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index d62ec1895..ed7ff38ae 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -184,6 +184,9 @@ std::map GetDstType(const py::tuple &py_args,
     auto arg = py::cast(py_args[index]);
     TypeId arg_type_id = arg->data_type();
     auto type_priority = prim::type_map.find(arg_type_id);
+    if (type_priority == prim::type_map.end()) {
+      continue;
+    }
     if (type_priority->second > priority) {
       max_type = type_priority->first;
       priority = type_priority->second;
     }
   }
@@ -204,36 +207,14 @@ std::string TypeIdToMsTypeStr(const TypeId &type_id) {
-  switch (type_id) {
-    case kNumberTypeFloat16:
-      return "float16";
-    case kNumberTypeFloat32:
-      return "float32";
-    case kNumberTypeFloat64:
-      return "float64";
-    case kNumberTypeInt8:
-      return "int8";
-    case kNumberTypeInt16:
-      return "int16";
-    case kNumberTypeInt32:
-      return "int32";
-    case kNumberTypeInt64:
-      return "int64";
-    case kNumberTypeUInt8:
-      return "uint8";
-    case kNumberTypeUInt16:
-      return "uint16";
-    case kNumberTypeUInt32:
-      return "uint32";
-    case kNumberTypeUInt64:
-      return "uint64";
-    case kNumberTypeBool:
-      return "bool_";
-    default:
-      MS_LOG(EXCEPTION) << "For implicit type conversion, not support the type: " << TypeIdToType(type_id);
+  auto type_name = type_name_map.find(type_id);
+  if (type_name == type_name_map.end()) {
+    MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id);
   }
+  return type_name->second;
 }
-py::object DoAutoCast(const py::object arg, const TypeId &type_id) {
+
+py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
   py::tuple args(3);
   std::string module_name = "mindspore.ops.functional";
   std::string op_name = "cast";
@@ -283,11 +264,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
       continue;
     }
     if (signature[i].rw == SignatureEnumRW::kRWWrite) {
-      MS_LOG(EXCEPTION) << "In op '" << prim->name() << "', \n"
-                        << "the type of writable argument is '" << TypeIdToMsTypeStr(arg->data_type()) << "', "
-                        << "but the largest type in the same SignatureEumDtype is '" << TypeIdToMsTypeStr(it->second)
-                        << "'. The writable arg type is not equal to the largest type, "
-                        << "so can not cast automatically.";
+      prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()),
+                                             TypeIdToMsTypeStr(it->second));
     }
   }
   py::object cast_output = DoAutoCast(py_args[i], it->second);
diff --git a/tests/ut/python/pynative_mode/test_implicit_conversion.py b/tests/ut/python/pynative_mode/test_implicit_conversion.py
index 093b095b7..ecaffd87f 100644
--- a/tests/ut/python/pynative_mode/test_implicit_conversion.py
+++ b/tests/ut/python/pynative_mode/test_implicit_conversion.py
@@ -15,7 +15,8 @@
 """ test implicit conversion """
 import numpy as np
 
-from mindspore import Tensor
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
 
 
 def test_float_tensor_and_int_add():
@@ -23,6 +24,7 @@ def test_float_tensor_and_int_add():
     y = 2
     ret_actual = x + y
     ret_expect = Tensor(np.array([[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]], dtype=np.float32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -31,6 +33,7 @@ def test_bool_tensor_and_float_add():
     y = 3.3
     ret_actual = x + y
     ret_expect = Tensor(np.array([[4.3, 3.3], [3.3, 4.3]], dtype=np.float32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -39,6 +42,7 @@ def test_bool_tensor_and_int_add():
     y = 3
     ret_actual = x + y
     ret_expect = Tensor(np.array([[4, 3], [3, 4]], dtype=np.int32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -47,13 +51,16 @@ def test_bool_and_int_tensor_add():
     y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
     ret_actual = x + y
     ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
+
 def test_float_tensor_and_int_tensor_add():
     x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
     y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
     ret_actual = x + y
     ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -62,6 +69,7 @@ def test_float_tensor_and_float_tensor_add():
     y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float16))
     ret_actual = x + y
     ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -70,6 +78,7 @@ def test_int_tensor_and_int_tensor_add():
     y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
     ret_actual = x + y
     ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int32))
+    assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
@@ -79,3 +88,117 @@ def test_float_tensor_and_bool_tensors_add():
     ret_actual = x + y
     ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32))
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
+
+
+def test_float_tensor_and_bool_tensors_add_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y):
+            return x + y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y, sens):
+
+            return C.grad_all_with_sens(self.net)(x, y, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, y, sens)
+    assert ret[0].dtype == x.dtype
+    assert ret[1].dtype == y.dtype
+    assert (ret[0].asnumpy() == sens.asnumpy()).all()
+    assert (ret[1].asnumpy() == sens.asnumpy().astype(np.bool_)).all()
+
+
+def test_float_tensor_and_int_tensors_sub_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y):
+            return x - y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y, sens):
+
+            return C.grad_all_with_sens(self.net)(x, y, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, y, sens)
+    print(ret)
+    assert ret[0].dtype == x.dtype
+    assert ret[1].dtype == y.dtype
+    assert (ret[0].asnumpy() == sens.asnumpy()).all()
+    assert (ret[1].asnumpy() == sens.asnumpy() * -1).all()
+
+
+def test_float16_tensor_and_float32_tensors_sub_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y):
+            return x - y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y, sens):
+
+            return C.grad_all_with_sens(self.net)(x, y, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float16))
+    y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, y, sens)
+    print(ret)
+    assert ret[0].dtype == x.dtype
+    assert ret[1].dtype == y.dtype
+    assert (ret[0].asnumpy() == sens.asnumpy()).all()
+    assert (ret[1].asnumpy() == sens.asnumpy() * -1).all()
+
+
+def test_float_tensor_and_int_add_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x):
+            return x + 2
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, sens)
+    assert ret[0].dtype == x.dtype
+    assert (ret[0].asnumpy() == sens.asnumpy()).all()
--
GitLab
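
For readers who want to try the behaviour outside the unit-test file, below is a minimal PyNative-mode sketch of what the new tests assert: the forward result keeps the largest participating dtype, and each gradient is handed back in its input's original dtype. It reuses only the APIs already exercised above (mindspore.Tensor, nn.Cell, composite.grad_all_with_sens); the class names Sub and GradSub are illustrative and are not part of the patch.

# Illustrative sketch only -- mirrors the dtype checks added in test_implicit_conversion.py.
import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import composite as C


class Sub(nn.Cell):
    """Subtraction cell: the gradient w.r.t. x is sens, w.r.t. y is -sens."""
    def construct(self, x, y):
        return x - y


class GradSub(nn.Cell):
    """Wraps a network and returns the gradients of both inputs for a given sens."""
    def __init__(self, net):
        super(GradSub, self).__init__()
        self.net = net

    def construct(self, x, y, sens):
        return C.grad_all_with_sens(self.net)(x, y, sens)


x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))  # implicitly cast up to float32
sens = Tensor(np.ones((2, 3), dtype=np.float32))

out = Sub()(x, y)
assert out.dtype == x.dtype  # forward result keeps the largest dtype (float32)

dx, dy = GradSub(Sub())(x, y, sens)
assert dx.dtype == x.dtype   # gradients come back in each input's original dtype
assert dy.dtype == y.dtype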