diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 99d359f4f5b59f9cf6e887c03514baf4e08d852e..40223d4bcc2a9d8a33a1b7093f28e6f5a081165f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3658,8 +3658,8 @@ class Block(object): Rename variable in vars and ops' inputs and outputs Args: - name(bytes): the name that need to be renamed. - new_name(bytes): the name that need to rename to. + name(str|bytes): the name that need to be renamed. + new_name(str|bytes): the name that need to rename to. Raises: ValueError: If this block doesn't have this the giving name, @@ -3669,8 +3669,10 @@ class Block(object): Returns: Variable: the Variable with the giving name. """ - name = name.decode() - new_name = new_name.decode() + # Ensure the type of name and new_name is str + name = name.decode() if isinstance(name, bytes) else name + new_name = new_name.decode() if isinstance(new_name, + bytes) else new_name if not self.has_var(name): raise ValueError("var %s is not in current block" % name) diff --git a/python/paddle/fluid/tests/unittests/test_block_rename_var.py b/python/paddle/fluid/tests/unittests/test_block_rename_var.py new file mode 100644 index 0000000000000000000000000000000000000000..322cb8bc4471f591377bdaa4006039880609e347 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_block_rename_var.py @@ -0,0 +1,56 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle + + +class TestBlockRenameVar(unittest.TestCase): + + def setUp(self): + self.program = paddle.static.Program() + self.block = self.program.current_block() + self.var = self.block.create_var(name="X", + shape=[-1, 23, 48], + dtype='float32') + self.op = self.block.append_op(type="abs", + inputs={"X": [self.var]}, + outputs={"Out": [self.var]}) + self.new_var_name = self.get_new_var_name() + + def get_new_var_name(self): + return "Y" + + def test_rename_var(self): + self.block._rename_var(self.var.name, self.new_var_name) + new_var_name_str = self.new_var_name if isinstance( + self.new_var_name, str) else self.new_var_name.decode() + self.assertTrue(new_var_name_str in self.block.vars) + + +class TestBlockRenameVarStrCase2(TestBlockRenameVar): + + def get_new_var_name(self): + return "ABC" + + +class TestBlockRenameVarBytes(TestBlockRenameVar): + + def get_new_var_name(self): + return b"Y" + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 7e92b6b9b7afcfd20873a081c857025839f2a6fb..aec18068c9a5f61ba0fa39483b86690fd5f8c261 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -15,6 +15,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_affine_channel_op', + 'test_block_rename_var', 'test_transfer_dtype_op', 'test_transfer_layout_op', 'test_concat_op',