提交 fe4b4710 编写于 作者: C chenfeiyu

use g2p in deepvoice3

上级 b7e74f78
......@@ -14,7 +14,7 @@ from hparams import hparams, hparams_debug_string
from data.data import TextDataSource, MelSpecDataSource
from nnmnkwii.datasets import FileSourceDataset
from tqdm import trange
from modules import frontend
import g2p as frontend
def build_parser():
......
......@@ -25,7 +25,7 @@ import random
# import global hyper parameters
from hparams import hparams
from modules import frontend
import g2p as frontend
import builder
_frontend = getattr(frontend, hparams.frontend)
......
......@@ -17,7 +17,7 @@ from paddle import fluid
import paddle.fluid.dygraph as dg
from hparams import hparams, hparams_debug_string
from modules import frontend
import g2p as frontend
from deepvoice3 import DeepVoiceTTS
......
......@@ -37,7 +37,7 @@ from tensorboardX import SummaryWriter
# import global hyper parameters
from hparams import hparams
from modules import frontend
import g2p as frontend
_frontend = getattr(frontend, hparams.frontend)
......
......@@ -30,7 +30,7 @@ import paddle.fluid.dygraph as dg
sys.path.append("../")
import audio
from modules import frontend
import g2p as frontend
import dry_run
from hparams import hparams
......
......@@ -32,7 +32,7 @@ from data import (TextDataSource, MelSpecDataSource,
LinearSpecDataSource,
PartialyRandomizedSimilarTimeLengthSampler,
Dataset, make_loader, create_batch)
from modules import frontend
import g2p as frontend
from builder import deepvoice3, WindowRange
from dry_run import dry_run
from train_model import train_model
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
from weight_norm import Conv2D, Conv2DTranspose
class Conv1D(dg.Layer):
    """
    A convolution 1D block implemented with Conv2D. For simplicity and
    ensuring the output has the same length as the input, it does not allow
    stride > 1.

    Besides the usual whole-sequence ``forward``, the layer supports
    step-by-step evaluation through ``start_new_sequence`` / ``add_input``,
    which keeps a buffer of past inputs and applies the filter as a matmul.
    """

    def __init__(self,
                 name_scope,
                 in_cahnnels,
                 num_filters,
                 filter_size=3,
                 dilation=1,
                 groups=None,
                 causal=False,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 dtype="float32"):
        """
        Args:
            name_scope (str): name scope of the layer.
            in_cahnnels (int): number of input channels (the misspelled
                parameter name is kept for backward compatibility).
            num_filters (int): number of output channels.
            filter_size (int): width of the filter along time.
            dilation (int): dilation along time.
            groups (int, optional): forwarded to the underlying Conv2D.
            causal (bool): if True, pad with ``dilation * (filter_size - 1)``
                and trim the trailing steps in ``forward``.
            param_attr, bias_attr, use_cudnn, act, dtype: forwarded to the
                underlying Conv2D.
        """
        super(Conv1D, self).__init__(name_scope, dtype=dtype)
        if causal:
            padding = dilation * (filter_size - 1)
        else:
            padding = (dilation * (filter_size - 1)) // 2

        self.in_channels = in_cahnnels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.padding = padding
        self.act = act

        self.conv = Conv2D(
            self.full_name(),
            num_filters=num_filters,
            filter_size=(1, filter_size),
            stride=(1, 1),
            dilation=(1, dilation),
            padding=(0, padding),
            groups=groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
            dtype=dtype)

        # Incremental-decoding state. Initialised here so that add_input()
        # works even if start_new_sequence() has not been called yet.
        self.temp_weight = None
        self.input_buffer = None

    def forward(self, x):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
                input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
                output channels (num_filters).
        """
        x = self.conv(x)
        if self.filter_size > 1:
            if self.causal:
                # drop the last `padding` time steps produced by the padding
                x = fluid.layers.slice(
                    x, axes=[3], starts=[0], ends=[-self.padding])
            elif self.filter_size % 2 == 0:
                # even filter sizes: drop the last time step
                x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1])
        return x

    def start_new_sequence(self):
        """Reset the cached weight matrix and the input buffer."""
        self.temp_weight = None
        self.input_buffer = None

    def add_input(self, x):
        """
        Adding input for a time step and compute an output for a time step.

        Args:
            x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
                input channels, and T = 1.
        Returns:
            out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
                means output channels (num_filters), and T = 1.
        """
        if self.temp_weight is None:
            self.temp_weight = self._reshaped_weight()

        # number of past steps (including the current one) the filter spans
        window_size = 1 + (self.filter_size - 1) * self.dilation
        batch_size = x.shape[0]
        in_channels = x.shape[1]

        if self.filter_size > 1:
            if self.input_buffer is None:
                # first step: the history is all zeros
                self.input_buffer = fluid.layers.fill_constant(
                    [batch_size, in_channels, 1, window_size - 1],
                    dtype=x.dtype,
                    value=0.0)
            else:
                # drop the oldest step
                self.input_buffer = self.input_buffer[:, :, :, 1:]
            self.input_buffer = fluid.layers.concat(
                [self.input_buffer, x], axis=3)
            x = self.input_buffer
            if self.dilation > 1:
                if not hasattr(self, "indices"):
                    self.indices = dg.to_variable(
                        np.arange(0, window_size, self.dilation))
                # gather works on the first axis, so move time to the front,
                # pick every `dilation`-th step, then move it back
                tmp = fluid.layers.transpose(
                    self.input_buffer, perm=[3, 1, 2, 0])
                tmp = fluid.layers.gather(tmp, index=self.indices)
                tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0])
                x = tmp
        inputs = fluid.layers.reshape(
            x, shape=[batch_size, in_channels * 1 * self.filter_size])
        out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True)
        out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1)
        out = fluid.layers.reshape(out, out.shape + [1, 1])
        out = self._helper.append_activation(out, act=self.act)
        return out

    def _reshaped_weight(self):
        """
        Get the linearized weight of convolution filter, cause it is by nature
        a matmul weight. And because the model uses weight norm, compute the
        weight by weight_v * weight_g to make it faster.

        Returns:
            weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
        """
        shape = self.conv._filter_param_v.shape
        matrix_shape = [shape[0], np.prod(shape[1:])]
        weight_matrix = fluid.layers.reshape(
            self.conv._filter_param_v, shape=matrix_shape)
        weight_matrix = fluid.layers.elementwise_mul(
            fluid.layers.l2_normalize(
                weight_matrix, axis=1),
            self.conv._filter_param_g,
            axis=0)
        return weight_matrix
class Conv1DTranspose(dg.Layer):
    """
    A convolutional transpose 1D block implemented with convolutional transpose
    2D. It does not ensure that the output is exactly expanded stride times in
    time dimension.
    """

    def __init__(self,
                 name_scope,
                 in_channels,
                 num_filters,
                 filter_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=None,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 act=None,
                 dtype="float32"):
        # The 1D hyper-parameters are stored on the layer and mapped to the
        # last (time) axis of a Conv2DTranspose whose height is fixed to 1.
        super(Conv1DTranspose, self).__init__(name_scope, dtype=dtype)
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.conv_transpose = Conv2DTranspose(
            self.full_name(),
            num_filters,
            filter_size=(1, filter_size),
            padding=(0, padding),
            stride=(1, stride),
            dilation=(1, dilation),
            groups=groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
            dtype=dtype)

    def forward(self, x):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T_in), where C_in means the input
                channels and T_in means the number of time steps of input.
        Returns:
            out (Variable): shape(B, C_out, 1, T_out), where C_out means the
                output channels and T_out means the number of time steps of
                output.
        """
        return self.conv_transpose(x)
