未验证 提交 c3bd88a2 编写于 作者: G gaotingquan

circuitous rep

上级 32adae6c
......@@ -187,12 +187,6 @@ class RepDepthwiseSeparable(TheseusLayer):
def forward(self, x):
if self.use_rep:
if not self.training and not self.is_repped:
self.rep()
self.is_repped = True
if self.training and self.is_repped:
self.is_repped = False
input_x = x
if not self.training:
x = self.act(self.dw_conv(x))
......@@ -215,10 +209,14 @@ class RepDepthwiseSeparable(TheseusLayer):
x = x + input_x
return x
def rep(self):
kernel, bias = self._get_equivalent_kernel_bias()
self.dw_conv.weight.set_value(kernel)
self.dw_conv.bias.set_value(bias)
def eval(self):
if self.use_rep:
kernel, bias = self._get_equivalent_kernel_bias()
self.dw_conv.weight.set_value(kernel)
self.dw_conv.bias.set_value(bias)
self.training = False
for layer in self.sublayers():
layer.eval()
def _get_equivalent_kernel_bias(self):
kernel_sum = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册