diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index d2c03a18e013ca722dcd8c99cc23d7719293e78f..233986a9a141c98638166cc29c81a92dc971c3a3 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -21,8 +21,8 @@ from yacs.config import CfgNode from deepspeech.modules.conv import ConvStack from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.rnn import RNNStack -from deepspeech.utils import checkpoint from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -222,7 +222,7 @@ class DeepSpeech2Model(nn.Layer): rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) - infos = checkpoint.load_parameters( + infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") layer_tools.summary(model)