未验证 提交 03ba5b74 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2static] Add for enumerate Variable support (#24398)

* initial test

* for enumerate basic implement, test=develop

* update unittests, test=develop

* refine unittests to adapt new training mode, test=develop

* refactor for node stmts parsing code, test=develop

* self-review & polish details, test=develop
上级 d980d251
......@@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer']
......@@ -61,87 +62,26 @@ class ForToWhileTransformer(gast.NodeTransformer):
raise ValueError(
"parent_node doesn't contain the loop_node in ForToWhileTransformer")
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
old_cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
cond_stmt = gast.BoolOp(
op=gast.And(), values=[old_cond_stmt, self.condition_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
assert isinstance(
node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
return [node]
if not isinstance(node.target, gast.Name):
# 1. parse current gast.For node
current_for_node_parser = ForNodeParser(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
iter_var_name = node.target.id
init_stmts, cond_stmt, body_stmts = stmts_tuple
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
# 2. append break statement
new_cond_stmt = gast.BoolOp(
op=gast.And(), values=[cond_stmt, self.condition_node])
new_body = node.body
new_body.append(change_stmt)
# 3. construct gast.While node
while_node = gast.While(
test=cond_stmt, body=new_body, orelse=node.orelse)
return [init_stmt, while_node]
test=new_cond_stmt, body=body_stmts, orelse=node.orelse)
init_stmts.append(while_node)
return init_stmts
class BreakContinueTransformer(gast.NodeTransformer):
......
......@@ -23,10 +23,12 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
......@@ -321,7 +323,7 @@ class LoopTransformer(gast.NodeTransformer):
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of WhileTransformer."
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_visitor = NameVisitor(self.root)
......@@ -355,86 +357,45 @@ class LoopTransformer(gast.NodeTransformer):
else:
i += 1
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
NOTE(TODO): Python allows to access iteration variable after loop, such
as "for i in range(10)" will create i = 9 after the loop. But using
current conversion will make i = 10. We should find a way to change it
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
# TODO: consider for - else in python
if not self.name_visitor.is_control_flow_loop(node):
return [node]
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
# 1. check whether need to transform
# NOTE: Current need transform cases:
# 1). for x in range(VarBase.numpy()[0])
# 2). for x in VarBase.numpy()
# 3). for i, x in enumerate(VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node):
return [node]
if not isinstance(node.target, gast.Name):
# 2. get key statements for different cases
# NOTE: three key statements:
# 1). init_stmts: list[node], prepare nodes of for loop, may not only one
# 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement
current_for_node_parser = ForNodeParser(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
iter_var_name = node.target.id
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
init_stmts, cond_stmt, body_stmts = stmts_tuple
# 3. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
# we need append new loop var & remove useless loop var
# 1. for x in var -> x is no need
# 2. for i, x in enumerate(var) -> x is no need
if current_for_node_parser.is_for_iter(
) or current_for_node_parser.is_for_enumerate_iter():
iter_var_name = current_for_node_parser.iter_var_name
iter_idx_name = current_for_node_parser.iter_idx_name
loop_var_names.add(iter_idx_name)
if iter_var_name not in create_var_names:
loop_var_names.remove(iter_var_name)
# 4. prepare result statement list
new_stmts = []
# Python can create variable in loop and use it out of loop, E.g.
#
......@@ -447,12 +408,13 @@ class LoopTransformer(gast.NodeTransformer):
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
new_stmts.append(init_stmt)
# 5. append init statements
new_stmts.extend(init_stmts)
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
# 6. create & append condition function node
condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(
......@@ -480,9 +442,9 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node)
new_body = node.body
new_body.append(change_stmt)
new_body.append(
# 7. create & append loop body function node
# append return values for loop body
body_stmts.append(
gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load())))
body_func_node = gast.FunctionDef(
......@@ -501,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer):
kw_defaults=None,
kwarg=None,
defaults=[]),
body=new_body,
body=body_stmts,
decorator_list=[],
returns=None,
type_comment=None)
......@@ -512,6 +474,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node)
# 8. create & append while loop node
while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node)
......
......@@ -25,6 +25,8 @@ import os
import six
import tempfile
from paddle.fluid import unique_name
dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay",
......@@ -35,6 +37,9 @@ dygraph_class_to_static_api = {
"PolynomialDecay": "polynomial_decay",
}
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_SHAPE_PREFIX = '__for_loop_var_shape'
def _is_api_in_module_helper(obj, module_prefix):
m = inspect.getmodule(obj)
......@@ -504,12 +509,22 @@ class IsControlFlowVisitor(gast.NodeVisitor):
assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call):
return
if not isinstance(node.iter.func, gast.Name):
return
if node.iter.func.id != "range":
return
# for in range(v.numpy()) or for in enumerate(v.numpy())
if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate":
for arg in node.iter.args:
self.visit(arg)
else:
return
# for in v.numpy()
elif isinstance(node.iter.func, gast.Attribute):
if node.iter.func.attr == 'numpy':
self._visit_Call(node.iter)
else:
return
else:
return
for child_node in gast.walk(node):
if isinstance(child_node, (gast.Continue, gast.Break)):
......@@ -609,3 +624,308 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def get_compare_nodes_with_tensor(self):
return self._compare_node_tenor_set
class NameNodeReplaceTransformer(gast.NodeTransformer):
"""
This class transform specfice gast.Name node to replace node
"""
def __init__(self, root_node, target_name, replace_node):
assert isinstance(target_name, str)
self.target_name = target_name
self.replace_node = replace_node
self.visit(root_node)
def visit_Name(self, node):
if node.id == self.target_name:
return self.replace_node
return node
class ForNodeParser(object):
"""
This class parse python for statement, get transformed 3 statement components of for node
three key statements:
1). init_stmts: list[node], prepare nodes of for loop, may not only one
2). cond_stmt: node, condition node to judge whether continue loop
3). body_stmts: list[node], updated loop body, sometimes we should change
the original statement in body, not just append new statement
In this process, the semantics of for does not change.
Now only can parse 3 type statements:
1). for x in range(***)
2). for x in var.numpy()
3). for i, x enumerate(var.numpy())
"""
def __init__(self, for_node):
assert isinstance(
for_node, gast.For
), "Input node for the initialization of ForNodeParser is not gast.For node."
# 1. original for node
self.node = for_node
# 2. gast.For node main parts
self.target = for_node.target
# NOTE: type may be Node or list[Node]
self.iter_args = for_node.iter if self.is_for_iter(
) else for_node.iter.args
self.body = for_node.body
# 3. key shared node or names
# - x:
# - for x in range(***)
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
self.iter_var_name = self._get_iter_var_name()
# - created index var to slice Variable: __for_loop_var_index_0
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX)
# - created shape var to build loop condition: __for_loop_var_shape_0
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
self.iter_var_shape_name = unique_name.generate(
FOR_ITER_VAR_SHAPE_PREFIX)
# - var.numpy()
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
self.iter_node = self._get_iter_node()
# - enumeate i:
# - for i, x enumerate(var.numpy())
self.enum_idx_name = self._get_enum_idx_name()
# - range/enumerate args length
self.args_length = None
def parse(self):
self._args_check()
if self.is_for_range_iter():
return self._parse_for_range_stmts()
elif self.is_for_iter():
return self._parse_for_stmts()
elif self.is_for_enumerate_iter():
return self._parse_for_enumerate_stmts()
else:
raise None
def is_for_range_iter(self):
return isinstance(self.node.iter.func,
gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self):
return isinstance(
self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy'
def is_for_enumerate_iter(self):
return isinstance(self.node.iter.func,
gast.Name) and self.node.iter.func.id == "enumerate"
def _args_check(self):
if self.is_for_range_iter():
self.args_length = len(self.iter_args)
assert self.args_length >= 1 and self.args_length <= 3, "range() function takes 1 to 3 arguments"
elif self.is_for_enumerate_iter():
self.args_length = len(self.iter_args)
assert self.args_length >= 1 and self.args_length <= 2, "enumerate() function takes 1 to 2 arguments"
else:
self.args_length = None
def _parse_for_range_stmts(self):
init_stmts = []
init_stmts.append(self._build_index_init_node())
compare_node = self._build_compare_node()
step_node = self._build_step_node()
cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body
body_stmts.append(self._build_index_increase_node(step_node))
return init_stmts, cond_stmt, body_stmts
def _parse_for_stmts(self):
init_stmts = []
init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node())
compare_node = self._build_compare_node()
step_node = self._build_step_node()
cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body
var_slice_node = self._build_var_slice_node()
for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node)
body_stmts.append(self._build_index_increase_node(step_node))
return init_stmts, cond_stmt, body_stmts
def _parse_for_enumerate_stmts(self):
init_stmts = []
init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node())
init_stmts.append(self._build_enum_init_node())
compare_node = self._build_compare_node()
step_node = self._build_step_node()
cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body
var_slice_node = self._build_var_slice_node()
for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node)
body_stmts.append(self._build_index_increase_node(step_node))
body_stmts.append(self._build_enum_increase_node())
return init_stmts, cond_stmt, body_stmts
def _build_index_init_node(self):
if self.is_for_range_iter():
if self.args_length == 1:
index_init_node = get_constant_variable_node(self.iter_var_name,
0)
else:
index_init_node = gast.Assign(
targets=[
gast.Name(
id=self.iter_var_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=self.iter_args[0])
else:
# TODO: slice bug, only support int32 index
index_init_node = get_constant_variable_node(
self.iter_idx_name, 0, dtype='int32')
return index_init_node
def _build_var_shape_assign_node(self):
# get variable shape as iter length
return gast.Assign(
targets=[
gast.Name(
id=self.iter_var_shape_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
value=create_api_shape_node(self.iter_node.func))
def _build_enum_init_node(self):
enum_init_node = get_constant_variable_node(
name=self.enum_idx_name, value=0)
if self.is_for_enumerate_iter() and self.args_length != 1:
enum_init_node = gast.Assign(
targets=[
gast.Name(
id=self.enum_idx_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=self.iter_args[1])
return enum_init_node
def _build_compare_node(self):
if self.is_for_range_iter():
compare_node = self.iter_args[
0] if self.args_length == 1 else self.iter_args[1]
else:
compare_node = gast.Subscript(
value=gast.Name(
id=self.iter_var_shape_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
slice=gast.Index(value=gast.Constant(
value=0, kind=None)),
ctx=gast.Load())
return compare_node
def _build_step_node(self):
if self.is_for_range_iter():
step_node = self.iter_args[
2] if self.args_length == 3 else gast.Constant(
value=1, kind=None)
else:
step_node = gast.Constant(value=1, kind=None)
return step_node
def _build_cond_stmt(self, step_node, compare_node):
return gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[compare_node])
def _build_index_increase_node(self, step_node):
return gast.AugAssign(
target=gast.Name(
id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
def _build_var_slice_node(self):
return gast.Subscript(
value=self.iter_node,
slice=gast.Index(value=gast.Name(
id=self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)),
ctx=gast.Load())
def _build_enum_increase_node(self):
return gast.AugAssign(
target=gast.Name(
id=self.enum_idx_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=gast.Constant(
value=1, kind=None))
def _get_iter_var_name(self):
if self.is_for_range_iter():
return self.target.id
elif self.is_for_iter():
return self.target.id
elif self.is_for_enumerate_iter():
return self.target.elts[1].id
return None
def _get_iter_node(self):
if self.is_for_iter():
return self.iter_args
elif self.is_for_enumerate_iter():
return self.iter_args[0]
return None
def _get_enum_idx_name(self):
if self.is_for_enumerate_iter():
return self.target.elts[0].id
return None
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator()
# 0. for in range with var case
@declarative
def dygraph_for_in_range(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
z = z + i
return z
# 1. for iter list
@declarative
def dygraph_for_iter_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in x_array:
z = z + x
return z
# 2. for enumerate list
@declarative
def dygraph_for_enumerate_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(x_array):
z = z + x + i
return z
# 3. for iter var.numpy()
@declarative
def dygraph_for_iter_var_numpy(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array.numpy():
z = z + x
return z
# 4. for enumerate var.numpy()
@declarative
def dygraph_for_enumerate_var_numpy(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy()):
y = y + i
z = z + x
return y, z
# 5. for enumerate var.numpy() with start
@declarative
def dygraph_for_enumerate_var_numpy_with_start(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy(), 1):
y = y + i
z = z + x
return y, z
# 6. for in range with break
@declarative
def dygraph_for_in_range_with_break(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
z = z + i
if i > 2:
break
return z
# 7. for enumerate var.numpy() with break
@declarative
def dygraph_for_enumerate_var_numpy_with_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy()):
y = y + i
z = z + x
if i > 2:
break
return y, z
# 8. for enumerate var.numpy() with continue
@declarative
def dygraph_for_enumerate_var_numpy_with_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy()):
y = y + i
if i > 2:
continue
z = z + x
return y, z
# 9. for enumerate var.numpy() with start & break
@declarative
def dygraph_for_enumerate_var_numpy_with_start_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy(), 1):
y = y + i
z = z + x
if i > 2:
break
return y, z
# 10. for enumerate var.numpy() with start & continue
@declarative
def dygraph_for_enumerate_var_numpy_with_start_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array.numpy(), 1):
y = y + i
if i > 2:
continue
z = z + x
return y, z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.set_input()
self.set_test_func()
def set_input(self):
self.input = [1, 2, 3]
def set_test_func(self):
raise NotImplementedError(
"For Enumerate test should implement set_test_func")
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
return self.dygraph_func(self.input)
def get_dygraph_output(self):
return self._run(to_static=False)
def get_static_output(self):
return self._run(to_static=True)
class TestTransform(TestTransformBase):
def transformed_result_compare(self):
dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, tuple):
dy_outs = (dy_outs, )
# NOTE: return type is difference
st_outs = self.get_static_output()
if not isinstance(st_outs, list):
st_outs = (st_outs, )
else:
st_outs = tuple(st_outs)
for x, y in zip(dy_outs, st_outs):
self.assertTrue(np.allclose(x.numpy(), y.numpy()))
class TestTransformError(TestTransformBase):
def transformed_error(self, etype):
with self.assertRaises(etype):
dy_out = self.get_dygraph_output()
st_out = self.get_static_output()
class TestForInRange(TestTransform):
def set_input(self):
self.input = np.array([5])
def set_test_func(self):
self.dygraph_func = dygraph_for_in_range
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForIterList(TestTransform):
def set_test_func(self):
self.dygraph_func = dygraph_for_iter_list
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForEnumerateSimple(TestForIterList):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_list
class TestForInRangeWithBreak(TestForInRange):
def set_test_func(self):
self.dygraph_func = dygraph_for_in_range_with_break
class TestForIterVarNumpy(TestTransform):
def set_input(self):
self.input = np.array([1, 2, 3, 4, 5])
def set_test_func(self):
self.dygraph_func = dygraph_for_iter_var_numpy
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForEnumerateVarNumpy(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy
class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_break
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_continue
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_break
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_continue
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册