提交 d9522982 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5117 Fix test case interface

Merge pull request !5117 from amongo/FixTestCaseInterface
...@@ -26,6 +26,9 @@ from mindspore.ops import operations as P ...@@ -26,6 +26,9 @@ from mindspore.ops import operations as P
# from tests.vm_impl.vm_interface import * # from tests.vm_impl.vm_interface import *
# from tests.vm_impl import * # from tests.vm_impl import *
grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
grad_all = C.GradOperation('get_all', get_all=True)
def setup_module(): def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
...@@ -86,7 +89,7 @@ def test_while_opt_endless(): ...@@ -86,7 +89,7 @@ def test_while_opt_endless():
@ms_function @ms_function
def construct(self, *inputs): def construct(self, *inputs):
return C.grad_all(self.net)(*inputs) return grad_all(self.net)(*inputs)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -149,7 +152,7 @@ def test_while_with_param_grad_with_const_branch(): ...@@ -149,7 +152,7 @@ def test_while_with_param_grad_with_const_branch():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -189,7 +192,7 @@ def test_for_while_with_param_grad_with_const_branch(): ...@@ -189,7 +192,7 @@ def test_for_while_with_param_grad_with_const_branch():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -226,7 +229,7 @@ def test_for_while_with_param_grad_basic(): ...@@ -226,7 +229,7 @@ def test_for_while_with_param_grad_basic():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -263,7 +266,7 @@ def test_for_while_with_param_grad_normal(): ...@@ -263,7 +266,7 @@ def test_for_while_with_param_grad_normal():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -297,7 +300,7 @@ def test_while_with_param_basic_grad(): ...@@ -297,7 +300,7 @@ def test_while_with_param_basic_grad():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -331,7 +334,7 @@ def test_while_with_param_basic_grad_mul(): ...@@ -331,7 +334,7 @@ def test_while_with_param_basic_grad_mul():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -366,7 +369,7 @@ def test_while_with_param_basic_grad_two(): ...@@ -366,7 +369,7 @@ def test_while_with_param_basic_grad_two():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -402,7 +405,7 @@ def test_while_with_param_basic_grad_three(): ...@@ -402,7 +405,7 @@ def test_while_with_param_basic_grad_three():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -439,7 +442,7 @@ def test_while_if_with_param_grad(): ...@@ -439,7 +442,7 @@ def test_while_if_with_param_grad():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -472,7 +475,7 @@ def test_while_with_param_grad_not_enter_while(): ...@@ -472,7 +475,7 @@ def test_while_with_param_grad_not_enter_while():
@ms_function @ms_function
def construct(self, a, b, c): def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c) return grad_by_list(self.net, self.weights)(a, b, c)
while_net = MyWhileNet() while_net = MyWhileNet()
net = GradNet(while_net) net = GradNet(while_net)
...@@ -534,7 +537,7 @@ def test_with_param_if_by_if_grad_inputs(): ...@@ -534,7 +537,7 @@ def test_with_param_if_by_if_grad_inputs():
@ms_function @ms_function
def construct(self, *inputs): def construct(self, *inputs):
return C.grad_all(self.net)(*inputs) return grad_all(self.net)(*inputs)
if_net = MyIfByIfNet() if_net = MyIfByIfNet()
net = GradNet(if_net) net = GradNet(if_net)
...@@ -568,7 +571,7 @@ def test_with_param_if_by_if_grad_parameter(): ...@@ -568,7 +571,7 @@ def test_with_param_if_by_if_grad_parameter():
@ms_function @ms_function
def construct(self, *inputs): def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs) return grad_by_list(self.net, self.weights)(*inputs)
if_net = MyIfByIfNet() if_net = MyIfByIfNet()
net = GradNet(if_net) net = GradNet(if_net)
...@@ -600,7 +603,7 @@ def test_with_param_if_by_if_grad_param_excute_null(): ...@@ -600,7 +603,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
@ms_function @ms_function
def construct(self, *inputs): def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs) return grad_by_list(self.net, self.weights)(*inputs)
if_net = MyIfByIfNet() if_net = MyIfByIfNet()
net = GradNet(if_net) net = GradNet(if_net)
...@@ -634,7 +637,7 @@ def test_if_by_if_return_inside_grad(): ...@@ -634,7 +637,7 @@ def test_if_by_if_return_inside_grad():
@ms_function @ms_function
def construct(self, *inputs): def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs) return grad_by_list(self.net, self.weights)(*inputs)
if_net = MyIfByIfNet() if_net = MyIfByIfNet()
net = GradNet(if_net) net = GradNet(if_net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册