提交 81de331e 编写于 作者: G gaotingquan 提交者: Tingquan Gao

rename to re_parameterize() for re-parameterization nets

上级 d4d3d013
...@@ -216,7 +216,7 @@ class RepDepthwiseSeparable(TheseusLayer): ...@@ -216,7 +216,7 @@ class RepDepthwiseSeparable(TheseusLayer):
x = x + input_x x = x + input_x
return x return x
def rep(self): def re_parameterize(self):
if self.use_rep: if self.use_rep:
self.is_repped = True self.is_repped = True
kernel, bias = self._get_equivalent_kernel_bias() kernel, bias = self._get_equivalent_kernel_bias()
......
...@@ -172,7 +172,7 @@ class RepVGGBlock(nn.Layer): ...@@ -172,7 +172,7 @@ class RepVGGBlock(nn.Layer):
return self.nonlinearity( return self.nonlinearity(
self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
def rep(self): def re_parameterize(self):
if not hasattr(self, 'rbr_reparam'): if not hasattr(self, 'rbr_reparam'):
self.rbr_reparam = nn.Conv2D( self.rbr_reparam = nn.Conv2D(
in_channels=self.in_channels, in_channels=self.in_channels,
......
...@@ -49,10 +49,10 @@ def quantize_model(config, model, mode="train"): ...@@ -49,10 +49,10 @@ def quantize_model(config, model, mode="train"):
if mode in ["infer", "export"]: if mode in ["infer", "export"]:
QUANT_CONFIG['activation_preprocess_type'] = None QUANT_CONFIG['activation_preprocess_type'] = None
# for rep nets, convert to reparameterized model first # for re-parameterization nets, convert to reparameterized model first
for layer in model.sublayers(): for layer in model.sublayers():
if hasattr(layer, "rep"): if hasattr(layer, "re_parameterize"):
layer.rep() layer.re_parameterize()
model.quanter = QAT(config=QUANT_CONFIG) model.quanter = QAT(config=QUANT_CONFIG)
model.quanter.quantize(model) model.quanter.quantize(model)
......
...@@ -560,10 +560,11 @@ class Engine(object): ...@@ -560,10 +560,11 @@ class Engine(object):
model.eval() model.eval()
# for rep nets # for re-parameterization nets
for layer in self.model.sublayers(): for layer in self.model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"): if hasattr(layer, "re_parameterize") and not getattr(layer,
layer.rep() "is_repped"):
layer.re_parameterize()
save_path = os.path.join(self.config["Global"]["save_inference_dir"], save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference") "inference")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册