Commit c1b512c5 authored by 小湉湉

rm fluid in tts, test=tts

Parent e6ddb0cc
......@@ -86,7 +86,7 @@ def get_path_from_url(url,
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
......
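The hunk above replaces the legacy `paddle.fluid.dygraph.parallel.ParallelEnv` with the public `paddle.distributed.ParallelEnv`. A minimal sketch, assuming Paddle 2.x, of the two attributes the download helper relies on:

```python
# Minimal sketch (assumes Paddle 2.x): the public ParallelEnv exposes the same
# endpoint information previously imported from paddle.fluid.dygraph.parallel.
from paddle.distributed import ParallelEnv

env = ParallelEnv()
print(env.current_endpoint)    # endpoint of the current trainer, e.g. "127.0.0.1:6170"
print(env.trainer_endpoints)   # endpoints of all trainers in the job
```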
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import tarfile
import time
import zipfile
import requests
from tqdm import tqdm
from paddlespeech.cli.log import logger
__all__ = ['get_path_from_url']
DOWNLOAD_RETRY_LIMIT = 3
def _is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sort first so the result does not depend on per-card environment variable ordering
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True,
method='get'):
""" Download from given url to root_dir.
If the file or directory specified by url exists under
root_dir, return the path directly; otherwise download
it from url, decompress it, and return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly for multi-machine setups: each unique ip downloads the
# data once, and ranks that share an ip do not download it again.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# To protect against an interrupted download, write to
# tmp_fullname first and move tmp_fullname to fullname
# once the download has finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# --user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# To protect against an interrupted decompression,
# decompress into a temporary fpath_tmp directory first; if
# decompression succeeds, move the extracted files to fpath,
# then delete fpath_tmp and the downloaded archive.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
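A hedged usage sketch of the `get_path_from_url` helper defined above; the URL and cache directory are placeholders and not part of this commit:

```python
# Hypothetical call to get_path_from_url as defined above; the URL and
# root_dir are placeholders. md5sum=None skips the checksum step.
local_path = get_path_from_url(
    url="https://example.com/models/fastspeech2_nosil_baker.zip",  # placeholder
    root_dir="./pretrained_models",
    md5sum=None,       # no md5 verification
    decompress=True,   # unpack the archive after downloading
    method="get")      # use requests; 'wget' shells out to the wget binary
print(local_path)      # local path of the downloaded (and decompressed) artifact
```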
......@@ -29,9 +29,9 @@ import requests
import yaml
from paddle.framework import load
from . import download
from .entry import client_commands
from .entry import server_commands
from paddlespeech.cli import download
try:
from .. import __version__
except ImportError:
......
......@@ -16,7 +16,6 @@ from pathlib import Path
import paddle
from paddle import distributed as dist
from paddle.fluid.layers import huber_loss
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle.nn import Layer
......@@ -78,8 +77,11 @@ class SpeedySpeechUpdater(StandardUpdater):
target_durations.astype(predicted_durations.dtype),
paddle.to_tensor([1.0]))
duration_loss = weighted_mean(
huber_loss(
predicted_durations, paddle.log(target_durations), delta=1.0),
F.smooth_l1_loss(
predicted_durations,
paddle.log(target_durations),
delta=1.0,
reduction='none', ),
text_mask, )
# ssim loss
......@@ -146,8 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
target_durations.astype(predicted_durations.dtype),
paddle.to_tensor([1.0]))
duration_loss = weighted_mean(
huber_loss(
predicted_durations, paddle.log(target_durations), delta=1.0),
F.smooth_l1_loss(
predicted_durations,
paddle.log(target_durations),
delta=1.0,
reduction='none', ),
text_mask, )
# ssim loss
......
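In both hunks above, the fluid-only `huber_loss` is replaced by `paddle.nn.functional.smooth_l1_loss` with `reduction='none'`, so the per-token losses can still be masked by `text_mask` inside `weighted_mean`. A minimal sketch, not part of the commit, of why the two coincide at `delta=1.0`:

```python
# Minimal sketch (assumes Paddle 2.x): with delta=1.0 and reduction='none',
# smooth_l1_loss returns the elementwise Huber loss, so masking and weighted
# averaging work exactly as they did with the old fluid huber_loss.
import paddle
import paddle.nn.functional as F

pred = paddle.to_tensor([[0.2, 1.5, -0.3]])
target = paddle.to_tensor([[0.0, 1.0, 0.0]])

loss = F.smooth_l1_loss(pred, target, delta=1.0, reduction='none')
print(loss.shape)  # same shape as the inputs, ready for text_mask + weighted_mean
```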
......@@ -17,7 +17,6 @@ import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.fluid.layers import sequence_mask
from paddle.nn import functional as F
from scipy import signal
......@@ -160,7 +159,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
return x
# Loss for new Tacotron2
# Loss for Tacotron2
class GuidedAttentionLoss(nn.Layer):
"""Guided attention loss function module.
......@@ -428,41 +427,6 @@ class Tacotron2Loss(nn.Layer):
return l1_loss, mse_loss, bce_loss
# Loss for Tacotron2
def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None):
"""Build that W matrix. shape(B, T_dec, T_enc)
W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2))
See also:
Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. “Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.” ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969.
"""
dtype = dtype or paddle.get_default_dtype()
dec_pos = paddle.arange(0, N).astype(dtype) / dec_lens.unsqueeze(
-1) # n/N # shape(B, T_dec)
enc_pos = paddle.arange(0, T).astype(dtype) / enc_lens.unsqueeze(
-1) # t/T # shape(B, T_enc)
W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) - enc_pos.unsqueeze(1))**2 /
(2 * g**2))
dec_mask = sequence_mask(dec_lens, maxlen=N)
enc_mask = sequence_mask(enc_lens, maxlen=T)
mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1)
mask = paddle.cast(mask, W.dtype)
W *= mask
return W
def guided_attention_loss(attention_weight, dec_lens, enc_lens, g):
"""Guided attention loss, masked to excluded padding parts."""
_, N, T = attention_weight.shape
W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype)
total_tokens = (dec_lens * enc_lens).astype(W.dtype)
loss = paddle.mean(paddle.sum(W * attention_weight, [1, 2]) / total_tokens)
return loss
# Losses for GAN Vocoder
def stft(x,
fft_size,
......
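The hunk above removes the fluid-only `sequence_mask` import together with the standalone `attention_guide` and `guided_attention_loss` helpers that depended on it (a `GuidedAttentionLoss` module is kept above). If a length mask is still needed without fluid, a broadcasted comparison works; this is a sketch under that assumption, not code from the commit:

```python
# Minimal fluid-free sketch (an assumption, not from the commit): rebuild
# sequence_mask with a broadcasted comparison instead of paddle.fluid.layers.
import paddle

def sequence_mask(lengths, maxlen):
    """Return a bool mask of shape (B, maxlen), True where index < length."""
    steps = paddle.arange(maxlen, dtype=lengths.dtype)
    return steps.unsqueeze(0) < lengths.unsqueeze(-1)

print(sequence_mask(paddle.to_tensor([2, 4]), maxlen=5).astype('int32'))
```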
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import paddle
from paddle.framework import core
from paddle.framework import CUDAPlace
def synchronize():
"""Trigger cuda synchronization for better timing."""
place = paddle.fluid.framework._current_expected_place()
if isinstance(place, CUDAPlace):
paddle.fluid.core._cuda_synchronize(place)
@contextmanager
def nvtx_span(name):
try:
core.nvprof_nvtx_push(name)
yield
finally:
core.nvprof_nvtx_pop()
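A hypothetical usage of the `synchronize` and `nvtx_span` helpers defined above; the layer and input shapes are placeholders:

```python
# Hypothetical usage of the helpers above; the layer and input are placeholders.
import paddle

layer = paddle.nn.Linear(16, 16)
x = paddle.randn([4, 16])

synchronize()                      # flush pending GPU work before timing
with nvtx_span("linear_forward"):  # appears as a named range in Nsight Systems
    y = layer(x)
synchronize()                      # make sure the ranged work has completed
```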
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
import six
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--profile_path',
type=str,
default='',
help='Input profile file name. If there are multiple files, the format '
'should be trainer1=file1,trainer2=file2,ps=file3')
parser.add_argument(
'--timeline_path', type=str, default='', help='Output timeline file name.')
args = parser.parse_args()
class _ChromeTraceFormatter(object):
def __init__(self):
self._events = []
self._metadata = []
def _create_event(self, ph, category, name, pid, tid, timestamp):
"""Creates a new Chrome Trace event.
For details of the file format, see:
https://github.com/catapult-project/catapult/blob/master/tracing/README.md
Args:
ph: The type of event - usually a single character.
category: The event category as a string.
name: The event name as a string.
pid: Identifier of the process generating this event as an integer.
tid: Identifier of the thread generating this event as an integer.
timestamp: The timestamp of this event as a long integer.
Returns:
A JSON compatible event object.
"""
event = {}
event['ph'] = ph
event['cat'] = category
event['name'] = name.replace("ParallelExecutor::Run/", "")
event['pid'] = pid
event['tid'] = tid
event['ts'] = timestamp
return event
def emit_pid(self, name, pid):
"""Adds a process metadata event to the trace.
Args:
name: The process name as a string.
pid: Identifier of the process as an integer.
"""
event = {}
event['name'] = 'process_name'
event['ph'] = 'M'
event['pid'] = pid
event['args'] = {'name': name}
self._metadata.append(event)
def emit_region(self, timestamp, duration, pid, tid, category, name, args):
"""Adds a region event to the trace.
Args:
timestamp: The start timestamp of this region as a long integer.
duration: The duration of this region as a long integer.
pid: Identifier of the process generating this event as an integer.
tid: Identifier of the thread generating this event as an integer.
category: The event category as a string.
name: The event name as a string.
args: A JSON-compatible dictionary of event arguments.
"""
event = self._create_event('X', category, name, pid, tid, timestamp)
event['dur'] = duration
event['args'] = args
self._events.append(event)
def emit_counter(self, category, name, pid, timestamp, counter, value):
"""Emits a record for a single counter.
Args:
category: The event category as string
name: The event name as string
pid: Identifier of the process generating this event as integer
timestamp: The timestamps of this event as long integer
counter: Name of the counter as string
value: Value of the counter as integer
tid: Thread id of the allocation as integer
"""
event = self._create_event('C', category, name, pid, 0, timestamp)
event['args'] = {counter: value}
self._events.append(event)
def format_to_string(self, pretty=False):
"""Formats the chrome trace to a string.
Args:
pretty: (Optional.) If True, produce human-readable JSON output.
Returns:
A JSON-formatted string in Chrome Trace format.
"""
trace = {}
trace['traceEvents'] = self._metadata + self._events
if pretty:
return json.dumps(trace, indent=4, separators=(',', ': '))
else:
return json.dumps(trace, separators=(',', ':'))
class Timeline(object):
def __init__(self, profile_dict):
self._profile_dict = profile_dict
self._pid = 0
self._devices = dict()
self._mem_devices = dict()
self._chrome_trace = _ChromeTraceFormatter()
def _allocate_pid(self):
cur_pid = self._pid
self._pid += 1
return cur_pid
def _allocate_pids(self):
for k, profile_pb in six.iteritems(self._profile_dict):
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
if (k, event.device_id, "CPU") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "CPU")] = pid
# -1 device id represents CUDA API (runtime) calls, e.g. cudaLaunch, cudaMemcpy
if event.device_id == -1:
self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
else:
self._chrome_trace.emit_pid(
"%s:cpu:block:%d" % (k, event.device_id), pid)
elif event.type == profiler_pb2.Event.GPUKernel:
if (k, event.device_id, "GPUKernel") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "GPUKernel")] = pid
self._chrome_trace.emit_pid("%s:gpu:%d" %
(k, event.device_id), pid)
if not hasattr(profile_pb, "mem_events"):
continue
for mevent in profile_pb.mem_events:
if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
if (k, mevent.device_id, "GPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "GPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:gpu:%d" % (k, mevent.device_id),
pid)
elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
if (k, mevent.device_id, "CPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "CPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cpu:%d" % (k, mevent.device_id),
pid)
elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
if (k, mevent.device_id,
"CUDAPinnedPlace") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id,
"CUDAPinnedPlace")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cudapinnedplace:%d" %
(k, mevent.device_id), pid)
elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
if (k, mevent.device_id, "NPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "NPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:npu:%d" % (k, mevent.device_id),
pid)
if (k, 0, "CPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "CPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
(k, 0), pid)
if (k, 0, "GPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "GPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" %
(k, 0), pid)
if (k, 0, "CUDAPinnedPlace") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cudapinnedplace:%d" % (k, 0), pid)
if (k, 0, "NPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "NPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:npu:%d" %
(k, 0), pid)
def _allocate_events(self):
for k, profile_pb in six.iteritems(self._profile_dict):
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
type = "CPU"
elif event.type == profiler_pb2.Event.GPUKernel:
type = "GPUKernel"
pid = self._devices[(k, event.device_id, type)]
args = {'name': event.name}
if event.memcopy.bytes > 0:
args['mem_bytes'] = event.memcopy.bytes
if hasattr(event, "detail_info") and event.detail_info:
args['detail_info'] = event.detail_info
# TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops take micro-seconds. Hence, we keep the ns here.
self._chrome_trace.emit_region(
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
event.sub_device_id, 'Op', event.name, args)
def _allocate_memory_event(self):
if not hasattr(profiler_pb2, "MemEvent"):
return
place_to_str = {
profiler_pb2.MemEvent.CPUPlace: "CPU",
profiler_pb2.MemEvent.CUDAPlace: "GPU",
profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
profiler_pb2.MemEvent.NPUPlace: "NPU"
}
for k, profile_pb in six.iteritems(self._profile_dict):
mem_list = []
end_profiler = 0
for mevent in profile_pb.mem_events:
crt_info = dict()
crt_info['time'] = mevent.start_ns
crt_info['size'] = mevent.bytes
if mevent.place in place_to_str:
place = place_to_str[mevent.place]
else:
place = "UnDefine"
crt_info['place'] = place
pid = self._mem_devices[(k, mevent.device_id, place)]
crt_info['pid'] = pid
crt_info['thread_id'] = mevent.thread_id
crt_info['device_id'] = mevent.device_id
mem_list.append(crt_info)
crt_info = dict()
crt_info['place'] = place
crt_info['pid'] = pid
crt_info['thread_id'] = mevent.thread_id
crt_info['device_id'] = mevent.device_id
crt_info['time'] = mevent.end_ns
crt_info['size'] = -mevent.bytes
mem_list.append(crt_info)
end_profiler = max(end_profiler, crt_info['time'])
mem_list.sort(key=lambda tmp: (tmp.get('time', 0)))
i = 0
total_size = 0
while i < len(mem_list):
total_size += mem_list[i]['size']
while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
i + 1]['time']:
total_size += mem_list[i + 1]['size']
i += 1
self._chrome_trace.emit_counter(
"Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
0, total_size)
i += 1
def generate_chrome_trace(self):
self._allocate_pids()
self._allocate_events()
self._allocate_memory_event()
return self._chrome_trace.format_to_string()
profile_path = '/tmp/profile'
if args.profile_path:
profile_path = args.profile_path
timeline_path = '/tmp/timeline'
if args.timeline_path:
timeline_path = args.timeline_path
profile_paths = profile_path.split(',')
profile_dict = dict()
if len(profile_paths) == 1:
with open(profile_path, 'rb') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict['trainer'] = profile_pb
else:
for profile_path in profile_paths:
k, v = profile_path.split('=')
with open(v, 'rb') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict[k] = profile_pb
tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f:
f.write(tl.generate_chrome_trace())
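A minimal sketch of driving the `_ChromeTraceFormatter` defined above directly; the process and event names are placeholders, and the printed JSON can be loaded in chrome://tracing:

```python
# Minimal sketch (placeholder names): emit one process and one region event
# with the formatter above, then print the Chrome-trace JSON.
fmt = _ChromeTraceFormatter()
fmt.emit_pid("trainer:cpu:block:0", pid=0)
fmt.emit_region(
    timestamp=1000, duration=500, pid=0, tid=0,
    category="Op", name="matmul", args={"mem_bytes": 0})
print(fmt.format_to_string(pretty=True))
```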