未验证 提交 d8eef4e4 编写于 作者: L LielinJiang 提交者: GitHub

Remove dependence of scipy (#29121)

* lazy import for scipy

* rm unused check
上级 a069e1ca
...@@ -13,9 +13,6 @@ ...@@ -13,9 +13,6 @@
# limitations under the License. # limitations under the License.
import os import os
from paddle.check_import_scipy import check_import_scipy
check_import_scipy(os.name)
try: try:
from paddle.version import full_version as __version__ from paddle.version import full_version as __version__
......
...@@ -35,7 +35,6 @@ import itertools ...@@ -35,7 +35,6 @@ import itertools
import functools import functools
from .common import download from .common import download
import tarfile import tarfile
import scipy.io as scio
from paddle.dataset.image import * from paddle.dataset.image import *
from paddle.reader import map_readers, xmap_readers from paddle.reader import map_readers, xmap_readers
from paddle import compat as cpt from paddle import compat as cpt
...@@ -45,6 +44,7 @@ import numpy as np ...@@ -45,6 +44,7 @@ import numpy as np
from multiprocessing import cpu_count from multiprocessing import cpu_count
import six import six
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
from paddle.utils import try_import
__all__ = ['train', 'test', 'valid'] __all__ = ['train', 'test', 'valid']
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz' DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
...@@ -108,8 +108,11 @@ def reader_creator(data_file, ...@@ -108,8 +108,11 @@ def reader_creator(data_file,
:return: data reader :return: data reader
:rtype: callable :rtype: callable
''' '''
scio = try_import('scipy.io')
labels = scio.loadmat(label_file)['labels'][0] labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[dataset_name][0] indexes = scio.loadmat(setid_file)[dataset_name][0]
img2label = {} img2label = {}
for i in indexes: for i in indexes:
img = "jpg/image_%05d.jpg" % i img = "jpg/image_%05d.jpg" % i
......
...@@ -19,6 +19,10 @@ import importlib ...@@ -19,6 +19,10 @@ import importlib
def try_import(module_name): def try_import(module_name):
"""Try importing a module, with an informative error message on failure.""" """Try importing a module, with an informative error message on failure."""
install_name = module_name install_name = module_name
if module_name.find('.') > -1:
install_name = module_name.split('.')[0]
if module_name == 'cv2': if module_name == 'cv2':
install_name = 'opencv-python' install_name = 'opencv-python'
...@@ -28,7 +32,7 @@ def try_import(module_name): ...@@ -28,7 +32,7 @@ def try_import(module_name):
except ImportError: except ImportError:
err_msg = ( err_msg = (
"Failed importing {}. This likely means that some paddle modules " "Failed importing {}. This likely means that some paddle modules "
"requires additional dependencies that have to be " "require additional dependencies that have to be "
"manually installed (usually with `pip install {}`). ").format( "manually installed (usually with `pip install {}`). ").format(
module_name, install_name) module_name, install_name)
raise ImportError(err_msg) raise ImportError(err_msg)
...@@ -18,11 +18,11 @@ import os ...@@ -18,11 +18,11 @@ import os
import io import io
import tarfile import tarfile
import numpy as np import numpy as np
import scipy.io as scio
from PIL import Image from PIL import Image
import paddle import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.utils import try_import
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
__all__ = ["Flowers"] __all__ = ["Flowers"]
...@@ -127,6 +127,8 @@ class Flowers(Dataset): ...@@ -127,6 +127,8 @@ class Flowers(Dataset):
for ele in self.data_tar.getmembers(): for ele in self.data_tar.getmembers():
self.name2mem[ele.name] = ele self.name2mem[ele.name] = ele
scio = try_import('scipy.io')
self.labels = scio.loadmat(self.label_file)['labels'][0] self.labels = scio.loadmat(self.label_file)['labels'][0]
self.indexes = scio.loadmat(self.setid_file)[self.flag][0] self.indexes = scio.loadmat(self.setid_file)[self.flag][0]
......
...@@ -4,9 +4,6 @@ numpy>=1.13 ; python_version>="3.5" and platform_system != "Windows" ...@@ -4,9 +4,6 @@ numpy>=1.13 ; python_version>="3.5" and platform_system != "Windows"
numpy>=1.13, <=1.19.3 ; python_version>="3.5" and platform_system == "Windows" numpy>=1.13, <=1.19.3 ; python_version>="3.5" and platform_system == "Windows"
protobuf>=3.1.0 protobuf>=3.1.0
gast==0.3.3 gast==0.3.3
scipy>=0.19.0, <=1.2.1 ; python_version<"3.5"
scipy<=1.3.1 ; python_version=="3.5"
scipy ; python_version>"3.5"
rarfile rarfile
Pillow Pillow
six six
......
...@@ -6,4 +6,7 @@ gym ...@@ -6,4 +6,7 @@ gym
opencv-python<=4.2.0.32 opencv-python<=4.2.0.32
visualdl ; python_version>="3.5" visualdl ; python_version>="3.5"
paddle2onnx>=0.4 paddle2onnx>=0.4
scipy>=0.19.0, <=1.2.1 ; python_version<"3.5"
scipy<=1.3.1 ; python_version=="3.5"
scipy ; python_version>"3.5"
prettytable prettytable
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册