提交 04c522d0 编写于 作者: Z zhaojichen

Add Group Normalization

上级 0b7de696
......@@ -57,6 +57,7 @@ def test_compile():
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
_executor.compile(net, input_data)
class GroupNet(nn.Cell):
def __init__(self):
super(GroupNet, self).__init__()
......@@ -64,6 +65,7 @@ class GroupNet(nn.Cell):
def construct(self, x):
return self.group_bn(x)
def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册