diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index f33904b6e40348256816d7ab3be5fd72c3b4f061..eaf0a782ff454139ef427d481e733fff93841039 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -23,6 +23,13 @@ from .core.interpretation import Interpretation from .core.normlime_base import precompute_normlime_weights +def gen_user_home(): + if "HOME" in os.environ: + home_path = os.environ["HOME"] + if os.path.exists(home_path) and os.path.isdir(home_path): + return home_path + return os.path.expanduser('~') + def visualize(img_file, model, dataset=None, @@ -109,7 +116,7 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5 labels_name = None if dataset is not None: labels_name = dataset.labels - root_path = os.environ['HOME'] + root_path = gen_user_home() root_path = osp.join(root_path, '.paddlex') pre_models_path = osp.join(root_path, "pre_models") if not osp.exists(pre_models_path):