提交 3757c1ee 编写于 作者: S shippingwang

Modify test layers, test=develop

上级 942d7cf7
......@@ -9341,8 +9341,8 @@ def shuffle_channel(x, group=1, name=None):
with multiple group convolutional layers.
Args:
x: The input tensor variable..
group: The num of group
x(Variable): The input tensor variable.
group(Integer): The num of group.
Returns:
Variable: channels shuffled tensor variable.
......@@ -9358,8 +9358,7 @@ def shuffle_channel(x, group=1, name=None):
"""
helper = LayerHelper("shuffle_channel", **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('X'))
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if not isinstance(group, int):
raise TypeError("group must be int type")
......
......@@ -1018,8 +1018,8 @@ class TestBook(unittest.TestCase):
def test_shuffle_channel(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[10, 16, 4, 4], dtype="float32")
out = layers.shuffle_channel(x, group=2)
x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
out = layers.shuffle_channel(x, group=4)
self.assertIsNotNone(out)
print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册