__init__.py 891 字节
Newer Older
C
update  
ceci3 已提交
1 2
"""Based on https://github.com/mit-han-lab/gan-compression """

C
ceci3 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
import importlib
from .base_model import BaseModel


def find_model_using_name(model_name):
    """Import ``models.<model_name>_model`` and return its model class.

    The class is located by name: underscores are stripped from
    ``model_name`` and the result is compared case-insensitively against
    every attribute name of the imported module. Only classes that are
    subclasses of ``BaseModel`` qualify (e.g. ``"cycle_gan"`` would match
    a class named ``CycleGAN``).

    Args:
        model_name: model identifier; also selects the module file
            ``models/<model_name>_model``.

    Returns:
        The matching ``BaseModel`` subclass itself (not an instance). If
        several attributes match, the last one in the module namespace wins.

    Raises:
        ImportError: if ``models.<model_name>_model`` cannot be imported.
        AssertionError: if the module defines no matching ``BaseModel``
            subclass.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    # Normalize once, outside the loop.
    target_model_name = model_name.replace('_', '').lower()
    model = None
    for name, cls in modellib.__dict__.items():
        # issubclass() raises TypeError for non-class objects, so a function
        # or variable whose name happens to match must be skipped first.
        if (name.lower() == target_model_name and isinstance(cls, type)
                and issubclass(cls, BaseModel)):
            model = cls
    if model is None:
        # Explicit raise instead of `assert` so the check survives
        # `python -O`; AssertionError is kept for existing callers.
        raise AssertionError(
            "model {} is not right, please check it!".format(model_name))

    return model


def get_special_cfg(model):
    """Return the ``add_special_cfgs`` attribute of the model class *model*.

    The model class is resolved with ``find_model_using_name``. The
    attribute is returned as-is and is not called here. It is presumably a
    hook that registers model-specific config options; TODO confirm
    against ``BaseModel``.

    Args:
        model: model identifier, e.g. ``"cycle_gan"``.
    """
    special_cfgs_hook = find_model_using_name(model).add_special_cfgs
    return special_cfgs_hook


def create_model(cfg):
    """Instantiate the model selected by ``cfg.model``.

    The class is resolved with ``find_model_using_name(cfg.model)``, and a
    new instance is returned, constructed with the whole ``cfg`` object.
    """
    return find_model_using_name(cfg.model)(cfg)