未验证 提交 ed19e243 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #930 from Jackwaterveg/join_ctc

Join ctc
...@@ -24,6 +24,7 @@ from .utils import add_results_to_json ...@@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf # from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load # from espnet.asr.asr_utils import torch_load
...@@ -78,12 +79,18 @@ def recog_v2(args): ...@@ -78,12 +79,18 @@ def recog_v2(args):
preprocess_args={"train": False}, ) preprocess_args={"train": False}, )
if args.rnnlm: if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) lm_path = args.rnnlm
# NOTE: for a compatibility with less than 0.5.0 version models lm = TransformerLM(
lm_model_module = getattr(lm_args, "model_module", "default") n_vocab=5002,
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) pos_enc=None,
lm = lm_class(len(char_list), lm_args) embed_unit=128,
torch_load(args.rnnlm, lm) att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
model_dict = paddle.load(lm_path)
lm.set_state_dict(model_dict)
lm.eval() lm.eval()
else: else:
lm = None lm = None
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Tuple from typing import Tuple
...@@ -150,7 +151,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): ...@@ -150,7 +151,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
h, _, cache = self.encoder.forward_one_step( h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state) emb, self._target_mask(y), cache=state)
h = self.decoder(h[:, -1]) h = self.decoder(h[:, -1])
logp = h.log_softmax(axis=-1).squeeze(0) logp = F.log_softmax(h).squeeze(0)
return logp, cache return logp, cache
# batch beam search API (see BatchScorerInterface) # batch beam search API (see BatchScorerInterface)
...@@ -193,7 +194,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): ...@@ -193,7 +194,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
h, _, states = self.encoder.forward_one_step( h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state) emb, self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1]) h = self.decoder(h[:, -1])
logp = h.log_softmax(axi=-1) logp = F.log_softmax(h)
# transpose state of [layer, batch] into [batch, layer] # transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] state_list = [[states[i][b] for i in range(n_layers)]
...@@ -219,7 +220,7 @@ if __name__ == "__main__": ...@@ -219,7 +220,7 @@ if __name__ == "__main__":
# head: int=2, # head: int=2,
# unit: int=1024, # unit: int=1024,
# layer: int=4, # layer: int=4,
# dropout_rate: float=0.5, # dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0, # emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0, # att_dropout_rate: float = 0.0,
# tie_weights: bool = False,): # tie_weights: bool = False,):
...@@ -231,14 +232,14 @@ if __name__ == "__main__": ...@@ -231,14 +232,14 @@ if __name__ == "__main__":
#Test the score #Test the score
input2 = np.array([5]) input2 = np.array([5])
input2 = paddle.to_tensor(input2) input2 = paddle.to_tensor(input2)
state = (None, None, 0) state = None
output, state = tlm.score(input2, state, None) output, state = tlm.score(input2, state, None)
input3 = np.array([10]) input3 = np.array([5, 10])
input3 = paddle.to_tensor(input3) input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None) output, state = tlm.score(input3, state, None)
input4 = np.array([0]) input4 = np.array([5, 10, 0])
input4 = paddle.to_tensor(input4) input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None) output, state = tlm.score(input4, state, None)
print("output", output) print("output", output)
......
...@@ -399,7 +399,8 @@ class TransformerEncoder(BaseEncoder): ...@@ -399,7 +399,8 @@ class TransformerEncoder(BaseEncoder):
xs, pos_emb, masks = self.embed( xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0) xs, masks.astype(xs.dtype), offset=0)
else: else:
xs = self.embed(xs) xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool) masks = masks.astype(paddle.bool)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册