提交 0f858449 编写于 作者: S sunyanfang01

fix the vis.py

上级 87a1222c
...@@ -13,15 +13,29 @@ ...@@ -13,15 +13,29 @@
#limitations under the License. #limitations under the License.
import os import os
import os.path as osp
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from ..as_data_reader.readers import preprocess_image from ..as_data_reader.readers import preprocess_image
root_path = os.environ['HOME'] def gen_user_home():
root_path = os.path.join(root_path, '.paddlex') if "HOME" in os.environ:
h_pre_models = os.path.join(root_path, "pre_models") home_path = os.environ["HOME"]
h_pre_models_kmeans = os.path.join(h_pre_models, "kmeans_model.pkl") if os.path.exists(home_path) and os.path.isdir(home_path):
return home_path
return os.path.expanduser('~')
root_path = gen_user_home()
root_path = osp.join(root_path, '.paddlex')
h_pre_models = osp.join(root_path, "pre_models")
if not osp.exists(h_pre_models):
if not osp.exists(root_path):
os.makedirs(root_path)
url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
pdx.utils.download_and_decompress(url, path=root_path)
h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
def paddle_get_fc_weights(var_name="fc_0.w_0"): def paddle_get_fc_weights(var_name="fc_0.w_0"):
......
...@@ -21,14 +21,7 @@ import paddlex as pdx ...@@ -21,14 +21,7 @@ import paddlex as pdx
from .interpretation_predict import interpretation_predict from .interpretation_predict import interpretation_predict
from .core.interpretation import Interpretation from .core.interpretation import Interpretation
from .core.normlime_base import precompute_normlime_weights from .core.normlime_base import precompute_normlime_weights
from .core._session_preparation import gen_user_home
def gen_user_home():
if "HOME" in os.environ:
home_path = os.environ["HOME"]
if os.path.exists(home_path) and os.path.isdir(home_path):
return home_path
return os.path.expanduser('~')
def visualize(img_file, def visualize(img_file,
model, model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册