提交 f04b5234 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: unable to export rep net

上级 03142ea3
...@@ -188,7 +188,7 @@ class RepDepthwiseSeparable(TheseusLayer): ...@@ -188,7 +188,7 @@ class RepDepthwiseSeparable(TheseusLayer):
def forward(self, x): def forward(self, x):
if self.use_rep: if self.use_rep:
input_x = x input_x = x
if not self.training: if self.is_repped:
x = self.act(self.dw_conv(x)) x = self.act(self.dw_conv(x))
else: else:
y = self.dw_conv_list[0](x) y = self.dw_conv_list[0](x)
...@@ -209,14 +209,12 @@ class RepDepthwiseSeparable(TheseusLayer): ...@@ -209,14 +209,12 @@ class RepDepthwiseSeparable(TheseusLayer):
x = x + input_x x = x + input_x
return x return x
def eval(self): def rep(self):
if self.use_rep: if self.use_rep:
self.is_repped = True
kernel, bias = self._get_equivalent_kernel_bias() kernel, bias = self._get_equivalent_kernel_bias()
self.dw_conv.weight.set_value(kernel) self.dw_conv.weight.set_value(kernel)
self.dw_conv.bias.set_value(bias) self.dw_conv.bias.set_value(bias)
self.training = False
for layer in self.sublayers():
layer.eval()
def _get_equivalent_kernel_bias(self): def _get_equivalent_kernel_bias(self):
kernel_sum = 0 kernel_sum = 0
......
...@@ -124,13 +124,7 @@ class RepVGGBlock(nn.Layer): ...@@ -124,13 +124,7 @@ class RepVGGBlock(nn.Layer):
groups=groups) groups=groups)
def forward(self, inputs): def forward(self, inputs):
if not self.training and not self.is_repped: if self.is_repped:
self.rep()
self.is_repped = True
if self.training and self.is_repped:
self.is_repped = False
if not self.training:
return self.nonlinearity(self.rbr_reparam(inputs)) return self.nonlinearity(self.rbr_reparam(inputs))
if self.rbr_identity is None: if self.rbr_identity is None:
...@@ -154,6 +148,7 @@ class RepVGGBlock(nn.Layer): ...@@ -154,6 +148,7 @@ class RepVGGBlock(nn.Layer):
kernel, bias = self.get_equivalent_kernel_bias() kernel, bias = self.get_equivalent_kernel_bias()
self.rbr_reparam.weight.set_value(kernel) self.rbr_reparam.weight.set_value(kernel)
self.rbr_reparam.bias.set_value(bias) self.rbr_reparam.bias.set_value(bias)
self.is_repped = True
def get_equivalent_kernel_bias(self): def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
......
...@@ -452,6 +452,12 @@ class Engine(object): ...@@ -452,6 +452,12 @@ class Engine(object):
self.config["Global"]["pretrained_model"]) self.config["Global"]["pretrained_model"])
model.eval() model.eval()
# for rep nets
for layer in self.model.sublayers():
if hasattr(layer, "rep"):
layer.rep()
save_path = os.path.join(self.config["Global"]["save_inference_dir"], save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference") "inference")
if model.quanter: if model.quanter:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册