提交 6b2a5221 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2975 Add MixedPrecisionCast for KeywordArg

Merge pull request !2975 from Kang/add_AbstractKeywordArg
......@@ -321,6 +321,13 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node
}
target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
func_graph->NewCNode(dict_value_nodes)});
} else if (node_type->isa<AbstractKeywordArg>()) {
auto x = node_type->cast<AbstractKeywordArgPtr>();
std::string kwarg_key = x->get_key();
AnfNodePtr kwarg_value_node =
func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
}
return target_node;
}
......
......@@ -219,3 +219,31 @@ def test_dict_cast():
y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
net = FirstNet()
net(x, y)
def test_kwarg_cast():
class FirstNet(nn.Cell):
def __init__(self):
super(FirstNet, self).__init__()
self.net = SecondNet().add_flags_recursive(fp16=True)
self.add = P.TensorAdd()
def construct(self, tensor_a, tensor_b):
tensor_c = self.add(tensor_a, tensor_b)
dictionary = {"key": tensor_a}
result = self.net(key1=tensor_c, key2=dictionary)
return result
class SecondNet(nn.Cell):
def __init__(self):
super(SecondNet, self).__init__()
self.add = P.TensorAdd()
def construct(self, key1=1, key2=2):
tensor_d = self.add(key1, key2["key"])
return tensor_d
x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
net = FirstNet()
net(x, y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册