Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
bc893c19
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bc893c19
编写于
9月 20, 2022
作者:
H
Hui Zhang
提交者:
GitHub
9月 20, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2417 from SmileGoat/update_audio_api_in_apps
format audio dir
上级
79f14319
d94996f2
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
523 addition
and
315 deletion
+523
-315
audio/paddleaudio/__init__.py
audio/paddleaudio/__init__.py
+2
-1
audio/paddleaudio/backends/__init__.py
audio/paddleaudio/backends/__init__.py
+3
-4
audio/paddleaudio/backends/soundfile_backend.py
audio/paddleaudio/backends/soundfile_backend.py
+49
-33
audio/paddleaudio/backends/sox_io_backend.py
audio/paddleaudio/backends/sox_io_backend.py
+41
-36
audio/paddleaudio/backends/utils.py
audio/paddleaudio/backends/utils.py
+4
-2
audio/paddleaudio/utils/__init__.py
audio/paddleaudio/utils/__init__.py
+2
-2
audio/paddleaudio/utils/tensor_utils.py
audio/paddleaudio/utils/tensor_utils.py
+192
-0
audio/tests/backends/soundfile/common.py
audio/tests/backends/soundfile/common.py
+3
-3
audio/tests/backends/soundfile/info_test.py
audio/tests/backends/soundfile/info_test.py
+40
-39
audio/tests/backends/soundfile/load_test.py
audio/tests/backends/soundfile/load_test.py
+90
-95
audio/tests/backends/soundfile/save_test.py
audio/tests/backends/soundfile/save_test.py
+69
-67
audio/tests/common_utils/__init__.py
audio/tests/common_utils/__init__.py
+9
-14
audio/tests/common_utils/wav_utils.py
audio/tests/common_utils/wav_utils.py
+19
-19
未找到文件。
audio/paddleaudio/__init__.py
浏览文件 @
bc893c19
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.
import
backends
from
.
import
compliance
from
.
import
compliance
from
.
import
datasets
from
.
import
datasets
from
.
import
features
from
.
import
features
...
@@ -18,4 +19,4 @@ from . import functional
...
@@ -18,4 +19,4 @@ from . import functional
from
.
import
io
from
.
import
io
from
.
import
metric
from
.
import
metric
from
.
import
sox_effects
from
.
import
sox_effects
from
.
import
backend
s
from
.
import
util
s
audio/paddleaudio/backends/__init__.py
浏览文件 @
bc893c19
...
@@ -11,16 +11,15 @@
...
@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.
import
utils
from
.soundfile_backend
import
depth_convert
from
.soundfile_backend
import
depth_convert
from
.soundfile_backend
import
soundfile_load
from
.soundfile_backend
import
normalize
from
.soundfile_backend
import
normalize
from
.soundfile_backend
import
resample
from
.soundfile_backend
import
resample
from
.soundfile_backend
import
soundfile_load
from
.soundfile_backend
import
soundfile_save
from
.soundfile_backend
import
soundfile_save
from
.soundfile_backend
import
to_mono
from
.soundfile_backend
import
to_mono
from
.
import
utils
from
.utils
import
get_audio_backend
from
.utils
import
get_audio_backend
from
.utils
import
list_audio_backends
from
.utils
import
list_audio_backends
from
.utils
import
set_audio_backend
from
.utils
import
set_audio_backend
utils
.
_init_audio_backend
()
utils
.
_init_audio_backend
()
\ No newline at end of file
audio/paddleaudio/backends/soundfile_backend.py
浏览文件 @
bc893c19
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
import
warnings
import
warnings
from
typing
import
Optional
from
typing
import
Optional
...
@@ -204,6 +203,7 @@ def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
...
@@ -204,6 +203,7 @@ def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
wavfile
.
write
(
file
,
sr
,
y_out
)
wavfile
.
write
(
file
,
sr
,
y_out
)
def
soundfile_load
(
def
soundfile_load
(
file
:
os
.
PathLike
,
file
:
os
.
PathLike
,
sr
:
Optional
[
int
]
=
None
,
sr
:
Optional
[
int
]
=
None
,
...
@@ -256,9 +256,13 @@ def soundfile_load(
...
@@ -256,9 +256,13 @@ def soundfile_load(
y
=
depth_convert
(
y
,
dtype
)
y
=
depth_convert
(
y
,
dtype
)
return
y
,
r
return
y
,
r
#the code below token form: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py with modificaion.
#the code below token form: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py with modificaion.
def
_get_subtype_for_wav
(
dtype
:
paddle
.
dtype
,
encoding
:
str
,
bits_per_sample
:
int
):
def
_get_subtype_for_wav
(
dtype
:
paddle
.
dtype
,
encoding
:
str
,
bits_per_sample
:
int
):
if
not
encoding
:
if
not
encoding
:
if
not
bits_per_sample
:
if
not
bits_per_sample
:
subtype
=
{
subtype
=
{
...
@@ -315,7 +319,10 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
...
@@ -315,7 +319,10 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
raise
ValueError
(
f
"sph does not support
{
encoding
}
."
)
raise
ValueError
(
f
"sph does not support
{
encoding
}
."
)
def
_get_subtype
(
dtype
:
paddle
.
dtype
,
format
:
str
,
encoding
:
str
,
bits_per_sample
:
int
):
def
_get_subtype
(
dtype
:
paddle
.
dtype
,
format
:
str
,
encoding
:
str
,
bits_per_sample
:
int
):
if
format
==
"wav"
:
if
format
==
"wav"
:
return
_get_subtype_for_wav
(
dtype
,
encoding
,
bits_per_sample
)
return
_get_subtype_for_wav
(
dtype
,
encoding
,
bits_per_sample
)
if
format
==
"flac"
:
if
format
==
"flac"
:
...
@@ -328,7 +335,8 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
...
@@ -328,7 +335,8 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
return
"PCM_S8"
if
bits_per_sample
==
8
else
f
"PCM_
{
bits_per_sample
}
"
return
"PCM_S8"
if
bits_per_sample
==
8
else
f
"PCM_
{
bits_per_sample
}
"
if
format
in
(
"ogg"
,
"vorbis"
):
if
format
in
(
"ogg"
,
"vorbis"
):
if
encoding
or
bits_per_sample
:
if
encoding
or
bits_per_sample
:
raise
ValueError
(
"ogg/vorbis does not support encoding/bits_per_sample."
)
raise
ValueError
(
"ogg/vorbis does not support encoding/bits_per_sample."
)
return
"VORBIS"
return
"VORBIS"
if
format
==
"sph"
:
if
format
==
"sph"
:
return
_get_subtype_for_sphere
(
encoding
,
bits_per_sample
)
return
_get_subtype_for_sphere
(
encoding
,
bits_per_sample
)
...
@@ -336,16 +344,16 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
...
@@ -336,16 +344,16 @@ def _get_subtype(dtype: paddle.dtype, format: str, encoding: str, bits_per_sampl
return
"PCM_16"
return
"PCM_16"
raise
ValueError
(
f
"Unsupported format:
{
format
}
"
)
raise
ValueError
(
f
"Unsupported format:
{
format
}
"
)
def
save
(
def
save
(
filepath
:
str
,
filepath
:
str
,
src
:
paddle
.
Tensor
,
src
:
paddle
.
Tensor
,
sample_rate
:
int
,
sample_rate
:
int
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
):
"""Save audio data to file.
"""Save audio data to file.
Note:
Note:
...
@@ -441,11 +449,11 @@ def save(
...
@@ -441,11 +449,11 @@ def save(
if
compression
is
not
None
:
if
compression
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
'`save` function of "soundfile" backend does not support "compression" parameter. '
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
"The argument is silently ignored."
)
)
if
hasattr
(
filepath
,
"write"
):
if
hasattr
(
filepath
,
"write"
):
if
format
is
None
:
if
format
is
None
:
raise
RuntimeError
(
"`format` is required when saving to file object."
)
raise
RuntimeError
(
"`format` is required when saving to file object."
)
ext
=
format
.
lower
()
ext
=
format
.
lower
()
else
:
else
:
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
...
@@ -455,8 +463,7 @@ def save(
...
@@ -455,8 +463,7 @@ def save(
if
bits_per_sample
==
24
:
if
bits_per_sample
==
24
:
warnings
.
warn
(
warnings
.
warn
(
"Saving audio with 24 bits per sample might warp samples near -1. "
"Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this."
"Using 16 bits per sample might be able to avoid this."
)
)
subtype
=
_get_subtype
(
src
.
dtype
,
ext
,
encoding
,
bits_per_sample
)
subtype
=
_get_subtype
(
src
.
dtype
,
ext
,
encoding
,
bits_per_sample
)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
...
@@ -467,7 +474,13 @@ def save(
...
@@ -467,7 +474,13 @@ def save(
if
channels_first
:
if
channels_first
:
src
=
src
.
t
()
src
=
src
.
t
()
soundfile
.
write
(
file
=
filepath
,
data
=
src
,
samplerate
=
sample_rate
,
subtype
=
subtype
,
format
=
format
)
soundfile
.
write
(
file
=
filepath
,
data
=
src
,
samplerate
=
sample_rate
,
subtype
=
subtype
,
format
=
format
)
_SUBTYPE2DTYPE
=
{
_SUBTYPE2DTYPE
=
{
"PCM_S8"
:
"int8"
,
"PCM_S8"
:
"int8"
,
...
@@ -478,14 +491,14 @@ _SUBTYPE2DTYPE = {
...
@@ -478,14 +491,14 @@ _SUBTYPE2DTYPE = {
"DOUBLE"
:
"float64"
,
"DOUBLE"
:
"float64"
,
}
}
def
load
(
def
load
(
filepath
:
str
,
filepath
:
str
,
frame_offset
:
int
=
0
,
frame_offset
:
int
=
0
,
num_frames
:
int
=
-
1
,
num_frames
:
int
=-
1
,
normalize
:
bool
=
True
,
normalize
:
bool
=
True
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
paddle
.
Tensor
,
int
]:
)
->
Tuple
[
paddle
.
Tensor
,
int
]:
"""Load audio data from file.
"""Load audio data from file.
Note:
Note:
...
@@ -564,7 +577,7 @@ def load(
...
@@ -564,7 +577,7 @@ def load(
waveform
=
paddle
.
to_tensor
(
waveform
)
waveform
=
paddle
.
to_tensor
(
waveform
)
if
channels_first
:
if
channels_first
:
waveform
=
paddle
.
transpose
(
waveform
,
perm
=
[
1
,
0
])
waveform
=
paddle
.
transpose
(
waveform
,
perm
=
[
1
,
0
])
return
waveform
,
sample_rate
return
waveform
,
sample_rate
...
@@ -588,7 +601,8 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
...
@@ -588,7 +601,8 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
"ALAW"
:
8
,
# A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"ALAW"
:
8
,
# A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"IMA_ADPCM"
:
0
,
# IMA ADPCM.
"IMA_ADPCM"
:
0
,
# IMA ADPCM.
"MS_ADPCM"
:
0
,
# Microsoft ADPCM.
"MS_ADPCM"
:
0
,
# Microsoft ADPCM.
"GSM610"
:
0
,
# GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"GSM610"
:
0
,
# GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"VOX_ADPCM"
:
0
,
# OKI / Dialogix ADPCM
"VOX_ADPCM"
:
0
,
# OKI / Dialogix ADPCM
"G721_32"
:
0
,
# 32kbs G721 ADPCM encoding.
"G721_32"
:
0
,
# 32kbs G721 ADPCM encoding.
"G723_24"
:
0
,
# 24kbs G723 ADPCM encoding.
"G723_24"
:
0
,
# 24kbs G723 ADPCM encoding.
...
@@ -606,16 +620,17 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
...
@@ -606,16 +620,17 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
"ALAC_32"
:
32
,
# Apple Lossless Audio Codec (32 bit).
"ALAC_32"
:
32
,
# Apple Lossless Audio Codec (32 bit).
}
}
def
_get_bit_depth
(
subtype
):
def
_get_bit_depth
(
subtype
):
if
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
if
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
warnings
.
warn
(
warnings
.
warn
(
f
"The
{
subtype
}
subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
f
"The
{
subtype
}
subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
"You may otherwise ignore this warning."
)
)
return
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
subtype
,
0
)
return
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
subtype
,
0
)
_SUBTYPE_TO_ENCODING
=
{
_SUBTYPE_TO_ENCODING
=
{
"PCM_S8"
:
"PCM_S"
,
"PCM_S8"
:
"PCM_S"
,
"PCM_16"
:
"PCM_S"
,
"PCM_16"
:
"PCM_S"
,
...
@@ -629,12 +644,14 @@ _SUBTYPE_TO_ENCODING = {
...
@@ -629,12 +644,14 @@ _SUBTYPE_TO_ENCODING = {
"VORBIS"
:
"VORBIS"
,
"VORBIS"
:
"VORBIS"
,
}
}
def
_get_encoding
(
format
:
str
,
subtype
:
str
):
def
_get_encoding
(
format
:
str
,
subtype
:
str
):
if
format
==
"FLAC"
:
if
format
==
"FLAC"
:
return
"FLAC"
return
"FLAC"
return
_SUBTYPE_TO_ENCODING
.
get
(
subtype
,
"UNKNOWN"
)
return
_SUBTYPE_TO_ENCODING
.
get
(
subtype
,
"UNKNOWN"
)
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioInfo
:
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioInfo
:
"""Get signal information of an audio file.
"""Get signal information of an audio file.
Note:
Note:
...
@@ -657,5 +674,4 @@ def info(filepath: str, format: Optional[str] = None) -> AudioInfo:
...
@@ -657,5 +674,4 @@ def info(filepath: str, format: Optional[str] = None) -> AudioInfo:
sinfo
.
frames
,
sinfo
.
frames
,
sinfo
.
channels
,
sinfo
.
channels
,
bits_per_sample
=
_get_bit_depth
(
sinfo
.
subtype
),
bits_per_sample
=
_get_bit_depth
(
sinfo
.
subtype
),
encoding
=
_get_encoding
(
sinfo
.
format
,
sinfo
.
subtype
),
encoding
=
_get_encoding
(
sinfo
.
format
,
sinfo
.
subtype
),
)
)
\ No newline at end of file
audio/paddleaudio/backends/sox_io_backend.py
浏览文件 @
bc893c19
from
pathlib
import
Path
import
os
from
typing
import
Callable
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Tuple
import
paddle
import
paddle
import
paddleaudio
import
paddleaudio
from
paddle
import
Tensor
from
paddle
import
Tensor
from
.common
import
AudioInfo
from
paddleaudio._internal
import
module_utils
as
_mod_utils
import
os
from
paddleaudio._internal
import
module_utils
as
_mod_utils
from
.common
import
AudioInfo
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
def
_fail_info
(
filepath
:
str
,
format
:
Optional
[
str
])
->
AudioInfo
:
def
_fail_info
(
filepath
:
str
,
format
:
Optional
[
str
])
->
AudioInfo
:
raise
RuntimeError
(
"Failed to fetch metadata from {}"
.
format
(
filepath
))
raise
RuntimeError
(
"Failed to fetch metadata from {}"
.
format
(
filepath
))
...
@@ -22,73 +22,78 @@ def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioInfo:
...
@@ -22,73 +22,78 @@ def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioInfo:
# Note: need to comply TorchScript syntax -- need annotation and no f-string
# Note: need to comply TorchScript syntax -- need annotation and no f-string
def
_fail_load
(
def
_fail_load
(
filepath
:
str
,
filepath
:
str
,
frame_offset
:
int
=
0
,
frame_offset
:
int
=
0
,
num_frames
:
int
=
-
1
,
num_frames
:
int
=-
1
,
normalize
:
bool
=
True
,
normalize
:
bool
=
True
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Tensor
,
int
]:
)
->
Tuple
[
Tensor
,
int
]:
raise
RuntimeError
(
"Failed to load audio from {}"
.
format
(
filepath
))
raise
RuntimeError
(
"Failed to load audio from {}"
.
format
(
filepath
))
def
_fail_load_fileobj
(
fileobj
,
*
args
,
**
kwargs
):
def
_fail_load_fileobj
(
fileobj
,
*
args
,
**
kwargs
):
raise
RuntimeError
(
f
"Failed to load audio from
{
fileobj
}
"
)
raise
RuntimeError
(
f
"Failed to load audio from
{
fileobj
}
"
)
_fallback_info
=
_fail_info
_fallback_info
=
_fail_info
_fallback_info_fileobj
=
_fail_info_fileobj
_fallback_info_fileobj
=
_fail_info_fileobj
_fallback_load
=
_fail_load
_fallback_load
=
_fail_load
_fallback_load_filebj
=
_fail_load_fileobj
_fallback_load_filebj
=
_fail_load_fileobj
@
_mod_utils
.
requires_sox
()
@
_mod_utils
.
requires_sox
()
def
load
(
def
load
(
filepath
:
str
,
filepath
:
str
,
frame_offset
:
int
=
0
,
frame_offset
:
int
=
0
,
num_frames
:
int
=-
1
,
num_frames
:
int
=-
1
,
normalize
:
bool
=
True
,
normalize
:
bool
=
True
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
format
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Tensor
,
int
]:
format
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
Tensor
,
int
]:
if
hasattr
(
filepath
,
"read"
):
if
hasattr
(
filepath
,
"read"
):
ret
=
paddleaudio
.
_paddleaudio
.
load_audio_fileobj
(
ret
=
paddleaudio
.
_paddleaudio
.
load_audio_fileobj
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
)
format
)
if
ret
is
not
None
:
if
ret
is
not
None
:
audio_tensor
=
paddle
.
to_tensor
(
ret
[
0
])
audio_tensor
=
paddle
.
to_tensor
(
ret
[
0
])
return
(
audio_tensor
,
ret
[
1
])
return
(
audio_tensor
,
ret
[
1
])
return
_fallback_load_fileobj
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
return
_fallback_load_fileobj
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
filepath
=
os
.
fspath
(
filepath
)
filepath
=
os
.
fspath
(
filepath
)
ret
=
paddleaudio
.
_paddleaudio
.
sox_io_load_audio_file
(
ret
=
paddleaudio
.
_paddleaudio
.
sox_io_load_audio_file
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
)
if
ret
is
not
None
:
if
ret
is
not
None
:
audio_tensor
=
paddle
.
to_tensor
(
ret
[
0
])
audio_tensor
=
paddle
.
to_tensor
(
ret
[
0
])
return
(
audio_tensor
,
ret
[
1
])
return
(
audio_tensor
,
ret
[
1
])
return
_fallback_load
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
return
_fallback_load
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
@
_mod_utils
.
requires_sox
()
@
_mod_utils
.
requires_sox
()
def
save
(
filepath
:
str
,
def
save
(
src
:
Tenso
r
,
filepath
:
st
r
,
sample_rate
:
int
,
src
:
Tensor
,
channels_first
:
bool
=
True
,
sample_rate
:
int
,
compression
:
Optional
[
float
]
=
Non
e
,
channels_first
:
bool
=
Tru
e
,
format
:
Optional
[
str
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
):
bits_per_sample
:
Optional
[
int
]
=
None
,
):
src_arr
=
src
.
numpy
()
src_arr
=
src
.
numpy
()
if
hasattr
(
filepath
,
"write"
):
if
hasattr
(
filepath
,
"write"
):
paddleaudio
.
_paddleaudio
.
save_audio_fileobj
(
paddleaudio
.
_paddleaudio
.
save_audio_fileobj
(
filepath
,
src_arr
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sample
filepath
,
src_arr
,
sample_rate
,
channels_first
,
compression
,
format
,
)
encoding
,
bits_per_sample
)
return
return
filepath
=
os
.
fspath
(
filepath
)
filepath
=
os
.
fspath
(
filepath
)
paddleaudio
.
_paddleaudio
.
sox_io_save_audio_file
(
paddleaudio
.
_paddleaudio
.
sox_io_save_audio_file
(
filepath
,
src_arr
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sample
filepath
,
src_arr
,
sample_rate
,
channels_first
,
compression
,
format
,
)
encoding
,
bits_per_sample
)
@
_mod_utils
.
requires_sox
()
@
_mod_utils
.
requires_sox
()
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
,)
->
AudioInfo
:
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
,
)
->
AudioInfo
:
if
hasattr
(
filepath
,
"read"
):
if
hasattr
(
filepath
,
"read"
):
sinfo
=
paddleaudio
.
_paddleaudio
.
get_info_fileobj
(
filepath
,
format
)
sinfo
=
paddleaudio
.
_paddleaudio
.
get_info_fileobj
(
filepath
,
format
)
if
sinfo
is
not
None
:
if
sinfo
is
not
None
:
...
...
audio/paddleaudio/backends/utils.py
浏览文件 @
bc893c19
"""Defines utilities for switching audio backends"""
"""Defines utilities for switching audio backends"""
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/utils.py
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/utils.py
import
warnings
import
warnings
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
...
@@ -8,7 +7,9 @@ from typing import Optional
...
@@ -8,7 +7,9 @@ from typing import Optional
import
paddleaudio
import
paddleaudio
from
paddleaudio._internal
import
module_utils
as
_mod_utils
from
paddleaudio._internal
import
module_utils
as
_mod_utils
from
.
import
no_backend
,
soundfile_backend
,
sox_io_backend
from
.
import
no_backend
from
.
import
soundfile_backend
from
.
import
sox_io_backend
__all__
=
[
__all__
=
[
"list_audio_backends"
,
"list_audio_backends"
,
...
@@ -55,6 +56,7 @@ def set_audio_backend(backend: Optional[str]):
...
@@ -55,6 +56,7 @@ def set_audio_backend(backend: Optional[str]):
for
func
in
[
"save"
,
"load"
,
"info"
]:
for
func
in
[
"save"
,
"load"
,
"info"
]:
setattr
(
paddleaudio
,
func
,
getattr
(
module
,
func
))
setattr
(
paddleaudio
,
func
,
getattr
(
module
,
func
))
def
_init_audio_backend
():
def
_init_audio_backend
():
backends
=
list_audio_backends
()
backends
=
list_audio_backends
()
if
"soundfile"
in
backends
:
if
"soundfile"
in
backends
:
...
...
audio/paddleaudio/utils/__init__.py
浏览文件 @
bc893c19
...
@@ -21,7 +21,7 @@ from .env import USER_HOME
...
@@ -21,7 +21,7 @@ from .env import USER_HOME
from
.error
import
ParameterError
from
.error
import
ParameterError
from
.log
import
Logger
from
.log
import
Logger
from
.log
import
logger
from
.log
import
logger
from
.time
import
seconds_to_hms
from
.time
import
Timer
from
.numeric
import
depth_convert
from
.numeric
import
depth_convert
from
.numeric
import
pcm16to32
from
.numeric
import
pcm16to32
from
.time
import
seconds_to_hms
from
.time
import
Timer
audio/paddleaudio/utils/tensor_utils.py
0 → 100644
浏览文件 @
bc893c19
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from
typing
import
List
from
typing
import
Tuple
import
paddle
from
.log
import
Logger
__all__
=
[
"pad_sequence"
,
"add_sos_eos"
,
"th_accuracy"
,
"has_tensor"
]
logger
=
Logger
(
__name__
)
def
has_tensor
(
val
):
if
isinstance
(
val
,
(
list
,
tuple
)):
for
item
in
val
:
if
has_tensor
(
item
):
return
True
elif
isinstance
(
val
,
dict
):
for
k
,
v
in
val
.
items
():
print
(
k
)
if
has_tensor
(
v
):
return
True
else
:
return
paddle
.
is_tensor
(
val
)
def
pad_sequence
(
sequences
:
List
[
paddle
.
Tensor
],
batch_first
:
bool
=
False
,
padding_value
:
float
=
0.0
)
->
paddle
.
Tensor
:
r
"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).shape
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size
=
paddle
.
shape
(
sequences
[
0
])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims
=
tuple
(
max_size
[
1
:].
numpy
().
tolist
())
if
sequences
[
0
].
ndim
>=
2
else
()
max_len
=
max
([
s
.
shape
[
0
]
for
s
in
sequences
])
if
batch_first
:
out_dims
=
(
len
(
sequences
),
max_len
)
+
trailing_dims
else
:
out_dims
=
(
max_len
,
len
(
sequences
))
+
trailing_dims
out_tensor
=
paddle
.
full
(
out_dims
,
padding_value
,
sequences
[
0
].
dtype
)
for
i
,
tensor
in
enumerate
(
sequences
):
length
=
tensor
.
shape
[
0
]
# use index notation to prevent duplicate references to the tensor
if
batch_first
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if
length
!=
0
:
out_tensor
[
i
,
:
length
]
=
tensor
else
:
out_tensor
[
i
,
length
]
=
tensor
else
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if
length
!=
0
:
out_tensor
[:
length
,
i
]
=
tensor
else
:
out_tensor
[
length
,
i
]
=
tensor
return
out_tensor
def
add_sos_eos
(
ys_pad
:
paddle
.
Tensor
,
sos
:
int
,
eos
:
int
,
ignore_id
:
int
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eeos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B
=
ys_pad
.
shape
[
0
]
_sos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
sos
_eos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
eos
ys_in
=
paddle
.
cat
([
_sos
,
ys_pad
],
dim
=
1
)
mask_pad
=
(
ys_in
==
ignore_id
)
ys_in
=
ys_in
.
masked_fill
(
mask_pad
,
eos
)
ys_out
=
paddle
.
cat
([
ys_pad
,
_eos
],
dim
=
1
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
eos
)
mask_eos
=
(
ys_out
==
ignore_id
)
ys_out
=
ys_out
.
masked_fill
(
mask_eos
,
eos
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
ignore_id
)
return
ys_in
,
ys_out
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
pad_targets
:
paddle
.
Tensor
,
ignore_label
:
int
)
->
float
:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
shape
[
0
],
pad_targets
.
shape
[
1
],
pad_outputs
.
shape
[
1
]).
argmax
(
2
)
mask
=
pad_targets
!=
ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator
=
(
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
numerator
=
paddle
.
sum
(
numerator
.
type_as
(
pad_targets
))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator
=
paddle
.
sum
(
mask
.
type_as
(
pad_targets
))
return
float
(
numerator
)
/
float
(
denominator
)
audio/tests/backends/soundfile/common.py
浏览文件 @
bc893c19
import
itertools
import
itertools
from
unittest
import
skipIf
from
unittest
import
skipIf
from
parameterized
import
parameterized
from
paddleaudio._internal.module_utils
import
is_module_available
from
paddleaudio._internal.module_utils
import
is_module_available
from
parameterized
import
parameterized
def
name_func
(
func
,
_
,
params
):
def
name_func
(
func
,
_
,
params
):
...
@@ -31,7 +31,8 @@ def skipIfFormatNotSupported(fmt):
...
@@ -31,7 +31,8 @@ def skipIfFormatNotSupported(fmt):
def
parameterize
(
*
params
):
def
parameterize
(
*
params
):
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
def
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
):
def
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
):
...
@@ -54,4 +55,3 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample):
...
@@ -54,4 +55,3 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample):
if
subtype
:
if
subtype
:
return
subtype
return
subtype
raise
ValueError
(
f
"wav does not support (
{
encoding
}
,
{
bits_per_sample
}
)."
)
raise
ValueError
(
f
"wav does not support (
{
encoding
}
,
{
bits_per_sample
}
)."
)
audio/tests/backends/soundfile/info_test.py
浏览文件 @
bc893c19
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/info_test.py
import
tarfile
import
tarfile
import
warnings
import
unittest
import
unittest
import
warnings
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
paddle
import
paddle
from
paddleaudio._internal
import
module_utils
as
_mod_utils
import
soundfile
from
common
import
parameterize
from
common
import
skipIfFormatNotSupported
from
paddleaudio.backends
import
soundfile_backend
from
paddleaudio.backends
import
soundfile_backend
from
tests.backends.common
import
get_bits_per_sample
,
get_encoding
from
tests.common_utils
import
(
get_wav_data
,
nested_params
,
save_wav
,
TempDirMixin
,
)
from
common
import
parameterize
,
skipIfFormatNotSupported
from
tests.backends.common
import
get_bits_per_sample
from
tests.backends.common
import
get_encoding
import
soundfile
from
tests.common_utils
import
get_wav_data
from
tests.common_utils
import
nested_params
from
tests.common_utils
import
save_wav
from
tests.common_utils
import
TempDirMixin
class
TestInfo
(
TempDirMixin
,
unittest
.
TestCase
):
class
TestInfo
(
TempDirMixin
,
unittest
.
TestCase
):
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
],
[
"float32"
,
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file correctly"""
"""`soundfile_backend.info` can check wav file correctly"""
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
soundfile_backend
.
info
(
path
)
info
=
soundfile_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
...
@@ -62,32 +62,31 @@ class TestInfo(TempDirMixin, unittest.TestCase):
...
@@ -62,32 +62,31 @@ class TestInfo(TempDirMixin, unittest.TestCase):
#@parameterize([8000, 16000], [1, 2])
#@parameterize([8000, 16000], [1, 2])
#@skipIfFormatNotSupported("OGG")
#@skipIfFormatNotSupported("OGG")
#def test_ogg(self, sample_rate, num_channels):
#def test_ogg(self, sample_rate, num_channels):
#"""`soundfile_backend.info` can check ogg file correctly"""
#"""`soundfile_backend.info` can check ogg file correctly"""
#duration = 1
#duration = 1
#num_frames = sample_rate * duration
#num_frames = sample_rate * duration
##data = torch.randn(num_frames, num_channels).numpy()
##data = torch.randn(num_frames, num_channels).numpy()
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
#data = paddle.randn(shape=[num_frames, num_channels]).numpy()
#print(len(data))
#print(len(data))
#path = self.get_temp_path("data.ogg")
#path = self.get_temp_path("data.ogg")
#soundfile.write(path, data, sample_rate)
#soundfile.write(path, data, sample_rate)
#info = soundfile_backend.info(path)
#info = soundfile_backend.info(path)
#print(info)
#print(info)
#assert info.sample_rate == sample_rate
#assert info.sample_rate == sample_rate
#print("info")
#print("info")
#print(info.num_frames)
#print(info.num_frames)
#print("jiji")
#print("jiji")
#print(sample_rate*duration)
#print(sample_rate*duration)
##assert info.num_frames == sample_rate * duration
##assert info.num_frames == sample_rate * duration
#assert info.num_channels == num_channels
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0
#assert info.bits_per_sample == 0
#assert info.encoding == "VORBIS"
#assert info.encoding == "VORBIS"
@
nested_params
(
@
nested_params
(
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[(
"PCM_24"
,
24
),
(
"PCM_32"
,
32
)],
[(
"PCM_24"
,
24
),
(
"PCM_32"
,
32
)],
)
)
@
skipIfFormatNotSupported
(
"NIST"
)
@
skipIfFormatNotSupported
(
"NIST"
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
,
subtype_and_bit_depth
):
def
test_sphere
(
self
,
sample_rate
,
num_channels
,
subtype_and_bit_depth
):
"""`soundfile_backend.info` can check sph file correctly"""
"""`soundfile_backend.info` can check sph file correctly"""
...
@@ -127,7 +126,8 @@ class TestInfo(TempDirMixin, unittest.TestCase):
...
@@ -127,7 +126,8 @@ class TestInfo(TempDirMixin, unittest.TestCase):
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
info
=
soundfile_backend
.
info
(
"foo"
)
info
=
soundfile_backend
.
info
(
"foo"
)
assert
len
(
w
)
==
1
assert
len
(
w
)
==
1
assert
"UNSEEN_SUBTYPE subtype is unknown to PaddleAudio"
in
str
(
w
[
-
1
].
message
)
assert
"UNSEEN_SUBTYPE subtype is unknown to PaddleAudio"
in
str
(
w
[
-
1
].
message
)
assert
info
.
bits_per_sample
==
0
assert
info
.
bits_per_sample
==
0
...
@@ -195,5 +195,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
...
@@ -195,5 +195,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Query compressed audio via file-like object works"""
"""Query compressed audio via file-like object works"""
self
.
_test_tarobj
(
"flac"
,
"PCM_16"
,
16
)
self
.
_test_tarobj
(
"flac"
,
"PCM_16"
,
16
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
audio/tests/backends/soundfile/load_test.py
浏览文件 @
bc893c19
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py
#this code is from: https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/backend/soundfile/load_test.py
import
os
import
os
import
tarfile
import
tarfile
import
unittest
import
unittest
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
numpy
as
np
from
parameterized
import
parameterized
import
numpy
as
np
import
paddle
import
paddle
from
paddleaudio._internal
import
module_utils
as
_mod_utils
import
soundfile
from
common
import
dtype2subtype
from
common
import
parameterize
from
common
import
skipIfFormatNotSupported
from
paddleaudio.backends
import
soundfile_backend
from
paddleaudio.backends
import
soundfile_backend
from
tests.backends.common
import
get_bits_per_sample
,
get_encoding
from
parameterized
import
parameterized
from
tests.common_utils
import
(
get_wav_data
,
load_wav
,
nested_params
,
normalize_wav
,
save_wav
,
TempDirMixin
,
)
from
common
import
dtype2subtype
,
parameterize
,
skipIfFormatNotSupported
import
soundfile
from
tests.common_utils
import
get_wav_data
from
tests.common_utils
import
load_wav
from
tests.common_utils
import
normalize_wav
from
tests.common_utils
import
save_wav
from
tests.common_utils
import
TempDirMixin
def
_get_mock_path
(
def
_get_mock_path
(
ext
:
str
,
ext
:
str
,
dtype
:
str
,
dtype
:
str
,
sample_rate
:
int
,
sample_rate
:
int
,
num_channels
:
int
,
num_channels
:
int
,
num_frames
:
int
,
num_frames
:
int
,
):
):
return
f
"
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
num_frames
}
.
{
ext
}
"
return
f
"
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
num_frames
}
.
{
ext
}
"
...
@@ -87,9 +81,8 @@ class SoundFileMock:
...
@@ -87,9 +81,8 @@ class SoundFileMock:
self
.
_params
[
"num_channels"
],
self
.
_params
[
"num_channels"
],
normalize
=
False
,
normalize
=
False
,
num_frames
=
self
.
_params
[
"num_frames"
],
num_frames
=
self
.
_params
[
"num_frames"
],
channels_first
=
False
,
channels_first
=
False
,
).
numpy
()
).
numpy
()
return
data
[
self
.
_start
:
self
.
_start
+
frames
]
return
data
[
self
.
_start
:
self
.
_start
+
frames
]
def
__enter__
(
self
):
def
__enter__
(
self
):
return
self
return
self
...
@@ -99,13 +92,17 @@ class SoundFileMock:
...
@@ -99,13 +92,17 @@ class SoundFileMock:
class
MockedLoadTest
(
unittest
.
TestCase
):
class
MockedLoadTest
(
unittest
.
TestCase
):
def
assert_dtype
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
def
assert_dtype
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames
=
3
*
sample_rate
num_frames
=
3
*
sample_rate
path
=
_get_mock_path
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
path
=
_get_mock_path
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
expected_dtype
=
paddle
.
float32
if
normalize
or
ext
not
in
[
"wav"
,
"nist"
]
else
getattr
(
paddle
,
dtype
)
expected_dtype
=
paddle
.
float32
if
normalize
or
ext
not
in
[
"wav"
,
"nist"
]
else
getattr
(
paddle
,
dtype
)
with
patch
(
"soundfile.SoundFile"
,
SoundFileMock
):
with
patch
(
"soundfile.SoundFile"
,
SoundFileMock
):
found
,
sr
=
soundfile_backend
.
load
(
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)
found
,
sr
=
soundfile_backend
.
load
(
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)
assert
found
.
dtype
==
expected_dtype
assert
found
.
dtype
==
expected_dtype
assert
sample_rate
==
sr
assert
sample_rate
==
sr
...
@@ -114,44 +111,47 @@ class MockedLoadTest(unittest.TestCase):
...
@@ -114,44 +111,47 @@ class MockedLoadTest(unittest.TestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
],
[
True
,
False
],
[
True
,
False
],
)
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
channels_first
):
"""Returns native dtype when normalize=False else float32"""
"""Returns native dtype when normalize=False else float32"""
self
.
assert_dtype
(
"wav"
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_dtype
(
"wav"
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
@
parameterize
(
@
parameterize
(
[
"int32"
],
[
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
],
[
True
,
False
],
[
True
,
False
],
)
)
def
test_sphere
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
def
test_sphere
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
channels_first
):
"""Returns float32 always"""
"""Returns float32 always"""
self
.
assert_dtype
(
"sph"
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_dtype
(
"sph"
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
])
def
test_ogg
(
self
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
def
test_ogg
(
self
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
"""Returns float32 always"""
"""Returns float32 always"""
self
.
assert_dtype
(
"ogg"
,
"int16"
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_dtype
(
"ogg"
,
"int16"
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[
True
,
False
],
[
True
,
False
])
def
test_flac
(
self
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
"""`soundfile_backend.load` can load ogg format."""
"""`soundfile_backend.load` can load ogg format."""
self
.
assert_dtype
(
"flac"
,
"int16"
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_dtype
(
"flac"
,
"int16"
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
class
LoadTestBase
(
TempDirMixin
,
unittest
.
TestCase
):
class
LoadTestBase
(
TempDirMixin
,
unittest
.
TestCase
):
def
assert_wav
(
def
assert_wav
(
self
,
self
,
dtype
,
dtype
,
sample_rate
,
sample_rate
,
num_channels
,
num_channels
,
normalize
,
normalize
,
channels_first
=
True
,
channels_first
=
True
,
duration
=
1
,
duration
=
1
,
):
):
"""`soundfile_backend.load` can load wav format correctly.
"""`soundfile_backend.load` can load wav format correctly.
Wav data loaded with soundfile backend should match those with scipy
Wav data loaded with soundfile backend should match those with scipy
...
@@ -163,22 +163,22 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
...
@@ -163,22 +163,22 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels
,
num_channels
,
normalize
=
normalize
,
normalize
=
normalize
,
num_frames
=
num_frames
,
num_frames
=
num_frames
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
)
)
save_wav
(
path
,
data
,
sample_rate
,
channels_first
=
channels_first
)
save_wav
(
path
,
data
,
sample_rate
,
channels_first
=
channels_first
)
expected
=
load_wav
(
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)[
0
]
expected
=
load_wav
(
data
,
sr
=
soundfile_backend
.
load
(
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)[
0
]
data
,
sr
=
soundfile_backend
.
load
(
path
,
normalize
=
normalize
,
channels_first
=
channels_first
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
def
assert_sphere
(
def
assert_sphere
(
self
,
self
,
dtype
,
dtype
,
sample_rate
,
sample_rate
,
num_channels
,
num_channels
,
channels_first
=
True
,
channels_first
=
True
,
duration
=
1
,
duration
=
1
,
):
):
"""`soundfile_backend.load` can load SPHERE format correctly."""
"""`soundfile_backend.load` can load SPHERE format correctly."""
path
=
self
.
get_temp_path
(
"reference.sph"
)
path
=
self
.
get_temp_path
(
"reference.sph"
)
num_frames
=
duration
*
sample_rate
num_frames
=
duration
*
sample_rate
...
@@ -187,9 +187,9 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
...
@@ -187,9 +187,9 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels
,
num_channels
,
num_frames
=
num_frames
,
num_frames
=
num_frames
,
normalize
=
False
,
normalize
=
False
,
channels_first
=
False
,
channels_first
=
False
,
)
)
soundfile
.
write
(
soundfile
.
write
(
path
,
raw
,
sample_rate
,
subtype
=
dtype2subtype
(
dtype
),
format
=
"NIST"
)
path
,
raw
,
sample_rate
,
subtype
=
dtype2subtype
(
dtype
),
format
=
"NIST"
)
expected
=
normalize_wav
(
raw
.
t
()
if
channels_first
else
raw
)
expected
=
normalize_wav
(
raw
.
t
()
if
channels_first
else
raw
)
data
,
sr
=
soundfile_backend
.
load
(
path
,
channels_first
=
channels_first
)
data
,
sr
=
soundfile_backend
.
load
(
path
,
channels_first
=
channels_first
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
...
@@ -197,13 +197,12 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
...
@@ -197,13 +197,12 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
def
assert_flac
(
def
assert_flac
(
self
,
self
,
dtype
,
dtype
,
sample_rate
,
sample_rate
,
num_channels
,
num_channels
,
channels_first
=
True
,
channels_first
=
True
,
duration
=
1
,
duration
=
1
,
):
):
"""`soundfile_backend.load` can load FLAC format correctly."""
"""`soundfile_backend.load` can load FLAC format correctly."""
path
=
self
.
get_temp_path
(
"reference.flac"
)
path
=
self
.
get_temp_path
(
"reference.flac"
)
num_frames
=
duration
*
sample_rate
num_frames
=
duration
*
sample_rate
...
@@ -212,15 +211,13 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
...
@@ -212,15 +211,13 @@ class LoadTestBase(TempDirMixin, unittest.TestCase):
num_channels
,
num_channels
,
num_frames
=
num_frames
,
num_frames
=
num_frames
,
normalize
=
False
,
normalize
=
False
,
channels_first
=
False
,
channels_first
=
False
,
)
)
soundfile
.
write
(
path
,
raw
,
sample_rate
)
soundfile
.
write
(
path
,
raw
,
sample_rate
)
expected
=
normalize_wav
(
raw
.
t
()
if
channels_first
else
raw
)
expected
=
normalize_wav
(
raw
.
t
()
if
channels_first
else
raw
)
data
,
sr
=
soundfile_backend
.
load
(
path
,
channels_first
=
channels_first
)
data
,
sr
=
soundfile_backend
.
load
(
path
,
channels_first
=
channels_first
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
#self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
np
.
testing
.
assert_array_almost_equal
(
data
.
numpy
(),
expected
.
numpy
())
class
TestLoad
(
LoadTestBase
):
class
TestLoad
(
LoadTestBase
):
...
@@ -231,41 +228,43 @@ class TestLoad(LoadTestBase):
...
@@ -231,41 +228,43 @@ class TestLoad(LoadTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
)
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
channels_first
):
"""`soundfile_backend.load` can load wav format correctly."""
"""`soundfile_backend.load` can load wav format correctly."""
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
@
parameterize
(
@
parameterize
(
[
"int32"
],
[
"int32"
],
[
16000
],
[
16000
],
[
2
],
[
2
],
[
False
],
[
False
],
)
)
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
"""`soundfile_backend.load` can load large wav file correctly."""
"""`soundfile_backend.load` can load large wav file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
=
two_hours
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
=
two_hours
)
@
parameterize
([
"float32"
,
"int32"
],
[
4
,
8
,
16
,
32
],
[
False
,
True
])
@
parameterize
([
"float32"
,
"int32"
],
[
4
,
8
,
16
,
32
],
[
False
,
True
])
def
test_multiple_channels
(
self
,
dtype
,
num_channels
,
channels_first
):
def
test_multiple_channels
(
self
,
dtype
,
num_channels
,
channels_first
):
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
sample_rate
=
8000
sample_rate
=
8000
normalize
=
False
normalize
=
False
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("NIST")
#@skipIfFormatNotSupported("NIST")
#def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
#def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load sphere format correctly."""
#"""`soundfile_backend.load` can load sphere format correctly."""
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
#self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@parameterize(["int32"], [8000, 16000], [1, 2], [False, True])
#@skipIfFormatNotSupported("FLAC")
#@skipIfFormatNotSupported("FLAC")
#def test_flac(self, dtype, sample_rate, num_channels, channels_first):
#def test_flac(self, dtype, sample_rate, num_channels, channels_first):
#"""`soundfile_backend.load` can load flac format correctly."""
#"""`soundfile_backend.load` can load flac format correctly."""
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
#self.assert_flac(dtype, sample_rate, num_channels, channels_first)
class
TestLoadFormat
(
TempDirMixin
,
unittest
.
TestCase
):
class
TestLoadFormat
(
TempDirMixin
,
unittest
.
TestCase
):
...
@@ -291,21 +290,17 @@ class TestLoadFormat(TempDirMixin, unittest.TestCase):
...
@@ -291,21 +290,17 @@ class TestLoadFormat(TempDirMixin, unittest.TestCase):
#self.assertEqual(found, expected)
#self.assertEqual(found, expected)
np
.
testing
.
assert_array_almost_equal
(
found
,
expected
)
np
.
testing
.
assert_array_almost_equal
(
found
,
expected
)
@
parameterized
.
expand
(
@
parameterized
.
expand
([
[
(
"WAV"
,
),
(
"WAV"
,),
(
"wav"
,
),
(
"wav"
,),
])
]
)
def
test_wav
(
self
,
format_
):
def
test_wav
(
self
,
format_
):
self
.
_test_format
(
format_
)
self
.
_test_format
(
format_
)
@
parameterized
.
expand
(
@
parameterized
.
expand
([
[
(
"FLAC"
,
),
(
"FLAC"
,),
(
"flac"
,
),
(
"flac"
,),
])
]
)
@
skipIfFormatNotSupported
(
"FLAC"
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_flac
(
self
,
format_
):
def
test_flac
(
self
,
format_
):
self
.
_test_format
(
format_
)
self
.
_test_format
(
format_
)
...
@@ -356,7 +351,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
...
@@ -356,7 +351,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
#self.assertEqual(expected, found)
#self.assertEqual(expected, found)
np
.
testing
.
assert_array_almost_equal
(
found
.
numpy
(),
expected
)
np
.
testing
.
assert_array_almost_equal
(
found
.
numpy
(),
expected
)
def
test_tarfile_wav
(
self
):
def
test_tarfile_wav
(
self
):
"""Loading audio via file-like object works"""
"""Loading audio via file-like object works"""
self
.
_test_tarfile
(
"wav"
)
self
.
_test_tarfile
(
"wav"
)
...
@@ -365,5 +359,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
...
@@ -365,5 +359,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Loading audio via file-like object works"""
"""Loading audio via file-like object works"""
self
.
_test_tarfile
(
"flac"
)
self
.
_test_tarfile
(
"flac"
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
audio/tests/backends/soundfile/save_test.py
浏览文件 @
bc893c19
...
@@ -2,23 +2,18 @@ import io
...
@@ -2,23 +2,18 @@ import io
import
unittest
import
unittest
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
paddleaudio._internal
import
module_utils
as
_mod_utils
from
paddleaudio.backends
import
soundfile_backend
from
tests.common_utils
import
(
get_wav_data
,
load_wav
,
nested_params
,
normalize_wav
,
save_wav
,
TempDirMixin
,
)
from
common
import
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
import
paddle
import
numpy
as
np
import
numpy
as
np
import
paddle
import
soundfile
import
soundfile
from
common
import
fetch_wav_subtype
from
common
import
parameterize
from
common
import
skipIfFormatNotSupported
from
paddleaudio.backends
import
soundfile_backend
from
tests.common_utils
import
get_wav_data
from
tests.common_utils
import
load_wav
from
tests.common_utils
import
nested_params
from
tests.common_utils
import
TempDirMixin
class
MockedSaveTest
(
unittest
.
TestCase
):
class
MockedSaveTest
(
unittest
.
TestCase
):
...
@@ -41,10 +36,10 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -41,10 +36,10 @@ class MockedSaveTest(unittest.TestCase):
(
"ULAW"
,
8
),
(
"ULAW"
,
8
),
(
"ALAW"
,
None
),
(
"ALAW"
,
None
),
(
"ALAW"
,
8
),
(
"ALAW"
,
8
),
],
],
)
)
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
,
mocked_write
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
,
mocked_write
):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath
=
"foo.wav"
filepath
=
"foo.wav"
input_tensor
=
get_wav_data
(
input_tensor
=
get_wav_data
(
...
@@ -52,8 +47,7 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -52,8 +47,7 @@ class MockedSaveTest(unittest.TestCase):
num_channels
,
num_channels
,
num_frames
=
3
*
sample_rate
,
num_frames
=
3
*
sample_rate
,
normalize
=
dtype
==
"float32"
,
normalize
=
dtype
==
"float32"
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
)
)
input_tensor
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
input_tensor
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
encoding
,
bits_per_sample
=
enc_params
encoding
,
bits_per_sample
=
enc_params
...
@@ -63,33 +57,32 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -63,33 +57,32 @@ class MockedSaveTest(unittest.TestCase):
sample_rate
,
sample_rate
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
encoding
=
encoding
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
bits_per_sample
=
bits_per_sample
,
)
)
# on +Py3.8 call_args.kwargs is more descreptive
# on +Py3.8 call_args.kwargs is more descreptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
==
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
)
assert
args
[
"subtype"
]
==
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
)
assert
args
[
"format"
]
is
None
assert
args
[
"format"
]
is
None
tensor_result
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
if
channels_first
else
input_tensor
tensor_result
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
if
channels_first
else
input_tensor
#self.assertEqual(args["data"], tensor_result.numpy())
#self.assertEqual(args["data"], tensor_result.numpy())
np
.
testing
.
assert_array_almost_equal
(
args
[
"data"
].
numpy
(),
tensor_result
.
numpy
())
np
.
testing
.
assert_array_almost_equal
(
args
[
"data"
].
numpy
(),
tensor_result
.
numpy
())
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
assert_non_wav
(
def
assert_non_wav
(
self
,
self
,
fmt
,
fmt
,
dtype
,
dtype
,
sample_rate
,
sample_rate
,
num_channels
,
num_channels
,
channels_first
,
channels_first
,
mocked_write
,
mocked_write
,
encoding
=
None
,
encoding
=
None
,
bits_per_sample
=
None
,
bits_per_sample
=
None
,
):
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath
=
f
"foo.
{
fmt
}
"
filepath
=
f
"foo.
{
fmt
}
"
input_tensor
=
get_wav_data
(
input_tensor
=
get_wav_data
(
...
@@ -97,11 +90,11 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -97,11 +90,11 @@ class MockedSaveTest(unittest.TestCase):
num_channels
,
num_channels
,
num_frames
=
3
*
sample_rate
,
num_frames
=
3
*
sample_rate
,
normalize
=
False
,
normalize
=
False
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
)
)
input_tensor
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
input_tensor
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
expected_data
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
if
channels_first
else
input_tensor
expected_data
=
paddle
.
transpose
(
input_tensor
,
[
1
,
0
])
if
channels_first
else
input_tensor
soundfile_backend
.
save
(
soundfile_backend
.
save
(
filepath
,
filepath
,
...
@@ -109,8 +102,7 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -109,8 +102,7 @@ class MockedSaveTest(unittest.TestCase):
sample_rate
,
sample_rate
,
channels_first
,
channels_first
,
encoding
=
encoding
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
bits_per_sample
=
bits_per_sample
,
)
)
# on +Py3.8 call_args.kwargs is more descreptive
# on +Py3.8 call_args.kwargs is more descreptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
...
@@ -120,7 +112,8 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -120,7 +112,8 @@ class MockedSaveTest(unittest.TestCase):
assert
args
[
"format"
]
==
"NIST"
assert
args
[
"format"
]
==
"NIST"
else
:
else
:
assert
args
[
"format"
]
is
None
assert
args
[
"format"
]
is
None
np
.
testing
.
assert_array_almost_equal
(
args
[
"data"
].
numpy
(),
expected_data
.
numpy
())
np
.
testing
.
assert_array_almost_equal
(
args
[
"data"
].
numpy
(),
expected_data
.
numpy
())
#self.assertEqual(args["data"], expected_data)
#self.assertEqual(args["data"], expected_data)
@
nested_params
(
@
nested_params
(
...
@@ -139,45 +132,57 @@ class MockedSaveTest(unittest.TestCase):
...
@@ -139,45 +132,57 @@ class MockedSaveTest(unittest.TestCase):
(
"ALAW"
,
16
),
(
"ALAW"
,
16
),
(
"ALAW"
,
24
),
(
"ALAW"
,
24
),
(
"ALAW"
,
32
),
(
"ALAW"
,
32
),
],
],
)
)
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
):
enc_params
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
encoding
,
bits_per_sample
=
enc_params
encoding
,
bits_per_sample
=
enc_params
self
.
assert_non_wav
(
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
fmt
,
)
dtype
,
sample_rate
,
num_channels
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
],
[
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
[
8
,
16
,
24
],
[
8
,
16
,
24
],
)
)
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
):
bits_per_sample
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
=
bits_per_sample
)
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
],
[
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
)
)
def
test_ogg
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_ogg
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
"ogg"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
self
.
assert_non_wav
(
"ogg"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
class
SaveTestBase
(
TempDirMixin
,
unittest
.
TestCase
):
class
SaveTestBase
(
TempDirMixin
,
unittest
.
TestCase
):
def
assert_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
def
assert_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
"""`soundfile_backend.save` can save wav format."""
"""`soundfile_backend.save` can save wav format."""
path
=
self
.
get_temp_path
(
"data.wav"
)
path
=
self
.
get_temp_path
(
"data.wav"
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
,
normalize
=
False
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
,
normalize
=
False
)
soundfile_backend
.
save
(
path
,
expected
,
sample_rate
)
soundfile_backend
.
save
(
path
,
expected
,
sample_rate
)
found
,
sr
=
load_wav
(
path
,
normalize
=
False
)
found
,
sr
=
load_wav
(
path
,
normalize
=
False
)
assert
sample_rate
==
sr
assert
sample_rate
==
sr
...
@@ -192,7 +197,8 @@ class SaveTestBase(TempDirMixin, unittest.TestCase):
...
@@ -192,7 +197,8 @@ class SaveTestBase(TempDirMixin, unittest.TestCase):
"""
"""
num_frames
=
sample_rate
*
3
num_frames
=
sample_rate
*
3
path
=
self
.
get_temp_path
(
f
"data.
{
fmt
}
"
)
path
=
self
.
get_temp_path
(
f
"data.
{
fmt
}
"
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
,
normalize
=
False
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
,
normalize
=
False
)
soundfile_backend
.
save
(
path
,
expected
,
sample_rate
)
soundfile_backend
.
save
(
path
,
expected
,
sample_rate
)
sinfo
=
soundfile
.
info
(
path
)
sinfo
=
soundfile
.
info
(
path
)
assert
sinfo
.
format
==
fmt
.
upper
()
assert
sinfo
.
format
==
fmt
.
upper
()
...
@@ -220,16 +226,14 @@ class TestSave(SaveTestBase):
...
@@ -220,16 +226,14 @@ class TestSave(SaveTestBase):
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
],
[
"float32"
,
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`soundfile_backend.save` can save wav format."""
"""`soundfile_backend.save` can save wav format."""
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
],
[
"float32"
,
"int32"
],
[
4
,
8
,
16
,
32
],
[
4
,
8
,
16
,
32
],
)
)
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
"""`soundfile_backend.save` can save wav with more than 2 channels."""
"""`soundfile_backend.save` can save wav with more than 2 channels."""
sample_rate
=
8000
sample_rate
=
8000
...
@@ -238,8 +242,7 @@ class TestSave(SaveTestBase):
...
@@ -238,8 +242,7 @@ class TestSave(SaveTestBase):
@
parameterize
(
@
parameterize
(
[
"int32"
],
[
"int32"
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)
)
@
skipIfFormatNotSupported
(
"NIST"
)
@
skipIfFormatNotSupported
(
"NIST"
)
def
test_sphere
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_sphere
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`soundfile_backend.save` can save sph format."""
"""`soundfile_backend.save` can save sph format."""
...
@@ -247,8 +250,7 @@ class TestSave(SaveTestBase):
...
@@ -247,8 +250,7 @@ class TestSave(SaveTestBase):
@
parameterize
(
@
parameterize
(
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)
)
@
skipIfFormatNotSupported
(
"FLAC"
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_flac
(
self
,
sample_rate
,
num_channels
):
def
test_flac
(
self
,
sample_rate
,
num_channels
):
"""`soundfile_backend.save` can save flac format."""
"""`soundfile_backend.save` can save flac format."""
...
@@ -256,8 +258,7 @@ class TestSave(SaveTestBase):
...
@@ -256,8 +258,7 @@ class TestSave(SaveTestBase):
@
parameterize
(
@
parameterize
(
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)
)
@
skipIfFormatNotSupported
(
"OGG"
)
@
skipIfFormatNotSupported
(
"OGG"
)
def
test_ogg
(
self
,
sample_rate
,
num_channels
):
def
test_ogg
(
self
,
sample_rate
,
num_channels
):
"""`soundfile_backend.save` can save ogg/vorbis format."""
"""`soundfile_backend.save` can save ogg/vorbis format."""
...
@@ -318,5 +319,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
...
@@ -318,5 +319,6 @@ class TestFileObject(TempDirMixin, unittest.TestCase):
"""Saving audio via file-like object works"""
"""Saving audio via file-like object works"""
self
.
_test_fileobj
(
"OGG"
)
self
.
_test_fileobj
(
"OGG"
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
audio/tests/common_utils/__init__.py
浏览文件 @
bc893c19
from
.wav_utils
import
get_wav_data
,
load_wav
,
save_wav
,
normalize_wav
from
.case_utils
import
name_func
from
.parameterized_utils
import
nested_params
from
.case_utils
import
TempDirMixin
from
.case_utils
import
(
from
.parameterized_utils
import
nested_params
TempDirMixin
,
from
.wav_utils
import
get_wav_data
name_func
from
.wav_utils
import
load_wav
)
from
.wav_utils
import
normalize_wav
from
.wav_utils
import
save_wav
__all__
=
[
__all__
=
[
"get_wav_data"
,
"get_wav_data"
,
"load_wav"
,
"save_wav"
,
"normalize_wav"
,
"get_sinusoid"
,
"load_wav"
,
"name_func"
,
"nested_params"
,
"TempDirMixin"
"save_wav"
,
"normalize_wav"
,
"get_sinusoid"
,
"name_func"
,
"nested_params"
,
"TempDirMixin"
]
]
audio/tests/common_utils/wav_utils.py
浏览文件 @
bc893c19
from
typing
import
Optional
from
typing
import
Optional
import
scipy.io.wavfile
import
paddle
import
paddle
import
numpy
as
np
import
scipy.io.wavfile
def
normalize_wav
(
tensor
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
normalize_wav
(
tensor
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
if
tensor
.
dtype
==
paddle
.
float32
:
if
tensor
.
dtype
==
paddle
.
float32
:
...
@@ -23,13 +23,12 @@ def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
...
@@ -23,13 +23,12 @@ def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
def
get_wav_data
(
def
get_wav_data
(
dtype
:
str
,
dtype
:
str
,
num_channels
:
int
,
num_channels
:
int
,
*
,
*
,
num_frames
:
Optional
[
int
]
=
None
,
num_frames
:
Optional
[
int
]
=
None
,
normalize
:
bool
=
True
,
normalize
:
bool
=
True
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
):
):
"""Generate linear signal of the given dtype and num_channels
"""Generate linear signal of the given dtype and num_channels
Data range is
Data range is
...
@@ -53,25 +52,26 @@ def get_wav_data(
...
@@ -53,25 +52,26 @@ def get_wav_data(
# paddle linspace not support uint8, int8, int16
# paddle linspace not support uint8, int8, int16
#if dtype == "uint8":
#if dtype == "uint8":
# base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
# base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(0, 255, num_frames, dtype_np)
#base_np = np.linspace(0, 255, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#elif dtype == "int8":
#elif dtype == "int8":
# base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
# base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-128, 127, num_frames, dtype_np)
#base_np = np.linspace(-128, 127, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#base = paddle.to_tensor(base_np, dtype=dtype_)
if
dtype
==
"float32"
:
if
dtype
==
"float32"
:
base
=
paddle
.
linspace
(
-
1.0
,
1.0
,
num_frames
,
dtype
=
dtype_
)
base
=
paddle
.
linspace
(
-
1.0
,
1.0
,
num_frames
,
dtype
=
dtype_
)
elif
dtype
==
"float64"
:
elif
dtype
==
"float64"
:
base
=
paddle
.
linspace
(
-
1.0
,
1.0
,
num_frames
,
dtype
=
dtype_
)
base
=
paddle
.
linspace
(
-
1.0
,
1.0
,
num_frames
,
dtype
=
dtype_
)
elif
dtype
==
"int32"
:
elif
dtype
==
"int32"
:
base
=
paddle
.
linspace
(
-
2147483648
,
2147483647
,
num_frames
,
dtype
=
dtype_
)
base
=
paddle
.
linspace
(
-
2147483648
,
2147483647
,
num_frames
,
dtype
=
dtype_
)
#elif dtype == "int16":
#elif dtype == "int16":
# base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
# base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
#dtype_np = getattr(np, dtype)
#dtype_np = getattr(np, dtype)
#base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
#base_np = np.linspace(-32768, 32767, num_frames, dtype_np)
#base = paddle.to_tensor(base_np, dtype=dtype_)
#base = paddle.to_tensor(base_np, dtype=dtype_)
else
:
else
:
raise
NotImplementedError
(
f
"Unsupported dtype
{
dtype
}
"
)
raise
NotImplementedError
(
f
"Unsupported dtype
{
dtype
}
"
)
data
=
base
.
tile
([
num_channels
,
1
])
data
=
base
.
tile
([
num_channels
,
1
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录