未验证 提交 4955c97e 编写于 作者: A Aurelius84 提交者: GitHub

Add unitTest for `Tensor==constant` for ifElse in dygraph2static (#23407)

* Add unitTest for `Tensor==constant` for ifElse in dygraph2static test=develop
上级 9676ac1c
......@@ -85,6 +85,7 @@ class IfElseTransformer(gast.NodeTransformer):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
self.generic_visit(node)
return node
def visit_IfExp(self, node):
......@@ -292,12 +293,12 @@ class NodeTestTransformer(gast.NodeTransformer):
return self.visit(self.ast_root)
def visit_Call(self, node):
# self.generic_visit(node)
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
self.generic_visit(node)
return node
def visit_UnaryOp(self, node):
......
......@@ -116,6 +116,30 @@ class TestDygraphIfElse6(TestDygraphIfElse):
self.dyfunc = dyfunc_ifExp_with_while
def dyfunc_ifExp_with_while2(x):
y = [x]
def add_fn(x):
x = x + 1
return x
def map_func(func, tensor_list):
return [func(x) for x in tensor_list]
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
# It will be converted into `layers.cond` as followed.
# map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
# `i (Tensor) == 0` is supported in dygraph.
y = map_func(lambda x: x if i == 0 else add_fn(x), y)
return y[0]
class TestDygraphIfElse7(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_ifExp_with_while2
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册