提交 4f6e63fc 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4576 Support if by if not inline

Merge pull request !4576 from amongo/SupportIfByIfNotInline
......@@ -20,12 +20,14 @@
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......@@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
return nullptr;
}
std::vector<AnfNodePtr> params;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
std::vector<AnfNodePtr> args;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
// compare size to avoid the case that the function has default value after grad.
// for which after renormalize, the function default value will be an input
if (fg->parameters().size() != params.size()) {
if (fg->parameters().size() != args.size()) {
return nullptr;
}
// Not to inline after block if it has switch call inside, to avoid switch expansion.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
auto has_branch_call = GraphHasBranch(fg);
if (has_branch_call) {
return TransformBranchCall(fg, node, args);
}
}
if (use_move_ && IsUniqueUse(fg, nullptr)) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, params, fg);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
......@@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
is_checked_ = false;
is_recursive_ = false;
}
// For after block which contains branch call, delete the parameters which is not used.
// In most cases, it may be a `Module` or other constant input.
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
auto &fg_params = fg->parameters();
std::vector<int> used_param_index;
auto mng = fg->manager();
for (size_t i = 0; i < fg_params.size(); i++) {
if (mng->node_users()[fg_params[i]].size() != 0) {
used_param_index.emplace_back(i);
}
}
if (used_param_index.size() != fg_params.size()) {
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
// clone a new graph and ignore the not used parameters
FuncGraphPtr new_fg = TransformableClone(fg);
auto &new_fg_params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
[&new_fg_params](size_t i) { return new_fg_params[i]; });
new_fg->set_parameters(new_params);
std::vector<AnfNodePtr> node_inputs;
node_inputs.push_back(NewValueNode(new_fg));
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
[&args](size_t i) { return args[i]; });
return node->func_graph()->NewCNode(node_inputs);
}
return nullptr;
}
// This is a try-best algorithm to find a graph which may generate branch call.
// It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
bool GraphHasBranch(FuncGraphPtr fg) {
if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
return graph_branch_cache_[fg];
}
bool has_branch = false;
auto nodes = fg->nodes();
for (auto &item : nodes) {
if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
auto sw_inputs = item->cast<CNodePtr>()->inputs();
if (sw_inputs.size() != 4) {
MS_LOG(EXCEPTION) << "switch inputs should be 4";
}
if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
has_branch = true;
break;
}
} else if (IsCNodeGraph(item)) {
auto cinputs = item->cast<CNodePtr>()->inputs();
if (cinputs.size() < 1) {
MS_LOG(EXCEPTION) << "graph call inputs should greater than 1";
}
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
bool call_fg_has_branch = GraphHasBranch(call_fg);
if (call_fg_has_branch) {
has_branch = true;
break;
}
} else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
auto cinputs = item->cast<CNodePtr>()->inputs();
if (cinputs.size() < 2) {
MS_LOG(EXCEPTION) << "partial call inputs should greater than 2";
}
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
if (call_fg == nullptr) {
continue;
}
bool call_fg_has_branch = GraphHasBranch(call_fg);
if (call_fg_has_branch) {
has_branch = true;
break;
}
}
}
graph_branch_cache_[fg] = has_branch;
return has_branch;
}
private:
bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
};
class Inliner : public InlinerBase {
......
......@@ -1029,6 +1029,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
if (MsContext::GetInstance()->backend_policy() != "ge") {
// for backends excludes 'ge', it can handle multi graph call, use this flag to
// generate call not inline `after_block` graph to reduce if by if switch expansion.
after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
}
// process the if-true branch
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
......
......@@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
......
......@@ -42,7 +42,7 @@ def test_while_forward():
idx = idx + 1
return x
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
......@@ -72,7 +72,7 @@ def test_while_grad():
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -99,7 +99,7 @@ def test_while_with_param_forward():
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
......@@ -124,7 +124,7 @@ def test_while_endless_case():
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
......@@ -159,7 +159,7 @@ def test_while_with_param_grad():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -187,7 +187,7 @@ def test_while_with_param_forward_with_const_branch():
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = while_net
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -224,7 +224,7 @@ def test_while_opt_endless():
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -250,7 +250,7 @@ def test_no_while_call():
out = out + idx + self.param
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = while_net
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -287,7 +287,7 @@ def test_while_with_param_grad_with_const_branch():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -327,7 +327,7 @@ def test_for_while_with_param_grad_with_const_branch():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -364,7 +364,7 @@ def test_for_while_with_param_grad_basic():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -401,7 +401,7 @@ def test_for_while_with_param_grad_normal():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -435,7 +435,7 @@ def test_while_with_param_basic_grad():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -469,7 +469,7 @@ def test_while_with_param_basic_grad_mul():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -504,7 +504,7 @@ def test_while_with_param_basic_grad_two():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -540,7 +540,7 @@ def test_while_with_param_basic_grad_three():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -577,7 +577,7 @@ def test_while_if_with_param_grad():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -610,7 +610,7 @@ def test_while_with_param_grad_not_enter_while():
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(3), dtype=ms.int32)
......@@ -639,7 +639,7 @@ def test_with_param_if_by_if_forward():
out = out + x*2
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -672,7 +672,7 @@ def test_with_param_if_by_if_grad_inputs():
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -706,7 +706,7 @@ def test_with_param_if_by_if_grad_parameter():
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(0), dtype=ms.int32)
......@@ -738,7 +738,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(4), dtype=ms.int32)
......@@ -772,7 +772,7 @@ def test_if_by_if_return_inside_grad():
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(1), dtype=ms.int32)
......@@ -807,10 +807,342 @@ def test_if_by_if_forward():
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(4), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_control_tuple_switch():
"""tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
class Branch3Net(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
if b == x:
b = self.add(a, b)
else:
b = self.add(a, x)
return a, b, x
class Branch2Net(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.net = Branch3Net()
def construct(self, a, b, x):
if a == x:
a = self.mul(a, b)
else:
a = self.div(a, b)
return self.net(a, b, x)
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.net = Branch2Net()
def construct(self, a, b, x):
if a < b:
a = self.add(a, b)
else:
a = self.sub(a, b)
a, b, x = self.net(a, b, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_control_inside_net():
class Branch3Net(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
if b == x:
b = self.add(a, b)
else:
b = self.add(a, x)
a = a * b
out = a + b + x
return out
class Branch2Net(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.net = Branch3Net()
def construct(self, a, b, x):
if a == x:
a = self.mul(a, b)
else:
a = self.div(a, b)
return self.net(a, b, x)
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.net = Branch2Net()
def construct(self, a, b, x):
if a < b:
a = self.add(a, b)
else:
a = self.sub(a, b)
out = self.net(a, b, x)
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_use_namespace():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
if a < b:
a = P.TensorAdd()(a, b)
else:
a = P.Sub()(a, b)
if a == x:
a = P.Mul()(a, b)
else:
a = P.RealDiv()(a, b)
if b == x:
b = P.TensorAdd()(a, b)
else:
b = P.TensorAdd()(a, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_use_global_op():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
add = P.TensorAdd()
sub = P.Sub()
mul = P.Mul()
div = P.RealDiv()
if a < b:
a = add(a, b)
else:
a = sub(a, b)
if a == x:
a = mul(a, b)
else:
a = div(a, b)
if b == x:
b = add(a, b)
else:
b = add(a, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_for_with_if_by_if_forward():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
def construct(self, a, b, x):
for _ in range(0, 4):
if a < b:
a = self.add(a, b)
else:
b = self.sub(b, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_for_with_if_by_if_forward_namespace():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
for _ in range(0, 6):
if a < b:
a = P.TensorAdd()(a, b)
else:
b = P.Sub()(b, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_const_branch_inner():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
add = P.TensorAdd()
sub = P.Sub()
mul = P.Mul()
div = P.RealDiv()
if a < b:
a = add(a, b)
else:
a = sub(a, b)
if 2 > 1:
a = mul(a, b)
else:
a = div(a, b)
if b == x:
b = add(a, b)
else:
b = add(a, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward_all_const_branch():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
add = P.TensorAdd()
sub = P.Sub()
mul = P.Mul()
div = P.RealDiv()
if 2 < 12:
a = add(a, b)
else:
a = sub(a, b)
if 2 > 1:
a = mul(a, b)
else:
a = div(a, b)
if 2 == 1:
b = add(a, b)
else:
b = add(a, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册