gallery2fc.py 1.0 KB
Newer Older
W
weishengyu 已提交
1
import paddle
W
weishengyu 已提交
2
from ppcls.arch import build_model
W
weishengyu 已提交
3
from ppcls.utils.config import parse_config, parse_args
W
dbg  
weishengyu 已提交
4
from ppcls.utils.save_load import load_dygraph_pretrain
W
weishengyu 已提交
5
from ppcls.utils.logger import init_logger
W
dbg  
weishengyu 已提交
6 7 8 9 10


def load_feature_extractor(configs):
    """Build the backbone described by the config and load its pretrained weights.

    Args:
        configs: Parsed config as returned by ``parse_config``. Reads
            ``configs["Arch"]`` (passed to ``build_model``) and
            ``configs["Global"]["pretrained_model"]`` (passed to
            ``load_dygraph_pretrain``).

    Returns:
        The model created by ``build_model`` after ``load_dygraph_pretrain``
        has been applied to it.
    """
    arch = build_model(configs["Arch"])
    # NOTE(review): assumes load_dygraph_pretrain loads weights into ``arch``
    # in place (its return value is not used) -- confirm against ppcls.
    load_dygraph_pretrain(arch, configs["Global"]["pretrained_model"])
    # FuseModel stores this result as its feature extractor, so the model must
    # be returned; falling off the end would hand back None.
    return arch
W
weishengyu 已提交
11 12 13 14 15 16 17 18 19 20 21 22 23


def build_gallery_feature(feature_extractor):
    """Build the layer that ``FuseModel`` stores as ``gallery_layer``.

    Args:
        feature_extractor: The model returned by ``load_feature_extractor``.

    Returns:
        Currently always ``None``.

    TODO: not implemented yet -- this is a placeholder with no side effects.
    """
    return None


def save_fuse_model(fuse_model):
    """Save ``fuse_model``.

    The destination and format are undecided; presumably this will export the
    model to disk -- TODO confirm.

    Args:
        fuse_model: The ``FuseModel`` instance built in ``main``.

    Returns:
        Currently always ``None``.

    TODO: not implemented yet -- this is a placeholder with no side effects.
    """
    return None


class FuseModel(paddle.nn.Layer):
    """Layer that runs a feature extractor followed by a gallery layer.

    ``feature_extractor`` comes from ``load_feature_extractor`` and
    ``gallery_layer`` from ``build_gallery_feature``. ``forward`` applies them
    in that order.
    """

    def __init__(self, configs):
        """Create both sub-layers from the parsed config ``configs``."""
        super().__init__()
        self.feature_extractor = load_feature_extractor(configs)
        # NOTE(review): build_gallery_feature is still a stub that returns
        # None, so forward() cannot run until it is implemented.
        self.gallery_layer = build_gallery_feature(self.feature_extractor)

    def forward(self, x):
        """Run ``x`` through the feature extractor, then the gallery layer."""
        # Use the attribute assigned in __init__; ``self.feature_model`` is
        # never set and would raise AttributeError.
        x = self.feature_extractor(x)
        x = self.gallery_layer(x)
        return x


def main():
    """Script entry point: read the config, build the fused model, and save it.

    Steps, in order: parse CLI arguments, parse the config file they name,
    initialise the ``gallery2fc`` logger, construct ``FuseModel``, and pass it
    to ``save_fuse_model``.
    """
    config_path = parse_args().config
    configs = parse_config(config_path)
    init_logger(name='gallery2fc')

    save_fuse_model(FuseModel(configs))


if __name__ == '__main__':
    # Run the conversion only when executed as a script, not when imported.
    main()