PaddlePaddle / models — commit c4aa59ab
Authored by KP on Apr 30, 2021; committed by GitHub on Apr 30, 2021.

Add sound classification example (#5303)

* Add sound classification example

Parent: 8d17108b
Showing 11 changed files with 465 additions and 4 deletions (+465 −4)
PaddleAudio/examples/sound_classification/README.md      +79   −0
PaddleAudio/examples/sound_classification/model.py       +38   −0
PaddleAudio/examples/sound_classification/predict.py     +61   −0
PaddleAudio/examples/sound_classification/train.py       +140  −0
PaddleAudio/paddleaudio/datasets/__init__.py             +1    −1
PaddleAudio/paddleaudio/datasets/dataset.py              +2    −2
PaddleAudio/paddleaudio/datasets/esc50.py                +57   −0
PaddleAudio/paddleaudio/datasets/urban_sound.py          +2    −0
PaddleAudio/paddleaudio/models/panns.py                  +0    −1
PaddleAudio/paddleaudio/utils/__init__.py                +18   −0
PaddleAudio/paddleaudio/utils/time.py                    +67   −0
PaddleAudio/examples/sound_classification/README.md (new file, mode 100644)
# Sound Classification

Sound classification and detection is an active research direction in audio algorithms.

For sound classification tasks, a common traditional machine-learning approach is to hand-craft a variety of time-domain and frequency-domain audio features, apply feature selection, combination, and transformation, and then classify with an SVM or a decision tree. End-to-end deep learning instead uses deep networks such as RNNs and CNNs to perform representation learning and classification directly on the audio waveform or on time-frequency features.

At IEEE ICASSP 2017, Google released the large-scale audio dataset [Audioset](https://research.google.com/audioset/). It covers 632 audio classes with 2,084,320 human-labeled 10-second sound clips sourced from YouTube videos. The dataset now contains 2.1 million annotated videos and 5,800 hours of audio, with 527 labeled sound categories.

`PANNs` ([PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf)) are sound classification/recognition models trained on Audioset. After pretraining, the models can be used to extract audio embeddings. This example fine-tunes the pretrained `PANNs` models to perform a sound classification task.
## Model overview

PaddleAudio provides pretrained CNN14, CNN10, and CNN6 PANNs models to choose from:

- CNN14: 12 convolutional layers and 2 fully connected layers, 79.6M parameters, embedding dimension 2048.
- CNN10: 8 convolutional layers and 2 fully connected layers, 4.9M parameters, embedding dimension 512.
- CNN6: 4 convolutional layers and 2 fully connected layers, 4.5M parameters, embedding dimension 512.
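The embedding dimension listed above is exposed by each backbone as an `emb_size` attribute, which is what this example's `SoundClassifier` uses to size its classification head. A minimal sketch to confirm the sizes (using `pretrained=False` so no weights are downloaded; this snippet is illustrative and not part of the committed files):

```python
from paddleaudio.models.panns import cnn6, cnn10, cnn14

# Build each backbone without loading pretrained weights and print its embedding size.
for build in (cnn14, cnn10, cnn6):
    backbone = build(pretrained=False, extract_embedding=True)
    print(build.__name__, backbone.emb_size)  # expected: 2048, 512, 512
```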
## Quick start

### Model training

Using the environmental sound classification dataset `ESC50` as an example, run the command below to fine-tune the model on the training set. Single-card training and multi-card distributed training on a single machine are both supported. For how to launch distributed training with `paddle.distributed.launch`, see [PaddlePaddle 2.0 distributed training](https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/quick_start/high_level_api/high_level_api.html#danjiduoka).
```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0" train.py --device gpu --epochs 50 --batch_size 16 --num_workers 16 --checkpoint_dir ./checkpoint --save_freq 10
```
Configurable arguments:

- `device`: Device to train on, `cpu` or `gpu`; defaults to `gpu`. When training on GPU, the `--gpus` flag selects the card number.
- `epochs`: Number of training epochs; defaults to 50.
- `learning_rate`: Learning rate for fine-tuning; defaults to 5e-5.
- `batch_size`: Batch size. Adjust it to your GPU memory and lower it if you run out of memory; defaults to 16.
- `num_workers`: Number of subprocesses used by the DataLoader to fetch data. Defaults to 0, in which case data loading runs in the main process.
- `checkpoint_dir`: Directory for saving model parameters and optimizer state; defaults to `./checkpoint`.
- `save_freq`: How often (in epochs) checkpoints are saved during training; defaults to 10.
- `log_freq`: How often (in steps) training information is printed; defaults to 10.
The example code uses the `CNN14` pretrained model. To switch to another pretrained backbone, build the model as follows:
```python
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14, cnn10, cnn6

# CNN14
backbone = cnn14(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))

# CNN10
backbone = cnn10(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))

# CNN6
backbone = cnn6(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
```
### Model prediction

```shell
export CUDA_VISIBLE_DEVICES=0
python -u predict.py --device gpu --wav ./dog.wav --top_k 3 --checkpoint ./checkpoint/epoch_50/model.pdparams
```
Configurable arguments:

- `device`: Device to run prediction on, `cpu` or `gpu`; defaults to `gpu`.
- `wav`: Audio file to predict on.
- `top_k`: Number of top-scoring labels to display; defaults to 1.
- `checkpoint`: Model parameter checkpoint file.
The prediction output looks like this:
```
[/audio/dog.wav]
Dog: 0.9999538660049438
Clock tick: 1.3341237718123011e-05
Cat: 6.579841738130199e-06
```
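The scores above are simply the softmax probabilities of the classifier logits, ranked in descending order, as done in `predict.py`. A minimal hedged sketch of that ranking step in isolation (the random logits are only a stand-in for a real forward pass over mel features):

```python
import paddle
import paddle.nn.functional as F
from paddleaudio.datasets import ESC50

# Stand-in logits; in practice they come from model(feats) as in predict.py.
logits = paddle.randn([1, len(ESC50.label_list)])
probs = F.softmax(logits, axis=1).numpy()

top_k = 3
for idx in (-probs[0]).argsort()[:top_k]:
    print(f'{ESC50.label_list[idx]}: {probs[0][idx]}')
```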
PaddleAudio/examples/sound_classification/model.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SoundClassifier(nn.Layer):
    """
    Model for sound classification which uses panns pretrained models to extract
    embeddings from audio files.
    """

    def __init__(self, backbone, num_class, dropout=0.1):
        super(SoundClassifier, self).__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.backbone.emb_size, num_class)

    def forward(self, x):
        # x: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
        x = x.unsqueeze(1)
        x = self.backbone(x)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
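As a quick shape check: the classifier expects batched mel features of shape (batch_size, num_frames, num_melbins) and returns one logit per class. A minimal hedged sketch (the 64 mel bins follow the PANNs setup used in this commit, while the batch size and frame count are arbitrary placeholders; this snippet is not part of the committed files):

```python
import paddle
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14

# Build the classifier the same way predict.py does, but feed it random features.
backbone = cnn14(pretrained=False, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
model.eval()

feats = paddle.randn([4, 128, 64])  # (batch_size, num_frames, num_melbins)
logits = model(feats)
print(logits.shape)  # expected: [4, 50] for the 50 ESC50 classes
```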
PaddleAudio/examples/sound_classification/predict.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import os

import numpy as np
import paddle
import paddle.nn.functional as F
from model import SoundClassifier
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets import ESC50
from paddleaudio.features import mel_spect
from paddleaudio.models.panns import cnn14

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to predict, defaults to gpu.")
parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
parser.add_argument("--top_k", type=int, default=1, help="Show top k predicted results")
parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint of model.")
args = parser.parse_args()
# yapf: enable


def extract_features(file: str, **kwargs):
    waveform, sr = load_audio(args.wav, sr=None)
    feats = mel_spect(waveform, sample_rate=sr, **kwargs).transpose()
    return feats


if __name__ == '__main__':
    paddle.set_device(args.device)

    model = SoundClassifier(
        backbone=cnn14(pretrained=False, extract_embedding=True),
        num_class=len(ESC50.label_list))
    model.set_state_dict(paddle.load(args.checkpoint))
    model.eval()

    feats = extract_features(args.wav)
    feats = paddle.to_tensor(np.expand_dims(feats, 0))
    logits = model(feats)
    probs = F.softmax(logits, axis=1).numpy()

    sorted_indices = (-probs[0]).argsort()

    msg = f'[{args.wav}]\n'
    for idx in sorted_indices[:args.top_k]:
        msg += f'{ESC50.label_list[idx]}: {probs[0][idx]}\n'
    print(msg)
PaddleAudio/examples/sound_classification/train.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import os

import paddle
import paddle.nn.functional as F
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14
from paddleaudio.utils import Timer, logger

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument("--epochs", type=int, default=50, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers in dataloader.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
parser.add_argument("--save_freq", type=int, default=10, help="Save checkpoint every n epoch.")
parser.add_argument("--log_freq", type=int, default=10, help="Log the training infomation every n steps.")
args = parser.parse_args()
# yapf: enable

if __name__ == "__main__":
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()

    backbone = cnn14(pretrained=True, extract_embedding=True)
    model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate, parameters=model.parameters())
    criterion = paddle.nn.loss.CrossEntropyLoss()

    train_ds = ESC50(mode='train', feat_type='mel_spect')
    dev_ds = ESC50(mode='dev', feat_type='mel_spect')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        return_list=True,
        use_buffer_reader=True,
    )

    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * args.epochs)
    timer.start()

    for epoch in range(1, args.epochs + 1):
        model.train()

        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, batch in enumerate(train_loader):
            feats, labels = batch
            logits = model(feats)

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # Calculate loss
            avg_loss += loss.numpy()[0]

            # Calculate metrics
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]

            timer.count()

            if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
                lr = optimizer.get_lr()
                avg_loss /= args.log_freq
                avg_acc = num_corrects / num_samples

                print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                    epoch, args.epochs, batch_idx + 1, steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.train(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

            if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
                dev_sampler = paddle.io.BatchSampler(
                    dev_ds,
                    batch_size=args.batch_size,
                    shuffle=False,
                    drop_last=False)
                dev_loader = paddle.io.DataLoader(
                    dev_ds,
                    batch_sampler=dev_sampler,
                    num_workers=args.num_workers,
                    return_list=True,
                )

                model.eval()
                num_corrects = 0
                num_samples = 0
                with logger.processing('Evaluation on validation dataset'):
                    for batch_idx, batch in enumerate(dev_loader):
                        feats, labels = batch
                        logits = model(feats)

                        preds = paddle.argmax(logits, axis=1)
                        num_corrects += (preds == labels).numpy().sum()
                        num_samples += feats.shape[0]

                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)

                logger.eval(print_msg)

                # Save model
                save_dir = os.path.join(args.checkpoint_dir,
                                        'epoch_{}'.format(epoch))
                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(save_dir, 'model.pdopt'))
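train.py saves both the model parameters and the optimizer state every `save_freq` epochs, but it never reloads them itself. A minimal hedged sketch of how one saved epoch could be restored before resuming fine-tuning (the checkpoint path is only an example of what train.py writes; `cnn14`, `SoundClassifier`, and `ESC50` come from this commit):

```python
import os

import paddle
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14

ckpt_dir = './checkpoint/epoch_50'  # example directory written by train.py

# Rebuild the same model and optimizer, then restore the saved state dicts.
backbone = cnn14(pretrained=False, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
optimizer = paddle.optimizer.Adam(
    learning_rate=5e-5, parameters=model.parameters())

model.set_state_dict(paddle.load(os.path.join(ckpt_dir, 'model.pdparams')))
optimizer.set_state_dict(paddle.load(os.path.join(ckpt_dir, 'model.pdopt')))
```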
PaddleAudio/paddleaudio/datasets/__init__.py

-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
PaddleAudio/paddleaudio/datasets/dataset.py
...
@@ -15,11 +15,11 @@
 import os
 from typing import List, Tuple
-import librosa
 import numpy as np
 import paddle
 from tqdm import tqdm
+from ..backends import load as load_audio
 from ..features import linear_spect, log_spect, mel_spect
 from ..utils.log import logger
...
@@ -71,7 +71,7 @@ class AudioClassificationDataset(paddle.io.Dataset):
     def _convert_to_record(self, idx):
         file, label = self.files[idx], self.labels[idx]
-        waveform, _ = librosa.load(file, sr=self.sample_rate)
+        waveform, _ = load_audio(file, sr=self.sample_rate)
         normal_length = self.sample_rate * self.duration
         if len(waveform) > normal_length:
             waveform = waveform[:normal_length]
...
PaddleAudio/paddleaudio/datasets/esc50.py
...
@@ -41,6 +41,63 @@ class ESC50(AudioClassificationDataset):
            'md5': '1fdc5dd87626d5eb91be20ed53c9aed9',
        },
    ]
    label_list = [
        # Animals
        'Dog',
        'Rooster',
        'Pig',
        'Cow',
        'Frog',
        'Cat',
        'Hen',
        'Insects (flying)',
        'Sheep',
        'Crow',
        # Natural soundscapes & water sounds
        'Rain',
        'Sea waves',
        'Crackling fire',
        'Crickets',
        'Chirping birds',
        'Water drops',
        'Wind',
        'Pouring water',
        'Toilet flush',
        'Thunderstorm',
        # Human, non-speech sounds
        'Crying baby',
        'Sneezing',
        'Clapping',
        'Breathing',
        'Coughing',
        'Footsteps',
        'Laughing',
        'Brushing teeth',
        'Snoring',
        'Drinking, sipping',
        # Interior/domestic sounds
        'Door knock',
        'Mouse click',
        'Keyboard typing',
        'Door, wood creaks',
        'Can opening',
        'Washing machine',
        'Vacuum cleaner',
        'Clock alarm',
        'Clock tick',
        'Glass breaking',
        # Exterior/urban noises
        'Helicopter',
        'Chainsaw',
        'Siren',
        'Car horn',
        'Engine',
        'Train',
        'Church bells',
        'Airplane',
        'Fireworks',
        'Hand saw',
    ]
    meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
    meta_info = collections.namedtuple(
        'META_INFO',
        ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
...
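The new `label_list` fixes the 50 ESC50 categories in target order, which is exactly what `predict.py` indexes into when printing results. A small hedged sketch of that mapping (touching only the class attribute, so the dataset archive is not downloaded):

```python
from paddleaudio.datasets import ESC50

# Class indices returned by the model map directly into label_list.
print(len(ESC50.label_list))           # 50
print(ESC50.label_list[0])             # 'Dog'
print(ESC50.label_list.index('Cat'))   # index used for the 'Cat' class
```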
PaddleAudio/paddleaudio/datasets/urban_sound.py
...
@@ -41,6 +41,8 @@ class UrbanSound8K(AudioClassificationDataset):
            'md5': '9aa69802bbf37fb986f71ec1483a196e',
        },
    ]
    label_list = ["air_conditioner", "car_horn", "children_playing", "dog_bark", "drilling", \
                  "engine_idling", "gun_shot", "jackhammer", "siren", "street_music"]
    meta = os.path.join('UrbanSound8K', 'metadata', 'UrbanSound8K.csv')
    meta_info = collections.namedtuple(
        'META_INFO', ('filename', 'fsid', 'start', 'end', 'salience', 'fold', 'class_id', 'label'))
...
PaddleAudio/paddleaudio/models/panns.py
...
@@ -25,7 +25,6 @@ from ..utils.log import logger
__all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6']
pretrained_model_urls = {
    # TODO: replace test urls
    'cnn14': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams',
    'cnn10': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams',
    'cnn6': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams',
...
PaddleAudio/paddleaudio/utils/__init__.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .download import *
from .env import *
from .log import *
from .time import *
PaddleAudio/paddleaudio/utils/time.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import time


class Timer(object):
    '''Calculate runing speed and estimated time of arrival(ETA)'''

    def __init__(self, total_step: int):
        self.total_step = total_step
        self.last_start_step = 0
        self.current_step = 0
        self._is_running = True

    def start(self):
        self.last_time = time.time()
        self.start_time = time.time()

    def stop(self):
        self._is_running = False
        self.end_time = time.time()

    def count(self) -> int:
        if not self.current_step >= self.total_step:
            self.current_step += 1
        return self.current_step

    @property
    def timing(self) -> float:
        run_steps = self.current_step - self.last_start_step
        self.last_start_step = self.current_step
        time_used = time.time() - self.last_time
        self.last_time = time.time()
        return run_steps / time_used

    @property
    def is_running(self) -> bool:
        return self._is_running

    @property
    def eta(self) -> str:
        if not self.is_running:
            return '00:00:00'
        scale = self.total_step / self.current_step
        remaining_time = (time.time() - self.start_time) * scale
        return seconds_to_hms(remaining_time)


def seconds_to_hms(seconds: int) -> str:
    '''Convert the number of seconds to hh:mm:ss'''
    h = math.floor(seconds / 3600)
    m = math.floor((seconds - h * 3600) / 60)
    s = int(seconds - h * 3600 - m * 60)
    hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s)
    return hms_str
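A minimal hedged sketch of how `Timer` is meant to be driven, mirroring its use in train.py (the step count and `time.sleep` are placeholders for real training work; this snippet is not part of the committed files):

```python
import time

from paddleaudio.utils import Timer

total_steps = 5
timer = Timer(total_steps)
timer.start()

for _ in range(total_steps):
    time.sleep(0.1)  # stand-in for one training step
    timer.count()
    print('step/sec={:.2f} | ETA {}'.format(timer.timing, timer.eta))

timer.stop()
```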