未验证 提交 c9211e2a 编写于 作者: C chenjian 提交者: GitHub

add environment version filter for module download (#2122)

上级 141396e9
...@@ -12,16 +12,19 @@ ...@@ -12,16 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import requests
from typing import List from typing import List
import requests
import paddlehub import paddlehub
from paddlehub.utils import platform from paddlehub.utils import platform
from paddlehub.utils.utils import convert_version
from paddlehub.utils.utils import Version
class ServerConnectionError(Exception): class ServerConnectionError(Exception):
def __init__(self, url: str): def __init__(self, url: str):
self.url = url self.url = url
...@@ -76,26 +79,28 @@ class ServerSource(object): ...@@ -76,26 +79,28 @@ class ServerSource(object):
result = self.request(path='search', params=params) result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0: if result['status'] == 0 and len(result['data']) > 0:
return result['data'] results = []
for module_info in result['data']:
should_skip = False
if module_info['paddle_version']:
paddle_version_intervals = convert_version(module_info['paddle_version'])
for module_paddle_version in paddle_version_intervals:
if not Version(params['paddle_version']).match(module_paddle_version):
should_skip = True
if module_info['hub_version']:
hub_version_intervals = convert_version(module_info['hub_version'])
for module_hub_version in hub_version_intervals:
if not Version(params['hub_version']).match(module_hub_version):
should_skip = True
if should_skip:
continue
results.append(module_info)
if results:
return results
return None return None
def get_module_compat_info(self, name: str) -> dict: def get_module_compat_info(self, name: str) -> dict:
'''Get the version compatibility information of the model.''' '''Get the version compatibility information of the model.'''
def _convert_version(version: str) -> List:
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
if version[0] != '-1.0.0':
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
params = {'name': name} params = {'name': name}
result = self.request(path='info', params=params) result = self.request(path='info', params=params)
if result['status'] == 0 and len(result['data']) > 0: if result['status'] == 0 and len(result['data']) > 0:
...@@ -103,8 +108,8 @@ class ServerSource(object): ...@@ -103,8 +108,8 @@ class ServerSource(object):
for _info in result['data']['info']: for _info in result['data']['info']:
infos[_info['version']] = { infos[_info['version']] = {
'url': _info['url'], 'url': _info['url'],
'paddle_version': _convert_version(_info['paddle_version']), 'paddle_version': convert_version(_info['paddle_version']),
'hub_version': _convert_version(_info['hub_version']) 'hub_version': convert_version(_info['hub_version'])
} }
return infos return infos
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64 import base64
import contextlib import contextlib
import hashlib import hashlib
...@@ -25,7 +24,8 @@ import tempfile ...@@ -25,7 +24,8 @@ import tempfile
import time import time
import traceback import traceback
import types import types
from typing import Generator, List from typing import Generator
from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
import cv2 import cv2
...@@ -410,14 +410,13 @@ def extract_melspectrogram(y, ...@@ -410,14 +410,13 @@ def extract_melspectrogram(y,
logger.error('Failed to import librosa. Please check that librosa and numba are correctly installed.') logger.error('Failed to import librosa. Please check that librosa and numba are correctly installed.')
raise raise
s = librosa.stft( s = librosa.stft(y,
y, n_fft=window_size,
n_fft=window_size, hop_length=hop_size,
hop_length=hop_size, win_length=window_size,
win_length=window_size, window=window,
window=window, center=center,
center=center, pad_mode=pad_mode)
pad_mode=pad_mode)
power = np.abs(s)**2 power = np.abs(s)**2
melW = librosa.filters.mel(sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax) melW = librosa.filters.mel(sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax)
...@@ -425,3 +424,23 @@ def extract_melspectrogram(y, ...@@ -425,3 +424,23 @@ def extract_melspectrogram(y,
db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=None) db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=None)
db = db.transpose() db = db.transpose()
return db return db
def convert_version(version: str) -> List:
'''
Convert version string in modules dataset such as [1.5.4, 2.0.0] to >=1.5.4 and <=2.0.0
'''
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
# Although -1.0.0 represents no least version limited,
# we should also consider when users write another minus number
if not version[0].startswith('-'):
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册