提交 4291ea82 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2024 Add MixedPrecisionCast for Dict

Merge pull request !2024 from Kang/master
...@@ -286,6 +286,22 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node ...@@ -286,6 +286,22 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node
++idx; ++idx;
} }
target_node = func_graph->NewCNode(nodes); target_node = func_graph->NewCNode(nodes);
} else if (node_type->isa<AbstractDictionary>()) {
auto x = node_type->cast<AbstractDictionaryPtr>();
auto &items = x->elements();
std::vector<AnfNodePtr> dict_key_nodes;
std::vector<AnfNodePtr> dict_value_nodes;
dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (const auto &item : items) {
AnfNodePtr dict_value_node =
func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
dict_key_nodes.emplace_back(NewValueNode(item.first));
dict_value_nodes.emplace_back(node);
}
target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
func_graph->NewCNode(dict_value_nodes)});
} }
return target_node; return target_node;
} }
......
...@@ -308,7 +308,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr ...@@ -308,7 +308,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
evaluator = std::make_shared<UnpackGraphEvaluator>(prim); evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
return evaluator; return evaluator;
} }
if (prim->name() == prim::kPrimMixedPrecisionCast->name()) { if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim); evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
return evaluator; return evaluator;
} }
......
...@@ -25,6 +25,7 @@ from mindspore.nn import Momentum ...@@ -25,6 +25,7 @@ from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from tests.ops_common import convert from tests.ops_common import convert
from ....train_step_wrap import train_step_with_loss_warp from ....train_step_wrap import train_step_with_loss_warp
...@@ -185,3 +186,36 @@ def test_grad_conv_prelu(): ...@@ -185,3 +186,36 @@ def test_grad_conv_prelu():
net = GetParamGrad(net) net = GetParamGrad(net)
net.set_train() net.set_train()
net(*all_inputs) net(*all_inputs)
def test_dict_cast():
class FirstNet(nn.Cell):
def __init__(self):
super(FirstNet, self).__init__()
self.net = SecondNet()
self.sub = P.Sub()
def construct(self, tensor_a, tensor_b):
a = F.mixed_precision_cast(mstype.float16, tensor_a)
b = F.mixed_precision_cast(mstype.float16, tensor_b)
c = self.sub(a, b)
dictionary = {"key": a}
result = self.net(c, key1=a, key2=dictionary)
return result
class SecondNet(nn.Cell):
def __init__(self):
super(SecondNet, self).__init__()
self.add = P.TensorAdd()
def construct(self, tensor_c, **kwargs):
d = F.mixed_precision_cast(mstype.float16, tensor_c)
dict_cast = F.mixed_precision_cast(mstype.float16, kwargs)
e = self.add(d, dict_cast["key1"])
f = self.add(e, dict_cast["key2"]["key"])
return f
x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
net = FirstNet()
net(x, y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册