diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index 231b98ab0032136384379716dafd9f517f7aba13..51c4fc17ec79eb9e16a412748fa42da0825a407a 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -68,9 +68,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo return param; } auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base"); - auto partial = - func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(cast_helper), NewValueNode(dst_type)}); - auto cast = func_graph->NewCNode({NewValueNode(prim::kCompositeHyperMap), partial, param}); + auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); return cast; } diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 8670f4aa7ce307428dd8566d8976d932b160c3c8..4b559d1605b48a4d23a289586164cb797908ab8f 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -307,3 +307,12 @@ def _mixed_precision_cast_helper_2(type_, x): if F.issubclass_(F.dtype(x), mstype.float_): return P.Cast()(x, type_) return x + +@_mp_cast_helper.register("TypeType", "Tuple") +@core +def _mixed_precision_cast_helper_3(type_, x): + """if x is a tuple""" + t = () + for item in x: + t = t + (_mp_cast_helper(type_, item),) + return t diff --git a/tests/ut/python/parameter_feature/test_parameter.py b/tests/ut/python/parameter_feature/test_parameter.py index 696b107f56b75268894fa3c0bf2e40c33ecc9486..1409fef386d364cbece635992b77ce49a1a04c73 100644 --- a/tests/ut/python/parameter_feature/test_parameter.py +++ b/tests/ut/python/parameter_feature/test_parameter.py @@ -19,7 +19,7 @@ from mindspore.nn import Cell from mindspore.ops import operations as P import mindspore.ops.composite as C -context.set_context(mode=context.GRAPH_MODE) +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) def test_parser_three_default_mixed_args_subnet(): @@ -227,3 +227,43 @@ def test_net_vargs_expand(): net.set_train() net(x, y, sens) + + +def test_mixed_precision_const_parameter(): + class NetLoss(Cell): + def __init__(self): + super(NetLoss, self).__init__() + self.shape = P.Shape() + self.up_sample1 = P.ResizeBilinear((14, 14)) + self.up_sample2 = P.ResizeBilinear((28, 28)) + self.up_sample3 = P.ResizeBilinear((36, 36)) + def construct(self, x, y, z, *args): + ret = 0 + if args[0] == self.shape(z)[2]: + if args[0] == 14: + ret = self.up_sample1(y) + x + elif args[0] == 28: + ret = self.up_sample2(y) - x + else: + ret = x / y + else: + ret = x * y + ret = ret * z + return ret + class NetMain(Cell): + def __init__(self, loss_fn): + super(NetMain, self).__init__() + self.loss_fn = loss_fn + self.shape = P.Shape() + def construct(self, x, y, z): + size_x = self.shape(x)[2] + size_y = self.shape(y)[2] + ret = self.loss_fn(x, y, z, size_x, size_y) + return ret + loss_fn = NetLoss() + net = NetMain(loss_fn) + net.add_flags_recursive(fp32=True) + x = Tensor(np.ones((1, 3, 28, 28), np.float32)) + y = Tensor(np.ones((1, 3, 14, 14), np.float32)) + z = Tensor(np.ones((1, 3, 28, 28), np.float32)) + out = net(x, y, z) \ No newline at end of file