PaddlePaddle / DeepSpeech
Commit e5a6c243
Authored Aug 01, 2022 by Hui Zhang

fix jit save for conformer

Parent: 4e7106d9
Showing 2 changed files with 62 additions and 205 deletions (+62 −205):

- paddlespeech/s2t/exps/u2/model.py (+37 −168)
- paddlespeech/s2t/models/u2/u2.py (+25 −37)
paddlespeech/s2t/exps/u2/model.py
```diff
@@ -25,8 +25,6 @@ import paddle
 from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
```
```diff
@@ -109,7 +107,8 @@ class U2Trainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
             logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
```
```diff
@@ -136,7 +135,8 @@ class U2Trainer(Trainer):
             msg += "epoch: {}, ".format(self.epoch)
             msg += "step: {}, ".format(self.iteration)
             if not self.use_streamdata:
                 msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
             msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in valid_dump.items())
             logger.info(msg)
```
```diff
@@ -157,7 +157,8 @@ class U2Trainer(Trainer):
         self.before_train()
         if not self.use_streamdata:
             logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
```
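The `if not self.use_streamdata:` guards visible in the three hunks above all exist for the same reason: a streaming loader is iterable-only, so `len(self.valid_loader)`, `len(self.train_loader.dataset)` and friends are unavailable when `use_stream_data` is enabled. A minimal sketch of the failure the guards avoid, using a hypothetical iterable-only loader (not PaddleSpeech's actual StreamDataLoader):

```python
# Hypothetical iterable-only loader: iteration works, len() does not.
class StreamingLoader:
    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        yield from self.batches


loader = StreamingLoader([["utt1"], ["utt2"]])
for batch in loader:       # fine: streams support iteration
    print(batch)
try:
    print(len(loader))     # a pure stream defines no __len__
except TypeError as err:
    print("len() unsupported:", err)
```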
```diff
@@ -225,14 +226,18 @@ class U2Trainer(Trainer):
         config = self.config.clone()
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
                 'decode_batch_size', 1)
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
-            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader(
+                'test', config, self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader(
+                'align', config, self.args)
             logger.info("Setup test/align Dataloader!")
 
     def setup_model(self):
```
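All four loaders above come from a single mode-keyed factory call, `DataLoaderFactory.get_dataloader(mode, config, self.args)`. The stand-in below is a hedged sketch of that dispatch shape only; the real factory returns actual loader objects (BatchDataLoader or StreamDataLoader instances), not strings:

```python
# Hypothetical stand-in for DataLoaderFactory, illustrating the dispatch
# shape only; the real factory constructs actual loader instances.
class DataLoaderFactory:
    @staticmethod
    def get_dataloader(mode, config, args):
        assert mode in ('train', 'valid', 'test', 'align')
        if config.get('use_stream_data', False):
            return f'StreamDataLoader[{mode}]'  # iterable-only stream
        return f'BatchDataLoader[{mode}]'       # map-style, len() works


config = {'use_stream_data': True}
print(DataLoaderFactory.get_dataloader('train', config, args=None))
```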
```diff
@@ -470,166 +475,30 @@ class U2Tester(U2Trainer):
 
     def export(self):
         infer_model, input_spec = self.load_inferspec()
         assert isinstance(input_spec, list), type(input_spec)
+        del input_spec
         infer_model.eval()
-        # static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
-        # logger.info(f"Export code: {static_model.forward.code}")
-        # paddle.jit.save(static_model, self.args.export_path)
-
-        # # to check outputs
-        # def flatten(out):
-        #     if isinstance(out, paddle.Tensor):
-        #         return [out]
-        #     flatten_out = []
-        #     for var in out:
-        #         if isinstance(var, (list, tuple)):
-        #             flatten_out.extend(flatten(var))
-        #         else:
-        #             flatten_out.append(var)
-        #     return flatten_out
-
-        # ######################### infer_model.forward_attention_decoder ########################
-        # a = paddle.full(shape=[10, 8], fill_value=10, dtype='int64')
-        # b = paddle.full(shape=[10], fill_value=8, dtype='int64')
-        # # c = paddle.rand(shape=[1, 20, 512], dtype='float32')
-        # c = paddle.full(shape=[1, 20, 512], fill_value=1, dtype='float32')
-        # out1 = infer_model.forward_attention_decoder(a, b, c)
-        # print(out1)
-
-        # input_spec = [paddle.static.InputSpec(shape=[None, None], dtype='int64'),
-        #               paddle.static.InputSpec(shape=[None], dtype='int64'),
-        #               paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]
-        # static_model = paddle.jit.to_static(infer_model.forward_attention_decoder, input_spec=input_spec)
-        # paddle.jit.save(static_model, self.args.export_path)
-        # static_model = paddle.jit.load(self.args.export_path)
-        # out2 = static_model(a, b, c)
-        # # print(out2)
-
-        # out1 = flatten(out1)
-        # out2 = flatten(out2)
-        # for i in range(len(out1)):
-        #     print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
-
-        # ######################### infer_model.forward_encoder_chunk ########################
-        # xs = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([80], dtype='int32')
-        # required_cache_size = -16
-        # att_cache = paddle.randn(shape=[12, 8, 80, 128], dtype='float32')
-        # cnn_cache = paddle.randn(shape=[12, 1, 512, 14], dtype='float32')
-        # # out1 = infer_model.forward_encoder_chunk(xs, offset, required_cache_size, att_cache, cnn_cache)
-        # # print(out1)
-        # zero_out1 = infer_model.forward_encoder_chunk(xs, offset, required_cache_size, att_cache=paddle.zeros([0, 0, 0, 0]), cnn_cache=paddle.zeros([0, 0, 0, 0]))
-        # # print(zero_out1)
-
-        # input_spec = [
-        #     paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
-        #     paddle.static.InputSpec(shape=[1], dtype='int32'),
-        #     -16,
-        #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
-        #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
-        # static_model = paddle.jit.to_static(infer_model.forward_encoder_chunk, input_spec=input_spec)
-        # paddle.jit.save(static_model, self.args.export_path)
-        # static_model = paddle.jit.load(self.args.export_path)
-        # # out2 = static_model(xs, offset, att_cache, cnn_cache)
-        # # print(out2)
-        # zero_out2 = static_model(xs, offset, paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
-        # # out1 = flatten(out1)
-        # # out2 = flatten(out2)
-        # # for i in range(len(out1)):
-        # #     print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
-        # zero_out1 = flatten(zero_out1)
-        # zero_out2 = flatten(zero_out2)
-        # for i in range(len(zero_out1)):
-        #     print(np.equal(zero_out1[i].numpy(), zero_out2[i].numpy()).all())
-
-        # ######################### infer_model.forward_encoder_chunk zero Tensor online ########################
-        # xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([0], dtype='int32')
-        # required_cache_size = -16
-        # att_cache = paddle.zeros([0, 0, 0, 0])
-        # cnn_cache=paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
-        # xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
-        # # print(out1)
-        # input_spec = [
-        #     paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
-        #     paddle.static.InputSpec(shape=[1], dtype='int32'),
-        #     -16,
-        #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
-        #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
-        # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
-        # static_model = paddle.jit.to_static(infer_model.forward_encoder_chunk, input_spec=input_spec)
-        # paddle.jit.save(static_model, self.args.export_path)
-        # static_model = paddle.jit.load(self.args.export_path)
-        # offset = paddle.to_tensor([0], dtype='int32')
-        # att_cache = paddle.zeros([0, 0, 0, 0])
-        # cnn_cache=paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = static_model(xs1, offset, att_cache, cnn_cache)
-        # xs = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out2 = static_model(xs2, offset, att_cache, cnn_cache)
-        # # print(out2)
-        # out1 = flatten(out1)
-        # out2 = flatten(out2)
-        # for i in range(len(out1)):
-        #     print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
-        # xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([0], dtype='int32')
-        # required_cache_size = -16
-        # att_cache = paddle.zeros([0, 0, 0, 0])
-        # cnn_cache=paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
-        # xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
-        # # print(out1)
-        # from paddle.jit.layer import Layer
-        # layer = Layer()
-        # layer.load('/workspace/conformer/PaddleSpeech-conformer/conformer/conformer', paddle.CUDAPlace(0))
-        # offset = paddle.to_tensor([0], dtype='int32')
-        # att_cache = paddle.zeros([0, 0, 0, 0])
-        # cnn_cache=paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = layer.forward_encoder_chunk(xs1, offset, att_cache, cnn_cache)
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out2 = layer.forward_encoder_chunk(xs2, offset, att_cache, cnn_cache)
-        # # print(out2)
-        # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
-        # out1 = flatten(out1)
-        # out2 = flatten(out2)
-        # for i in range(len(out1)):
-        #     print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
+        ######################### infer_model.forward_encoder_chunk zero Tensor online ########################
+        input_spec = [
+            paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
+            paddle.static.InputSpec(shape=[1], dtype='int32'),
+            -1,
+            paddle.static.InputSpec(shape=[None, None, None, None],
+                                    dtype='float32'),
+            paddle.static.InputSpec(
+                shape=[None, None, None, None], dtype='float32')
+        ]
+        infer_model.forward_encoder_chunk = paddle.jit.to_static(
+            infer_model.forward_encoder_chunk, input_spec=input_spec)
+
+        ###################### save/load combine ########################
+        paddle.jit.save(
+            infer_model,
+            '/workspace/conformer/PaddleSpeech-conformer/conformer/conformer',
+            combine_params=True)
+
+        ######################### infer_model.forward_attention_decoder ########################
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype='int64'),
+            paddle.static.InputSpec(shape=[None], dtype='int64'),
+            paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
+        ]
+        infer_model.forward_attention_decoder = paddle.jit.to_static(
+            infer_model.forward_attention_decoder, input_spec=input_spec)
+        paddle.jit.save(infer_model, './export.jit', combine_params=True)
\ No newline at end of file
```
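The debugging comments deleted above sketched how the exported artifact was validated. Restated as a hedged round-trip check (the shapes, the `paddle.jit.layer.Layer` loader, and the 4-argument call convention, with `required_cache_size` baked in by the InputSpec, are all taken from those comments, not re-verified here):

```python
import paddle
from paddle.jit.layer import Layer  # C++-backed loader used in the old notes

layer = Layer()
layer.load('./export.jit', paddle.CUDAPlace(0))

xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')  # first feature chunk
offset = paddle.to_tensor([0], dtype='int32')          # decoding starts at t=0
att_cache = paddle.zeros([0, 0, 0, 0])                 # empty caches signal
cnn_cache = paddle.zeros([0, 0, 0, 0])                 # "no history yet"

xs, att_cache, cnn_cache = layer.forward_encoder_chunk(xs1, offset, att_cache,
                                                       cnn_cache)

# The next chunk resumes from the caches returned above.
xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
offset = paddle.to_tensor([16], dtype='int32')
out2 = layer.forward_encoder_chunk(xs2, offset, att_cache, cnn_cache)
```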
paddlespeech/s2t/models/u2/u2.py
```diff
@@ -29,6 +29,9 @@ import paddle
 from paddle import jit
 from paddle import nn
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
 from paddlespeech.s2t.frontend.utility import load_cmvn
```
```diff
@@ -48,9 +51,6 @@ from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.audio.utils.tensor_utils import add_sos_eos
-from paddlespeech.audio.utils.tensor_utils import pad_sequence
-from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import log_add
 from paddlespeech.s2t.utils.utility import UpdateConfig
```
```diff
@@ -59,20 +59,6 @@ __all__ = ["U2Model", "U2InferModel"]
 
 logger = Log(__name__).getlog()
 
-# input_spec1 = [paddle.static.InputSpec(shape=[None, None], dtype='int64'),
-#                paddle.static.InputSpec(shape=[None], dtype='int64'),
-#                paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]
-# input_spec2 = [
-#     paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
-#     paddle.static.InputSpec(shape=[1], dtype='int32'),
-#     -16,
-#     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
-#     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
-# input_spec3 = [paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
-#                paddle.static.InputSpec(shape=[1], dtype='int64')]
-
 
 class U2BaseModel(ASRInterface, nn.Layer):
     """CTC-Attention hybrid Encoder-Decoder model"""
```
```diff
@@ -588,44 +574,44 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 best_index = i
         return hyps[best_index][0]
 
-    #@jit.to_static
+    @jit.to_static(property=True)
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
             model
         """
         return self.encoder.embed.subsampling_rate
 
-    #@jit.to_static
+    @jit.to_static(property=True)
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    #@jit.to_static
+    @jit.to_static(property=True)
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    #@jit.to_static
+    @jit.to_static(property=True)
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
         return self.eos
 
-    # @jit.to_static(input_spec=[
-    #     paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
-    #     paddle.static.InputSpec(shape=[1], dtype='int32'),
-    #     -1,
-    #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
-    #     paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
+    @jit.to_static(input_spec=[
+        paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
+        paddle.static.InputSpec(shape=[1], dtype='int32'),
+        -16,
+        paddle.static.InputSpec(
+            shape=[None, None, None, None], dtype='float32'),
+        paddle.static.InputSpec(
+            shape=[None, None, None, None], dtype='float32')])
     def forward_encoder_chunk(
             self,
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
             att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
             cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
```
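Taken together, `@jit.to_static(property=True)` exposes the scalar getters (subsampling rate, right context, sos/eos ids) to C++ callers, while the `input_spec` form freezes the chunk interface. A dygraph sketch of that interface, assuming `model` is an already-built U2 model (not constructed here); the shapes follow the InputSpec above:

```python
import paddle

# `model` is assumed to be a constructed U2 model instance.
xs = paddle.rand(shape=[1, 67, 80], dtype='float32')  # [batch=1, frames, feat=80]
offset = paddle.to_tensor([0], dtype='int32')

# Empty 4-D caches signal the first chunk; later chunks pass the returned ones.
ys, att_cache, cnn_cache = model.forward_encoder_chunk(
    xs, offset, required_cache_size=-16,
    att_cache=paddle.zeros([0, 0, 0, 0]),
    cnn_cache=paddle.zeros([0, 0, 0, 0]))
```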
```diff
@@ -660,8 +646,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
             paddle.Tensor: new conformer cnn cache required for next chunk, with
                 same shape as the original cnn_cache.
         """
         return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                           att_cache, cnn_cache)
 
     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
```
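`ctc_activation` is a thin wrapper over the CTC head's log-softmax. A hedged shape sketch (`model` as above; the 512-wide encoder output matches shapes used elsewhere in this diff, and the vocabulary size is whatever the CTC head was built with):

```python
# Per-frame CTC log-probabilities over the vocabulary.
encoder_out = paddle.rand(shape=[1, 20, 512], dtype='float32')  # [B, T, D]
ctc_log_probs = model.ctc_activation(encoder_out)               # [B, T, vocab]
```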
```diff
@@ -674,10 +660,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)
 
-    # @jit.to_static(input_spec=[
-    #     paddle.static.InputSpec(shape=[None, None], dtype='int64'),
-    #     paddle.static.InputSpec(shape=[None], dtype='int64'),
-    #     paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
+    @jit.to_static(input_spec=[
+        paddle.static.InputSpec(shape=[None, None], dtype='int64'),
+        paddle.static.InputSpec(shape=[None], dtype='int64'),
+        paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
```
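The debug comments deleted from model.py's export() exercised exactly this interface; restated as a hedged sketch with the same shapes and fill values (`model` assumed built, as above):

```python
import paddle

# Ten padded hypotheses of length 8, rescored against a [1, 20, 512]
# encoder output, the same tensors the deleted debug comments used.
hyps = paddle.full(shape=[10, 8], fill_value=10, dtype='int64')   # [n_hyps, max_len]
hyps_lens = paddle.full(shape=[10], fill_value=8, dtype='int64')  # [n_hyps]
encoder_out = paddle.full(shape=[1, 20, 512], fill_value=1, dtype='float32')

out = model.forward_attention_decoder(hyps, hyps_lens, encoder_out)
```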
```diff
@@ -942,7 +928,8 @@ class U2InferModel(U2Model):
     @jit.to_static(input_spec=[
         paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
-        paddle.static.InputSpec(shape=[1], dtype='int64')])
+        paddle.static.InputSpec(shape=[1], dtype='int64')
+    ])
     def forward(self,
                 feats,
                 feats_lengths,
```
```diff
@@ -958,6 +945,7 @@ class U2InferModel(U2Model):
         Returns:
             List[List[int]]: best path result
         """
+        # dummy code for dy2st
         # return self.ctc_greedy_search(
         #     feats,
         #     feats_lengths,
```