This package is adapted from https://github.com/r9y9/deepvoice3_pytorch/tree/master/deepvoice3_pytorch/frontend, Copyright (c) 2017: Ryuichi Yamamoto, whose license applies.
# coding: utf-8
"""Text processing frontend
All frontend module should have the following functions:
- text_to_sequence(text, p)
- sequence_to_text(sequence)
and the property:
- n_vocab
"""
from . import en
# optional Japanese frontend
try:
from . import jp
except ImportError:
jp = None
try:
from . import ko
except ImportError:
ko = None
# If you are going to use the Spanish (es) frontend, you need to modify
# _characters in symbols.py:
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + '¡¿ñáéíóúÁÉÍÓÚÑ'
try:
from . import es
except ImportError:
es = None
# coding: utf-8
from modules.frontend.text.symbols import symbols
import nltk
from random import random
# Vocabulary size of the English frontend: one entry per known symbol.
n_vocab = len(symbols)
# Word -> list of pronunciations, loaded once at import time from NLTK's
# CMUdict corpus (presumably requires the corpus to be downloaded — confirm).
_arpabet = nltk.corpus.cmudict.dict()
def _maybe_get_arpabet(word, p):
    """Return ``word`` or, with probability ``p``, its ARPAbet form.

    Words missing from the pronunciation dictionary are returned unchanged
    (and no random number is drawn for them). Otherwise the first listed
    pronunciation is used, wrapped in curly braces.
    """
    if word not in _arpabet:
        return word
    pronunciation = " ".join(_arpabet[word][0])
    if random() < p:
        return '{%s}' % pronunciation
    return word
def mix_pronunciation(text, p):
    """Replace each space-separated word by its ARPAbet form with probability ``p``."""
    words = text.split(' ')
    mixed = [_maybe_get_arpabet(word, p) for word in words]
    return ' '.join(mixed)
def text_to_sequence(text, p=0.0):
    """Convert English text to a list of symbol ids.

    Args:
        text: the input string.
        p: probability of replacing a word by its ARPAbet pronunciation;
            the mixing step runs whenever ``p >= 0``.
    Returns:
        The id sequence produced by the shared text module using the
        "english_cleaners" pipeline.
    """
    # Imported lazily under an alias so it does not shadow this function.
    from modules.frontend.text import text_to_sequence as _encode
    if p >= 0:
        text = mix_pronunciation(text, p)
    return _encode(text, ["english_cleaners"])
from modules.frontend.text import sequence_to_text
# coding: utf-8
from deepvoice3_paddle.frontend.text.symbols import symbols
import nltk
from random import random
# Vocabulary size of the Spanish frontend: one entry per known symbol.
n_vocab = len(symbols)
def text_to_sequence(text, p=0.0):
    """Convert Spanish text to a list of symbol ids using "basic_cleaners".

    ``p`` is accepted for interface parity with the other frontends and is
    not used here.
    """
    # Imported lazily under an alias so it does not shadow this function.
    from deepvoice3_paddle.frontend.text import text_to_sequence as _encode
    return _encode(text, ["basic_cleaners"])
