未验证 提交 7bba9f8d 编写于 作者: L LielinJiang 提交者: GitHub

Fix import error (#127)

* fix some bug
上级 a92ade4a
from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
\ No newline at end of file
...@@ -29,7 +29,7 @@ class PerceptualVGG(nn.Layer): ...@@ -29,7 +29,7 @@ class PerceptualVGG(nn.Layer):
layer_name_list, layer_name_list,
vgg_type='vgg19', vgg_type='vgg19',
use_input_norm=True, use_input_norm=True,
pretrained_url='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams' pretrained_url='https://paddlegan.bj.bcebos.com/models/vgg19.pdparams'
): ):
super(PerceptualVGG, self).__init__() super(PerceptualVGG, self).__init__()
......
...@@ -18,20 +18,17 @@ from __future__ import print_function ...@@ -18,20 +18,17 @@ from __future__ import print_function
import os import os
import sys import sys
import os.path as osp import time
import shutil import shutil
import requests
import hashlib import hashlib
import tarfile import tarfile
import zipfile import zipfile
import time import requests
import os.path as osp
from tqdm import tqdm from tqdm import tqdm
import logging
from .logger import get_logger from .logger import get_logger
logger = get_logger('ppgan')
PPGAN_HOME = os.path.expanduser("~/.cache/ppgan/") PPGAN_HOME = os.path.expanduser("~/.cache/ppgan/")
DOWNLOAD_RETRY_LIMIT = 3 DOWNLOAD_RETRY_LIMIT = 3
...@@ -75,6 +72,7 @@ def get_path_from_url(url, md5sum=None, check_exist=True): ...@@ -75,6 +72,7 @@ def get_path_from_url(url, md5sum=None, check_exist=True):
fullpath = _map_path(url, root_dir) fullpath = _map_path(url, root_dir)
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger = get_logger('ppgan')
logger.info("Found {}".format(fullpath)) logger.info("Found {}".format(fullpath))
else: else:
if ParallelEnv().local_rank == 0: if ParallelEnv().local_rank == 0:
...@@ -111,6 +109,7 @@ def _download(url, path, md5sum=None): ...@@ -111,6 +109,7 @@ def _download(url, path, md5sum=None):
raise RuntimeError("Download from {} failed. " raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url)) "Retry limit reached".format(url))
logger = get_logger('ppgan')
logger.info("Downloading {} from {} to {}".format(fname, url, fullname)) logger.info("Downloading {} from {} to {}".format(fname, url, fullname))
req = requests.get(url, stream=True) req = requests.get(url, stream=True)
...@@ -141,6 +140,7 @@ def _md5check(fullname, md5sum=None): ...@@ -141,6 +140,7 @@ def _md5check(fullname, md5sum=None):
if md5sum is None: if md5sum is None:
return True return True
logger = get_logger('ppgan')
logger.info("File {} md5 checking...".format(fullname)) logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5() md5 = hashlib.md5()
with open(fullname, 'rb') as f: with open(fullname, 'rb') as f:
...@@ -159,6 +159,8 @@ def _decompress(fname): ...@@ -159,6 +159,8 @@ def _decompress(fname):
""" """
Decompress for zip and tar file Decompress for zip and tar file
""" """
logger = get_logger('ppgan')
logger.info("Decompressing {}...".format(fname)) logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted, # For protecting decompressing interupted,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册