Commit e9ee59c7 authored by chenzupeng

add per-channel quantization training

Parent 4bbd4414
......@@ -47,7 +47,6 @@ Dataset used: imagenet
├── eval.py
```
Note: the current hyperparameters have only been tested for training on 4 cards. To train on 8 cards, adjust parameters such as the learning rate in 'src/config.py', for example as in the sketch below.
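One common way to make that adjustment is to scale the learning rate linearly with the number of devices. A minimal sketch, assuming the `lr` key from this repository's `src/config.py` and treating linear scaling as a rule of thumb rather than something this commit prescribes:

```python
# Hypothetical helper for adapting the 4-card learning rate to 8 cards.
# Linear scaling is an assumption here, not a value validated by this commit.
def scale_lr(base_lr, base_cards=4, target_cards=8):
    """Scale the learning rate linearly with the number of training devices."""
    return base_lr * target_cards / base_cards

# With the 4-card default lr = 0.3 from src/config.py below:
print(scale_lr(0.3))  # 0.6 as a starting point for 8-card training
```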
## Training process
......
......@@ -22,10 +22,10 @@ config_ascend = ed({
"image_height": 224,
"image_width": 224,
"batch_size": 192,
"epoch_size": 40,
"epoch_size": 60,
"start_epoch": 200,
"warmup_epochs": 1,
"lr": 0.15,
"lr": 0.3,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
......
......@@ -20,7 +20,8 @@ from mindspore.ops.operations import TensorAdd
__all__ = ['mobilenet_v2_quant']
_ema_decay = 0.999
_symmetric = False
_symmetric = True
_per_channel = True
def _make_divisible(v, divisor, min_value=None):
......@@ -77,10 +78,10 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups)
group=groups, per_channel=_per_channel, symmetric=_symmetric)
layers = [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)
self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric, min_init=0)
self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, min_init=0)
def construct(self, x):
output = self.features(x)
......@@ -119,12 +120,13 @@ class InvertedResidual(nn.Cell):
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1),
nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1,
per_channel=_per_channel, symmetric=_symmetric),
nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
])
self.conv = nn.SequentialCell(layers)
self.add = TensorAdd()
self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
def construct(self, x):
identity = x
......@@ -175,7 +177,7 @@ class MobileNetV2Quant(nn.Cell):
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in self.cfgs:
......@@ -189,8 +191,12 @@ class MobileNetV2Quant(nn.Cell):
# make it nn.CellList
self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
head = ([GlobalAvgPooling(),
nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2),
nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)])
self.head = nn.SequentialCell(head)
def construct(self, x):
......
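Taken together, the flags above make weight quantization per-channel and symmetric, while the activation fake-quant nodes stay per-tensor. A minimal sketch of the resulting conv block, following the `ConvBNReLU` pattern in this file and assuming the same MindSpore API as this commit:

```python
import mindspore.nn as nn

_ema_decay = 0.999
_symmetric = True      # symmetric weight quantization, as enabled in this commit
_per_channel = True    # one (min, max) pair per output channel for weights

class QuantConvBlock(nn.Cell):
    """Conv + folded BN + ReLU with per-channel weight fake-quant (sketch)."""
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(QuantConvBlock, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride,
                                            pad_mode='pad', padding=padding, group=groups,
                                            per_channel=_per_channel, symmetric=_symmetric)
        self.relu = nn.ReLU()
        # Activation fake-quant stays per-tensor and asymmetric, mirroring the
        # FakeQuantWithMinMax calls above that dropped the symmetric flag.
        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, min_init=0)

    def construct(self, x):
        return self.fake(self.relu(self.conv(x)))
```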
......@@ -51,7 +51,7 @@ Parameters for both training and inference can be set in config.py.
"loss_scale": 1024, # loss scale
"momentum": 0.9, # momentum optimizer
"weight_decay": 1e-4, # weight decay
"epoch_size": 110, # only valid for taining, which is always 1 for inference
"epoch_size": 120, # only valid for taining, which is always 1 for inference
"pretrained_epoch_size": 90, # epoch size that model has been trained before load pretrained checkpoint
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
......@@ -65,7 +65,7 @@ Parameters for both training and inference can be set in config.py.
"label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0, # initial learning rate
"lr_max": 0.1, # maximum learning rate
"lr_max": 0.005, # maximum learning rate
```
## Running the example
......
......@@ -22,6 +22,7 @@ from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
_ema_decay = 0.999
_symmetric = False
_fake = True
_per_channel = True
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
......@@ -85,7 +86,7 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake)
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
layers = [conv, nn.ReLUQuant()] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)
......@@ -119,10 +120,13 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=1,
pad_mode='same', padding=0)
......@@ -134,18 +138,22 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
kernel_size=1,
stride=stride,
pad_mode='same',
padding=0)
self.add = P.TensorAdd()
self.fake = FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
self.relu = nn.ReLUQuant() if _fake else P.ReLU()
def construct(self, x):
identity = x
......@@ -157,9 +165,7 @@ class ResidualBlock(nn.Cell):
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = P.ReLU()(out)
if _fake:
out = self.fake(out)
out = self.relu(out)
return out
......
......@@ -23,7 +23,7 @@ config = ed({
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 110,
"epoch_size": 120,
"pretrained_epoch_size": 90,
"buffer_size": 1000,
"image_height": 224,
......@@ -37,6 +37,6 @@ config = ed({
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.1
"lr_max": 0.005
})
......@@ -91,11 +91,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, the 2d weight [co, ci] is reshaped to 4d [1, co, ci, 1], so channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -122,7 +126,7 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
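The `channel_axis_` adjustment above exists because Dense weights arrive at the kernel padded from 2d `[co, ci]` to 4d `[1, co, ci, 1]`, which moves the channel dimension from axis 0 to axis 1. A standalone sketch of the same heuristic in plain Python (names are illustrative):

```python
def resolve_channel_axis(x_shape, min_shape, channel_axis):
    """Return the axis that actually carries the per-channel min/max values.

    When axis 0 no longer matches the length of the min/max tensor but axis 1
    does, the weight is a Dense weight padded to [1, co, ci, 1], so the
    effective channel axis shifts from 0 to 1.
    """
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

# Conv weight [co, ci, kh, kw]: axis 0 still matches the per-channel size.
assert resolve_channel_axis([64, 32, 3, 3], [64], 0) == 0
# Dense weight padded to [1, co, ci, 1]: the channel axis moves to 1.
assert resolve_channel_axis([1, 1000, 1280, 1], [1000], 0) == 1
```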
......@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, the 2d weight [co, ci] is reshaped to 4d [1, co, ci, 1], so channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
......
......@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, the 2d weight [co, ci] is reshaped to 4d [1, co, ci, 1], so channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
......