未验证 提交 c9211e2a 编写于 作者: C chenjian 提交者: GitHub

add environment version filter for module download (#2122)

上级 141396e9
......@@ -12,16 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import requests
from typing import List
import requests
import paddlehub
from paddlehub.utils import platform
from paddlehub.utils.utils import convert_version
from paddlehub.utils.utils import Version
class ServerConnectionError(Exception):
def __init__(self, url: str):
self.url = url
......@@ -76,26 +79,28 @@ class ServerSource(object):
result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0:
return result['data']
results = []
for module_info in result['data']:
should_skip = False
if module_info['paddle_version']:
paddle_version_intervals = convert_version(module_info['paddle_version'])
for module_paddle_version in paddle_version_intervals:
if not Version(params['paddle_version']).match(module_paddle_version):
should_skip = True
if module_info['hub_version']:
hub_version_intervals = convert_version(module_info['hub_version'])
for module_hub_version in hub_version_intervals:
if not Version(params['hub_version']).match(module_hub_version):
should_skip = True
if should_skip:
continue
results.append(module_info)
if results:
return results
return None
def get_module_compat_info(self, name: str) -> dict:
'''Get the version compatibility information of the model.'''
def _convert_version(version: str) -> List:
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
if version[0] != '-1.0.0':
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
params = {'name': name}
result = self.request(path='info', params=params)
if result['status'] == 0 and len(result['data']) > 0:
......@@ -103,8 +108,8 @@ class ServerSource(object):
for _info in result['data']['info']:
infos[_info['version']] = {
'url': _info['url'],
'paddle_version': _convert_version(_info['paddle_version']),
'hub_version': _convert_version(_info['hub_version'])
'paddle_version': convert_version(_info['paddle_version']),
'hub_version': convert_version(_info['hub_version'])
}
return infos
......
......@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import contextlib
import hashlib
......@@ -25,7 +24,8 @@ import tempfile
import time
import traceback
import types
from typing import Generator, List
from typing import Generator
from typing import List
from urllib.parse import urlparse
import cv2
......@@ -410,8 +410,7 @@ def extract_melspectrogram(y,
logger.error('Failed to import librosa. Please check that librosa and numba are correctly installed.')
raise
s = librosa.stft(
y,
s = librosa.stft(y,
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
......@@ -425,3 +424,23 @@ def extract_melspectrogram(y,
db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=None)
db = db.transpose()
return db
def convert_version(version: str) -> List:
'''
Convert version string in modules dataset such as [1.5.4, 2.0.0] to >=1.5.4 and <=2.0.0
'''
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
# Although -1.0.0 represents no least version limited,
# we should also consider when users write another minus number
if not version[0].startswith('-'):
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册