from deepvoice3_paddle.frontend.text import sequence_to_text
# coding: utf-8
import MeCab
import jaconv
from random import random
# Characters are encoded by their code point (see text_to_sequence), so the
# vocabulary is sized to 0xffff entries.
n_vocab = 0xffff
# Ids appended as end-of-sequence marker / reserved for padding.
_eos = 1
_pad = 0
# Lazily created MeCab tagger (see mix_pronunciation).
_tagger = None
def _yomi(mecab_result):
tokens = []
yomis = []
for line in mecab_result.split("\n")[:-1]:
s = line.split("\t")
if len(s) == 1:
break
token, rest = s
rest = rest.split(",")
tokens.append(token)
yomi = rest[7] if len(rest) > 7 else None
yomi = None if yomi == "*" else yomi
yomis.append(yomi)
return tokens, yomis
def _mix_pronunciation(tokens, yomis, p):
return "".join(yomis[idx]
if yomis[idx] is not None and random() < p else tokens[idx]
for idx in range(len(tokens)))
def mix_pronunciation(text, p):
    """Tokenize ``text`` with MeCab and mix in readings with probability ``p``.

    The MeCab tagger is created on first use and cached in the module-level
    ``_tagger``.
    """
    global _tagger
    if _tagger is None:
        _tagger = MeCab.Tagger("")
    parsed = _tagger.parse(text)
    tokens, yomis = _yomi(parsed)
    return _mix_pronunciation(tokens, yomis, p)
def add_punctuation(text):
    """Append a Japanese full stop unless the text already ends in punctuation.

    Empty input is returned unchanged (indexing ``text[-1]`` on it used to
    raise IndexError).

    NOTE(review): the original list contained "!" and "?" twice; the
    duplicates look like full-width forms that were folded to ASCII, so the
    full-width variants are listed explicitly below — confirm.
    """
    if not text:
        return text
    terminators = (".", ",", "、", "。", "!", "?", "\uff01", "\uff1f")
    if text[-1] not in terminators:
        text = text + "。"
    return text
def normalize_delimitor(text):
    """Map commas and periods to the Japanese delimiters "、" and "。".

    NOTE(review): the original body repeated the two ASCII replacements, so
    the second pair was dead code. The duplicates look like the full-width
    comma/period after being folded to ASCII; they are restored here using
    explicit escapes — confirm.
    """
    table = str.maketrans({
        ",": "、",
        "\uff0c": "、",  # full-width comma
        ".": "。",
        "\uff0e": "。",  # full-width full stop
    })
    return text.translate(table)
def text_to_sequence(text, p=0.0):
    """Convert Japanese text to a list of code points followed by EOS.

    Steps: strip spaces/brackets, normalize delimiters, run
    ``jaconv.normalize``, optionally mix in readings (when ``p > 0``),
    convert hiragana to katakana, ensure trailing punctuation, then encode
    every character with ``ord`` and append ``_eos``.

    NOTE(review): the original strip list contained " ", "(" and ")" twice;
    the duplicates look like the ideographic space and full-width
    parentheses folded to ASCII, so they are restored with explicit escapes.
    The original also had ``replace("!", "!")`` / ``replace("?", "?")``,
    which are no-ops as written; the intended direction cannot be
    determined from this file, so they are dropped — confirm upstream.
    """
    stripped_chars = (" ", "\u3000", "「", "」", "『", "』", "・", "【", "】",
                      "\uff08", "\uff09", "(", ")")
    for c in stripped_chars:
        text = text.replace(c, "")
    text = normalize_delimitor(text)
    text = jaconv.normalize(text)
    if p > 0:
        text = mix_pronunciation(text, p)
    text = jaconv.hira2kata(text)
    text = add_punctuation(text)
    return [ord(c) for c in text] + [_eos]  # EOS
def sequence_to_text(seq):
    """Decode a sequence of code points back into a string."""
    return "".join(map(chr, seq))
# coding: utf-8
from random import random
# Characters are encoded by their code point (see text_to_sequence), so the
# vocabulary is sized to 0xffff entries.
n_vocab = 0xffff
# Ids appended as end-of-sequence marker / reserved for padding.
_eos = 1
_pad = 0
# NOTE(review): not referenced by the functions visible in this module.
_tagger = None
def text_to_sequence(text, p=0.0):
    """Encode every character of ``text`` with ``ord`` and append EOS.

    ``p`` is accepted for interface parity with the other frontends and is
    not used here.
    """
    codes = list(map(ord, text))
    codes.append(_eos)  # EOS
    return codes
def sequence_to_text(seq):
    """Decode a sequence of code points back into a string."""
    return "".join(map(chr, seq))
