提交 258956c9 编写于 作者: Z zhangyinhui

Added some paddle.jit.save debug code cases

上级 43b52082
......@@ -599,30 +599,30 @@ class U2BaseModel(nn.Module):
best_index = i
return hyps[best_index][0]
@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
# @jit.export
# def subsampling_rate(self) -> int:
# """ Export interface for c++ call, return subsampling_rate of the
# model
# """
# return self.encoder.embed.subsampling_rate
# @jit.export
# def right_context(self) -> int:
# """ Export interface for c++ call, return right_context of the model
# """
# return self.encoder.embed.right_context
# @jit.export
# def sos_symbol(self) -> int:
# """ Export interface for c++ call, return sos symbol id of the model
# """
# return self.sos
# @jit.export
# def eos_symbol(self) -> int:
# """ Export interface for c++ call, return eos symbol id of the model
# """
# return self.eos
@jit.export
def forward_encoder_chunk(
......@@ -654,16 +654,16 @@ class U2BaseModel(nn.Module):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output
Returns:
paddle.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
# @jit.export
# def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
# """ Export interface for c++ call, apply linear transform and log
# softmax before ctc
# Args:
# xs (paddle.Tensor): encoder output
# Returns:
# paddle.Tensor: activation before ctc
# """
# return self.ctc.log_softmax(xs)
@jit.export
def forward_attention_decoder(
......@@ -878,12 +878,10 @@ class U2Model(U2BaseModel):
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataloader (paddle.io.DataLoader): not used.
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
......
......@@ -70,10 +70,11 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
# n_batch = query.size(0)
n_batch = query.shape[0]
q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
......@@ -96,7 +97,8 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
n_batch = value.size(0)
# n_batch = value.size(0)
n_batch = value.shape[0]
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
......@@ -205,8 +207,10 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
#n_batch_pos = pos_emb.size(0)
n_batch_pos = pos_emb.shape[0]
# p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
......
......@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
assert offset + x.size(1) < self.max_len
assert offset + x.shape[1] < self.max_len
x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]
......
......@@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks
......
......@@ -128,7 +128,8 @@ class Conv2dSubsampling4(BaseSubsampling):
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
#import pdb;pdb.set_trace()
b, c, t, f = x.shape
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册