提交 6595325b 编写于 作者: R root 提交者: Tingquan Gao

perf: remove sys.path.append(); print_info() is now only called in the CLI

上级 87822ba1
...@@ -13,10 +13,6 @@ ...@@ -13,10 +13,6 @@
# limitations under the License. # limitations under the License.
import os import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ""))
from typing import Union, Generator from typing import Union, Generator
import argparse import argparse
import shutil import shutil
...@@ -32,16 +28,16 @@ from tqdm import tqdm ...@@ -32,16 +28,16 @@ from tqdm import tqdm
from prettytable import PrettyTable from prettytable import PrettyTable
import paddle import paddle
import ppcls.arch.backbone as backbone from .ppcls.arch import backbone
from ppcls.utils import logger from .ppcls.utils import logger
from deploy.python.predict_cls import ClsPredictor from .deploy.python.predict_cls import ClsPredictor
from deploy.utils.get_image_list import get_image_list from .deploy.utils.get_image_list import get_image_list
from deploy.utils import config from .deploy.utils import config
# for the PaddleClas Project to import # for the PaddleClas Project
import deploy from . import deploy
import ppcls from . import ppcls
# for building model with loading pretrained weights from backbone # for building model with loading pretrained weights from backbone
logger.init_logger() logger.init_logger()
...@@ -205,6 +201,7 @@ class InputModelError(Exception): ...@@ -205,6 +201,7 @@ class InputModelError(Exception):
def init_config(model_type, model_name, inference_model_dir, **kwargs): def init_config(model_type, model_name, inference_model_dir, **kwargs):
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml" cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
__dir__ = os.path.dirname(__file__)
cfg_path = os.path.join(__dir__, cfg_path) cfg_path = os.path.join(__dir__, cfg_path)
cfg = config.get_config(cfg_path, show=False) cfg = config.get_config(cfg_path, show=False)
...@@ -456,10 +453,6 @@ class PaddleClas(object): ...@@ -456,10 +453,6 @@ class PaddleClas(object):
"""PaddleClas. """PaddleClas.
""" """
if not os.environ.get('ppcls', False):
os.environ.setdefault('ppcls', 'True')
print_info()
def __init__(self, def __init__(self,
model_name: str=None, model_name: str=None,
inference_model_dir: str=None, inference_model_dir: str=None,
...@@ -474,6 +467,7 @@ class PaddleClas(object): ...@@ -474,6 +467,7 @@ class PaddleClas(object):
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5. topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
""" """
super().__init__() super().__init__()
self.model_type, inference_model_dir = self._check_input_model( self.model_type, inference_model_dir = self._check_input_model(
model_name, inference_model_dir) model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name, self._config = init_config(self.model_type, model_name,
...@@ -598,6 +592,7 @@ class PaddleClas(object): ...@@ -598,6 +592,7 @@ class PaddleClas(object):
def main(): def main():
"""Function API used for commad line. """Function API used for commad line.
""" """
print_info()
cfg = args_cfg() cfg = args_cfg()
clas_engine = PaddleClas(**cfg) clas_engine = PaddleClas(**cfg)
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True) res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册