提交 3fd7a779 编写于 作者: 小湉湉

add type hints for updater and evaluator, test=tts

上级 9c7f0762
......@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
......@@ -28,13 +32,13 @@ logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
def __init__(self,
model,
optimizer,
dataloader,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking=False,
use_weighted_masking=False,
output_dir=None):
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None):
super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = FastSpeech2Loss(
......@@ -104,11 +108,11 @@ class FastSpeech2Updater(StandardUpdater):
class FastSpeech2Evaluator(StandardEvaluator):
def __init__(self,
model,
dataloader,
use_masking=False,
use_weighted_masking=False,
output_dir=None):
model: Layer,
dataloader: DataLoader,
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict
from paddle import distributed as dist
from paddle.io import DataLoader
......@@ -34,8 +33,8 @@ logger.setLevel(logging.INFO)
class Tacotron2Updater(StandardUpdater):
def __init__(self,
model: Dict[str, Layer],
optimizer: Dict[str, Optimizer],
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking: bool=True,
......@@ -126,8 +125,8 @@ class Tacotron2Updater(StandardUpdater):
class Tacotron2Evaluator(StandardEvaluator):
def __init__(self,
model,
dataloader,
model: Layer,
dataloader: DataLoader,
use_masking: bool=True,
use_weighted_masking: bool=False,
bce_pos_weight: float=5.0,
......
......@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import paddle
from paddle import distributed as dist
from paddle.fluid.layers import huber_loss
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.modules.losses import masked_l1_loss
from paddlespeech.t2s.modules.losses import ssim
......@@ -33,11 +37,11 @@ logger.setLevel(logging.INFO)
class SpeedySpeechUpdater(StandardUpdater):
def __init__(self,
model,
optimizer,
dataloader,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
output_dir=None):
output_dir: Path=None):
super().__init__(model, optimizer, dataloader, init_state=None)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
......@@ -103,7 +107,10 @@ class SpeedySpeechUpdater(StandardUpdater):
class SpeedySpeechEvaluator(StandardEvaluator):
def __init__(self, model, dataloader, output_dir=None):
def __init__(self,
model: Layer,
dataloader: DataLoader,
output_dir: Path=None):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
......
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Sequence
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.modules.losses import GuidedMultiHeadAttentionLoss
from paddlespeech.t2s.modules.losses import Tacotron2Loss as TransformerTTSLoss
......@@ -32,14 +36,14 @@ logger.setLevel(logging.INFO)
class TransformerTTSUpdater(StandardUpdater):
def __init__(
self,
model,
optimizer,
dataloader,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking=False,
use_weighted_masking=False,
output_dir=None,
bce_pos_weight=5.0,
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None,
bce_pos_weight: float=5.0,
loss_type: str="L1",
use_guided_attn_loss: bool=True,
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),
......@@ -185,13 +189,13 @@ class TransformerTTSUpdater(StandardUpdater):
class TransformerTTSEvaluator(StandardEvaluator):
def __init__(
self,
model,
dataloader,
model: Layer,
dataloader: DataLoader,
init_state=None,
use_masking=False,
use_weighted_masking=False,
output_dir=None,
bce_pos_weight=5.0,
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None,
bce_pos_weight: float=5.0,
loss_type: str="L1",
use_guided_attn_loss: bool=True,
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册