diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
index 290e2a1a98cb593fe19dc15270d80da125cbf802..aec366ab9bd321ed8e2b42ee7e50901da738a67c 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
@@ -20,12 +20,14 @@
 #include
 #include
 #include
+#include

 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/anf_visitor.h"
 #include "ir/func_graph.h"
 #include "ir/func_graph_cloner.h"
+#include "ir/tensor.h"
 #include "frontend/operator/ops.h"

 namespace mindspore {
@@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
       return nullptr;
     }

-    std::vector<AnfNodePtr> params;
-    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
+    std::vector<AnfNodePtr> args;
+    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));

     // compare size to avoid the case that the function has default value after grad.
     // for which after renormalize, the function default value will be an input
-    if (fg->parameters().size() != params.size()) {
+    if (fg->parameters().size() != args.size()) {
       return nullptr;
     }

+    // Do not inline the after-block if it has a switch call inside, to avoid switch expansion.
+    if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
+      auto has_branch_call = GraphHasBranch(fg);
+      if (has_branch_call) {
+        return TransformBranchCall(fg, node, args);
+      }
+    }
+
     if (use_move_ && IsUniqueUse(fg, nullptr)) {
       auto mng = fg->manager();
       MS_EXCEPTION_IF_NULL(mng);
-      ReplaceParams(mng, params, fg);
+      ReplaceParams(mng, args, fg);
       auto out_node = fg->output();
       mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
       return out_node;
     }
-    return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
+    return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
   }

   void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
@@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
     is_checked_ = false;
     is_recursive_ = false;
   }
+  // For an after-block which contains a branch call, delete the parameters which are not used.
+  // In most cases, the unused parameter is a `Module` or some other constant input.
+  AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
+    auto &fg_params = fg->parameters();
+    std::vector<size_t> used_param_index;
+    auto mng = fg->manager();
+    for (size_t i = 0; i < fg_params.size(); i++) {
+      if (mng->node_users()[fg_params[i]].size() != 0) {
+        used_param_index.emplace_back(i);
+      }
+    }
+    if (used_param_index.size() != fg_params.size()) {
+      MS_LOG(DEBUG) << "Unused parameter found for graph: " << fg->ToString();
+      // Clone a new graph and drop the unused parameters.
+      FuncGraphPtr new_fg = TransformableClone(fg);
+      auto &new_fg_params = new_fg->parameters();
+      std::vector<AnfNodePtr> new_params;
+      std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
+                     [&new_fg_params](size_t i) { return new_fg_params[i]; });
+      new_fg->set_parameters(new_params);
+      std::vector<AnfNodePtr> node_inputs;
+      node_inputs.push_back(NewValueNode(new_fg));
+      std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
+                     [&args](size_t i) { return args[i]; });
+      return node->func_graph()->NewCNode(node_inputs);
+    }
+    return nullptr;
+  }
+
+  // This is a best-effort algorithm to find a graph which may generate a branch call.
+  // It does not handle high-order function calls; a branch created through a high-order call may still be inlined.
+  bool GraphHasBranch(FuncGraphPtr fg) {
+    if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
+      return graph_branch_cache_[fg];
+    }
+    bool has_branch = false;
+    auto nodes = fg->nodes();
+    for (auto &item : nodes) {
+      if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
+        auto sw_inputs = item->cast<CNodePtr>()->inputs();
+        if (sw_inputs.size() != 4) {
+          MS_LOG(EXCEPTION) << "Switch inputs size should be 4.";
+        }
+        if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
+          has_branch = true;
+          break;
+        }
+      } else if (IsCNodeGraph(item)) {
+        auto cinputs = item->cast<CNodePtr>()->inputs();
+        if (cinputs.size() < 1) {
+          MS_LOG(EXCEPTION) << "Graph call inputs size should be at least 1.";
+        }
+        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
+        bool call_fg_has_branch = GraphHasBranch(call_fg);
+        if (call_fg_has_branch) {
+          has_branch = true;
+          break;
+        }
+      } else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
+        auto cinputs = item->cast<CNodePtr>()->inputs();
+        if (cinputs.size() < 2) {
+          MS_LOG(EXCEPTION) << "Partial call inputs size should be at least 2.";
+        }
+        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
+        if (call_fg == nullptr) {
+          continue;
+        }
+        bool call_fg_has_branch = GraphHasBranch(call_fg);
+        if (call_fg_has_branch) {
+          has_branch = true;
+          break;
+        }
+      }
+    }
+    graph_branch_cache_[fg] = has_branch;
+    return has_branch;
+  }

  private:
   bool is_checked_{false}, is_recursive_{false};
   bool use_move_;
   std::vector> criterions_;
+  std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
 };

 class Inliner : public InlinerBase {
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index 486f8181ce78e0f13b8ec3917fc85941cee382a4..87c93ac8b15f4d47217db328ca7e32edb158befd 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -1029,6 +1029,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
   FunctionBlockPtr after_block = MakeFunctionBlock(*this);
   TraceManager::EndTrace();

+  if (MsContext::GetInstance()->backend_policy() != "ge") {
+    // For backends other than 'ge', which can handle multi-graph calls, use this flag to
+    // generate a call instead of inlining the `after_block` graph, reducing if-by-if switch expansion.
+ after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); + } + // process the if-true branch py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index e9694c659657db58902eb7973510fecb2138b0cb..01808ca4c609a9b269abcd645adb1af4ab6e2d55 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; +const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block"; const char FUNC_GRAPH_FLAG_CORE[] = "core"; const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py index c3baae1fb5b3dc4b1c9dc4cb7855e49534dfc732..6a71ecda2c066ea5b7ad61e860d14fcb5b76c8f6 100644 --- a/tests/st/control/test_cont_grad.py +++ b/tests/st/control/test_cont_grad.py @@ -42,7 +42,7 @@ def test_while_forward(): idx = idx + 1 return x - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = MyWhileNet() idx = Tensor(np.array(0), dtype=ms.int32) end = Tensor(np.array(2), dtype=ms.int32) @@ -72,7 +72,7 @@ def test_while_grad(): def construct(self, *inputs): return C.grad_all(self.net)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -99,7 +99,7 @@ def test_while_with_param_forward(): idx = idx + 1 return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = MyWhileNet() idx = Tensor(np.array(0), dtype=ms.int32) end = Tensor(np.array(2), dtype=ms.int32) @@ -124,7 +124,7 @@ def test_while_endless_case(): idx = idx + 1 return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = MyWhileNet() idx = Tensor(np.array(0), dtype=ms.int32) end = Tensor(np.array(2), dtype=ms.int32) @@ -159,7 +159,7 @@ def test_while_with_param_grad(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -187,7 +187,7 @@ def test_while_with_param_forward_with_const_branch(): idx = idx + 1 return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = while_net idx = Tensor(np.array(0), dtype=ms.int32) @@ -224,7 +224,7 @@ def test_while_opt_endless(): def construct(self, *inputs): return C.grad_all(self.net)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = 
MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -250,7 +250,7 @@ def test_no_while_call(): out = out + idx + self.param return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = while_net idx = Tensor(np.array(0), dtype=ms.int32) @@ -287,7 +287,7 @@ def test_while_with_param_grad_with_const_branch(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -327,7 +327,7 @@ def test_for_while_with_param_grad_with_const_branch(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -364,7 +364,7 @@ def test_for_while_with_param_grad_basic(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -401,7 +401,7 @@ def test_for_while_with_param_grad_normal(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -435,7 +435,7 @@ def test_while_with_param_basic_grad(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -469,7 +469,7 @@ def test_while_with_param_basic_grad_mul(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -504,7 +504,7 @@ def test_while_with_param_basic_grad_two(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -540,7 +540,7 @@ def test_while_with_param_basic_grad_three(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net 
= MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -577,7 +577,7 @@ def test_while_if_with_param_grad(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -610,7 +610,7 @@ def test_while_with_param_grad_not_enter_while(): def construct(self, a, b, c): return C.grad_by_list(self.net, self.weights)(a, b, c) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") while_net = MyWhileNet() net = GradNet(while_net) idx = Tensor(np.array(3), dtype=ms.int32) @@ -639,7 +639,7 @@ def test_with_param_if_by_if_forward(): out = out + x*2 return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = if_net idx = Tensor(np.array(0), dtype=ms.int32) @@ -672,7 +672,7 @@ def test_with_param_if_by_if_grad_inputs(): def construct(self, *inputs): return C.grad_all(self.net)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = GradNet(if_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -706,7 +706,7 @@ def test_with_param_if_by_if_grad_parameter(): def construct(self, *inputs): return C.grad_by_list(self.net, self.weights)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = GradNet(if_net) idx = Tensor(np.array(0), dtype=ms.int32) @@ -738,7 +738,7 @@ def test_with_param_if_by_if_grad_param_excute_null(): def construct(self, *inputs): return C.grad_by_list(self.net, self.weights)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = GradNet(if_net) idx = Tensor(np.array(4), dtype=ms.int32) @@ -772,7 +772,7 @@ def test_if_by_if_return_inside_grad(): def construct(self, *inputs): return C.grad_by_list(self.net, self.weights)(*inputs) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = GradNet(if_net) idx = Tensor(np.array(1), dtype=ms.int32) @@ -807,10 +807,342 @@ def test_if_by_if_forward(): out = a + b + x return out - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") if_net = MyIfByIfNet() net = if_net idx = Tensor(np.array(2), dtype=ms.float32) end = Tensor(np.array(3), dtype=ms.float32) x = Tensor(np.array(4), dtype=ms.float32) net(idx, end, x) + + +def test_if_by_if_forward_control_tuple_switch(): + """tuple_get from swtich op will generate new switch inside to eliminate tuple_get""" + class Branch3Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if 
b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + return a, b, x + class Branch2Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch3Net() + + def construct(self, a, b, x): + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + return self.net(a, b, x) + + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch2Net() + + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + a, b, x = self.net(a, b, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + + + +def test_if_by_if_forward_control_inside_net(): + class Branch3Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + a = a * b + out = a + b + x + return out + class Branch2Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch3Net() + + def construct(self, a, b, x): + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + return self.net(a, b, x) + + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch2Net() + + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + out = self.net(a, b, x) + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + + +def test_if_by_if_forward_use_namespace(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if a < b: + a = P.TensorAdd()(a, b) + else: + a = P.Sub()(a, b) + if a == x: + a = P.Mul()(a, b) + else: + a = P.RealDiv()(a, b) + if b == x: + b = P.TensorAdd()(a, b) + else: + b = P.TensorAdd()(a, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_use_global_op(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if a < b: + a = add(a, b) + else: + a = sub(a, b) + if a == x: + a = mul(a, b) + else: + a = div(a, b) + if b == x: + b = add(a, b) + else: + b = add(a, x) + a = a 
* b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_for_with_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + + def construct(self, a, b, x): + for _ in range(0, 4): + if a < b: + a = self.add(a, b) + else: + b = self.sub(b, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + + +def test_for_with_if_by_if_forward_namespace(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + for _ in range(0, 6): + if a < b: + a = P.TensorAdd()(a, b) + else: + b = P.Sub()(b, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + + +def test_if_by_if_forward_const_branch_inner(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if a < b: + a = add(a, b) + else: + a = sub(a, b) + if 2 > 1: + a = mul(a, b) + else: + a = div(a, b) + if b == x: + b = add(a, b) + else: + b = add(a, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + + + +def test_if_by_if_forward_all_const_branch(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if 2 < 12: + a = add(a, b) + else: + a = sub(a, b) + if 2 > 1: + a = mul(a, b) + else: + a = div(a, b) + if 2 == 1: + b = add(a, b) + else: + b = add(a, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) diff --git a/tests/ut/python/pynative_mode/test_cont_cases.py b/tests/ut/python/pynative_mode/test_cont_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d9f9c72af2b8af965d9232d1685356cc64567b --- /dev/null +++ b/tests/ut/python/pynative_mode/test_cont_cases.py @@ -0,0 +1,1006 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test control ops """ +import numpy as np +from mindspore import dtype as ms +from mindspore import Tensor +from mindspore import context +from mindspore import nn +from mindspore import ms_function +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.ops import composite as C +from mindspore.ops import operations as P +# from tests.vm_impl.math_ops_vm_impl import * +# from tests.vm_impl.vm_interface import * +# from tests.vm_impl import * + + +def setup_module(): + context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) + + +def test_while_with_param_forward_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + @ms_function + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + while_net = MyWhileNet() + net = while_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_opt_endless(): + """endless during optimization case""" + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.addn = P.AddN() + + def construct(self, idx, end, x): + addn1 = self.addn((x, x, x)) + out = addn1 + while idx < end: + out = self.addn((out, addn1)) + idx = idx + 1 + out = self.addn((out, x)) + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + @ms_function + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32) + net(idx, end, x) + + +def test_no_while_call(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + @ms_function + def construct(self, idx, end, x): + out = self.zero + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + return out + + while_net = MyWhileNet() + net = while_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), 
dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_grad_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = self.zero + for _ in range(0, 2): + idx = self.start + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_basic(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = self.zero + for _ in range(0, 2): + idx = self.start + while idx < end: + out = out + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_normal(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 
2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = x + for _ in range(0, 2): + idx = self.start + while idx < end: + out = out + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_mul(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out * self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_two(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + self.weight + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net 
= net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_three(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") + self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + self.weight + self.key + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_if_with_param_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if self.max(out) < self.max(x): + out = out + self.param * 2 + else: + out = out + self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_grad_not_enter_while(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param * 3 + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(3), dtype=ms.int32) + end = Tensor(np.array(0), 
dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + @ms_function + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param + else: + out = out + x + if a == b: + out = out + x*3 + self.param + else: + out = out + x*2 + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_inputs(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 4 + if a == b: + out = out + x*3 + self.param * 3 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + @ms_function + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_parameter(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 2 + if a == b: + out = out + x*3 + self.param + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_param_excute_null(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 2 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(4), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = 
Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_return_inside_grad(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + return out + x + self.param + if a == b: + return out + self.param * 2 + return out + self.param * 3 + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + @ms_function + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(1), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + if b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(4), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_control_tuple_switch(): + """tuple_get from swtich op will generate new switch inside to eliminate tuple_get""" + class Branch3Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + return a, b, x + + class Branch2Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch3Net() + + def construct(self, a, b, x): + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + return self.net(a, b, x) + + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch2Net() + + @ms_function + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + a, b, x = self.net(a, b, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_control_inside_net(): + class Branch3Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + a = a * b + out = a + b + x + return out + + class Branch2Net(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = 
P.Mul() + self.div = P.RealDiv() + self.net = Branch3Net() + + def construct(self, a, b, x): + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + return self.net(a, b, x) + + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + self.net = Branch2Net() + + @ms_function + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + out = self.net(a, b, x) + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_use_namespace(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def construct(self, a, b, x): + if a < b: + a = P.TensorAdd()(a, b) + else: + a = P.Sub()(a, b) + if a == x: + a = P.Mul()(a, b) + else: + a = P.RealDiv()(a, b) + if b == x: + b = P.TensorAdd()(a, b) + else: + b = P.TensorAdd()(a, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_use_global_op(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if a < b: + a = add(a, b) + else: + a = sub(a, b) + if a == x: + a = mul(a, b) + else: + a = div(a, b) + if b == x: + b = add(a, b) + else: + b = add(a, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_for_with_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + + @ms_function + def construct(self, a, b, x): + for _ in range(0, 4): + if a < b: + a = self.add(a, b) + else: + b = self.sub(b, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_for_with_if_by_if_forward_namespace(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def construct(self, a, b, x): + for _ in range(0, 6): + if a < b: + a = P.TensorAdd()(a, b) + else: + b = P.Sub()(b, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_const_branch_inner(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def 
construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if a < b: + a = add(a, b) + else: + a = sub(a, b) + if 2 > 1: + a = mul(a, b) + else: + a = div(a, b) + if b == x: + b = add(a, b) + else: + b = add(a, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward_all_const_branch(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + @ms_function + def construct(self, a, b, x): + add = P.TensorAdd() + sub = P.Sub() + mul = P.Mul() + div = P.RealDiv() + if 2 < 12: + a = add(a, b) + else: + a = sub(a, b) + if 2 > 1: + a = mul(a, b) + else: + a = div(a, b) + if 2 == 1: + b = add(a, b) + else: + b = add(a, x) + a = a * b + out = a + b + x + return out + + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(0), dtype=ms.float32) + net(idx, end, x)
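

The if-by-if pattern these tests exercise can also be reproduced with a minimal standalone script. The sketch below only reuses the public MindSpore APIs already imported in the tests (nn.Cell, Tensor, context, P.TensorAdd, P.Sub); the class name IfByIfNet and the variable names are illustrative and not part of the patch.

# Minimal standalone reproduction of the if-by-if pattern (assumptions noted above).
import numpy as np
from mindspore import dtype as ms
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P


class IfByIfNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()

    def construct(self, a, b, x):
        # Each data-dependent `if` produces a switch; the statements after it are
        # parsed into a separate after_block graph, which the inliner may now keep
        # as a call instead of inlining it branch by branch.
        if a < b:
            a = self.add(a, b)
        else:
            a = self.sub(a, b)
        if a == x:
            b = self.add(a, x)
        a = a * b
        return a + b + x


if __name__ == "__main__":
    # save_graphs=True dumps the IR so one can inspect whether the after_block
    # graphs remain separate calls rather than being inlined.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = IfByIfNet()
    a = Tensor(np.array(2), dtype=ms.float32)
    b = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    print(net(a, b, x))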