diff --git a/hubconf.py b/hubconf.py index 2ebfd23de571b12b88b534a162bdfec7f87a0563..6e5807619621ec65903d68645bb75e3eafdf4147 100644 --- a/hubconf.py +++ b/hubconf.py @@ -20,17 +20,20 @@ import sys class _SysPathG(object): + def __init__(self, path): + self.path = path + def __enter__(self, ): - sys.path.insert(0, - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - 'ppcls', 'modeling')) + sys.path.insert(0, self.path) def __exit__(self, type, value, traceback): - sys.path.pop(0) + _p = sys.path.pop(0) + assert _p == self.path, 'make sure pop {} correctly.'.format(self.path) -with _SysPathG(): +with _SysPathG( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'ppcls', 'modeling')): import architectures def _load_pretrained_parameters(model, name):