提交 776d094c 编写于 作者: W Wei Luning

quant export fix up for atc tools

上级 a2b32356
......@@ -270,6 +270,7 @@ class Parameter(MetaTensor):
"Update the parameter by a Tensor."
if isinstance(self, Tensor):
# for Tensor same shape:
self.init_flag = False
return self.assign_value(data)
# create a new tensor
return Parameter(data, self.name, self.requires_grad)
......
......@@ -29,6 +29,7 @@ from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
......@@ -366,8 +367,6 @@ class ExportToQuantInferNetwork:
sqrt_mode = True
dequant_op = inner.Dequant(sqrt_mode)
# get op
op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell
elif hasattr(activation, "get_origin"):
......@@ -383,10 +382,17 @@ class ExportToQuantInferNetwork:
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
# apply the quant
weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
weight = quant_utils.weight2int(weight, scale_w, zp_w)
if bias is not None:
bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
scale_deq = Tensor(scale_deq, mstype.float16)
# get op
if isinstance(cell_core, quant.DenseQuant):
op_core = P.MatMul()
weight = np.transpose(weight)
else:
op_core = cell_core.conv
weight = Tensor(weight, self.data_type)
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
......
......@@ -50,7 +50,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
......
......@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
......
......@@ -116,7 +116,7 @@ def run_pretrain():
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params},
{'order_params': params}]
......@@ -132,7 +132,7 @@ def run_pretrain():
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
......
......@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
......
......@@ -137,7 +137,7 @@ def run_pretrain():
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params},
{'order_params': params}]
......@@ -153,7 +153,7 @@ def run_pretrain():
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
......
......@@ -99,7 +99,7 @@ def run_general_distill():
power=common_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
......
......@@ -107,7 +107,7 @@ def run_predistill():
power=optimizer_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
......@@ -165,7 +165,7 @@ def run_task_distill(ckpt_file):
power=optimizer_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册