提交 96b38f72 编写于 作者: H huangdongrun

add ExpandDims whitelist

add comment for control_depend
上级 f6575616
...@@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { ...@@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimApplyRMSProp, {6, 7, 8}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}}, {prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}}, {prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}}); {prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) { for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
......
...@@ -30,6 +30,8 @@ class ControlDepend(Primitive): ...@@ -30,6 +30,8 @@ class ControlDepend(Primitive):
tells the engine that the destination operations should depend on the source operation which means the source tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination. operations should be executed before the destination.
Note:
This operation does not work in `PYNATIVE_MODE`.
Args: Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.
......
...@@ -19,6 +19,8 @@ import pytest ...@@ -19,6 +19,8 @@ import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as ms
from mindspore.common.api import _executor from mindspore.common.api import _executor
...@@ -116,3 +118,28 @@ def test_parser_map_0002(): ...@@ -116,3 +118,28 @@ def test_parser_map_0002():
net = NetMap0002() net = NetMap0002()
with pytest.raises(TypeError): with pytest.raises(TypeError):
net(input_me_x) net(input_me_x)
def test_fix_expanddims_loss_scale():
class ControlOneIfOneScaleOneScale(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.ExpandDims()
def construct(self, x, y, data):
if x > y:
out = 1
else:
out = 2
if x > y:
out = self.op(data, out)
else:
out = self.op(data, out)
return out
net = ControlOneIfOneScaleOneScale()
x = Tensor(1, ms.float32)
y = Tensor(0, ms.float32)
input_shape = (1024, 512, 7, 7)
input_data = np.random.randn(*input_shape).astype(np.float32)
net = ControlOneIfOneScaleOneScale()
net(x, y, Tensor(input_data))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册