From 9628e76c79bbf2125a092b815a8b7a90a55f5f1f Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Mon, 23 Aug 2021 19:06:34 +0800 Subject: [PATCH] Create vgg_variant.py --- .../backbone/variant_models/vgg_variant.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 ppcls/arch/backbone/variant_models/vgg_variant.py diff --git a/ppcls/arch/backbone/variant_models/vgg_variant.py b/ppcls/arch/backbone/variant_models/vgg_variant.py new file mode 100644 index 00000000..ac6c0138 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/vgg_variant.py @@ -0,0 +1,28 @@ +import paddle +from paddle.nn import Sigmoid +from ppcls.arch.backbone.legendary_models.vgg import VGG19 + +__all__ = ["VGG19Sigmoid"] + + +class SigmoidSuffix(paddle.nn.Layer): + def __init__(self, origin_layer): + super(SigmoidSuffix, self).__init__() + self.origin_layer = origin_layer + self.sigmoid = Sigmoid() + + def forward(self, *input, res_dict=None, **kwargs): + x = self.origin_layer(input) + x = self.sigmoid(x) + return x + + +def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs): + def replace_function(origin_layer): + new_layer = SigmoidSuffix(origin_layer) + return new_layer + + match_re = "linear_2" + model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs) + model.replace_sub(match_re, replace_function, True) + return model -- GitLab