提交 3757c1ee 编写于 作者: S shippingwang

Modify test layers, test=develop

上级 942d7cf7
...@@ -9341,8 +9341,8 @@ def shuffle_channel(x, group=1, name=None): ...@@ -9341,8 +9341,8 @@ def shuffle_channel(x, group=1, name=None):
with multiple group convolutional layers. with multiple group convolutional layers.
Args: Args:
x: The input tensor variable.. x(Variable): The input tensor variable.
group: The num of group group(Integer): The num of group.
Returns: Returns:
Variable: channels shuffled tensor variable. Variable: channels shuffled tensor variable.
...@@ -9358,8 +9358,7 @@ def shuffle_channel(x, group=1, name=None): ...@@ -9358,8 +9358,7 @@ def shuffle_channel(x, group=1, name=None):
""" """
helper = LayerHelper("shuffle_channel", **locals()) helper = LayerHelper("shuffle_channel", **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(dtype=x.dtype)
dtype=helper.input_dtype('X'))
if not isinstance(group, int): if not isinstance(group, int):
raise TypeError("group must be int type") raise TypeError("group must be int type")
......
...@@ -1018,8 +1018,8 @@ class TestBook(unittest.TestCase): ...@@ -1018,8 +1018,8 @@ class TestBook(unittest.TestCase):
def test_shuffle_channel(self): def test_shuffle_channel(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.data(name="X", shape=[10, 16, 4, 4], dtype="float32") x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
out = layers.shuffle_channel(x, group=2) out = layers.shuffle_channel(x, group=4)
self.assertIsNotNone(out) self.assertIsNotNone(out)
print(str(program)) print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册