未验证 提交 b6f86b84 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix Using "isinstance" in Loop, test=develop (#28641)

Fix a bug that used in PaddleGAN model which used `isinstance` in a for loop
上级 e4f94153
......@@ -22,6 +22,7 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
......@@ -84,6 +85,9 @@ class NameVisitor(gast.NodeVisitor):
self.condition_vars = defaultdict(set)
self.in_condition = False
# Some names are types, we shouldn't record them as loop var names.
self.type_vars = set()
self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
......@@ -249,6 +253,18 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node)
self.current_loop.pop()
def visit_Call(self, node):
# Store type var names such as "isinstance(x, some_type_names)" and
# Remove them later
if isinstance(node.func, gast.Name) and node.func.id == 'isinstance':
type_node = node.args[1]
if isinstance(type_node, gast.Tuple):
for element in type_node.elts:
self.type_vars.add(ast_to_source_code(element))
else:
self.type_vars.add(ast_to_source_code(type_node))
self.generic_visit(node)
def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
ret = set()
for node in node_set:
......@@ -290,6 +306,7 @@ class NameVisitor(gast.NodeVisitor):
Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
2. Remove vars only in gast.comprehension.
3. Remove vars that are type names, for example: "isinstance(x, var_type_name)"
:param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
......@@ -361,6 +378,12 @@ class NameVisitor(gast.NodeVisitor):
target_vars_of_for_node.add(var)
removed_vars = target_vars_of_for_node | vars_of_list_generator
# 3. Remove var type names which are stored in self.type_vars
for var in loop_vars:
if ast_to_source_code(var) in self.type_vars:
removed_vars.add(var)
return loop_vars - removed_vars
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
import unittest
import paddle
import paddle.nn as nn
class SimpleReturnLayer(nn.Layer):
def forward(self, x):
return x
class AddAttrLayer(nn.Layer):
def __init__(self):
super(AddAttrLayer, self).__init__()
self.attr = None
def forward(self, x):
out = x + self.attr
return out
class IsInstanceLayer(nn.Layer):
def __init__(self, layer):
super(IsInstanceLayer, self).__init__()
self.layer = layer
@paddle.jit.to_static
def forward(self, x):
if isinstance(self.layer, (AddAttrLayer, )):
self.layer.attr = x
res = self.layer(x)
return res
class SequentialLayer(nn.Layer):
def __init__(self, layers):
super(SequentialLayer, self).__init__()
self.layers = nn.LayerList(layers)
@paddle.jit.to_static
def forward(self, x):
res = x
for layer in self.layers:
if isinstance(layer, AddAttrLayer):
layer.attr = x
res = layer(res)
return res
def train(model, to_static):
prog_trans = paddle.jit.ProgramTranslator.get_instance()
prog_trans.enable(to_static)
x = paddle.ones(shape=[2, 3], dtype='int32')
out = model(x)
return out.numpy()
class TestIsinstance(unittest.TestCase):
def test_isinstance_simple_return_layer(self):
model = IsInstanceLayer(SimpleReturnLayer())
self._test_model(model)
def test_isinstance_add_attr_layer(self):
model = IsInstanceLayer(AddAttrLayer())
self._test_model(model)
def test_sequential_layer(self):
layers = []
for i in range(5):
layers.append(SimpleReturnLayer())
layers.append(AddAttrLayer())
model = SequentialLayer(layers)
self._test_model(model)
def _test_model(self, model):
st_out = train(model, to_static=True)
dy_out = train(model, to_static=False)
self.assertTrue(
np.allclose(dy_out, st_out),
msg="dy_out:\n {}\n st_out:\n{}".format(dy_out, st_out))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册