diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index 702cb8464adb8547458905bd434c241822af7e8a..d5c0310d16579e27714eaf44242bf1a366295ae8 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -17,7 +17,7 @@ import collections import copy import six import numpy as np -from ..framework import Variable, in_dygraph_mode +from ..framework import Block, Variable, in_dygraph_mode from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..layer_helper import LayerHelper from sys import version_info @@ -429,3 +429,31 @@ def try_get_constant_shape_from_tensor(shape_tensor): return None return None + + +def get_inputs_outputs_in_block(block): + """ + Returns the inputs and outputs variable used in this block but not + created in this block. + """ + assert isinstance( + block, + Block), "input non-Block argument for get_inputs_outputs_in_block." + assert block.parent_idx != -1, "input block should be a sub-block, not main block." + + # Find input/output var names of all ops in block + inner_inputs = set() + inner_outputs = set() + for op in block.ops: + for iname in op.input_names: + for in_var_name in op.input(iname): + if not block.has_var(in_var_name): + # variable not created in this block + inner_inputs.add(in_var_name) + for oname in op.output_names: + for out_var_name in op.output(oname): + if not block.has_var(out_var_name): + # variable not created in this block + inner_outputs.add(out_var_name) + + return inner_inputs, inner_outputs diff --git a/python/paddle/fluid/tests/unittests/test_get_inputs_outputs_in_block.py b/python/paddle/fluid/tests/unittests/test_get_inputs_outputs_in_block.py new file mode 100644 index 0000000000000000000000000000000000000000..9e82057959408002fe7c26d8c10be37a27c56d6b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_get_inputs_outputs_in_block.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import numpy as np +from paddle.fluid.layers import utils + +paddle.enable_static() + + +class TestGetInputsOutputsInBlock(unittest.TestCase): + def test_ordered(self): + # Program variable names may be different when test order is different + # This helper makes the test ordered. + self._test_while_loop() + self._test_cond() + + def _test_while_loop(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + i = paddle.assign(np.array([1])) + ten = paddle.assign(np.array([10])) + + def while_cond(i): + # use ten in parent block without passing it + return i < ten + + def while_body(i): + # variable created in sub block + one = paddle.assign(np.array([1])) + i = i + one + return [i] + + i = paddle.static.nn.while_loop(while_cond, while_body, [i]) + + sub_block = main_program.block(1) + inner_inputs, inner_outputs = utils.get_inputs_outputs_in_block( + sub_block) + # 'assign_0.tmp_0', 'assign_1.tmp_0' are name of i and ten in program + self.assertTrue(inner_inputs == {'assign_0.tmp_0', 'assign_1.tmp_0'}) + # 'tmp_0', 'assign_0.tmp_0' are name of i < ten and i in program + self.assertTrue(inner_outputs == {'tmp_0', 'assign_0.tmp_0'}) + + def _test_cond(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + a = paddle.zeros((1, 1)) + b = paddle.zeros((1, 1)) + c = a * b + out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b) + + sub_block = main_program.block(1) + inner_inputs, inner_outputs = utils.get_inputs_outputs_in_block( + sub_block) + #'fill_constant_1.tmp_0', 'tmp_3' are names of a, c + self.assertTrue(inner_inputs == {'fill_constant_1.tmp_0', 'tmp_3'}) + #'_generated_var_1', is name of a + c + self.assertTrue(inner_outputs == {'_generated_var_1'}) + + +if __name__ == "__main__": + unittest.main()