diff --git a/paddlespeech/cli/download.py b/paddlespeech/cli/download.py
index 0f09b6fad000f476f3bc38a851f982501a0232ba..ec72587470e8f0e211e453e3b2b2ea3d1f54f25b 100644
--- a/paddlespeech/cli/download.py
+++ b/paddlespeech/cli/download.py
@@ -86,7 +86,7 @@ def get_path_from_url(url,
         str: a local path to save downloaded models & weights & datasets.
     """
 
-    from paddle.fluid.dygraph.parallel import ParallelEnv
+    from paddle.distributed import ParallelEnv
 
     assert _is_url(url), "downloading from {} not a url".format(url)
     # parse path after download to decompress under root_dir
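The hunk above moves `ParallelEnv` from the deprecated `paddle.fluid.dygraph.parallel` path to its public `paddle.distributed` location; the attributes the function relies on later (`trainer_endpoints`, `current_endpoint`) are unchanged. A quick sketch of the equivalent call site, assuming Paddle 2.x:

```python
from paddle.distributed import ParallelEnv

# Same attributes as the old fluid-era class; on a single machine this
# prints one endpoint and a one-element endpoint list.
env = ParallelEnv()
print(env.current_endpoint)
print(env.trainer_endpoints)
```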
diff --git a/paddlespeech/server/download.py b/paddlespeech/server/download.py
deleted file mode 100644
index ea943dd8745c17cacdb0575a8552ba1a75ab4a7c..0000000000000000000000000000000000000000
--- a/paddlespeech/server/download.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import hashlib
-import os
-import os.path as osp
-import shutil
-import subprocess
-import tarfile
-import time
-import zipfile
-
-import requests
-from tqdm import tqdm
-
-from paddlespeech.cli.log import logger
-
-__all__ = ['get_path_from_url']
-
-DOWNLOAD_RETRY_LIMIT = 3
-
-
-def _is_url(path):
-    """
-    Whether path is URL.
-    Args:
-        path (string): URL string or not.
-    """
-    return path.startswith('http://') or path.startswith('https://')
-
-
-def _map_path(url, root_dir):
-    # parse path after download under root_dir
-    fname = osp.split(url)[-1]
-    fpath = fname
-    return osp.join(root_dir, fpath)
-
-
-def _get_unique_endpoints(trainer_endpoints):
-    # Sorting is to avoid different environmental variables for each card
-    trainer_endpoints.sort()
-    ips = set()
-    unique_endpoints = set()
-    for endpoint in trainer_endpoints:
-        ip = endpoint.split(":")[0]
-        if ip in ips:
-            continue
-        ips.add(ip)
-        unique_endpoints.add(endpoint)
-    logger.info("unique_endpoints {}".format(unique_endpoints))
-    return unique_endpoints
-
-
-def get_path_from_url(url,
-                      root_dir,
-                      md5sum=None,
-                      check_exist=True,
-                      decompress=True,
-                      method='get'):
-    """ Download from given url to root_dir.
-    if file or directory specified by url is exists under
-    root_dir, return the path directly, otherwise download
-    from url and decompress it, return the path.
-    Args:
-        url (str): download url
-        root_dir (str): root dir for downloading, it should be
-                        WEIGHTS_HOME or DATASET_HOME
-        md5sum (str): md5 sum of download package
-        decompress (bool): decompress zip or tar file. Default is `True`
-        method (str): which download method to use. Support `wget` and `get`. Default is `get`.
-    Returns:
-        str: a local path to save downloaded models & weights & datasets.
-    """
-
-    from paddle.fluid.dygraph.parallel import ParallelEnv
-
-    assert _is_url(url), "downloading from {} not a url".format(url)
-    # parse path after download to decompress under root_dir
-    fullpath = _map_path(url, root_dir)
-    # Mainly used to solve the problem of downloading data from different
-    # machines in the case of multiple machines. Different ips will download
-    # data, and the same ip will only download data once.
-    unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
-    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
-        logger.info("Found {}".format(fullpath))
-    else:
-        if ParallelEnv().current_endpoint in unique_endpoints:
-            fullpath = _download(url, root_dir, md5sum, method=method)
-        else:
-            while not os.path.exists(fullpath):
-                time.sleep(1)
-
-    if ParallelEnv().current_endpoint in unique_endpoints:
-        if decompress and (tarfile.is_tarfile(fullpath) or
-                           zipfile.is_zipfile(fullpath)):
-            fullpath = _decompress(fullpath)
-
-    return fullpath
-
-
-def _get_download(url, fullname):
-    # using requests.get method
-    fname = osp.basename(fullname)
-    try:
-        req = requests.get(url, stream=True)
-    except Exception as e:  # requests.exceptions.ConnectionError
-        logger.info("Downloading {} from {} failed with exception {}".format(
-            fname, url, str(e)))
-        return False
-
-    if req.status_code != 200:
-        raise RuntimeError("Downloading from {} failed with code "
-                           "{}!".format(url, req.status_code))
-
-    # For protecting download interupted, download to
-    # tmp_fullname firstly, move tmp_fullname to fullname
-    # after download finished
-    tmp_fullname = fullname + "_tmp"
-    total_size = req.headers.get('content-length')
-    with open(tmp_fullname, 'wb') as f:
-        if total_size:
-            with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
-                for chunk in req.iter_content(chunk_size=1024):
-                    f.write(chunk)
-                    pbar.update(1)
-        else:
-            for chunk in req.iter_content(chunk_size=1024):
-                if chunk:
-                    f.write(chunk)
-    shutil.move(tmp_fullname, fullname)
-
-    return fullname
-
-
-def _wget_download(url, fullname):
-    # using wget to download url
-    tmp_fullname = fullname + "_tmp"
-    # –user-agent
-    command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
-                                           url)
-    subprc = subprocess.Popen(
-        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    _ = subprc.communicate()
-
-    if subprc.returncode != 0:
-        raise RuntimeError(
-            '{} failed. Please make sure `wget` is installed or {} exists'.
-            format(command, url))
-
-    shutil.move(tmp_fullname, fullname)
-
-    return fullname
-
-
-_download_methods = {
-    'get': _get_download,
-    'wget': _wget_download,
-}
-
-
-def _download(url, path, md5sum=None, method='get'):
-    """
-    Download from url, save to path.
-    url (str): download url
-    path (str): download to given path
-    md5sum (str): md5 sum of download package
-    method (str): which download method to use. Support `wget` and `get`. Default is `get`.
-    """
-    assert method in _download_methods, 'make sure `{}` implemented'.format(
-        method)
-
-    if not osp.exists(path):
-        os.makedirs(path)
-
-    fname = osp.split(url)[-1]
-    fullname = osp.join(path, fname)
-    retry_cnt = 0
-
-    logger.info("Downloading {} from {}".format(fname, url))
-    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
-        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
-            retry_cnt += 1
-        else:
-            raise RuntimeError("Download from {} failed. "
-                               "Retry limit reached".format(url))
-
-        if not _download_methods[method](url, fullname):
-            time.sleep(1)
-            continue
-
-    return fullname
-
-
-def _md5check(fullname, md5sum=None):
-    if md5sum is None:
-        return True
-
-    logger.info("File {} md5 checking...".format(fullname))
-    md5 = hashlib.md5()
-    with open(fullname, 'rb') as f:
-        for chunk in iter(lambda: f.read(4096), b""):
-            md5.update(chunk)
-    calc_md5sum = md5.hexdigest()
-
-    if calc_md5sum != md5sum:
-        logger.info("File {} md5 check failed, {}(calc) != "
-                    "{}(base)".format(fullname, calc_md5sum, md5sum))
-        return False
-    return True
-
-
-def _decompress(fname):
-    """
-    Decompress for zip and tar file
-    """
-    logger.info("Decompressing {}...".format(fname))
-
-    # For protecting decompressing interupted,
-    # decompress to fpath_tmp directory firstly, if decompress
-    # successed, move decompress files to fpath and delete
-    # fpath_tmp and remove download compress file.
-
-    if tarfile.is_tarfile(fname):
-        uncompressed_path = _uncompress_file_tar(fname)
-    elif zipfile.is_zipfile(fname):
-        uncompressed_path = _uncompress_file_zip(fname)
-    else:
-        raise TypeError("Unsupport compress file type {}".format(fname))
-
-    return uncompressed_path
-
-
-def _uncompress_file_zip(filepath):
-    files = zipfile.ZipFile(filepath, 'r')
-    file_list = files.namelist()
-
-    file_dir = os.path.dirname(filepath)
-
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-
-        for item in file_list:
-            files.extract(item, file_dir)
-
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-
-        for item in file_list:
-            files.extract(item, file_dir)
-
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-
-    files.close()
-
-    return uncompressed_path
-
-
-def _uncompress_file_tar(filepath, mode="r:*"):
-    files = tarfile.open(filepath, mode)
-    file_list = files.getnames()
-
-    file_dir = os.path.dirname(filepath)
-
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-
-    files.close()
-
-    return uncompressed_path
-
-
-def _is_a_single_file(file_list):
-    if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
-        return True
-    return False
-
-
-def _is_a_single_dir(file_list):
-    new_file_list = []
-    for file_path in file_list:
-        if '/' in file_path:
-            file_path = file_path.replace('/', os.sep)
-        elif '\\' in file_path:
-            file_path = file_path.replace('\\', os.sep)
-        new_file_list.append(file_path)
-
-    file_name = new_file_list[0].split(os.sep)[0]
-    for i in range(1, len(new_file_list)):
-        if file_name != new_file_list[i].split(os.sep)[0]:
-            return False
-    return True
diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py
index ae3e9c6aa913174bc866cd071e6eaa28cd9c8059..13f2ddf6ee00477290d7e66f662c4ab8d2862f5f 100644
--- a/paddlespeech/server/util.py
+++ b/paddlespeech/server/util.py
@@ -29,9 +29,9 @@ import requests
 import yaml
 from paddle.framework import load
 
-from . import download
 from .entry import client_commands
 from .entry import server_commands
+from paddlespeech.cli import download
 try:
     from .. import __version__
 except ImportError:
diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
index e30a3fe1a5947c7046501ef26fe069656c1fcb31..b20fda1f7445ca05f6c8fe66866bbfd1f4d8ad91 100644
--- a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
+++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
@@ -16,7 +16,6 @@ from pathlib import Path
 
 import paddle
 from paddle import distributed as dist
-from paddle.fluid.layers import huber_loss
 from paddle.io import DataLoader
 from paddle.nn import functional as F
 from paddle.nn import Layer
@@ -78,8 +77,11 @@ class SpeedySpeechUpdater(StandardUpdater):
             target_durations.astype(predicted_durations.dtype),
             paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
-            huber_loss(
-                predicted_durations, paddle.log(target_durations), delta=1.0),
+            F.smooth_l1_loss(
+                predicted_durations,
+                paddle.log(target_durations),
+                delta=1.0,
+                reduction='none', ),
             text_mask, )
 
         # ssim loss
@@ -146,8 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
             target_durations.astype(predicted_durations.dtype),
             paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
-            huber_loss(
-                predicted_durations, paddle.log(target_durations), delta=1.0),
+            F.smooth_l1_loss(
+                predicted_durations,
+                paddle.log(target_durations),
+                delta=1.0,
+                reduction='none', ),
             text_mask, )
 
         # ssim loss
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index db31bcfbb4361281df49d3afeb00dfb97c59d7f9..86dffbe91af8a8008e9d239bbecb533e365d8dd5 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -17,7 +17,6 @@ import librosa
 import numpy as np
 import paddle
 from paddle import nn
-from paddle.fluid.layers import sequence_mask
 from paddle.nn import functional as F
 from scipy import signal
 
@@ -160,7 +159,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
     return x
 
 
-# Loss for new Tacotron2
+# Loss for Tacotron2
 class GuidedAttentionLoss(nn.Layer):
     """Guided attention loss function module.
 
@@ -428,41 +427,6 @@ class Tacotron2Loss(nn.Layer):
         return l1_loss, mse_loss, bce_loss
 
 
-# Loss for Tacotron2
-def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None):
-    """Build that W matrix. shape(B, T_dec, T_enc)
-    W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2))
-
-    See also:
-    Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. “Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.” ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969.
-    """
-    dtype = dtype or paddle.get_default_dtype()
-    dec_pos = paddle.arange(0, N).astype(dtype) / dec_lens.unsqueeze(
-        -1)  # n/N # shape(B, T_dec)
-    enc_pos = paddle.arange(0, T).astype(dtype) / enc_lens.unsqueeze(
-        -1)  # t/T # shape(B, T_enc)
-    W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) - enc_pos.unsqueeze(1))**2 /
-                       (2 * g**2))
-
-    dec_mask = sequence_mask(dec_lens, maxlen=N)
-    enc_mask = sequence_mask(enc_lens, maxlen=T)
-    mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1)
-    mask = paddle.cast(mask, W.dtype)
-
-    W *= mask
-    return W
-
-
-def guided_attention_loss(attention_weight, dec_lens, enc_lens, g):
-    """Guided attention loss, masked to excluded padding parts."""
-    _, N, T = attention_weight.shape
-    W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype)
-
-    total_tokens = (dec_lens * enc_lens).astype(W.dtype)
-    loss = paddle.mean(paddle.sum(W * attention_weight, [1, 2]) / total_tokens)
-    return loss
-
-
 # Losses for GAN Vocoder
 def stft(x,
          fft_size,
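In the SpeedySpeech hunks above, the removed `paddle.fluid.layers.huber_loss` is replaced by `paddle.nn.functional.smooth_l1_loss` with `reduction='none'`, which keeps the loss elementwise so the existing `weighted_mean(..., text_mask)` masking still applies. A small sanity-check sketch of that equivalence, assuming Paddle's documented Huber-style formula for `smooth_l1_loss`:

```python
import paddle
import paddle.nn.functional as F

# Elementwise Huber penalty with threshold `delta`:
#   0.5 * d**2                      if |d| < delta
#   delta * |d| - 0.5 * delta**2    otherwise, where d = input - label
pred = paddle.to_tensor([0.2, 1.7, -3.0])
target = paddle.to_tensor([0.0, 0.0, 0.0])
delta = 1.0

loss = F.smooth_l1_loss(pred, target, delta=delta, reduction='none')

d = pred - target
expected = paddle.where(
    paddle.abs(d) < delta,
    0.5 * d * d,
    delta * paddle.abs(d) - 0.5 * delta * delta)
print(loss.numpy(), expected.numpy())  # the two should match elementwise
```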
diff --git a/paddlespeech/t2s/utils/profile.py b/paddlespeech/t2s/utils/profile.py
deleted file mode 100644
index 5f9b49526c6f1ef036724efdb0deab73ebab9c16..0000000000000000000000000000000000000000
--- a/paddlespeech/t2s/utils/profile.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from contextlib import contextmanager
-
-import paddle
-from paddle.framework import core
-from paddle.framework import CUDAPlace
-
-
-def synchronize():
-    """Trigger cuda synchronization for better timing."""
-    place = paddle.fluid.framework._current_expected_place()
-    if isinstance(place, CUDAPlace):
-        paddle.fluid.core._cuda_synchronize(place)
-
-
-@contextmanager
-def nvtx_span(name):
-    try:
-        core.nvprof_nvtx_push(name)
-        yield
-    finally:
-        core.nvprof_nvtx_pop()
diff --git a/paddlespeech/t2s/utils/timeline.py b/paddlespeech/t2s/utils/timeline.py
deleted file mode 100644
index 0a5509dbe4530536df81313f7104afcd528a646f..0000000000000000000000000000000000000000
--- a/paddlespeech/t2s/utils/timeline.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import json
-
-import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
-import six
-
-parser = argparse.ArgumentParser(description=__doc__)
-parser.add_argument(
-    '--profile_path',
-    type=str,
-    default='',
-    help='Input profile file name. If there are multiple file, the format '
-    'should be trainer1=file1,trainer2=file2,ps=file3')
-parser.add_argument(
-    '--timeline_path', type=str, default='', help='Output timeline file name.')
-args = parser.parse_args()
-
-
-class _ChromeTraceFormatter(object):
-    def __init__(self):
-        self._events = []
-        self._metadata = []
-
-    def _create_event(self, ph, category, name, pid, tid, timestamp):
-        """Creates a new Chrome Trace event.
-
-        For details of the file format, see:
-        https://github.com/catapult-project/catapult/blob/master/tracing/README.md
-
-        Args:
-            ph: The type of event - usually a single character.
-            category: The event category as a string.
-            name: The event name as a string.
-            pid: Identifier of the process generating this event as an integer.
-            tid: Identifier of the thread generating this event as an integer.
-            timestamp: The timestamp of this event as a long integer.
-
-        Returns:
-            A JSON compatible event object.
-        """
-        event = {}
-        event['ph'] = ph
-        event['cat'] = category
-        event['name'] = name.replace("ParallelExecutor::Run/", "")
-        event['pid'] = pid
-        event['tid'] = tid
-        event['ts'] = timestamp
-        return event
-
-    def emit_pid(self, name, pid):
-        """Adds a process metadata event to the trace.
-
-        Args:
-            name: The process name as a string.
-            pid: Identifier of the process as an integer.
-        """
-        event = {}
-        event['name'] = 'process_name'
-        event['ph'] = 'M'
-        event['pid'] = pid
-        event['args'] = {'name': name}
-        self._metadata.append(event)
-
-    def emit_region(self, timestamp, duration, pid, tid, category, name, args):
-        """Adds a region event to the trace.
-
-        Args:
-            timestamp: The start timestamp of this region as a long integer.
-            duration: The duration of this region as a long integer.
-            pid: Identifier of the process generating this event as an integer.
-            tid: Identifier of the thread generating this event as an integer.
-            category: The event category as a string.
-            name: The event name as a string.
-            args: A JSON-compatible dictionary of event arguments.
-        """
-        event = self._create_event('X', category, name, pid, tid, timestamp)
-        event['dur'] = duration
-        event['args'] = args
-        self._events.append(event)
-
-    def emit_counter(self, category, name, pid, timestamp, counter, value):
-        """Emits a record for a single counter.
-
-        Args:
-            category: The event category as string
-            name: The event name as string
-            pid: Identifier of the process generating this event as integer
-            timestamp: The timestamps of this event as long integer
-            counter: Name of the counter as string
-            value: Value of the counter as integer
-            tid: Thread id of the allocation as integer
-        """
-        event = self._create_event('C', category, name, pid, 0, timestamp)
-        event['args'] = {counter: value}
-        self._events.append(event)
-
-    def format_to_string(self, pretty=False):
-        """Formats the chrome trace to a string.
-
-        Args:
-            pretty: (Optional.) If True, produce human-readable JSON output.
-
-        Returns:
-            A JSON-formatted string in Chrome Trace format.
-        """
-        trace = {}
-        trace['traceEvents'] = self._metadata + self._events
-        if pretty:
-            return json.dumps(trace, indent=4, separators=(',', ': '))
-        else:
-            return json.dumps(trace, separators=(',', ':'))
-
-
-class Timeline(object):
-    def __init__(self, profile_dict):
-        self._profile_dict = profile_dict
-        self._pid = 0
-        self._devices = dict()
-        self._mem_devices = dict()
-        self._chrome_trace = _ChromeTraceFormatter()
-
-    def _allocate_pid(self):
-        cur_pid = self._pid
-        self._pid += 1
-        return cur_pid
-
-    def _allocate_pids(self):
-        for k, profile_pb in six.iteritems(self._profile_dict):
-            for event in profile_pb.events:
-                if event.type == profiler_pb2.Event.CPU:
-                    if (k, event.device_id, "CPU") not in self._devices:
-                        pid = self._allocate_pid()
-                        self._devices[(k, event.device_id, "CPU")] = pid
-                        # -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy)
-                        if event.device_id == -1:
-                            self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
-                        else:
-                            self._chrome_trace.emit_pid(
-                                "%s:cpu:block:%d" % (k, event.device_id), pid)
-                elif event.type == profiler_pb2.Event.GPUKernel:
-                    if (k, event.device_id, "GPUKernel") not in self._devices:
-                        pid = self._allocate_pid()
-                        self._devices[(k, event.device_id, "GPUKernel")] = pid
-                        self._chrome_trace.emit_pid("%s:gpu:%d" %
-                                                    (k, event.device_id), pid)
-            if not hasattr(profile_pb, "mem_events"):
-                continue
-            for mevent in profile_pb.mem_events:
-                if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
-                    if (k, mevent.device_id, "GPU") not in self._mem_devices:
-                        pid = self._allocate_pid()
-                        self._mem_devices[(k, mevent.device_id, "GPU")] = pid
-                        self._chrome_trace.emit_pid(
-                            "memory usage on %s:gpu:%d" % (k, mevent.device_id),
-                            pid)
-                elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
-                    if (k, mevent.device_id, "CPU") not in self._mem_devices:
-                        pid = self._allocate_pid()
-                        self._mem_devices[(k, mevent.device_id, "CPU")] = pid
-                        self._chrome_trace.emit_pid(
-                            "memory usage on %s:cpu:%d" % (k, mevent.device_id),
-                            pid)
-                elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
-                    if (k, mevent.device_id,
-                            "CUDAPinnedPlace") not in self._mem_devices:
-                        pid = self._allocate_pid()
-                        self._mem_devices[(k, mevent.device_id,
-                                           "CUDAPinnedPlace")] = pid
-                        self._chrome_trace.emit_pid(
-                            "memory usage on %s:cudapinnedplace:%d" %
-                            (k, mevent.device_id), pid)
-                elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
-                    if (k, mevent.device_id, "NPU") not in self._mem_devices:
-                        pid = self._allocate_pid()
-                        self._mem_devices[(k, mevent.device_id, "NPU")] = pid
-                        self._chrome_trace.emit_pid(
-                            "memory usage on %s:npu:%d" % (k, mevent.device_id),
-                            pid)
-            if (k, 0, "CPU") not in self._mem_devices:
-                pid = self._allocate_pid()
-                self._mem_devices[(k, 0, "CPU")] = pid
-                self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
-                                            (k, 0), pid)
-            if (k, 0, "GPU") not in self._mem_devices:
-                pid = self._allocate_pid()
-                self._mem_devices[(k, 0, "GPU")] = pid
-                self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" %
-                                            (k, 0), pid)
-            if (k, 0, "CUDAPinnedPlace") not in self._mem_devices:
-                pid = self._allocate_pid()
-                self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid
-                self._chrome_trace.emit_pid(
-                    "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid)
-            if (k, 0, "NPU") not in self._mem_devices:
-                pid = self._allocate_pid()
-                self._mem_devices[(k, 0, "NPU")] = pid
-                self._chrome_trace.emit_pid("memory usage on %s:npu:%d" %
-                                            (k, 0), pid)
-
-    def _allocate_events(self):
-        for k, profile_pb in six.iteritems(self._profile_dict):
-            for event in profile_pb.events:
-                if event.type == profiler_pb2.Event.CPU:
-                    type = "CPU"
-                elif event.type == profiler_pb2.Event.GPUKernel:
-                    type = "GPUKernel"
-                pid = self._devices[(k, event.device_id, type)]
-                args = {'name': event.name}
-                if event.memcopy.bytes > 0:
-                    args['mem_bytes'] = event.memcopy.bytes
-                if hasattr(event, "detail_info") and event.detail_info:
-                    args['detail_info'] = event.detail_info
-                # TODO(panyx0718): Chrome tracing only handles ms. However, some
-                # ops takes micro-seconds. Hence, we keep the ns here.
-                self._chrome_trace.emit_region(
-                    event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
-                    event.sub_device_id, 'Op', event.name, args)
-
-    def _allocate_memory_event(self):
-        if not hasattr(profiler_pb2, "MemEvent"):
-            return
-        place_to_str = {
-            profiler_pb2.MemEvent.CPUPlace: "CPU",
-            profiler_pb2.MemEvent.CUDAPlace: "GPU",
-            profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
-            profiler_pb2.MemEvent.NPUPlace: "NPU"
-        }
-        for k, profile_pb in six.iteritems(self._profile_dict):
-            mem_list = []
-            end_profiler = 0
-            for mevent in profile_pb.mem_events:
-                crt_info = dict()
-                crt_info['time'] = mevent.start_ns
-                crt_info['size'] = mevent.bytes
-                if mevent.place in place_to_str:
-                    place = place_to_str[mevent.place]
-                else:
-                    place = "UnDefine"
-                crt_info['place'] = place
-                pid = self._mem_devices[(k, mevent.device_id, place)]
-                crt_info['pid'] = pid
-                crt_info['thread_id'] = mevent.thread_id
-                crt_info['device_id'] = mevent.device_id
-                mem_list.append(crt_info)
-                crt_info = dict()
-                crt_info['place'] = place
-                crt_info['pid'] = pid
-                crt_info['thread_id'] = mevent.thread_id
-                crt_info['device_id'] = mevent.device_id
-                crt_info['time'] = mevent.end_ns
-                crt_info['size'] = -mevent.bytes
-                mem_list.append(crt_info)
-                end_profiler = max(end_profiler, crt_info['time'])
-            mem_list.sort(key=lambda tmp: (tmp.get('time', 0)))
-            i = 0
-            total_size = 0
-            while i < len(mem_list):
-                total_size += mem_list[i]['size']
-                while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
-                        i + 1]['time']:
-                    total_size += mem_list[i + 1]['size']
-                    i += 1
-
-                self._chrome_trace.emit_counter(
-                    "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
-                    0, total_size)
-                i += 1
-
-    def generate_chrome_trace(self):
-        self._allocate_pids()
-        self._allocate_events()
-        self._allocate_memory_event()
-        return self._chrome_trace.format_to_string()
-
-
-profile_path = '/tmp/profile'
-if args.profile_path:
-    profile_path = args.profile_path
-timeline_path = '/tmp/timeline'
-if args.timeline_path:
-    timeline_path = args.timeline_path
-
-profile_paths = profile_path.split(',')
-profile_dict = dict()
-if len(profile_paths) == 1:
-    with open(profile_path, 'rb') as f:
-        profile_s = f.read()
-        profile_pb = profiler_pb2.Profile()
-        profile_pb.ParseFromString(profile_s)
-        profile_dict['trainer'] = profile_pb
-else:
-    for profile_path in profile_paths:
-        k, v = profile_path.split('=')
-        with open(v, 'rb') as f:
-            profile_s = f.read()
-            profile_pb = profiler_pb2.Profile()
-            profile_pb.ParseFromString(profile_s)
-            profile_dict[k] = profile_pb
-
-tl = Timeline(profile_dict)
-with open(timeline_path, 'w') as f:
-    f.write(tl.generate_chrome_trace())
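With the duplicate downloader in `paddlespeech/server/download.py` removed, `paddlespeech.cli.download` is the single implementation shared by the CLI and the server (the `paddlespeech/server/util.py` hunk above switches its import accordingly). A minimal usage sketch, assuming the CLI downloader keeps the same signature as the removed copy; the URL and cache directory below are placeholders, not values from this diff:

```python
from paddlespeech.cli.download import get_path_from_url

# Downloads the archive if it is not already cached, skips md5 verification
# (md5sum=None), decompresses it, and returns the extracted local path.
extracted_dir = get_path_from_url(
    url="https://example.com/models/demo_model.tar.gz",  # placeholder URL
    root_dir="./pretrained_models",                      # placeholder cache dir
    md5sum=None,
    decompress=True)
print(extracted_dir)
```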