提交 46700bec 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2500 add output activation quant in mobilenetv2 and resnet50

Merge pull request !2500 from chenzupeng/r0.3
...@@ -193,10 +193,12 @@ class MobileNetV2Quant(nn.Cell): ...@@ -193,10 +193,12 @@ class MobileNetV2Quant(nn.Cell):
# mobilenet head # mobilenet head
head = ([GlobalAvgPooling(), head = ([GlobalAvgPooling(),
nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel, nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)] if not has_dropout else symmetric=_symmetric),
nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), [GlobalAvgPooling(), nn.Dropout(0.2),
nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel, nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)]) symmetric=_symmetric),
nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)])
self.head = nn.SequentialCell(head) self.head = nn.SequentialCell(head)
def construct(self, x): def construct(self, x):
......
...@@ -24,6 +24,7 @@ _symmetric = False ...@@ -24,6 +24,7 @@ _symmetric = False
_fake = True _fake = True
_per_channel = True _per_channel = True
def _weight_variable(shape, factor=0.01): def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value) return Tensor(init_value)
...@@ -65,6 +66,7 @@ def _fc(in_channel, out_channel): ...@@ -65,6 +66,7 @@ def _fc(in_channel, out_channel):
weight = _weight_variable(weight_shape) weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ConvBNReLU(nn.Cell): class ConvBNReLU(nn.Cell):
""" """
Convolution/Depthwise fused with Batchnorm and ReLU block definition. Convolution/Depthwise fused with Batchnorm and ReLU block definition.
...@@ -82,6 +84,7 @@ class ConvBNReLU(nn.Cell): ...@@ -82,6 +84,7 @@ class ConvBNReLU(nn.Cell):
Examples: Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
""" """
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
...@@ -94,6 +97,7 @@ class ConvBNReLU(nn.Cell): ...@@ -94,6 +97,7 @@ class ConvBNReLU(nn.Cell):
output = self.features(x) output = self.features(x)
return output return output
class ResidualBlock(nn.Cell): class ResidualBlock(nn.Cell):
""" """
ResNet V1 residual block definition. ResNet V1 residual block definition.
...@@ -152,8 +156,8 @@ class ResidualBlock(nn.Cell): ...@@ -152,8 +156,8 @@ class ResidualBlock(nn.Cell):
stride=stride, stride=stride,
pad_mode='same', pad_mode='same',
padding=0) padding=0)
self.add = P.TensorAdd() self.add = nn.TensorAddQuant()
self.relu = nn.ReLUQuant() if _fake else P.ReLU() self.relu = P.ReLU()
def construct(self, x): def construct(self, x):
identity = x identity = x
...@@ -231,7 +235,9 @@ class ResNet(nn.Cell): ...@@ -231,7 +235,9 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True) self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True) self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride): def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
""" """
...@@ -273,7 +279,7 @@ class ResNet(nn.Cell): ...@@ -273,7 +279,7 @@ class ResNet(nn.Cell):
out = self.mean(c5, (2, 3)) out = self.mean(c5, (2, 3))
out = self.flatten(out) out = self.flatten(out)
out = self.end_point(out) out = self.end_point(out)
out = self.output_fake(out)
return out return out
...@@ -297,6 +303,7 @@ def resnet50_quant(class_num=10): ...@@ -297,6 +303,7 @@ def resnet50_quant(class_num=10):
[1, 2, 2, 2], [1, 2, 2, 2],
class_num) class_num)
def resnet101_quant(class_num=1001): def resnet101_quant(class_num=1001):
""" """
Get ResNet101 neural network. Get ResNet101 neural network.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册