import re
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa (the ID is the symbol's
# position in `symbols`):
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces; the groups are
# (text before the braces, text inside the braces, remaining text):
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through

    Returns:
      List of integers corresponding to the symbols in the text, terminated
      by the ID of the EOS symbol '~'
    '''
    ids = []
    remaining = text
    # Consume the text piece by piece: plain text is cleaned and encoded,
    # the contents of curly braces are treated as ARPAbet.
    while remaining:
        match = _curly_re.match(remaining)
        if match is None:
            ids.extend(
                _symbols_to_sequence(_clean_text(remaining, cleaner_names)))
            break
        before, arpabet, remaining = match.group(1), match.group(2), match.group(3)
        ids.extend(_symbols_to_sequence(_clean_text(before, cleaner_names)))
        ids.extend(_arpabet_to_sequence(arpabet))
    # Append EOS token
    ids.append(_symbol_to_id['~'])
    return ids
def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string.

    Unknown IDs are skipped. ARPAbet symbols (stored with a leading '@') are
    wrapped in curly braces, and adjacent brace groups are merged.
    '''
    pieces = []
    for symbol_id in sequence:
        if symbol_id not in _id_to_symbol:
            continue
        symbol = _id_to_symbol[symbol_id]
        # Enclose ARPAbet back in curly braces:
        if len(symbol) > 1 and symbol[0] == '@':
            symbol = '{%s}' % symbol[1:]
        pieces.append(symbol)
    return ''.join(pieces).replace('}{', ' ')
def _clean_text(text, cleaner_names):
    '''Runs `text` through each named cleaner from the `cleaners` module, in order.

    Raises:
      Exception: if a name does not refer to an attribute of `cleaners`.
    '''
    for name in cleaner_names:
        # Default to None so a missing cleaner reaches the explicit error
        # below instead of raising AttributeError from getattr.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
def _symbols_to_sequence(symbols):
    '''Maps symbols to IDs, silently dropping those that should not be kept.'''
    ids = []
    for symbol in symbols:
        if _should_keep_symbol(symbol):
            ids.append(_symbol_to_id[symbol])
    return ids
def _arpabet_to_sequence(text):
    '''Encodes whitespace-separated ARPAbet symbols; each is looked up with a '@' prefix.'''
    prefixed = ['@' + phoneme for phoneme in text.split()]
    return _symbols_to_sequence(prefixed)
def _should_keep_symbol(s):
    '''Returns True for known symbols other than the pad ('_') and EOS ('~') markers.'''
    # Compare by value: `is not` against a string literal tests identity and
    # only works by accident of string interning.
    return s in _symbol_to_id and s != '_' and s != '~'
'''
Cleaners are transformations that run over the input text at both training and
eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as
the "cleaners" hyperparameter. Some cleaners are English-specific. You'll
typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated
to ASCII using the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you
should also update the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations. Each
# pattern matches the abbreviation followed by a literal period, at a word
# boundary, case-insensitively:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
    '''Replaces every known abbreviation in `text` by its expansion.'''
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
def expand_numbers(text):
    '''Spells out numbers in `text` by delegating to `normalize_numbers`.'''
    expanded = normalize_numbers(text)
    return expanded
def lowercase(text):
    '''Returns `text` converted to lower case.'''
    lowered = text.lower()
    return lowered
def collapse_whitespace(text):
    '''Replaces every run of whitespace in `text` by a single space.'''
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    '''Transliterates `text` with `unidecode`.'''
    transliterated = unidecode(text)
    return transliterated
def add_punctuation(text):
    '''Appends a period unless `text` is empty or already ends in punctuation.

    Without the trailing punctuation the decoder is confused about when to
    output EOS.
    '''
    if text and text[-1] not in '!,.:;?':
        return text + '.'
    return text
def basic_cleaners(text):
    '''
    Basic pipeline that lowercases and collapses whitespace without
    transliteration.
    '''
    for step in (lowercase, collapse_whitespace):
        text = step(text)
    return text
def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    for step in (convert_to_ascii, lowercase, collapse_whitespace):
        text = step(text)
    return text
def english_cleaners(text):
    '''
    Pipeline for English text, including number and abbreviation expansion.
    '''
    # Order matters: punctuation is added before lowercasing and expansion,
    # whitespace is collapsed last.
    pipeline = (
        convert_to_ascii,
        add_punctuation,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1',
'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T',
'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y',
'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''
Thin wrapper around CMUDict data.
http://www.speech.cs.cmu.edu/cgi-bin/cmudict
'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)
# -*- coding: utf-8 -*-
import inflect
import re
# Shared inflect engine used to spell numbers out as words.
_inflect = inflect.engine()
# Digits with embedded thousands separators, e.g. "1,000".
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
# Decimal numbers, e.g. "3.14".
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
# Amounts prefixed by a pound sign / dollar sign.
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
# Ordinals such as "1st", "22nd".
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# Any remaining run of digits.
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
    """Spell out the whole matched ordinal using the inflect engine."""
    ordinal = m.group(0)
    return _inflect.number_to_words(ordinal)
def _expand_number(m):
    """Spell out a matched integer.

    Numbers strictly between 1000 and 3000 get special handling: 2000 is
    "two thousand", 2001-2009 are "two thousand <n>", multiples of 100 are
    "<n> hundred", and the rest are read in two-digit groups with "oh" for
    zero. Everything else is spelled out without "and".
    """
    num = int(m.group(0))
    if not (1000 < num < 3000):
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if 2000 < num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    grouped = _inflect.number_to_words(
        num, andword='', zero='oh', group=2)
    return grouped.replace(', ', ' ')
def normalize_numbers(text):
    """Rewrite numeric expressions in ``text`` as words.

    The substitutions run in a fixed order: thousands separators are removed
    first, then currency amounts, decimals, ordinals and finally any
    remaining plain numbers are expanded.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)
    return text
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text
that has been run through Unidecode. For other data, you can modify _characters.
See TRAINING_DATA.md for details.
'''
from .cmudict import valid_symbols
# Padding and end-of-sequence symbols; they occupy the first two positions
# of `symbols` below.
_pad = '_'
_eos = '~'
# Plain characters accepted in input text.
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from numba import jit
from paddle import fluid
import paddle.fluid.dygraph as dg
def masked_mean(inputs, mask):
    """
    Mean of ``inputs`` over the positions selected by ``mask``.

    Args:
        inputs (Variable): Shape(B, C, 1, T), the input, where B means
            batch size, C means channels of input, T means timesteps of
            the input.
        mask (Variable): Shape(B, T), a mask.
    Returns:
        loss (Variable): Shape(1, ), masked mean.
    """
    n_channels = inputs.shape[1]
    batch_size, n_steps = mask.shape[0], mask.shape[-1]

    # Broadcast the (B, T) mask to the (B, C, 1, T) layout of the inputs.
    mask_b11t = fluid.layers.reshape(mask, shape=[batch_size, 1, 1, n_steps])
    mask_bc1t = fluid.layers.expand(
        mask_b11t, expand_times=[1, n_channels, 1, 1])
    mask_bc1t.stop_gradient = True

    # The normaliser is the mask total; no gradient flows through it.
    n_valid = fluid.layers.reduce_sum(mask_bc1t)
    n_valid.stop_gradient = True

    masked_total = fluid.layers.reduce_sum(inputs * mask_bc1t)
    return masked_total / n_valid
# Build a (max_N, max_T) float32 weight matrix for one example. Entries with
# n < N and t < T hold 1 - exp(-(n/N - t/T)^2 / (2 g^2)), i.e. 0 on the
# diagonal n/N == t/T and approaching 1 away from it; all other entries
# (the padded region) stay 0.
@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    W = np.zeros((max_N, max_T), dtype=np.float32)
    for n in range(N):
        for t in range(T):
            W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
    return W
def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
    """
    Build the guided-attention weights for a whole batch.

    Args:
        input_lengths (numpy.ndarray): valid encoder lengths, one per example.
        target_lengths (numpy.ndarray): valid decoder lengths, one per example.
        max_target_len (int): padded decoder length.
        g (float): width parameter passed to ``guided_attention``.
    Returns:
        numpy.ndarray: Shape(B, max_target_len, max_input_len), float32,
            where max_input_len is the largest value in ``input_lengths``.
    """
    batch_size = len(input_lengths)
    max_input_len = input_lengths.max()
    weights = np.zeros(
        (batch_size, max_target_len, max_input_len), dtype=np.float32)
    for idx in range(batch_size):
        per_example = guided_attention(input_lengths[idx], max_input_len,
                                       target_lengths[idx], max_target_len, g)
        # guided_attention is laid out (encoder, decoder); transpose it.
        weights[idx] = per_example.T
    return weights
class TTSLoss(object):
    """
    Collection of the loss terms used to train the TTS model: (optionally
    masked / priority-weighted) L1 loss, binary divergence, done-flag loss
    and guided attention loss.
    """

    def __init__(self,
                 masked_weight=0.0,
                 priority_weight=0.0,
                 binary_divergence_weight=0.0,
                 guided_attention_sigma=0.2):
        """
        Args:
            masked_weight (float): weight of the masked mean when mixing it
                with the plain mean; 0 disables masking.
            priority_weight (float): weight of the loss on the first
                ``priority_bin`` channels in ``l1_loss``; 0 disables it.
            binary_divergence_weight (float): stored on the instance; not
                used by the methods of this class.
            guided_attention_sigma (float): width of the guided attention.
        """
        self.masked_weight = masked_weight
        self.priority_weight = priority_weight
        self.binary_divergence_weight = binary_divergence_weight
        self.guided_attention_sigma = guided_attention_sigma

    def _masked_weighted_mean(self, values, mask):
        """
        Mix the masked mean and the plain mean of ``values`` with
        ``masked_weight``; fall back to the plain mean when masking is
        disabled or no mask is given.
        """
        w = self.masked_weight
        if w > 0 and mask is not None:
            return w * masked_mean(values, mask) + (
                1 - w) * fluid.layers.reduce_mean(values)
        return fluid.layers.reduce_mean(values)

    def l1_loss(self, prediction, target, mask, priority_bin=None):
        """
        Mask-weighted L1 loss, optionally emphasising the first
        ``priority_bin`` channels (axis 1) with ``priority_weight``.
        """
        abs_diff = fluid.layers.abs(prediction - target)

        # basic mask-weighted l1 loss
        base_l1_loss = self._masked_weighted_mean(abs_diff, mask)

        if self.priority_weight > 0 and priority_bin is not None:
            # mask-weighted priority channels' l1-loss
            priority_abs_diff = fluid.layers.slice(
                abs_diff, axes=[1], starts=[0], ends=[priority_bin])
            priority_loss = self._masked_weighted_mean(priority_abs_diff,
                                                       mask)
            # priority weighted sum
            p = self.priority_weight
            return p * priority_loss + (1 - p) * base_l1_loss
        return base_l1_loss

    def binary_divergence(self, prediction, target, mask):
        """
        Mask-weighted mean of the element-wise log loss between
        ``prediction`` and ``target``.
        """
        # log_loss works on (N, 1) inputs: flatten, compute, restore shape
        flattened_prediction = fluid.layers.reshape(prediction, [-1, 1])
        flattened_target = fluid.layers.reshape(target, [-1, 1])
        flattened_loss = fluid.layers.log_loss(
            flattened_prediction, flattened_target, epsilon=1e-8)
        bin_div = fluid.layers.reshape(flattened_loss, prediction.shape)
        return self._masked_weighted_mean(bin_div, mask)

    @staticmethod
    def done_loss(done_hat, done):
        """Mean log loss between the predicted and the true done flags."""
        flat_done_hat = fluid.layers.reshape(done_hat, [-1, 1])
        flat_done = fluid.layers.reshape(done, [-1, 1])
        loss = fluid.layers.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
        loss = fluid.layers.reduce_mean(loss)
        return loss

    def attention_loss(self, predicted_attention, input_lengths,
                       target_lengths):
        """
        Given valid encoder_lengths and decoder_lengths, compute a diagonal
        guide, and compute loss from the predicted attention and the guide.

        Args:
            predicted_attention (Variable): Shape(*, B, T_dec, T_enc), the
                alignment tensor, where B means batch size, T_dec means number
                of time steps of the decoder, T_enc means the number of time
                steps of the encoder, * means other possible dimensions.
            input_lengths (numpy.ndarray): Shape(B,), dtype:int64, valid lengths
                (time steps) of encoder outputs.
            target_lengths (numpy.ndarray): Shape(batch_size,), dtype:int64,
                valid lengths (time steps) of decoder outputs.

        Returns:
            loss (Variable): Shape(1, ) attention loss.
        """
        # Only T_dec is needed; index from the end so that any number of
        # leading dimensions is accepted, as documented above.
        max_target_len = predicted_attention.shape[-2]
        soft_mask = guided_attentions(input_lengths, target_lengths,
                                      max_target_len,
                                      self.guided_attention_sigma)
        soft_mask_ = dg.to_variable(soft_mask)
        loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
        return loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
import conv
import weight_norm as weight_norm
def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       dropout=0.0,
       epsilon=1e-30,
       act=None,
       is_test=False,
       dtype="float32"):
    """
    A special Linear Layer, when it is used with dropout, the weight is
    initialized as normal(0, std=np.sqrt((1-dropout) / in_features))

    ``in_features`` may be a single int or a sequence of ints; one weight
    attribute is created per entry. ``epsilon`` and ``is_test`` are accepted
    but not used by this function.
    """
    feature_sizes = [in_features] if isinstance(in_features,
                                                int) else in_features

    # one normally-initialised weight per input, scaled by its fan-in
    weight_attrs = []
    for n_in in feature_sizes:
        std = np.sqrt((1 - dropout) / n_in)
        initializer = fluid.initializer.NormalInitializer(scale=std)
        weight_attrs.append(fluid.ParamAttr(initializer=initializer))

    # zero-initialised bias
    bias_attr = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(0.0))

    return weight_norm.FC(name_scope,
                          size,
                          num_flatten_dims=num_flatten_dims,
                          param_attr=weight_attrs,
                          bias_attr=bias_attr,
                          act=act,
                          dtype=dtype)
