diff --git a/ppcls/utils/model_zoo.py b/ppcls/utils/model_zoo.py index 8a154a90de99432c44ec004c08f0b5cb2a61d26e..dd65e921f3e258b75c911599ee234359f3fd7c51 100644 --- a/ppcls/utils/model_zoo.py +++ b/ppcls/utils/model_zoo.py @@ -24,7 +24,6 @@ import tqdm import zipfile from ppcls.modeling import similar_architectures -from ppcls.utils.check import check_architecture from ppcls.utils import logger __all__ = ['get'] @@ -168,11 +167,16 @@ def _decompress(fname): os.remove(fname) +def _get_pretrained(): + with open('./ppcls/utils/pretrained.list') as flist: + pretrained = [line.strip() for line in flist] + return pretrained + + def _check_pretrained_name(architecture): assert isinstance(architecture, str), \ - ("the type of architecture({}) should be str". format(architecture)) - with open('./configs/pretrained.list') as flist: - pretrained = [line.strip() for line in flist] + ("the type of architecture({}) should be str". format(architecture)) + pretrained = _get_pretrained() similar_names = similar_architectures(architecture, pretrained) model_list = ', '.join(similar_names) err = "{} is not exist! Maybe you want: [{}]" \ @@ -181,6 +185,14 @@ def _check_pretrained_name(architecture): raise ModelNameError(err) +def list_models(): + pretrained = _get_pretrained() + msg = "All avialable pretrained models are as follows: {}".format( + pretrained) + logger.info(msg) + return + + def get(architecture, path, decompress=True): """ Get the pretrained model. @@ -188,5 +200,6 @@ def get(architecture, path, decompress=True): _check_pretrained_name(architecture) url = _get_url(architecture) fname = _download(url, path) - if decompress: _decompress(fname) + if decompress: + _decompress(fname) logger.info("download {} finished ".format(fname)) diff --git a/configs/pretrained.list b/ppcls/utils/pretrained.list similarity index 100% rename from configs/pretrained.list rename to ppcls/utils/pretrained.list diff --git a/tools/download.py b/tools/download.py index 157bebccdaa26bbb40a46a5bf065b4e092527af7..d9fe1a8ee04a14b31cf4917e60f9348ae51b8d20 100644 --- a/tools/download.py +++ b/tools/download.py @@ -1,18 +1,17 @@ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import sys import argparse from ppcls import model_zoo @@ -26,6 +25,7 @@ def parse_args(): parser.add_argument('-a', '--architecture', type=str, default='ResNet50') parser.add_argument('-p', '--path', type=str, default='./pretrained/') parser.add_argument('-d', '--decompress', type=str2bool, default=True) + parser.add_argument('-l', '--list', type=str2bool, default=False) args = parser.parse_args() return args @@ -33,7 +33,10 @@ def parse_args(): def main(): args = parse_args() - model_zoo.get(args.architecture, args.path, args.decompress) + if args.list: + model_zoo.list_models() + else: + model_zoo.get(args.architecture, args.path, args.decompress) if __name__ == '__main__':