From 736d3acc24bcc8f05f305446c54f94edb728304e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 19 May 2020 16:03:25 +0800 Subject: [PATCH] [Dy2stat] Support lambda and enhance transformation of IfExpr (#24530) * fix bug with `if Tensor` in is_control_flow test=develop * remove continue test=develop * Support lambda and add unittest test=develop --- .../dygraph_to_static/ifelse_transformer.py | 63 +++++++++++ .../dygraph_to_static/test_lambda.py | 102 ++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_lambda.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index d9458007b60..7c31093568e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -101,8 +101,18 @@ class IfElseTransformer(gast.NodeTransformer): self.generic_visit(node) if need_transform: pred_node, new_assign_nodes = if_condition_visitor.transform() + + if len(new_assign_nodes) > 0: + pred_node = merge_multi_assign_nodes(new_assign_nodes) + new_node = create_cond_node(None, pred_node, node.body, node.orelse, True) + # Note: A blank line will be added separately if transform gast.Expr + # into source code. Using gast.Expr.value instead to avoid syntax error + # in python. + if isinstance(new_node, gast.Expr): + new_node = new_node.value + return new_node else: return node @@ -145,6 +155,59 @@ class IfElseTransformer(gast.NodeTransformer): return self.new_func_nodes +def merge_multi_assign_nodes(assign_nodes): + """ + Merges multiple separate assign statements into a single node. + """ + if not isinstance(assign_nodes, (list, tuple)): + assign_nodes = [assign_nodes] + + return MergeAssignTransformer().transform(assign_nodes) + + +class MergeAssignTransformer(gast.NodeTransformer): + """ + Merges multiple separate assign statements into a single node. + Because it cannot be determined the insertion location of new nodes for `IfExpr`, + so replaces original node with merges conditional node. + + Note: This is a very low level api and only used for IfExpr transformation + in control flow. + + For example: + IfExpr: + y = x+1 if mean or x > 0 else x-1 + + assign nodes: + bool_tensor_1 = fluid.layers.cast(x=mean, dtype='bool') + logic_or_0 = fluid.layers.logical_or(x=bool_tensor_1, y=x > 0) + + merged node: + fluid.layers.logical_or(x=fluid.layers.cast(x=mean, dtype='bool'), y=x > 0) + """ + + def __init__(self): + self._name_to_nodes_value = {} + + def transform(self, nodes): + value = None + for node in nodes: + assert isinstance(node, gast.Assign) + # Note: targets of created assign node in control flow `if` + # only contains one element. + assert isinstance(node.targets[0], gast.Name) + target_name = node.targets[0].id + value = self.visit(node.value) + self._name_to_nodes_value[target_name] = value + + return value + + def visit_Name(self, node): + if node.id in self._name_to_nodes_value: + node = self._name_to_nodes_value[node.id] + return node + + class NodeTestTransformer(gast.NodeTransformer): def __init__(self, ast_node, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lambda.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lambda.py new file mode 100644 index 00000000000..4608977ce77 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lambda.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import paddle.fluid as fluid + +from paddle.fluid.dygraph import declarative + + +def call_lambda_as_func(x): + x = fluid.dygraph.to_variable(x) + + add_func = lambda x, y: x + y + mean_func = lambda x: fluid.layers.mean(x) + + y = add_func(x, 1) + y = add_func(y, add_func(y, -1)) + out = mean_func(y) + + return out + + +def call_lambda_directly(x): + x = fluid.dygraph.to_variable(x) + + y = (lambda x, y: x + y)(x, x) + out = (lambda x: fluid.layers.mean(x))(y) + + return out + + +def call_lambda_in_func(x): + x = fluid.dygraph.to_variable(x) + + add_func = lambda x: x + 1 + + y = fluid.layers.mean((lambda x: fluid.layers.relu(x))(x)) + out = add_func(y) if y > 1 and y < 2 else (lambda x: x**2)(y) + + return out + + +def call_lambda_with_ifExpr(x): + x = fluid.dygraph.to_variable(x) + + add_func = lambda x: x + 1 + + y = fluid.layers.mean(x) + out = add_func(y) if y or y < 2 else (lambda x: x**2)(y) + + return out + + +class TestLambda(unittest.TestCase): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.x = np.array([1, 3]).astype('float32') + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.init_func() + + def init_func(self): + self.dyfuncs = [ + call_lambda_as_func, call_lambda_directly, call_lambda_in_func, + call_lambda_with_ifExpr + ] + + def run_static(self, func): + return self.run_dygraph(func, to_static=True) + + def run_dygraph(self, func, to_static=False): + + with fluid.dygraph.guard(self.place): + x_v = fluid.dygraph.to_variable(self.x) + if to_static: + ret = declarative(func)(x_v) + else: + ret = func(x_v) + return ret.numpy() + + def test_ast_to_func(self): + for func in self.dyfuncs: + self.assertTrue((self.run_dygraph(func) == self.run_static(func) + ).all()) + + +if __name__ == '__main__': + unittest.main() -- GitLab