Commit 2933eb7e authored by 走神的阿圆

replace add_scalars with add_scalar

Parent bf6d9ef0
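This commit removes the custom add_scalars helper that was monkey-patched onto VisualDL's LogWriter and logs each curve with the built-in add_scalar instead, folding the old main_tag into a slash-separated tag prefix such as training_loss/mel_loss. A minimal sketch of the adopted pattern, assuming VisualDL is installed (the log directory and values below are placeholders, not taken from the training script):

    from visualdl import LogWriter

    writer = LogWriter(logdir="./log")  # placeholder logdir
    for step in range(1, 101):
        # One chart per tag; the "training_loss/" prefix keeps the curves
        # grouped together in the VisualDL frontend.
        writer.add_scalar(tag="training_loss/mel_loss", value=step * 0.2, step=step)
        writer.add_scalar(tag="training_loss/post_mel_loss", value=step * 0.1, step=step)
    writer.close()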
@@ -29,41 +29,6 @@ from parakeet.models.transformer_tts import TransformerTTS
 from parakeet.utils import io
-def add_scalars(self, main_tag, tag_scalar_dict, step, walltime=None):
-    """Add scalars to vdl record file.
-    Args:
-        main_tag (string): The parent name for the tags.
-        tag_scalar_dict (dict): Key-value pairs storing the tags and corresponding values.
-        step (int): Step of scalars.
-        walltime (float): Wall time of scalars.
-    Example:
-        for index in range(1, 101):
-            writer.add_scalars(main_tag="train",
-                               tag_scalar_dict={"loss": index*0.2, "lr": index*0.5},
-                               step=index)
-    """
-    import time
-    from visualdl.writer.record_writer import RecordFileWriter
-    from visualdl.component.base_component import scalar
-    fw_logdir = self.logdir
-    walltime = round(time.time()) if walltime is None else walltime
-    for tag, value in tag_scalar_dict.items():
-        tag = os.path.join(fw_logdir, main_tag, tag)
-        if '%' in tag:
-            raise RuntimeError("% can't appear in tag!")
-        if tag in self._all_writers:
-            fw = self._all_writers[tag]
-        else:
-            fw = RecordFileWriter(
-                logdir=tag,
-                max_queue_size=self._max_queue,
-                flush_secs=self._flush_secs,
-                filename_suffix=self._filename_suffix)
-            self._all_writers.update({tag: fw})
-        fw.add_record(
-            scalar(tag=main_tag, value=value, step=step, walltime=walltime))
 def add_config_options_to_parser(parser):
     parser.add_argument("--config", type=str, help="path of the config file")
     parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
@@ -99,7 +64,6 @@ def main(args):
     writer = LogWriter(os.path.join(args.output,
                                     'log')) if local_rank == 0 else None
-    writer.add_scalars = add_scalars
     fluid.enable_dygraph(place)
     network_cfg = cfg['network']
@@ -167,23 +131,28 @@ def main(args):
                 loss = loss + stop_loss
                 if local_rank == 0:
-                    writer.add_scalars('training_loss', {
-                        'mel_loss': mel_loss.numpy(),
-                        'post_mel_loss': post_mel_loss.numpy()
-                    }, global_step)
+                    writer.add_scalar('training_loss/mel_loss',
+                                      mel_loss.numpy(),
+                                      global_step)
+                    writer.add_scalar('training_loss/post_mel_loss',
+                                      post_mel_loss.numpy(),
+                                      global_step)
                     writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
                     if parallel:
-                        writer.add_scalars('alphas', {
-                            'encoder_alpha': model._layers.encoder.alpha.numpy(),
-                            'decoder_alpha': model._layers.decoder.alpha.numpy(),
-                        }, global_step)
+                        writer.add_scalar('alphas/encoder_alpha',
+                                          model._layers.encoder.alpha.numpy(),
+                                          global_step)
+                        writer.add_scalar('alphas/decoder_alpha',
+                                          model._layers.decoder.alpha.numpy(),
+                                          global_step)
                     else:
-                        writer.add_scalars('alphas', {
-                            'encoder_alpha': model.encoder.alpha.numpy(),
-                            'decoder_alpha': model.decoder.alpha.numpy(),
-                        }, global_step)
+                        writer.add_scalar('alphas/encoder_alpha',
+                                          model.encoder.alpha.numpy(),
+                                          global_step)
+                        writer.add_scalar('alphas/decoder_alpha',
+                                          model.decoder.alpha.numpy(),
+                                          global_step)
                     writer.add_scalar('learning_rate',
                                       optimizer._learning_rate.step().numpy(),
@@ -121,7 +121,7 @@ def main(args):
             model.clear_gradients()
             if local_rank == 0:
-                writer.add_scalars('training_loss', {'loss': loss.numpy(), },
-                                   global_step)
+                writer.add_scalar('training_loss/loss', loss.numpy(),
+                                  global_step)
             # save checkpoint