vgg_variant.py 778 字节
Newer Older
B
Bin Lu 已提交
1 2
import paddle
from paddle.nn import Sigmoid
R
root 已提交
3
from ..legendary_models.vgg import VGG19
G
gaotingquan 已提交
4

B
Bin Lu 已提交
5
# Public API of this module: only the VGG19Sigmoid factory is exported;
# the SigmoidSuffix wrapper is an implementation detail.
__all__ = ["VGG19Sigmoid"]
G
gaotingquan 已提交
6 7


B
Bin Lu 已提交
8 9
class SigmoidSuffix(paddle.nn.Layer):
    """Wrap an existing layer so its output is passed through a sigmoid.

    Args:
        origin_layer: The layer to wrap. It is stored as
            ``self.origin_layer`` and is called first in ``forward``.
    """

    def __init__(self, origin_layer):
        super().__init__()
        # NOTE(review): attribute names are left unchanged on purpose; in
        # Paddle they presumably become sublayer names and therefore prefixes
        # of parameter keys in saved checkpoints.
        self.origin_layer = origin_layer
        self.sigmoid = Sigmoid()

    def forward(self, input, res_dict=None, **kwargs):
        """Apply the wrapped layer to ``input``, then a sigmoid to its result.

        ``res_dict`` and ``**kwargs`` are accepted but not used.
        NOTE(review): presumably they are kept so that callers passing
        these arguments still work. Confirm before removing them.
        """
        raw_out = self.origin_layer(input)
        return self.sigmoid(raw_out)
G
gaotingquan 已提交
18 19


B
Bin Lu 已提交
20
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
    """Build a VGG19 model whose ``fc2`` sublayer is wrapped by a sigmoid.

    Args:
        pretrained: Forwarded unchanged to ``VGG19``.
        use_ssld: Forwarded unchanged to ``VGG19``.
        **kwargs: Any additional keyword arguments, forwarded to ``VGG19``.

    Returns:
        The ``VGG19`` instance, after ``upgrade_sublayer`` was called for
        ``"fc2"`` with a callback that wraps the matched layer in
        ``SigmoidSuffix``.
    """

    def _wrap_in_sigmoid(origin_layer, pattern):
        # Callback handed to upgrade_sublayer. ``pattern`` belongs to the
        # callback signature and is not needed here.
        return SigmoidSuffix(origin_layer)

    model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
    target_sublayer = "fc2"
    model.upgrade_sublayer(target_sublayer, _wrap_in_sigmoid)
    return model