提交 758853ff 编写于 作者: C changzherui

modify case

上级 1f4222ed
...@@ -1342,9 +1342,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node ...@@ -1342,9 +1342,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
} }
for (size_t i = 1; i < input_size; i++) { for (size_t i = 1; i < input_size; i++) {
auto pred = inputs[i]; AnfNodePtr pred = nullptr;
if (case_flag != 0) { if (case_flag != 0) {
pred = case_input_handle_cache_[node.get()]->at(i - 1); pred = case_input_handle_cache_[node.get()]->at(i - 1);
} else {
pred = inputs[i];
} }
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") { while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") {
......
...@@ -29,8 +29,7 @@ class Net(nn.Cell): ...@@ -29,8 +29,7 @@ class Net(nn.Cell):
def construct(self, x, index): def construct(self, x, index):
x = self.layers[index](x) x = self.layers[index](x)
y = self.conv1(x) return 2 + x
return x + y
def test_case(): def test_case():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册