def Conv1D(name_scope,
           in_channels,
           num_filters,
           filter_size=3,
           dilation=1,
           groups=None,
           causal=False,
           std_mul=1.0,
           dropout=0.0,
           use_cudnn=True,
           act=None,
           dtype="float32"):
    """
    A special Conv1D Layer, when it is used with dropout, the weight is
    initialized as
    normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels)))
    and the bias is initialized to zero.
    """
    fan_in = filter_size * in_channels
    weight_std = np.sqrt((std_mul * (1 - dropout)) / fan_in)

    weight_attr = fluid.ParamAttr(
        initializer=fluid.initializer.NormalInitializer(
            loc=0.0, scale=weight_std))
    bias_attr = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(0.0))

    return conv.Conv1D(
        name_scope,
        in_channels,
        num_filters,
        filter_size,
        dilation,
        groups=groups,
        causal=causal,
        param_attr=weight_attr,
        bias_attr=bias_attr,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)
def Embedding(name_scope,
              num_embeddings,
              embed_dim,
              is_sparse=False,
              is_distributed=False,
              padding_idx=None,
              std=0.01,
              dtype="float32"):
    """
    An embedding layer whose weight is initialized as normal(0, std).

    Args:
        name_scope (str): name scope of the layer.
        num_embeddings (int): number of rows of the embedding table.
        embed_dim (int): size of each embedding vector.
        is_sparse (bool): forwarded to ``dg.Embedding``.
        is_distributed (bool): forwarded to ``dg.Embedding``.
        padding_idx (int, optional): forwarded to ``dg.Embedding``.
        std (float): scale of the normal initializer.
        dtype (str): data type of the weight.
    Returns:
        dg.Embedding: the configured layer.
    """
    # param attrs
    weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
        scale=std))
    # is_sparse / is_distributed used to be accepted but silently dropped;
    # pass them through so that callers get what they asked for.
    layer = dg.Embedding(
        name_scope, (num_embeddings, embed_dim),
        is_sparse=is_sparse,
        is_distributed=is_distributed,
        padding_idx=padding_idx,
        param_attr=weight_attr,
        dtype=dtype)
    return layer
