提交 2aed2752 编写于 作者: H Hui Zhang

fix test

上级 b15b6c6a
......@@ -357,9 +357,9 @@ if not hasattr(paddle.Tensor, 'tolist'):
########### hcak paddle.nn.functional #############
def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor:
def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
"""The gated linear unit (GLU) activation."""
a, b = x.split(2, axis=dim)
a, b = x.split(2, axis=axis)
act_b = F.sigmoid(b)
return a * act_b
......@@ -458,8 +458,8 @@ class ConstantPad2d(nn.Layer):
def __init__(self, padding: Union[tuple, list, int], value: float):
"""
Args:
paddle ([tuple]): the size of the padding.
If is int, uses the same padding in all boundaries.
paddle ([tuple]): the size of the padding.
If is int, uses the same padding in all boundaries.
If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
value ([flaot]): pad value
"""
......
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
import os
from deepspeech.exps.u2.config import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
......@@ -53,4 +52,4 @@ if __name__ == "__main__":
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))
pr.dump_stats('test.profile')
......@@ -91,7 +91,7 @@ training:
decoding:
batch_size: 128
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册