PaddlePaddle / DeepSpeech
Commit 6f651d76
Authored Jan 05, 2022 by Hui Zhang
fix batch sampler set_epoch when epoch start
Parent: 680eac02
Showing 5 changed files with 20 additions and 15 deletions.
Changed files:
  paddlespeech/s2t/exps/u2/model.py        +6 -2
  paddlespeech/s2t/io/dataloader.py        +8 -5
  paddlespeech/s2t/modules/ctc.py          +0 -3
  paddlespeech/s2t/training/scheduler.py   +5 -4
  paddlespeech/s2t/training/trainer.py     +1 -1
paddlespeech/s2t/exps/u2/model.py

@@ -240,7 +240,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)

             self.valid_loader = BatchDataLoader(
                 json_file=config.dev_manifest,
@@ -259,7 +261,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
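The two new keyword arguments select the distributed sampler and the batching order for the train and valid loaders. As a hedged, toy illustration (hypothetical data, not repository code), "shortest_first" batching just means sorting utterances by length and building batches from the shortest ones first, the SortaGrad-style warmup this flag toggles:

# Toy illustration of shortest-first batching (hypothetical data, not repo code).
utt_frames = {"utt1": 320, "utt2": 120, "utt3": 540, "utt4": 200}
order = sorted(utt_frames, key=utt_frames.get)  # shortest utterances first
batches = [order[i:i + 2] for i in range(0, len(order), 2)]
print(batches)  # [['utt2', 'utt4'], ['utt1', 'utt3']]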
paddlespeech/s2t/io/dataloader.py

@@ -78,7 +78,8 @@ class BatchDataLoader():
                  load_aux_input: bool=False,
                  load_aux_output: bool=False,
                  num_encs: int=1,
-                 dist_sampler: bool=False):
+                 dist_sampler: bool=False,
+                 shortest_first: bool=False):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -97,6 +98,7 @@ class BatchDataLoader():
         self.load_aux_input = load_aux_input
         self.load_aux_output = load_aux_output
         self.dist_sampler = dist_sampler
+        self.shortest_first = shortest_first

         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -113,7 +115,7 @@ class BatchDataLoader():
             maxlen_out,
             minibatches,  # for debug
             min_batch_size=mini_batch_size,
-            shortest_first=self.use_sortagrad,
+            shortest_first=self.shortest_first or self.use_sortagrad,
             count=batch_count,
             batch_bins=batch_bins,
             batch_frames_in=batch_frames_in,
@@ -149,13 +151,13 @@ class BatchDataLoader():
             self.reader)

         if self.dist_sampler:
-            self.sampler = DistributedBatchSampler(
+            self.batch_sampler = DistributedBatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
                 drop_last=False, )
         else:
-            self.sampler = BatchSampler(
+            self.batch_sampler = BatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
@@ -163,7 +165,7 @@ class BatchDataLoader():
         self.dataloader = DataLoader(
             dataset=self.dataset,
-            batch_sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )
@@ -194,5 +196,6 @@ class BatchDataLoader():
         echo += f"load_aux_input: {self.load_aux_input}, "
         echo += f"load_aux_output: {self.load_aux_output}, "
         echo += f"dist_sampler: {self.dist_sampler}, "
+        echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo
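Renaming self.sampler to self.batch_sampler matters because the trainer looks the sampler up under that exact attribute name before each epoch (see trainer.py below), so the wrapper now exposes the same name as a plain paddle DataLoader. A minimal sketch of that lookup, using a hypothetical helper name and assuming paddle is installed:

# Hypothetical helper (not repository code) showing the attribute lookup
# that the rename enables.
import paddle

def reset_epoch(train_loader, epoch):
    batch_sampler = getattr(train_loader, "batch_sampler", None)
    if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
        batch_sampler.set_epoch(epoch)  # tie the shuffle seed to the epoch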
paddlespeech/s2t/modules/ctc.py

@@ -39,9 +39,6 @@ except ImportError:
 except Exception as e:
     logger.info("paddlespeech_ctcdecoders not installed!")

-#try:
-#except Exception as e:
-#    logger.info("ctcdecoder not installed!")

 __all__ = ['CTCDecoder']
paddlespeech/s2t/training/scheduler.py

@@ -67,18 +67,19 @@ class WarmupLR(LRScheduler):
         super().__init__(learning_rate, last_epoch, verbose)

     def __repr__(self):
-        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})"

     def get_lr(self):
+        # self.last_epoch start from zero
         step_num = self.last_epoch + 1
         return self.base_lr * self.warmup_steps**0.5 * min(
             step_num**-0.5, step_num * self.warmup_steps**-1.5)

     def set_step(self, step: int=None):
         '''
         It will update the learning rate in optimizer according to current ``epoch`` .
         The new learning rate will take effect on next ``optimizer.step`` .

         Args:
             step (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
         Returns:
@@ -94,7 +95,7 @@ class ConstantLR(LRScheduler):
         learning_rate (float): The initial learning rate. It is a python float number.
         last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
         verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
         Returns:
             ``ConstantLR`` instance to schedule learning rate.
     """
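For reference, get_lr() above implements the inverse-square-root warmup schedule lr(step) = base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5): the rate climbs linearly for warmup_steps steps, peaks at base_lr, then decays as step**-0.5. A standalone sketch with illustrative values (not taken from any recipe config):

# Standalone re-implementation of the schedule in get_lr() above.
def warmup_lr(step_num, base_lr=0.002, warmup_steps=25000):
    return base_lr * warmup_steps**0.5 * min(step_num**-0.5,
                                             step_num * warmup_steps**-1.5)

print(warmup_lr(1))       # near zero, start of the linear ramp
print(warmup_lr(25000))   # equals base_lr at the end of warmup
print(warmup_lr(100000))  # decays as step**-0.5 afterwards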
paddlespeech/s2t/training/trainer.py

@@ -222,7 +222,7 @@ class Trainer():
         batch_sampler = self.train_loader.batch_sampler
         if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
             logger.debug(
-                f"train_loader.batch_sample set epoch: {self.epoch}")
+                f"train_loader.batch_sample.set_epoch: {self.epoch}")
             batch_sampler.set_epoch(self.epoch)

     def before_train(self):
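The batch_sampler.set_epoch(self.epoch) call that this log message accompanies is what the commit title refers to: seeding the sampler's shuffle from the epoch counter keeps batch order consistent across ranks and reproducible when training resumes. A self-contained, single-process sketch with a toy dataset (not repository code):

# Toy single-process sketch of DistributedBatchSampler.set_epoch usage.
import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler

class ToySet(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return paddle.to_tensor([idx])

dataset = ToySet()
sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # seed the shuffle from the epoch counter
    for batch in loader:
        pass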