class Conv1DGLU(dg.Layer):
    """
    A Convolution 1D block with GLU activation. It also applies dropout to the
    input x. It fuses speaker embeddings through a size-1 convolution
    activated by softsign. It has a residual connection from the input x, and
    scales the output by np.sqrt(0.5).
    """

    def __init__(self,
                 name_scope,
                 n_speakers,
                 speaker_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 std_mul=4.0,
                 dropout=0.0,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)

        # conv spec
        self.in_channels = in_channels
        self.n_speakers = n_speakers
        self.speaker_dim = speaker_dim
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        # weight init and dropout
        self.std_mul = std_mul
        self.dropout = dropout

        if residual:
            assert (
                in_channels == num_filters
            ), "this block uses residual connection"\
                "the input_channes should equals num_filters"

        # The convolution produces 2 * num_filters channels: one half is the
        # content, the other half the gate of the GLU.
        self.conv = Conv1D(
            self.full_name(),
            in_channels,
            2 * num_filters,
            filter_size,
            dilation,
            causal=causal,
            std_mul=std_mul,
            dropout=dropout,
            dtype=dtype)

        # The speaker projection only exists in the multi-speaker case.
        if n_speakers > 1:
            assert (speaker_dim is not None
                    ), "speaker embed should not be null in multi-speaker case"
            self.fc = Conv1D(
                self.full_name(),
                speaker_dim,
                num_filters,
                filter_size=1,
                dilation=1,
                causal=False,
                act="softsign",
                dtype=dtype)

    def forward(self, x, speaker_embed_bc1t=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
                layer, where B means batch_size, C_in means the input channels
                T means input time steps.
            speaker_embed_bc1t (Variable): Shape(B, C_sp, 1, T), expanded
                speaker embed, where C_sp means speaker embedding size. Note
                that when using residual connection, the Conv1DGLU does not
                change the number of channels, so out channels equals input
                channels.

        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
                C_out means the output channels of Conv1DGLU.
        """
        residual = x
        x = fluid.layers.dropout(
            x, self.dropout, dropout_implementation="upscale_in_train")
        x = self.conv(x)
        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        if speaker_embed_bc1t is not None:
            sp = self.fc(speaker_embed_bc1t)
            content = content + sp

        # glu
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)

        if self.residual:
            x = fluid.layers.scale(x + residual, np.sqrt(0.5))
        return x

    def add_input(self, x, speaker_embed_bc11=None):
        """
        Step-wise counterpart of ``forward``; it uses ``self.conv.add_input``
        instead of calling ``self.conv`` on the whole sequence.

        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)

        Outputs:
            out: shape(B, num_filters, 1, time_steps), where time_steps = 1
        """
        residual = x

        # add step input and produce step output
        x = fluid.layers.dropout(
            x, self.dropout, dropout_implementation="upscale_in_train")
        x = self.conv.add_input(x)

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        if speaker_embed_bc11 is not None:
            sp = self.fc(speaker_embed_bc11)
            content = content + sp

        # glu
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)

        if self.residual:
            x = fluid.layers.scale(x + residual, np.sqrt(0.5))
        return x
def Conv1DTranspose(name_scope,
                    in_channels,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=None,
                    std_mul=1.0,
                    dropout=0.0,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    """
    Build a ``conv.Conv1DTranspose`` whose weight is initialized as
    normal(0, std=np.sqrt(std_mul * (1-dropout) / (in_channels * filter_size)))
    and whose bias is initialized to zero.
    """
    fan_in = in_channels * filter_size
    weight_std = np.sqrt(std_mul * (1 - dropout) / fan_in)

    weight_attr = fluid.ParamAttr(
        initializer=fluid.initializer.NormalInitializer(scale=weight_std))
    bias_attr = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(0.0))

    return conv.Conv1DTranspose(
        name_scope,
        in_channels,
        num_filters,
        filter_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        groups=groups,
        param_attr=weight_attr,
        bias_attr=bias_attr,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)
def compute_position_embedding(rad):
    """
    Apply sin to the even rows and cos to the odd rows of ``rad``, then
    transpose the result.

    Args:
        rad (Variable): Shape(embed_dim, n_vocab), the transposed angles
            (presumably in radians — confirm against the caller).
    Returns:
        out (Variable): Shape(n_vocab, embed_dim).
    """
    # rad is a transposed radius, shape(embed_dim, n_vocab)
    embed_dim, n_vocab = rad.shape

    # row indices of the even / odd embedding dimensions
    even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
    odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))

    even_rads = fluid.layers.gather(rad, even_dims)
    odd_rads = fluid.layers.gather(rad, odd_dims)

    sines = fluid.layers.sin(even_rads)
    cosines = fluid.layers.cos(odd_rads)

    # write the transformed rows back to their original positions
    temp = fluid.layers.scatter(rad, even_dims, sines)
    out = fluid.layers.scatter(temp, odd_dims, cosines)
    out = fluid.layers.transpose(out, perm=[1, 0])
    return out
def position_encoding_init(n_position,
                           d_pos_vec,
                           position_rate=1.0,
                           sinusoidal=True):
    """Init the sinusoid position encoding table.

    Entry (pos, i) is position_rate * pos / 10000 ** (2 * (i // 2) / d_pos_vec)
    for pos >= 1. Row 0 is kept as an all-zero vector so that index 0 can be
    used for the padding position.

    Args:
        n_position (int): number of positions (rows of the table).
        d_pos_vec (int): size of each position vector (columns of the table).
        position_rate (float): multiplier applied to every position index.
        sinusoidal (bool): if True, apply sin to the even columns and cos to
            the odd columns of rows 1..; if False, return the raw angles.
    Returns:
        numpy.ndarray: float64 array of shape (n_position, d_pos_vec).
    """
    # Pre-allocating keeps the result 2-D for every input, including
    # n_position == 0, where building the table from an empty list would give
    # a 1-D array and break the 2-D slicing below.
    position_enc = np.zeros((n_position, d_pos_vec), dtype="float64")

    if n_position > 1 and d_pos_vec > 0:
        # Float exponent so the division is a true division on any
        # interpreter; columns 2i and 2i+1 share the same divisor.
        exponents = 2.0 * (np.arange(d_pos_vec) // 2) / d_pos_vec
        divisors = np.power(10000.0, exponents)
        positions = np.arange(1, n_position, dtype="float64")[:, np.newaxis]
        # Row 0 stays zero: it is reserved for the padding position.
        position_enc[1:] = position_rate * positions / divisors

    if sinusoidal:
        position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
        position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc
class PositionEmbedding(dg.Layer):
    """Position embedding whose sin/cos table is rebuilt at lookup time.

    The layer stores the raw (not yet activated) angle table as the weight of
    an inner ``dg.Embedding``. On every forward call the angles are optionally
    rescaled by a speaker position rate, turned into a table with
    ``compute_position_embedding`` and then looked up with a ``lookup_table``
    op.
    """

    def __init__(self,
                 name_scope,
                 n_position,
                 d_pos_vec,
                 position_rate=1.0,
                 is_sparse=False,
                 is_distributed=False,
                 param_attr=None,
                 max_norm=None,
                 padding_idx=None,
                 dtype="float32"):
        """
        Args:
            name_scope (str): name scope passed to ``dg.Layer``.
            n_position (int): number of positions (rows of the table).
            d_pos_vec (int): position embedding size (columns of the table).
            position_rate (float): rate used to initialize the angle table.
            is_sparse (bool): forwarded to the inner Embedding and to the
                ``lookup_table`` op.
            is_distributed (bool): forwarded to the inner Embedding and to
                the ``lookup_table`` op.
            param_attr: parameter attribute for the inner Embedding.
            max_norm: stored on the layer; not read anywhere in this class.
            padding_idx (int, optional): padding index used by the lookup.
                None is stored as -1; a negative value is offset by
                n_position.
            dtype (str): data type of the table and of the output.
        """
        super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
        # The inner Embedding is only used as storage for the angle table:
        # this class reads its weight (`_w`) and never calls it, so padding
        # is handled by the lookup_table op in forward, not here.
        self.embed = dg.Embedding(
            self.full_name(),
            size=(n_position, d_pos_vec),
            is_sparse=is_sparse,
            is_distributed=is_distributed,
            padding_idx=None,
            param_attr=param_attr,
            dtype=dtype)
        # Store raw angles (sinusoidal=False); sin/cos is applied in forward.
        self.set_weight(
            position_encoding_init(
                n_position,
                d_pos_vec,
                position_rate=position_rate,
                sinusoidal=False).astype(dtype))

        self._is_sparse = is_sparse
        self._is_distributed = is_distributed
        self._remote_prefetch = self._is_sparse and (not self._is_distributed)
        if self._remote_prefetch:
            assert self._is_sparse is True and self._is_distributed is False

        if padding_idx is None:
            self._padding_idx = -1
        elif padding_idx >= 0:
            self._padding_idx = padding_idx
        else:
            self._padding_idx = n_position + padding_idx

        self._position_rate = position_rate
        self._max_norm = max_norm
        self._dtype = dtype

    def set_weight(self, array):
        """Overwrite the stored angle table with `array`.

        Args:
            array (numpy.ndarray): new table; its shape must equal the shape
                of the inner Embedding's weight.
        Raises:
            AssertionError: if the shapes differ.
        """
        assert self.embed._w.shape == list(array.shape), "shape does not match"
        self.embed._w._ivar.value().get_tensor().set(
            array, fluid.framework._current_expected_place())

    def _lookup(self, ids, weight, padding_idx):
        """Append a lookup_table op that reads `weight` at `ids`."""
        out = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type="lookup_table",
            inputs={"Ids": ids,
                    "W": weight},
            outputs={"Out": out},
            attrs={
                "is_sparse": self._is_sparse,
                "is_distributed": self._is_distributed,
                "remote_prefetch": self._remote_prefetch,
                "padding_idx":
                padding_idx,  # special value for lookup table op
            })
        return out

    def forward(self, indices, speaker_position_rate=None):
        """
        Args:
            indices (Variable): Shape (B, T, 1), dtype: int64, position
                indices, where B means the batch size, T means the time steps.
            speaker_position_rate (Variable | float, optional), position
                rate. It can be a float point number or a Variable with
                shape (1,) or (1, 1), then this speaker_position_rate is used
                for every example. It can also be a Variable with shape
                (B, 1), which contains a speaker position rate for each
                speaker.
        Returns:
            out (Variable): position embedding of the given indices, with
                C_pos (the position embedding size) as the last dimension.
                Presumably Shape (B, T, C_pos) -- TODO confirm against the
                lookup_table op.
        Raises:
            Exception: if speaker_position_rate matches none of the forms
                above.
        """
        rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
        batch_size = indices.shape[0]

        # No speaker rate: use the stored angles as they are.
        if speaker_position_rate is None:
            weight = compute_position_embedding(rad)
            return self._lookup(indices, weight, self._padding_idx)

        # One python scalar rate shared by the whole batch.
        if np.isscalar(speaker_position_rate):
            scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
            weight = compute_position_embedding(scaled_rad)
            return self._lookup(indices, weight, self._padding_idx)

        # One Variable rate shared by the whole batch. Shape [1] is accepted
        # in addition to [1, 1] because the docstring has always promised it;
        # previously it fell through to the final raise.
        if (isinstance(speaker_position_rate, fluid.framework.Variable) and
                speaker_position_rate.shape in ([1], [1, 1])):
            rate = speaker_position_rate
            if rate.shape == [1, 1]:
                rate = rate[0]  # same shape [1] operand as the loop below
            scaled_rad = fluid.layers.elementwise_mul(rad, rate)
            weight = compute_position_embedding(scaled_rad)
            return self._lookup(indices, weight, self._padding_idx)

        # One rate per example: build a separate table for each of them.
        if np.prod(speaker_position_rate.shape) > 1:
            assert speaker_position_rate.shape == [batch_size, 1]
            outputs = []
            for i in range(batch_size):
                rate = speaker_position_rate[i]  # rate has shape [1]
                scaled_rad = fluid.layers.elementwise_mul(rad, rate)
                weight = compute_position_embedding(scaled_rad)
                # NOTE(review): this branch has always passed padding_idx=-1
                # instead of self._padding_idx; kept unchanged to preserve
                # behavior -- confirm whether that is intentional.
                outputs.append(self._lookup(indices[i], weight, -1))
            return fluid.layers.stack(outputs)

        raise Exception("Then you can just use position rate at init")
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册