提交 6595325b 编写于 作者: R root 提交者: Tingquan Gao

perf: rm sys.path.append() & only in CLI will print_info() be call

上级 87822ba1
......@@ -13,10 +13,6 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ""))
from typing import Union, Generator
import argparse
import shutil
......@@ -32,16 +28,16 @@ from tqdm import tqdm
from prettytable import PrettyTable
import paddle
import ppcls.arch.backbone as backbone
from ppcls.utils import logger
from .ppcls.arch import backbone
from .ppcls.utils import logger
from deploy.python.predict_cls import ClsPredictor
from deploy.utils.get_image_list import get_image_list
from deploy.utils import config
from .deploy.python.predict_cls import ClsPredictor
from .deploy.utils.get_image_list import get_image_list
from .deploy.utils import config
# for the PaddleClas Project to import
import deploy
import ppcls
# for the PaddleClas Project
from . import deploy
from . import ppcls
# for building model with loading pretrained weights from backbone
logger.init_logger()
......@@ -205,6 +201,7 @@ class InputModelError(Exception):
def init_config(model_type, model_name, inference_model_dir, **kwargs):
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
__dir__ = os.path.dirname(__file__)
cfg_path = os.path.join(__dir__, cfg_path)
cfg = config.get_config(cfg_path, show=False)
......@@ -456,10 +453,6 @@ class PaddleClas(object):
"""PaddleClas.
"""
if not os.environ.get('ppcls', False):
os.environ.setdefault('ppcls', 'True')
print_info()
def __init__(self,
model_name: str=None,
inference_model_dir: str=None,
......@@ -474,6 +467,7 @@ class PaddleClas(object):
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
"""
super().__init__()
self.model_type, inference_model_dir = self._check_input_model(
model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name,
......@@ -598,6 +592,7 @@ class PaddleClas(object):
def main():
"""Function API used for commad line.
"""
print_info()
cfg = args_cfg()
clas_engine = PaddleClas(**cfg)
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册