未验证 提交 ddf317ed 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][py2] fix a decode error caused by 47036 (#47097)

* [CodeStyle][py2] fix an decode error caused by 47036

* add a comment

* add an unittest for Block._rename_var

* add test_block_rename_var to static_mode_white_list
上级 d817d896
...@@ -3658,8 +3658,8 @@ class Block(object): ...@@ -3658,8 +3658,8 @@ class Block(object):
Rename variable in vars and ops' inputs and outputs Rename variable in vars and ops' inputs and outputs
Args: Args:
name(bytes): the name that need to be renamed. name(str|bytes): the name that need to be renamed.
new_name(bytes): the name that need to rename to. new_name(str|bytes): the name that need to rename to.
Raises: Raises:
ValueError: If this block doesn't have this the giving name, ValueError: If this block doesn't have this the giving name,
...@@ -3669,8 +3669,10 @@ class Block(object): ...@@ -3669,8 +3669,10 @@ class Block(object):
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
name = name.decode() # Ensure the type of name and new_name is str
new_name = new_name.decode() name = name.decode() if isinstance(name, bytes) else name
new_name = new_name.decode() if isinstance(new_name,
bytes) else new_name
if not self.has_var(name): if not self.has_var(name):
raise ValueError("var %s is not in current block" % name) raise ValueError("var %s is not in current block" % name)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
class TestBlockRenameVar(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.block = self.program.current_block()
self.var = self.block.create_var(name="X",
shape=[-1, 23, 48],
dtype='float32')
self.op = self.block.append_op(type="abs",
inputs={"X": [self.var]},
outputs={"Out": [self.var]})
self.new_var_name = self.get_new_var_name()
def get_new_var_name(self):
return "Y"
def test_rename_var(self):
self.block._rename_var(self.var.name, self.new_var_name)
new_var_name_str = self.new_var_name if isinstance(
self.new_var_name, str) else self.new_var_name.decode()
self.assertTrue(new_var_name_str in self.block.vars)
class TestBlockRenameVarStrCase2(TestBlockRenameVar):
def get_new_var_name(self):
return "ABC"
class TestBlockRenameVarBytes(TestBlockRenameVar):
def get_new_var_name(self):
return b"Y"
if __name__ == "__main__":
unittest.main()
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
STATIC_MODE_TESTING_LIST = [ STATIC_MODE_TESTING_LIST = [
'test_affine_channel_op', 'test_affine_channel_op',
'test_block_rename_var',
'test_transfer_dtype_op', 'test_transfer_dtype_op',
'test_transfer_layout_op', 'test_transfer_layout_op',
'test_concat_op', 'test_concat_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册