提交 69dfc2a5 编写于 作者: H Hui Zhang

fix mask for bool type; fix other

上级 c65ea55b
......@@ -126,7 +126,7 @@ class ConvolutionModule(nn.Layer):
if self.lorder > 0:
if cache is None:
x = nn.functional.pad(
x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
assert cache.shape[0] == x.shape[0] # B
assert cache.shape[1] == x.shape[1] # C
......
......@@ -209,7 +209,9 @@ class BaseEncoder(nn.Layer):
"""
assert xs.size(0) == 1 # batch size must be one
# tmp_masks is just for interface compatibility
tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
# TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None:
......
......@@ -121,7 +121,7 @@ def subsequent_chunk_mask(
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros([size, size], dtype=paddle.bool)
ret = paddle.zeros([size, size], dtype=paddle.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
......@@ -186,13 +186,15 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
......
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
augmentation_config: conf/augmentation.json
batch_size: 32
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 180
accum_grad: 1
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-6
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
decoding:
batch_size: 1
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: True # simulate streaming inference. Defaults to False.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册