未验证 提交 d7a7c5f0 编写于 作者: H Huihuang Zheng 提交者: GitHub

Support Simple For Range Loop in Dygraph to Static (#22867)

1. Add basic support for `for in range` loop
2. Move `test_dygraph_to_static_*` to `dygraph_to_static` dir and rename them
3. Add test case for dict in while_loop
上级 f70f1cf1
......@@ -21,6 +21,8 @@ from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
......@@ -29,6 +31,9 @@ __all__ = ['LoopTransformer', 'NameVisitor']
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
def create_while_node(condition_name, body_name, loop_var_names):
while_args = []
......@@ -63,13 +68,16 @@ class NameVisitor(gast.NodeVisitor):
'''
def __init__(self, root_node):
# Set of gast.Name
# Set of gast.Name or gast.Attribute for variables
self.current_seen_vars = set()
# list of nodes of current visit node
self.ancestor_nodes = []
# List of gast.While/gast.For nodes
self.current_loop = []
# Mapping from gast.While/gast.For to string name of vars
self.before_loop_vars = defaultdict(set)
# Mapping from gast.While/gast.For to variable nodes
self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
self.visit(root_node)
......@@ -86,13 +94,12 @@ class NameVisitor(gast.NodeVisitor):
read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = set(name.id for name in in_loop_vars)
before_loop_vars = self.before_loop_vars[node]
before_loop_name_strs = set(name.id for name in before_loop_vars)
after_loop_vars = self.current_seen_vars - before_loop_vars - in_loop_vars
after_loop_name_strs = set(
name.id for name in after_loop_vars
if type(name.ctx) in read_context)
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node]
before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context)
for name in in_loop_name_strs:
if name in before_loop_name_strs:
# If a variable is used in loop and created before loop, it
......@@ -106,23 +113,65 @@ class NameVisitor(gast.NodeVisitor):
return loop_var_names, create_var_names
def visit_Name(self, node):
if self._is_call_func_name_node(node):
self.generic_visit(node)
return
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
self.generic_visit(node)
def visit(self, node):
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_Attribute(self, node):
if self._is_call_func_name_node(node):
return
attr_full_name = get_attribute_full_name(node)
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
# sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again
def visit_For(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.visit(node.target)
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def visit_While(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.visit(node.test)
self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
ret = set()
for node in node_set:
if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
if isinstance(node, gast.Name):
ret.add(node.id)
elif isinstance(node, gast.Attribute):
ret.add(get_attribute_full_name(node))
return ret
def _is_call_func_name_node(self, node):
if self.ancestor_nodes:
parent_node = self.ancestor_nodes[-1]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
class LoopTransformer(gast.NodeTransformer):
"""
......@@ -140,11 +189,6 @@ class LoopTransformer(gast.NodeTransformer):
def transform(self):
self.visit(self.root)
def get_for_stmt_nodes(self, node):
self.generic_visit(node)
# TODO
return node
def visit(self, node):
self.generic_visit(node)
# All parent nodes that may contain gast.While/gast.For
......@@ -165,15 +209,166 @@ class LoopTransformer(gast.NodeTransformer):
body_list[i:i + 1] = new_stmts
i += len(new_stmts)
elif isinstance(body_list[i], gast.For):
# TODO
i += 1
new_stmts = self.get_for_stmt_nodes(body_list[i])
body_list[i:i + 1] = new_stmts
i += len(new_stmts)
else:
i += 1
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
NOTE(TODO): Python allows to access iteration variable after loop, such
as "for i in range(10)" will create i = 9 after the loop. But using
current conversion will make i = 10. We should find a way to change it
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
# TODO: consider for - else in python
if not self.name_visitor.is_control_flow_loop(node):
return [node]
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
return [node]
if not isinstance(node.target, gast.Name):
return [node]
iter_var_name = node.target.id
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
new_stmts = []
# Python can create variable in loop and use it out of loop, E.g.
#
# for x in range(10):
# y += x
# print(x) # x = 10
#
# We need to create static variable for those variables
for name in create_var_names:
new_stmts.append(create_static_variable_gast_node(name))
new_stmts.append(init_stmt)
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(
args=[
gast.Name(
id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=[gast.Return(value=cond_stmt)],
decorator_list=[],
returns=None,
type_comment=None)
new_stmts.append(condition_func_node)
new_body = node.body
new_body.append(change_stmt)
new_body.append(
gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load())))
body_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_BODY_PREFIX),
args=gast.arguments(
args=[
gast.Name(
id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=new_body,
decorator_list=[],
returns=None,
type_comment=None)
new_stmts.append(body_func_node)
while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node)
return new_stmts
def get_while_stmt_nodes(self, node):
# TODO: consider while - else in python
# self.generic_visit(node)
if not self.name_visitor.is_control_flow_loop(node):
return [node]
......
......@@ -215,6 +215,18 @@ def create_api_shape_node(tensor_shape_node):
return api_shape_node
def get_constant_variable_node(name, value, shape=[1], dtype='int64'):
return gast.parse('%s = fluid.layers.fill_constant(%s, "%s", %s)' %
(name, str(shape), dtype, str(value)))
def get_attribute_full_name(node):
assert isinstance(
node,
gast.Attribute), "Input non-Attribute node to get attribute full name"
return astor.to_source(gast.gast_to_ast(node)).strip()
def generate_name_node(name_ids, ctx=gast.Load()):
"""
Generate list or gast.Tuple of ast.Name for Return statement.
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import six
import gast
from paddle.fluid.layers import fill_constant
......@@ -39,8 +40,15 @@ def to_static_variable(x):
'''
if isinstance(x, bool):
return fill_constant(shape=[1], dtype='bool', value=x)
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
if isinstance(x, float):
return fill_constant(shape=[1], dtype='float64', value=x)
if six.PY2:
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int32', value=x)
if isinstance(x, long):
return fill_constant(shape=[1], dtype='int64', value=x)
else:
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
return x
......@@ -59,7 +59,17 @@ class SubNetWithDict(fluid.dygraph.Layer):
cache_k, cache_v = cache["k"], cache["v"]
k = 0.1 * cache_k + k
v = 0.2 * cache_v + v
cache["k"], cache["v"] = k, v
# TODO: currently while_loop can have a dict as loop_vars, but
# to change the value in a dict, you have to use layers.assign
# because cache["k"] = k is putting k in dict without building
# network. So we cannot write:
#
# cache["k"], cache["v"] = k, v
#
# we have to support this kind of dict in loop in the future.
# For example, automatically change = to assign in AutoTracer
fluid.layers.assign(k, cache["k"])
fluid.layers.assign(v, cache["v"])
weight = fluid.layers.matmul(x=q, y=k, transpose_y=True)
weight = fluid.layers.softmax(weight)
......@@ -94,12 +104,20 @@ class MainNetWithDict(fluid.dygraph.Layer):
for i in range(max_len):
out = self.sub_net(out, cache)
cache = self.update_cache(cache)
return out
def update_cache(self, cache):
for k, val in six.iteritems(cache):
cache[k] = fluid.layers.softmax(val)
# TODO: currently while_loop can have a dict as loop_vars, but
# to change the value in a dict, you have to use layers.assign
# because cache["k"] = k is putting k in dict without building
# network. So we cannot write:
#
# cache[k] = fluid.layers.softmax(val)
#
# we have to support this kind of dict in loop in the future.
# For example, automatically change = to assign in AutoTracer
fluid.layers.assign(fluid.layers.softmax(val), cache[k])
return cache
......
......@@ -35,20 +35,34 @@ def while_loop_dyfunc(x):
return i
def for_loop_dyfunc(max_len):
for i in range(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc]
self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])]
self.create_var_names = [set(), set(["ret"])]
def test_loop_vars(self):
test_func = inspect.getsource(while_loop_dyfunc)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
for node in gast.walk(gast_root):
if isinstance(node, gast.While):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, set(["i", "x"]))
self.assertEqual(create_var_names, set())
class TestTransformWhile(unittest.TestCase):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
for node in gast.walk(gast_root):
if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
class TestTransformWhileLoop(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
......@@ -83,5 +97,35 @@ class TestTransformWhile(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformForLoop(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.len = 100
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_func = dygraph_to_static_graph(for_loop_dyfunc)
out = static_func(self.len)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
return ret
def _run_dygraph(self):
with fluid.dygraph.guard(self.place):
ret = for_loop_dyfunc(self.len)
return ret.numpy()
def test_ast_to_func(self):
static_numpy = self._run_static()
self.assertTrue(
np.allclose(
np.full(
shape=(1), fill_value=2, dtype=np.int32), static_numpy))
self._run_dygraph()
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
if __name__ == '__main__':
unittest.main()
......@@ -77,6 +77,34 @@ class TestApiWhileLoop(unittest.TestCase):
data = np.add(data, data_one)
self.assertTrue(np.allclose(np.asarray(res[1]), data))
def test_var_dict(self):
def cond(i, ten, test_dict):
return layers.less_than(i, ten)
def body(i, ten, test_dict):
layers.assign(i, test_dict["test_key"])
i = layers.increment(i)
return [i, ten, test_dict]
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
i = layers.zeros(shape=[1], dtype='int64')
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
test_data = layers.fill_constant(shape=[1], dtype='int64', value=0)
test_dict = {"test_key": test_data}
i, ten, test_dict = layers.while_loop(cond, body,
[i, ten, test_dict])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[test_dict["test_key"]])
self.assertTrue(
np.allclose(
np.asarray(res[0]),
np.full(
shape=(1), fill_value=9, dtype=np.int64)))
class TestApiWhileLoop_Nested(unittest.TestCase):
def test_nested_net(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册