未验证 提交 b94adb0c 编写于 作者: B Bai Yifan 提交者: GitHub

fix hardsigmoid/hardswish (#607)

上级 84d54c2b
...@@ -21,7 +21,6 @@ import paddle ...@@ -21,7 +21,6 @@ import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.functional.activation import hard_sigmoid, hard_swish
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
...@@ -165,7 +164,7 @@ class MobileNetV3(nn.Layer): ...@@ -165,7 +164,7 @@ class MobileNetV3(nn.Layer):
x = self.pool(x) x = self.pool(x)
x = self.last_conv(x) x = self.last_conv(x)
x = hard_swish(x) x = paddle.nn.functional.activation.hardswish(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.out(x) x = self.out(x)
...@@ -214,7 +213,7 @@ class ConvBNLayer(nn.Layer): ...@@ -214,7 +213,7 @@ class ConvBNLayer(nn.Layer):
if self.act == "relu": if self.act == "relu":
x = F.relu(x) x = F.relu(x)
elif self.act == "hard_swish": elif self.act == "hard_swish":
x = hard_swish(x) x = paddle.nn.functional.activation.hardswish(x)
else: else:
print("The activation function is selected incorrectly.") print("The activation function is selected incorrectly.")
exit() exit()
...@@ -303,7 +302,8 @@ class SEModule(nn.Layer): ...@@ -303,7 +302,8 @@ class SEModule(nn.Layer):
outputs = self.conv1(outputs) outputs = self.conv1(outputs)
outputs = F.relu(outputs) outputs = F.relu(outputs)
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = hard_sigmoid(outputs) outputs = paddle.nn.functional.activation.hardsigmoid(
outputs, slope=0.2)
return paddle.multiply(x=inputs, y=outputs) return paddle.multiply(x=inputs, y=outputs)
......
...@@ -235,6 +235,8 @@ class QAT(object): ...@@ -235,6 +235,8 @@ class QAT(object):
quantizable_layer_type=self.config['quantizable_layer_type']) quantizable_layer_type=self.config['quantizable_layer_type'])
with paddle.utils.unique_name.guard(): with paddle.utils.unique_name.guard():
if hasattr(model, "_layers"):
model = model._layers
model.__init__() model.__init__()
self.imperative_qat.quantize(model) self.imperative_qat.quantize(model)
state_dict = model.state_dict() state_dict = model.state_dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册