未验证 提交 4af491c2 编写于 作者: L liym27 提交者: GitHub

Tensor.shape support control flow if/for/while and bugfix (#22866)

* Support Tensor.shape in control flow if/for/while and separate TensorShapeTransformer from BasicApiTransformer. test=develop
上级 714b0076
......@@ -27,13 +27,14 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst', 'convert_to_static']
......@@ -62,16 +63,21 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Generic transformation
self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph
basic_api_trans = BasicApiTransformer(node_wrapper,
self.static_analysis_visitor)
basic_api_trans.ast_visit()
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name
basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans.transform()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform Tensor.shape into fluid.layers.shape(Tensor)
TensorShapeTransformer(node_wrapper).transform()
# Transform list used in control flow
ListTransformer(node_wrapper).transform()
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
# Transform for loop and while loop
LoopTransformer(node_wrapper).transform()
def visit_FunctionDef(self, node):
......@@ -110,7 +116,7 @@ class BasicApiTransformer(gast.NodeTransformer):
Class to transform basic API from dygraph to static graph.
"""
def __init__(self, wrapper_root, static_analysis_visitor):
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
......@@ -123,20 +129,7 @@ class BasicApiTransformer(gast.NodeTransformer):
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
# Used for transformation of Tensor.shape
self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
self.scope_var_type_dict = {}
self._run_static_visitor()
def _run_static_visitor(self):
var_env = copy.deepcopy(self.static_analysis_visitor.get_var_env())
# TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def ast_visit(self):
def transform(self):
self.visit(self.root)
return self.wrapper_root
......@@ -153,9 +146,6 @@ class BasicApiTransformer(gast.NodeTransformer):
if self._update_class_node_dict(node):
return None
if self._update_name_to_tensor_shape(node):
return node
for child_node in gast.walk(node.value):
if isinstance(child_node, gast.Call):
self._visit_Call(child_node)
......@@ -171,25 +161,6 @@ class BasicApiTransformer(gast.NodeTransformer):
self._visit_Call(child_node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
if self.is_tensor_shape(node):
return create_api_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_tensor_shape:
if self._used_by_paddle_api(node):
tensor_shape_node = self.name_to_tensor_shape[node.id]
if isinstance(tensor_shape_node, gast.Attribute):
return create_api_shape_node(tensor_shape_node)
elif isinstance(tensor_shape_node, gast.Subscript):
result_node = copy.deepcopy(tensor_shape_node)
result_node.value = create_api_shape_node(
tensor_shape_node.value)
return result_node
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
......@@ -198,10 +169,6 @@ class BasicApiTransformer(gast.NodeTransformer):
node = to_assign_node(node)
return node
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary
self.generic_visit(node)
func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name):
......@@ -211,53 +178,6 @@ class BasicApiTransformer(gast.NodeTransformer):
else:
return node
def is_tensor_shape(self, node):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
"""
assert isinstance(node, gast.Attribute)
if node.attr != 'shape':
return False
try:
value_id = node.value.id
except AttributeError:
return False
if value_id in self.name_to_tensor_shape:
return True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[value_id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
return True
def _used_by_paddle_api(self, node):
assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node:
# Transformed node is not in node_to_wrapper_map
return False
while wrapper_node.parent:
parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call):
if is_paddle_api(parent_node):
return True
else:
return False
wrapper_node = wrapper_node.parent
return False
def _is_dygraph_forward(self, func_id):
return func_id in self.class_node_dict
......@@ -304,32 +224,6 @@ class BasicApiTransformer(gast.NodeTransformer):
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
def _update_name_to_tensor_shape(self, node):
assert isinstance(node, gast.Assign)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_tensor_shape:
self.name_to_tensor_shape[
target_id] = self.name_to_tensor_shape[value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_tensor_shape(value_node): # eg: x.shape
self.name_to_tensor_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
self.name_to_tensor_shape[target_id] = value_node
return True
return False
def convert_to_static(dyfunc):
"""
......
......@@ -20,7 +20,7 @@ import gast
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
......@@ -70,8 +70,6 @@ class NameVisitor(gast.NodeVisitor):
def __init__(self, root_node):
# Set of gast.Name or gast.Attribute for variables
self.current_seen_vars = set()
# list of nodes of current visit node
self.ancestor_nodes = []
# List of gast.While/gast.For nodes
self.current_loop = []
......@@ -80,6 +78,10 @@ class NameVisitor(gast.NodeVisitor):
self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
self.visit(root_node)
def is_control_flow_loop(self, node):
......@@ -123,11 +125,9 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node)
def visit(self, node):
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_Attribute(self, node):
......@@ -166,10 +166,9 @@ class NameVisitor(gast.NodeVisitor):
return ret
def _is_call_func_name_node(self, node):
if self.ancestor_nodes:
parent_node = self.ancestor_nodes[-1]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
parent_node = self.node_to_wrapper_map[node].parent.node
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
......
......@@ -313,6 +313,10 @@ class StaticAnalysisVisitor(object):
return self.var_env.get_var_type(node.id)
if isinstance(node, gast.Return):
# If return nothing:
if node.value is None:
return {NodeVarType.NONE}
return_type = self.node_to_wrapper_map[node.value].node_var_type
assert self.var_env.cur_scope.scope_type == AstVarScope.SCOPE_TYPE_FUNCTION, "Return at non-function scope"
func_name = self.var_env.cur_scope.scope_name
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import astor
import copy
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
class TensorShapeTransformer(gast.NodeTransformer):
"""
This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_to_tensor_shape = {}
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
var_env = self.static_analysis_visitor.get_var_env()
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self):
self.visit(self.root)
def visit_Assign(self, node):
if self._update_name_to_tensor_shape(node):
return node
self.generic_visit(node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
if self.is_tensor_shape(node):
return create_api_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_tensor_shape:
if self._used_by_paddle_api(node):
tensor_shape_node = self.name_to_tensor_shape[node.id]
return create_api_shape_node(tensor_shape_node)
return node
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
self.generic_visit(node)
return node
def visit_If(self, node):
# Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
self.generic_visit(node)
cond = node.test
self._transform_tensor_shape_if_necessary(cond)
return node
def visit_While(self, node):
self.generic_visit(node)
cond = node.test
self._transform_tensor_shape_if_necessary(cond)
return node
def visit_For(self, node):
self.generic_visit(node)
iter = node.iter
self._transform_tensor_shape_if_necessary(iter)
# If tensor.shape is a gast.Name and it is used in range function, transform it
self._transform_tensor_shape_in_range(node)
return node
def _transform_tensor_shape_in_range(self, node):
assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call):
return False
if not isinstance(node.iter.func, gast.Name):
return False
if node.iter.func.id != "range":
return False
args = node.iter.args
for idx, arg in enumerate(args):
if isinstance(arg,
gast.Name) and arg.id in self.name_to_tensor_shape:
args[idx] = create_api_shape_node(self.name_to_tensor_shape[
arg.id])
return True
def _transform_tensor_shape_if_necessary(self, cond):
for child_node in gast.walk(cond):
tensor_shape_node = None
if isinstance(child_node, (gast.Attribute)):
if self.is_tensor_shape(child_node):
tensor_shape_node = child_node
elif isinstance(child_node, (gast.Name)):
if child_node.id in self.name_to_tensor_shape:
tensor_shape_node = self.name_to_tensor_shape[child_node.id]
if tensor_shape_node:
wrapper_node = self.node_to_wrapper_map.get(child_node)
parent_node = wrapper_node.parent.node
for field, value in gast.iter_fields(parent_node):
if child_node is value:
setattr(parent_node, field,
create_api_shape_node(tensor_shape_node))
break
def _used_by_paddle_api(self, node):
assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node:
# Transformed node is not in node_to_wrapper_map
return False
while wrapper_node.parent:
parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call):
if is_paddle_api(parent_node):
return True
else:
return False
wrapper_node = wrapper_node.parent
return False
def is_tensor_shape(self, node):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
"""
assert isinstance(node, gast.Attribute)
if node.attr != 'shape':
return False
try:
value_id = node.value.id
except AttributeError:
return False
if value_id in self.name_to_tensor_shape:
return True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[value_id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
return True
def _update_name_to_tensor_shape(self, node):
assert isinstance(node, gast.Assign)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_tensor_shape:
self.name_to_tensor_shape[
target_id] = self.name_to_tensor_shape[value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_tensor_shape(value_node): # eg: x.shape
self.name_to_tensor_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
self.name_to_tensor_shape[target_id] = value_node
return True
return False
......@@ -230,12 +230,19 @@ def update_args_of_func(node, dygraph_node, method_name):
def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, gast.Attribute)
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node.value],
keywords=[])
return api_shape_node
assert isinstance(tensor_shape_node, (gast.Attribute, gast.Subscript))
if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node.value],
keywords=[])
return api_shape_node
if isinstance(tensor_shape_node, gast.Subscript):
result_node = copy.deepcopy(tensor_shape_node)
result_node.value = create_api_shape_node(result_node.value)
return result_node
def get_constant_variable_node(name, value, shape=[1], dtype='int64'):
......@@ -280,6 +287,8 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
# add return statement
if return_name_ids:
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
else:
nodes.append(gast.Return(value=None))
func_def_node = gast.FunctionDef(
name=name,
args=input_args,
......
......@@ -58,17 +58,118 @@ def dyfunc_tensor_shape_5(x):
return res
test_funcs = [
dyfunc_tensor_shape_1, dyfunc_tensor_shape_2, dyfunc_tensor_shape_3,
dyfunc_tensor_shape_4, dyfunc_tensor_shape_5
]
def dyfunc_with_if_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, [-1, 1])
x_shape_0 = x.shape[0]
if x_shape_0 < 1:
# `res.shape[0] > 1` is transformed into `if fluid.layers.shape(res)[0] > 1`
if res.shape[0] > 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
else:
res = fluid.layers.fill_constant(
value=3, shape=x.shape, dtype="int32")
return res
def dyfunc_with_if_2(x):
x = fluid.dygraph.to_variable(x)
# `len(x.shape)` will not be transformed.
if len(x.shape) < 1:
res = x
else:
res = fluid.layers.fill_constant(value=8, shape=x.shape, dtype="int32")
return res
def dyfunc_with_for_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]`
for i in range(x.shape[0]):
res += 1
return res
def dyfunc_with_for_2(x):
x = fluid.dygraph.to_variable(x)
x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x_shape_0` is transformed into `fluid.layers.shape(x)[0]`
for i in range(x_shape_0):
res += 1
return res
def dyfunc_with_for_3(x):
# TODO(liym27):
# It will fail to run because `for i in range(len(x.shape))` will be transformed into Paddle while_loop.
# Here the python list x.shape will be added to loop_vars. However, loop_vars doesn't support python list.
# And the condition of `for i in range(len(x.shape))` only uses the length of x.shape, so it doesn't have to be transformed into Paddle while_loop.
# After the AST tranformation of for loop is improved, add TestTensorShapeInFor3.
x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `len(x.shape)` is not transformed.
for i in range(len(x.shape)):
res += 1
return res
def dyfunc_with_while_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]`
i = 1
while i < x.shape[0]:
res += 1
i = i + 2
return res
def dyfunc_with_while_2(x):
x = fluid.dygraph.to_variable(x)
x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
i = 1
# `x_shape_0` is transformed into `fluid.layers.shape(x)[0]`
# TODO(liym27): If `x_shape_0` is at right like `while i < x_shape_0`, it will not be transformed.
# Fix this bug next PR.
while x_shape_0 > i:
res += 1
i = i + 2
return res
class TestTensorShape(unittest.TestCase):
def dyfunc_with_while_3(x):
# TODO(liym27):
# It will fail to run because the same problem as `dyfunc_with_for_3`.
# After the AST tranformation of for loop is improved, add TestTensorShapeInWhile3.
x = fluid.dygraph.to_variable(x)
x_shape = x.shape
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
i = 1
# `len(x.shape)` is not transformed.
while len(x_shape) > i:
res += 1
i += 1
return res
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_test_func()
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_1
def get_dygraph_output(self):
with fluid.dygraph.guard():
......@@ -86,14 +187,65 @@ class TestTensorShape(unittest.TestCase):
return static_res[0]
def test_transformed_static_result(self):
for func in test_funcs:
self.dygraph_func = func
static_res = self.get_static_output()
dygraph_res = self.get_dygraph_output()
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
static_res = self.get_static_output()
dygraph_res = self.get_dygraph_output()
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestTensorShapeBasic2(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_2
class TestTensorShapeBasic3(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_3
class TestTensorShapeBasic4(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_4
class TestTensorShapeBasic5(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_5
# 2. Tests with control flow if
class TestTensorShapeInIf1(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_if_1
class TestTensorShapeInIf2(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_if_2
# 3. Tests with control flow for loop
class TestTensorShapeInFor1(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_for_1
class TestTensorShapeInFor2(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_for_2
# 4. Tests with control flow while loop
class TestTensorShapeInWhile1(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_1
class TestTensorShapeInWhile2(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_2
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册