未验证 提交 1217a521 编写于 作者: A Aurelius84 提交者: GitHub

Modify the way of inserting newly defined func_nodes (#22837)

* Modify the way of inserting newly defined func_nodes test=develop
上级 c736fef9
...@@ -42,7 +42,7 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -42,7 +42,7 @@ class IfElseTransformer(gast.NodeTransformer):
wrapper_root) wrapper_root)
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.new_func_nodes = [] self.new_func_nodes = {}
def ast_visit(self): def ast_visit(self):
""" """
...@@ -59,10 +59,10 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -59,10 +59,10 @@ class IfElseTransformer(gast.NodeTransformer):
pred_node = node.test pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else( true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root) node, self.root)
self.new_func_nodes += [true_func_node, false_func_node]
# create layers.cond # create layers.cond
new_node = create_cond_node(return_name_ids, pred_node, new_node = create_cond_node(return_name_ids, pred_node,
true_func_node, false_func_node) true_func_node, false_func_node)
self.new_func_nodes[new_node] = [true_func_node, false_func_node]
return new_node return new_node
else: else:
return node return node
...@@ -82,10 +82,28 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -82,10 +82,28 @@ class IfElseTransformer(gast.NodeTransformer):
It can be used to add the created `true_fn/false_fn` in front of It can be used to add the created `true_fn/false_fn` in front of
the node.body before they are called in cond layer. the node.body before they are called in cond layer.
""" """
assert hasattr(node, 'body') self._insert_func_nodes(node)
# add new ast.funcDef of `if/else`
if self.new_func_nodes: def _insert_func_nodes(self, parent_node):
node.body = self.new_func_nodes + node.body """
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
Because private variables of class or other external scope will be modified.
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
return
idx = len(parent_node.body) - 1
while idx >= 0:
child_node = parent_node.body[idx]
if child_node in self.new_func_nodes:
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
self._insert_func_nodes(child_node)
idx = idx - 1
def get_new_func_nodes(self): def get_new_func_nodes(self):
return self.new_func_nodes return self.new_func_nodes
......
...@@ -104,7 +104,7 @@ def get_name_ids(nodes, not_name_set=None, node_black_list=None): ...@@ -104,7 +104,7 @@ def get_name_ids(nodes, not_name_set=None, node_black_list=None):
name_ids = defaultdict(list) name_ids = defaultdict(list)
for node in nodes: for node in nodes:
if node_black_list and node in node_black_list: continue if node_black_list and node in node_black_list: break
if isinstance(node, gast.AST): if isinstance(node, gast.AST):
# In two case, the ast.Name should be filtered. # In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x) # 1. Function name like `my_func` of my_func(x)
......
...@@ -119,5 +119,91 @@ class TestDygraphIfElse3(TestDygraphIfElse): ...@@ -119,5 +119,91 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self.dyfunc = nested_if_else self.dyfunc = nested_if_else
class NetWithControlFlowIf(fluid.dygraph.Layer):
def __init__(self, hidden_dim=16):
super(NetWithControlFlowIf, self).__init__()
self.hidden_dim = hidden_dim
self.fc = fluid.dygraph.Linear(
input_dim=hidden_dim,
output_dim=5,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
self.alpha = 10.
self.constant_vars = {}
@dygraph_to_static_graph
def forward(self, input):
hidden_dim = input.shape[-1]
# Plain `if` statement in Python
if hidden_dim != self.hidden_dim:
raise ValueError(
"hidden_dim {} of input is not equal to FC.weight[0]: {}"
.format(hidden_dim, self.hidden_dim))
self.constant_vars['bias'] = fluid.layers.fill_constant(
[5], dtype='float32', value=1)
# Control flow `if` statement
fc_out = self.fc(input)
if fluid.layers.mean(fc_out).numpy()[0] < 0:
y = fc_out + self.constant_vars['bias']
self.constant_vars['w'] = fluid.layers.fill_constant(
[5], dtype='float32', value=10)
if y.numpy()[0] < self.alpha:
# Create new var, but is not used.
x = 10
tmp = y * self.constant_vars['w']
y = fluid.layers.relu(tmp)
# Nested `if/else`
if y.numpy()[-1] < self.alpha:
# Modify variable of class
self.constant_vars['w'] = fluid.layers.fill_constant(
[hidden_dim], dtype='float32', value=9)
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant(
[5], dtype='float32', value=-1)
y = y - tmp
else:
y = fc_out - self.constant_vars['bias']
loss = fluid.layers.mean(y)
return loss
class TestDygraphIfElseNet(unittest.TestCase):
"""
TestCase for the transformation from control flow `if/else`
dependent on tensor in Dygraph into Static `fluid.layers.cond`.
"""
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.Net = NetWithControlFlowIf
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
net = self.Net()
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = net(x_v)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(main_program, fetch_list=out)
return ret[0]
def _run_dygraph(self):
with fluid.dygraph.guard(place):
net = self.Net()
x_v = fluid.dygraph.to_variable(self.x)
ret = net(x_v)
return ret.numpy()
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册