Commit 2e77c3c3
Authored on Aug 25, 2021 by huangyuxin

Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into ds2_online_export

Parents: 0d0b5811, 44c84e26

Showing 15 changed files with 888 additions and 0 deletions (+888, -0).
deepspeech/training/extensions/__init__.py        +28   -0
deepspeech/training/extensions/evaluator.py       +58   -0
deepspeech/training/extensions/extension.py       +41   -0
deepspeech/training/extensions/snapshot.py        +102  -0
deepspeech/training/extensions/visualizer.py      +24   -0
deepspeech/training/reporter.py                   +131  -0
deepspeech/training/triggers/__init__.py          +13   -0
deepspeech/training/triggers/interval_trigger.py  +24   -0
deepspeech/training/triggers/limit_trigger.py     +17   -0
deepspeech/training/triggers/time_trigger.py      +17   -0
deepspeech/training/updaters/__init__.py          +0    -0
deepspeech/training/updaters/standard_updater.py  +179  -0
deepspeech/training/updaters/trainer.py           +171  -0
deepspeech/training/updaters/updater.py           +82   -0
requirements.txt                                  +1    -0
deepspeech/training/extensions/__init__.py (new file, mode 100644)

from typing import Callable

from .extension import Extension


def make_extension(trigger: Callable=None,
                   default_name: str=None,
                   priority: int=None,
                   finalizer: Callable=None,
                   initializer: Callable=None,
                   on_error: Callable=None):
    """Make an Extension-like object by injecting required attributes to it.
    """
    if trigger is None:
        trigger = Extension.trigger
    if priority is None:
        priority = Extension.priority

    def decorator(ext):
        ext.trigger = trigger
        ext.default_name = default_name or ext.__name__
        ext.priority = priority
        ext.finalize = finalizer
        ext.on_error = on_error
        ext.initialize = initializer
        return ext

    return decorator
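For reference, a minimal sketch (not part of this commit) of how the decorator turns a plain function into an Extension-like object; the `print_progress` function and its trigger are illustrative:

from deepspeech.training.extensions import make_extension

@make_extension(trigger=(100, 'iteration'))
def print_progress(trainer):
    # a plain function; the decorator injects trigger/priority/etc. attributes
    print("iteration:", trainer.updater.state.iteration)

assert print_progress.trigger == (100, 'iteration')
assert print_progress.default_name == 'print_progress'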
deepspeech/training/extensions/evaluator.py (new file, mode 100644)

from typing import Dict

import paddle
from paddle.io import DataLoader
from paddle.nn import Layer

from deepspeech.training.extensions import extension

from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope


class StandardEvaluator(extension.Extension):

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # dataloaders
        self.dataloader = dataloader

    def evaluate_core(self, batch):
        # compute
        self.model(batch)
        # you may report here

    def evaluate(self):
        # switch to eval mode
        for model in self.models.values():
            model.eval()

        # to average evaluation metrics
        summary = DictSummary()
        for batch in self.dataloader:
            observation = {}
            with scope(observation):
                # main evaluation computation here
                with paddle.no_grad():
                    self.evaluate_core(batch)
            summary.add(observation)
        summary = summary.compute_mean()
        return summary

    def __call__(self, trainer=None):
        # Evaluate and report the averaged metrics to the current observation.
        # If it is used to extend a trainer, the metrics are reported to the
        # observation of the trainer; otherwise, you can use your own
        # observation.
        summary = self.evaluate()
        for k, v in summary.items():
            report(k, v)
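A sketch of a concrete evaluator built on this class; the batch layout and metric name below are hypothetical, assuming a model whose forward returns a loss:

class MyEvaluator(StandardEvaluator):
    def evaluate_core(self, batch):
        feats, labels = batch                # hypothetical batch layout
        loss = self.model(feats, labels)     # assumes forward returns a loss
        report("eval/loss", float(loss))     # averaged by DictSummary in evaluate()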
deepspeech/training/extensions/extension.py (new file, mode 100644)

from typing import Callable

PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100


class Extension():
    """Extension to customize the behavior of Trainer."""
    trigger = (1, 'iteration')
    priority = PRIORITY_READER
    name = None

    @property
    def default_name(self):
        """Default name of the extension, class name by default."""
        return type(self).__name__

    def __call__(self, trainer):
        """Main action of the extension. After each update, it is executed
        when the trigger fires."""
        raise NotImplementedError(
            'Extension implementation must override __call__.')

    def initialize(self, trainer):
        """Action that is executed once to get the correct trainer state.
        It is called before training normally, but if the trainer restores
        states with a Snapshot extension, this method should also be called.
        """
        pass

    def on_error(self, trainer, exc, tb):
        """Handles the error raised during training before finalization.
        """
        pass

    def finalize(self, trainer):
        """Action that is executed when training is done.
        For example, visualizers would need to be closed.
        """
        pass
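A sketch of a hand-written extension following this contract; the learning-rate reporting is illustrative and assumes the updater exposes a paddle optimizer (whose `get_lr()` is a real method):

class LRReporter(Extension):
    trigger = (1, 'iteration')
    priority = PRIORITY_READER   # readers run after writers/editors each step

    def __call__(self, trainer):
        # read-only: inspect the trainer, do not modify its state
        lr = trainer.updater.optimizer.get_lr()
        print("lr:", lr)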
deepspeech/training/extensions/snapshot.py (new file, mode 100644)

import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines

from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


def load_records(records_fp):
    """Load record files (json lines)."""
    with jsonlines.open(records_fp, 'r') as reader:
        records = list(reader)
    return records


class Snapshot(extension.Extension):
    """An extension to make snapshots of the updater object inside
    the trainer. It is done by calling the updater's `save` method.

    An Updater saves its state_dict by default, which contains the
    updater state (i.e. epoch and iteration) and all the model
    parameters and optimizer states. If the updater inside the trainer
    subclasses StandardUpdater, everything is good to go.

    Parameters
    ----------
    max_size : int
        The maximum number of snapshots to keep; -1 means keeping all of them.
    snapshot_on_error : bool
        Whether to take a snapshot when an error is raised during training.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Setting up this extension."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest snapshot if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # add the new one
        trainer.updater.save(path)
        record = {
            "time": str(datetime.now()),
            'path': str(path.resolve()),  # use absolute path
            'iteration': iteration,
            'epoch': epoch,
        }
        self.records.append(record)

        # remove the earliest
        if self.full():
            earliest_record = self.records[0]
            os.remove(earliest_record["path"])
            self.records.pop(0)

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for record in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(record)  # pylint: disable=no-member
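A sketch of wiring this into a trainer; the `trainer` object and its output directory are assumed to exist:

# keep the 3 latest checkpoints under <trainer.out>/checkpoints,
# plus an emergency checkpoint if training crashes
trainer.extend(Snapshot(max_size=3, snapshot_on_error=True), trigger=(1, 'epoch'))

# each line of records.jsonl then looks like (values illustrative):
# {"time": "2021-08-25 ...", "path": "/abs/path/1.pdz", "iteration": 120, "epoch": 1}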
deepspeech/training/extensions/visualizer.py (new file, mode 100644)

from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer


class VisualDL(extension.Extension):
    """A wrapper of the visualdl log writer. It assumes that the metrics to be
    visualized are all scalars which are recorded into the `.observation`
    dictionary of the trainer object. The dictionary is created for each step,
    thus the visualdl log writer uses the iteration from the updater's
    `iteration` as the global step to add records.
    """
    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        self.writer = writer

    def __call__(self, trainer: Trainer):
        for k, v in trainer.observation.items():
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

    def finalize(self, trainer):
        self.writer.close()
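A sketch of hooking the wrapper up to a visualdl `LogWriter`, assuming visualdl is installed; the log directory and trainer are illustrative:

from visualdl import LogWriter

writer = LogWriter(logdir="exp/log")
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))  # assumes a trainer exists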
deepspeech/training/reporter.py (new file, mode 100644)

import contextlib
import math
from collections import defaultdict

OBSERVATIONS = None


@contextlib.contextmanager
def scope(observations):
    # Make `observations` the target to report to.
    # It is basically a dictionary that stores temporary observations.
    global OBSERVATIONS
    old = OBSERVATIONS
    OBSERVATIONS = observations

    try:
        yield
    finally:
        OBSERVATIONS = old


def get_observations():
    global OBSERVATIONS
    return OBSERVATIONS


def report(name, value):
    # A simple function to report a named value.
    # You can use it everywhere; it will get the default target and write to
    # it. You can think of it as std.out.
    observations = get_observations()
    if observations is None:
        return
    else:
        observations[name] = value


class Summary():
    """Online summarization of a sequence of scalars.
    Summary computes the statistics of given scalars online.
    """

    def __init__(self):
        self._x = 0.0
        self._x2 = 0.0
        self._n = 0

    def add(self, value, weight=1):
        """Adds a scalar value.

        Args:
            value: Scalar value to accumulate. It is either a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
            weight: An optional weight for the value. It is a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
                Default is 1 (integer).
        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the mean."""
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.

        Returns:
            tuple: Mean and standard deviation values.
        """
        x, n = self._x, self._n
        mean = x / n
        var = self._x2 / n - mean * mean
        std = math.sqrt(var)
        return mean, std


class DictSummary():
    """Online summarization of a sequence of dictionaries.

    ``DictSummary`` computes the statistics of a given set of scalars online.
    It only computes the statistics for scalar values and variables of scalar
    values in the dictionaries.
    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.

        Args:
            d (dict): Dictionary of scalars to accumulate. Only elements of
                scalars, zero-dimensional arrays, and variables of
                zero-dimensional arrays are accumulated. When the value
                is a tuple, the second element is interpreted as a weight.
        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                # read the weight before unwrapping the value
                w = v[1]
                v = v[0]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Creates a dictionary of mean values.

        It returns a single dictionary that holds a mean value for each entry
        added to the summary.

        Returns:
            dict: Dictionary of mean values.
        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.

        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.

        Returns:
            dict: Dictionary of statistics of all entries.
        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std

        return stats
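A self-contained sketch of how `scope`, `report` and `DictSummary` fit together:

observation = {}
with scope(observation):
    report("loss", 0.5)           # lands in the active observation dict
print(observation)                # {'loss': 0.5}

summary = DictSummary()
summary.add({"loss": 1.0})
summary.add({"loss": 3.0})
print(summary.compute_mean())     # {'loss': 2.0}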
deepspeech/training/triggers/__init__.py (new file, mode 100644)

from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
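The three accepted forms, for reference:

assert get_trigger(None) is never_fail_trigger     # never fires
interval = get_trigger((10, 'iteration'))          # tuple -> IntervalTrigger(10, 'iteration')
passthrough = get_trigger(lambda trainer: True)    # callables are returned unchanged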
deepspeech/training/triggers/interval_trigger.py (new file, mode 100644)

class IntervalTrigger():
    """A predicate to do something every N cycles."""

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        self.last_index = None

    def __call__(self, trainer):
        if self.last_index is None:
            last_index = getattr(trainer.updater.state, self.unit)
            self.last_index = last_index

        last_index = self.last_index
        index = getattr(trainer.updater.state, self.unit)
        fire = index // self.period != last_index // self.period

        self.last_index = index
        return fire
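A worked example against a stub trainer (only `updater.state` is read by the trigger):

from types import SimpleNamespace

state = SimpleNamespace(iteration=0, epoch=0)
trainer = SimpleNamespace(updater=SimpleNamespace(state=state))
trigger = IntervalTrigger(10, 'iteration')

for i in range(1, 25):
    state.iteration = i
    if trigger(trainer):
        print(i)    # prints 10 and 20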
deepspeech/training/triggers/limit_trigger.py (new file, mode 100644)

class LimitTrigger():
    """A predicate to decide whether to stop."""

    def __init__(self, limit: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit = limit
        self.unit = unit

    def __call__(self, trainer):
        state = trainer.updater.state
        index = getattr(state, self.unit)
        fire = index >= self.limit
        return fire
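Again with a stub trainer: the trigger fires once the index reaches the limit:

from types import SimpleNamespace

state = SimpleNamespace(iteration=0, epoch=5)
trainer = SimpleNamespace(updater=SimpleNamespace(state=state))
assert LimitTrigger(5, 'epoch')(trainer) is True    # 5 >= 5
assert LimitTrigger(6, 'epoch')(trainer) is False   # 5 < 6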
deepspeech/training/triggers/time_trigger.py (new file, mode 100644)

class TimeTrigger():
    """Trigger based on a fixed time interval.

    This trigger accepts iterations with a given interval time.

    Args:
        period (float): Interval time. It is given in seconds.
    """

    def __init__(self, period):
        self._period = period
        self._next_time = self._period

    def __call__(self, trainer):
        if self._next_time < trainer.elapsed_time:
            self._next_time += self._period
            return True
        else:
            return False
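A stub demonstration; note that this trigger reads `trainer.elapsed_time`, which the Trainer class in this commit does not define, so it is usable only with a trainer that provides it:

from types import SimpleNamespace

t = TimeTrigger(60.0)                                  # fire every 60 seconds
assert t(SimpleNamespace(elapsed_time=30.0)) is False
assert t(SimpleNamespace(elapsed_time=90.0)) is True   # fires; next fire at 120s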
deepspeech/training/updaters/__init__.py (new file, mode 100644; empty)
deepspeech/training/updaters/standard_updater.py (new file, mode 100644)

from typing import Dict
from typing import Optional

from paddle import Tensor
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer

from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log

__all__ = ["StandardUpdater"]

logger = Log(__name__).getlog()


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        # init state
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

        self.train_iterator = iter(dataloader)

    def update(self):
        # We increase the iteration index after updating and before any
        # extension is executed. Here are the reasons.
        # 1. Snapshotting (as well as other extensions, like the visualizer)
        #    is executed after a step of updating.
        # 2. We do not increase the iteration index after the extensions
        #    because we prefer a consistent resume behavior: when loading
        #    from a `snapshot_iter_100.pdz`, the next step to train is `101`,
        #    naturally. If the iteration index were increased after the
        #    extensions (including snapshot), a `snapshot_iter_99` would be
        #    loaded instead, and an extra increment of the iteration index
        #    would be needed before training to avoid repeating iteration
        #    `99`, which was already done before snapshotting.
        # 3. Thus the iteration index represents "how many iterations have
        #    been done so far."
        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.

        # Then here comes the next question: when is the proper time to
        # increase the epoch index? Since all extensions are executed after
        # updating, right after updating is the proper time to increase the
        # epoch index.
        # 1. If we increased the epoch index before updating, an extension
        #    triggered on epochs would miss the correct timing; it could only
        #    fire after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased, so it is increased after updating.
        # 3. Thus the epoch index represents "how many epochs have been done
        #    so far", and it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        if self.updates_per_epoch is not None:
            if self.state.iteration % self.updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses;
        Parameters are updated at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]
            loss_dict = loss
            if "main" not in loss_dict:
                main_loss = 0
                for loss_item in loss.values():
                    main_loss += loss_item
                loss_dict["main"] = main_loss

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader."""
        length_of_dataloader = None
        try:
            length_of_dataloader = len(self.dataloader)
        except TypeError:
            logger.debug("This dataloader has no __len__.")
        finally:
            return length_of_dataloader

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch samplers for distributed training should
        # subclass DistributedBatchSampler and implement the `set_epoch` method
        if hasattr(self.dataloader, "batch_sampler"):
            batch_sampler = self.dataloader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(f"Read a batch takes {t.elapse}s.")
        return batch

    def state_dict(self):
        """State dict of an Updater; model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for an Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)
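A toy assembly of the pieces, assuming a model whose forward returns the loss (as `update_core` expects); every name below is illustrative, not part of the commit:

import paddle
from paddle.io import DataLoader, TensorDataset

class ToyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 1)

    def forward(self, x, y):
        # the whole batch is the model input; the forward returns the loss
        return paddle.nn.functional.mse_loss(self.linear(x), y)

model = ToyModel()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
dataset = TensorDataset([paddle.randn([16, 4]), paddle.randn([16, 1])])
updater = StandardUpdater(model, optimizer, DataLoader(dataset, batch_size=4))
updater.update()    # one training step; reports the loss into the active scope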
deepspeech/training/updaters/trainer.py (new file, mode 100644)

import sys
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Callable
from typing import List
from typing import Union

import six
import tqdm

from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase


class _ExtensionEntry():
    def __init__(self, extension, trigger, priority):
        self.extension = extension
        self.trigger = trigger
        self.priority = priority


class Trainer():
    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger: Callable=None,
                 out: Union[str, Path]='result',
                 extensions: List[Extension]=None):
        self.updater = updater
        self.extensions = OrderedDict()
        self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        self.observation = None

        self._done = False
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @property
    def is_before_training(self):
        return self.updater.state.iteration == 0

    def extend(self, extension, name=None, trigger=None, priority=None):
        # Get a name for the extension:
        # argument -> extension's name -> default_name (class name, when it
        # is an object) -> function name (when it is a function) -> error.
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    name = getattr(extension, '__name__', None)
                    if name is None:
                        raise ValueError(
                            "Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # add a suffix to avoid naming conflicts
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """Get an extension by name."""
        extensions = self.extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError(f'extension {name} not found')

    def run(self):
        if self._done:
            raise RuntimeError("Training is already done!")

        self.out.mkdir(parents=True, exist_ok=True)

        # sort extensions by priorities once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name])
                      for name in extension_order]

        # initialize all extensions
        for name, entry in extensions:
            if hasattr(entry.extension, "initialize"):
                entry.extension.initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        # display only one progress bar
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.limit
                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
                                            None)
                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
            else:
                max_iteration = self.stop_trigger.limit

        p = tqdm.tqdm(
            initial=self.updater.state.iteration, total=max_iteration)

        try:
            while not stop_trigger(self):
                self.observation = {}
                # Set observation as the report target;
                # you can use report freely in Updater.update().

                # update parameters and state
                with scope(self.observation):
                    update()
                    p.update()

                # execute extensions when necessary
                for name, entry in extensions:
                    if entry.trigger(self):
                        entry.extension(self)

        except Exception as e:
            f = sys.stderr
            f.write(f"Exception in main training loop: {e}\n")
            f.write("Traceback (most recent call last):\n")
            traceback.print_tb(sys.exc_info()[2])
            f.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize."
            )

            # capture the exception in the main training loop
            exc_info = sys.exc_info()

            # try to handle it
            for name, entry in extensions:
                if hasattr(entry.extension, "on_error"):
                    try:
                        entry.extension.on_error(self, e, sys.exc_info()[2])
                    except Exception as ee:
                        f.write(f"Exception in error handler: {ee}\n")
                        f.write('Traceback (most recent call last):\n')
                        traceback.print_tb(sys.exc_info()[2])

            # reraise the exception in the main training loop
            six.reraise(*exc_info)
        finally:
            for name, entry in extensions:
                if hasattr(entry.extension, "finalize"):
                    entry.extension.finalize(self)
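Continuing the toy example, a sketch of the full loop; the output path and limits are illustrative:

from deepspeech.training.extensions.snapshot import Snapshot

trainer = Trainer(
    updater,                        # e.g. the StandardUpdater sketched above
    stop_trigger=(10, 'epoch'),     # wrapped into a LimitTrigger
    out='exp/toy')
trainer.extend(Snapshot(max_size=3), trigger=(1, 'epoch'))
trainer.run()                       # train, firing extensions as their triggers allow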
deepspeech/training/updaters/updater.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import paddle

from deepspeech.utils.log import Log

__all__ = ["UpdaterBase", "UpdaterState"]

logger = Log(__name__).getlog()


@dataclass
class UpdaterState:
    iteration: int = 0
    epoch: int = 0


class UpdaterBase():
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.

    The `update_core` method is a step in the training loop with only the
    necessary operations (get a batch, forward and backward, update the
    parameters). Everything else is implemented as extensions: visualization,
    saving, loading and periodical validation and evaluation are not
    considered here.

    But even in such a simple case, things are not that simple. One could
    attempt to standardize this process, requiring only the model and the
    dataset and doing everything else automatically. But that may hurt
    flexibility.

    If we assume a batch yielded from the dataloader is just the input to the
    model, we will find that some models require more arguments, or just some
    keyword arguments. This prevents us from over-simplifying.

    From another perspective, the batch may include not just the input, but
    also the target, while the model's forward method may need just the
    input. We could pass a dict or a super-long tuple to the model and let it
    pick what it really needs, but that is an abuse of a lazy interface.

    After all, we care about how a model is trained, not just how the model
    is used for inference. We want to control how a model is trained; we just
    don't want to be messed up with other auxiliary code.

    So the best practice is to define a model and define an updater for it.
    """

    def __init__(self, init_state=None):
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

    def update(self, batch):
        raise NotImplementedError(
            "Implement your own `update` method for training a step.")

    def state_dict(self):
        state_dict = {
            "epoch": self.state.epoch,
            "iteration": self.state.iteration,
        }
        return state_dict

    def set_state_dict(self, state_dict):
        self.state.epoch = state_dict["epoch"]
        self.state.iteration = state_dict["iteration"]

    def save(self, path):
        logger.debug(f"Saving to {path}.")
        archive = self.state_dict()
        paddle.save(archive, str(path))

    def load(self, path):
        logger.debug(f"Loading from {path}.")
        archive = paddle.load(str(path))
        self.set_state_dict(archive)
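The state round trip this base class provides, in isolation (the file name is illustrative):

updater = UpdaterBase(UpdaterState(iteration=100, epoch=2))
updater.save("state.pdz")    # paddle.save of {"epoch": 2, "iteration": 100}

restored = UpdaterBase()
restored.load("state.pdz")
assert restored.state.iteration == 100 and restored.state.epoch == 2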
requirements.txt

@@ -15,3 +15,4 @@ tensorboardX
 textgrid
 typeguard
 yacs
+jsonlines