From 365975fdbb38670336096a07870b589307ab009f Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 11 Apr 2022 20:23:33 +0800 Subject: [PATCH] [cherry-pick] Refine statistic table and bug fix (#41581) * Refine statistic table (#41524) * Add get profiler from config (#41532) * no * maintain old profiler * add get profiler from serialization config * add unit test * improve coverage * fix * Revert "improve coverage" This reverts commit 4a980bfda48adadee551d0e1c5740bc5b7389200. * fix unit * fix * fix --- .../fluid/tests/unittests/test_newprofiler.py | 141 ++++++++++ .../unittests/test_profiler_statistic.py | 88 ++++--- python/paddle/profiler/profiler.py | 72 ++++++ python/paddle/profiler/profiler_statistic.py | 241 ++++++++++++------ 4 files changed, 428 insertions(+), 114 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_newprofiler.py b/python/paddle/fluid/tests/unittests/test_newprofiler.py index 0088687b12..ac2b205e61 100755 --- a/python/paddle/fluid/tests/unittests/test_newprofiler.py +++ b/python/paddle/fluid/tests/unittests/test_newprofiler.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import tempfile import paddle import paddle.profiler as profiler @@ -138,6 +139,146 @@ class TestNvprof(unittest.TestCase): y = x / 2.0 +class TestGetProfiler(unittest.TestCase): + def test_getprofiler(self): + config_content = ''' + { + "targets": ["CPU"], + "scheduler": [3,4], + "on_trace_ready": { + "export_chrome_tracing":{ + "module": "paddle.profiler", + "use_direct": false, + "args": [], + "kwargs": { + "dir_name": "testdebug/" + } + } + }, + "timer_only": false + } + ''' + filehandle = tempfile.NamedTemporaryFile(mode='w') + filehandle.write(config_content) + filehandle.flush() + import paddle.profiler.profiler as profiler + profiler = profiler.get_profiler(filehandle.name) + x_value = np.random.randn(2, 3, 3) + x = paddle.to_tensor( + x_value, stop_gradient=False, place=paddle.CPUPlace()) + with profiler: + for i in range(5): + y = x / 2.0 + ones_like_y = paddle.ones_like(y) + profiler.step() + + # below tests are just for coverage, wrong config + # test use_direct + config_content = ''' + { + "targets": ["Cpu", "Gpu"], + "scheduler": { + "make_scheduler":{ + "module": "paddle.profiler", + "use_direct": true, + "args": [], + "kwargs": {} + } + }, + "on_trace_ready": { + "export_chrome_tracing":{ + "module": "paddle.profiler1", + "use_direct": true, + "args": [], + "kwargs": { + } + } + }, + "timer_only": false + } + ''' + filehandle = tempfile.NamedTemporaryFile(mode='w') + filehandle.write(config_content) + filehandle.flush() + import paddle.profiler.profiler as profiler + try: + profiler = profiler.get_profiler(filehandle.name) + except: + pass + + # test scheduler + config_content = ''' + { + "targets": ["Cpu", "Gpu"], + "scheduler": { + "make_scheduler":{ + "module": "paddle.profiler", + "use_direct": false, + "args": [], + "kwargs": { + "closed": 1, + "ready": 1, + "record": 2 + } + } + }, + "on_trace_ready": { + "export_chrome_tracing":{ + "module": "paddle.profiler", + "use_direct": true, + "args": [], + "kwargs": { + } + } + }, + "timer_only": false + } + ''' + filehandle = tempfile.NamedTemporaryFile(mode='w') + filehandle.write(config_content) + filehandle.flush() + import paddle.profiler.profiler as profiler + profiler = profiler.get_profiler(filehandle.name) + + # test exception + config_content = ''' + { + "targets": [1], + "scheduler": { + "make_scheduler1":{ + "module": "paddle.profiler", + "use_direct": false, + "args": [], + "kwargs": { + "closed": 1, + "ready": 1, + "record": 2 + } + } + }, + "on_trace_ready": { + "export_chrome_tracing1":{ + "module": "paddle.profiler", + "use_direct": false, + "args": [], + "kwargs": { + "dir_name": "testdebug/" + } + } + }, + "timer_only": 1 + } + ''' + filehandle = tempfile.NamedTemporaryFile(mode='w') + filehandle.write(config_content) + filehandle.flush() + import paddle.profiler.profiler as profiler + profiler = profiler.get_profiler(filehandle.name) + # test path error + import paddle.profiler.profiler as profiler + profiler = profiler.get_profiler('nopath.json') + + class RandomDataset(Dataset): def __init__(self, num_samples): self.num_samples = num_samples diff --git a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py index adc42d0447..dc944e68c7 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py +++ b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py @@ -185,20 +185,22 @@ class TestProfilerStatistic(unittest.TestCase): profiler.TracerEventType.Communication), 5) self.assertEqual(len(event_summary.items), 2) self.assertEqual(len(event_summary.userdefined_items), 1) - self.assertEqual(len(event_summary.model_perspective_items), 3) + self.assertEqual(len(event_summary.model_perspective_items), 4) self.assertEqual(len(event_summary.memory_manipulation_items), 1) self.assertEqual(event_summary.items['conv2d'].cpu_time, 15) - self.assertEqual(event_summary.items['conv2d'].gpu_time, 25) + self.assertEqual(event_summary.items['conv2d'].general_gpu_time, 25) self.assertEqual( event_summary.model_perspective_items['Forward'].cpu_time, 100) self.assertEqual( - event_summary.model_perspective_items['Forward'].gpu_time, 135) + event_summary.model_perspective_items['Forward'].general_gpu_time, + 135) self.assertEqual( - event_summary.model_perspective_items['Backward'].gpu_time, 0) + event_summary.model_perspective_items['Backward'].general_gpu_time, + 0) self.assertEqual( event_summary.memory_manipulation_items['AsyncMemcpy'].cpu_time, 15) - self.assertEqual( - event_summary.memory_manipulation_items['AsyncMemcpy'].gpu_time, 60) + self.assertEqual(event_summary.memory_manipulation_items['AsyncMemcpy'] + .general_gpu_time, 60) print( profiler.profiler_statistic._build_table( statistic_data, @@ -226,31 +228,31 @@ class TestProfilerStatistic(unittest.TestCase): userdefined_node = HostPythonNode('Communication Time', profiler.TracerEventType.UserDefined, 100, 110, 1000, 1001) - reduce_all_launchkernel0 = HostPythonNode( + allreduce_launchkernel0 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 102, 104, 1000, 1001) - nccl_reduce_all_kernel0 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 105, 120, + nccl_allreduce_kernel0 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 105, 120, 0, 0, 2) communication_node = HostPythonNode( 'Communication', profiler.TracerEventType.Communication, 105, 110, 1000, 1001) - reduce_all_op1 = HostPythonNode('reduce_all_op1', - profiler.TracerEventType.Operator, 105, - 108, 1000, 1001) - reduce_all_op1_infershape = HostPythonNode( - 'reduce_all_op1::infershape', - profiler.TracerEventType.OperatorInner, 105, 106, 1000, 1001) + allreduce_op1 = HostPythonNode('allreduce_op1', + profiler.TracerEventType.Operator, 105, + 108, 1000, 1001) + allreduce_op1_infershape = HostPythonNode( + 'allreduce_op1::infershape', profiler.TracerEventType.OperatorInner, + 105, 106, 1000, 1001) - reduce_all_launchkernel1 = HostPythonNode( + allreduce_launchkernel1 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 106, 107, 1000, 1001) - nccl_reduce_all_kernel1 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 130, 150, + nccl_allreduce_kernel1 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 130, 150, 0, 0, 2) backward_node = HostPythonNode('Gradient Backward', @@ -305,19 +307,19 @@ class TestProfilerStatistic(unittest.TestCase): 'sync_batch_norm_memcpy', profiler.TracerEventType.Memcpy, 150, 200, 0, 0, 1) - reduce_all_node2 = HostPythonNode('reduce_all', - profiler.TracerEventType.Operator, - 230, 250, 1000, 1001) + allreduce_node2 = HostPythonNode('allreduce', + profiler.TracerEventType.Operator, 230, + 250, 1000, 1001) - reduce_all_node2_infershape = HostPythonNode( - 'reduce_all_node2::infershape', + allreduce_node2_infershape = HostPythonNode( + 'allreduce_node2::infershape', profiler.TracerEventType.OperatorInner, 231, 232, 1000, 1001) - reduce_all_launchkernel2 = HostPythonNode( + allreduce_launchkernel2 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 235, 240, 1000, 1001) - nccl_reduce_all_kernel2 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 250, 280, + nccl_allreduce_kernel2 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 250, 280, 0, 0, 2) root_node.children_node.append(profilerstep_node) @@ -329,12 +331,12 @@ class TestProfilerStatistic(unittest.TestCase): yolonet_node.children_node.extend( [sync_batch_norm_node, userdefined_node]) userdefined_node.children_node.append(communication_node) - userdefined_node.runtime_node.append(reduce_all_launchkernel0) - reduce_all_launchkernel0.device_node.append(nccl_reduce_all_kernel0) - communication_node.children_node.append(reduce_all_op1) - reduce_all_op1.children_node.append(reduce_all_op1_infershape) - reduce_all_op1.runtime_node.append(reduce_all_launchkernel1) - reduce_all_launchkernel1.device_node.append(nccl_reduce_all_kernel1) + userdefined_node.runtime_node.append(allreduce_launchkernel0) + allreduce_launchkernel0.device_node.append(nccl_allreduce_kernel0) + communication_node.children_node.append(allreduce_op1) + allreduce_op1.children_node.append(allreduce_op1_infershape) + allreduce_op1.runtime_node.append(allreduce_launchkernel1) + allreduce_launchkernel1.device_node.append(nccl_allreduce_kernel1) conv2d_node.children_node.extend( [conv2d_infer_shape, conv2d_compute, conv2d_MemCpy]) conv2d_compute.runtime_node.append(conv2d_launchkernel) @@ -350,10 +352,10 @@ class TestProfilerStatistic(unittest.TestCase): sync_batch_norm_MemCpy.runtime_node.append(sync_batch_norm_cudaMemCpy) sync_batch_norm_launchkernel.device_node.append(sync_batch_norm_kernel) sync_batch_norm_cudaMemCpy.device_node.append(sync_batch_norm_memcpy) - optimization_node.children_node.append(reduce_all_node2) - reduce_all_node2.children_node.append(reduce_all_node2_infershape) - reduce_all_node2.runtime_node.append(reduce_all_launchkernel2) - reduce_all_launchkernel2.device_node.append(nccl_reduce_all_kernel2) + optimization_node.children_node.append(allreduce_node2) + allreduce_node2.children_node.append(allreduce_node2_infershape) + allreduce_node2.runtime_node.append(allreduce_launchkernel2) + allreduce_launchkernel2.device_node.append(nccl_allreduce_kernel2) thread_tree = {'thread1001': root_node} extra_info = { 'Process Cpu Utilization': '1.02', @@ -415,20 +417,22 @@ class TestProfilerStatistic(unittest.TestCase): distributed_summary.overlap_range), 85) self.assertEqual(len(event_summary.items), 4) self.assertEqual(len(event_summary.userdefined_items), 1) - self.assertEqual(len(event_summary.model_perspective_items), 3) + self.assertEqual(len(event_summary.model_perspective_items), 4) self.assertEqual(len(event_summary.memory_manipulation_items), 1) self.assertEqual(event_summary.items['conv2d'].cpu_time, 15) - self.assertEqual(event_summary.items['conv2d'].gpu_time, 25) + self.assertEqual(event_summary.items['conv2d'].general_gpu_time, 25) self.assertEqual( event_summary.model_perspective_items['Forward'].cpu_time, 100) self.assertEqual( - event_summary.model_perspective_items['Forward'].gpu_time, 315) + event_summary.model_perspective_items['Forward'].general_gpu_time, + 315) self.assertEqual( - event_summary.model_perspective_items['Backward'].gpu_time, 0) + event_summary.model_perspective_items['Backward'].general_gpu_time, + 0) self.assertEqual( event_summary.memory_manipulation_items['AsyncMemcpy'].cpu_time, 15) - self.assertEqual( - event_summary.memory_manipulation_items['AsyncMemcpy'].gpu_time, 60) + self.assertEqual(event_summary.memory_manipulation_items['AsyncMemcpy'] + .general_gpu_time, 60) print( profiler.profiler_statistic._build_table( statistic_data, diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index c1c4f4ff8c..2fae583397 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -18,6 +18,8 @@ import datetime from enum import Enum from typing import Any, Callable, Iterable, Optional, Union from warnings import warn +import importlib +import json import paddle from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions, @@ -741,3 +743,73 @@ class Profiler: op_detail=op_detail, thread_sep=thread_sep, time_unit=time_unit)) + + +def get_profiler(config_path): + try: + with open(config_path, 'r') as filehandle: + config_dict = json.load(filehandle) + except Exception as e: + print('Load config file for profiler error: {}'.format(e)) + print('Use default parameters instead.') + return Profiler() + translated_config_dict = {} + if "targets" in config_dict: + try: + translated_config_dict['targets'] = [] + for target in config_dict['targets']: + if target.lower() == "cpu": + translated_config_dict['targets'].append(ProfilerTarget.CPU) + elif target.lower() == 'gpu': + translated_config_dict['targets'].append(ProfilerTarget.GPU) + except: + print('Set targets parameter error, use default parameter instead.') + translated_config_dict['targets'] = None + if "scheduler" in config_dict: + try: + if isinstance(config_dict['scheduler'], dict): + for key, value in config_dict['scheduler'].items(): + module_path = value['module'] + use_direct = value['use_direct'] + module = importlib.import_module(module_path) + method = getattr(module, key) + if not use_direct: + translated_config_dict['scheduler'] = method( + *value['args'], **value['kwargs']) + else: + translated_config_dict['scheduler'] = method + else: + translated_config_dict['scheduler'] = [ + config_dict['scheduler'][0], config_dict['scheduler'][1] + ] + + except: + print( + 'Set scheduler parameter error, use default parameter instead.') + translated_config_dict['scheduler'] = None + if "on_trace_ready" in config_dict: + try: + if isinstance(config_dict['on_trace_ready'], dict): + for key, value in config_dict['on_trace_ready'].items(): + module_path = value['module'] + use_direct = value['use_direct'] + module = importlib.import_module(module_path) + method = getattr(module, key) + if not use_direct: + translated_config_dict['on_trace_ready'] = method( + *value['args'], **value['kwargs']) + else: + translated_config_dict['on_trace_ready'] = method + except: + print( + 'Set on_trace_ready parameter error, use default parameter instead.' + ) + translated_config_dict['on_trace_ready'] = None + if "timer_only" in config_dict: + if isinstance(config_dict['timer_only'], bool): + translated_config_dict['timer_only'] = config_dict['timer_only'] + else: + print( + 'Set timer_only parameter error, use default parameter instead.') + + return Profiler(**translated_config_dict) diff --git a/python/paddle/profiler/profiler_statistic.py b/python/paddle/profiler/profiler_statistic.py index 3be6088a48..e4d4ff8c18 100755 --- a/python/paddle/profiler/profiler_statistic.py +++ b/python/paddle/profiler/profiler_statistic.py @@ -28,7 +28,7 @@ _AllTracerEventType = [ TracerEventType.PythonOp, TracerEventType.PythonUserDefined ] -_CommunicationOpName = ['reduce', 'broadcast', 'rpc'] +_CommunicationOpName = ['allreduce', 'broadcast', 'rpc'] class SortedKeys(Enum): @@ -74,8 +74,10 @@ class HostStatisticNode: self.runtime_node = [] self.cpu_time = 0 self.self_cpu_time = 0 - self.gpu_time = 0 + self.gpu_time = 0 # kernel time self.self_gpu_time = 0 + self.general_gpu_time = 0 # besides kernel, include time of gpu events like memcpy and memset + self.self_general_gpu_time = 0 def cal_statistic(self): for child in self.children_node: @@ -86,14 +88,20 @@ class HostStatisticNode: self.cpu_time = self.hostnode.end_ns - self.hostnode.start_ns for child in self.children_node: self.gpu_time += child.gpu_time + self.general_gpu_time += child.general_gpu_time self.self_cpu_time -= (child.end_ns - child.start_ns) for rt in self.runtime_node: self.self_cpu_time -= (rt.end_ns - rt.start_ns) self.gpu_time += rt.gpu_time self.self_gpu_time += rt.gpu_time + self.general_gpu_time += rt.general_gpu_time + self.self_general_gpu_time += rt.general_gpu_time for device in self.hostnode.device_node: - self.gpu_time += (device.end_ns - device.start_ns) - self.self_gpu_time += (device.end_ns - device.start_ns) + if device.type == TracerEventType.Kernel: + self.gpu_time += (device.end_ns - device.start_ns) + self.self_gpu_time += (device.end_ns - device.start_ns) + self.general_gpu_time += (device.end_ns - device.start_ns) + self.self_general_gpu_time += (device.end_ns - device.start_ns) @property def end_ns(self): @@ -258,6 +266,8 @@ class DistributedSummary: self.communication_range = [] self.computation_range = [] self.overlap_range = [] + self.cpu_calls = 0 + self.gpu_calls = 0 def parse(self, nodetrees): ''' @@ -300,6 +310,8 @@ class DistributedSummary: else: self.computation_range.append(( devicenode.start_ns, devicenode.end_ns)) + self.cpu_calls = len(set(self.cpu_communication_range)) + self.gpu_calls = len(set(self.gpu_communication_range)) self.cpu_communication_range = merge_self_ranges( self.cpu_communication_range, is_sorted=False) self.gpu_communication_range = merge_self_ranges( @@ -354,6 +366,9 @@ class EventSummary: self.min_gpu_time = float('inf') self.devices = {} self.operator_inners = {} + self.general_gpu_time = 0 + self.min_general_gpu_time = float('inf') + self.max_general_gpu_time = 0 @property def avg_cpu_time(self): @@ -363,6 +378,10 @@ class EventSummary: def avg_gpu_time(self): return self.gpu_time / self.call + @property + def avg_general_gpu_time(self): + return self.general_gpu_time / self.call + def add_cpu_time(self, time): if time > self.max_cpu_time: self.max_cpu_time = time @@ -377,6 +396,13 @@ class EventSummary: self.min_gpu_time = time self.gpu_time += time + def add_general_gpu_time(self, time): + if time > self.max_general_gpu_time: + self.max_general_gpu_time = time + if time < self.min_general_gpu_time: + self.min_general_gpu_time = time + self.general_gpu_time += time + def add_call(self): self.call += 1 @@ -384,6 +410,7 @@ class EventSummary: self.add_call() self.add_cpu_time(node.cpu_time) self.add_gpu_time(node.gpu_time) + self.add_general_gpu_time(node.general_gpu_time) for child in node.children_node: if child.name not in self.operator_inners: self.operator_inners[ @@ -407,6 +434,9 @@ class EventSummary: self.gpu_time = 0 self.max_gpu_time = 0 self.min_gpu_time = float('inf') + self.general_gpu_time = 0 + self.min_general_gpu_time = float('inf') + self.max_general_gpu_time = 0 @property def avg_cpu_time(self): @@ -416,6 +446,10 @@ class EventSummary: def avg_gpu_time(self): return self.gpu_time / self.call + @property + def avg_general_gpu_time(self): + return self.general_gpu_time / self.call + def add_cpu_time(self, time): if time > self.max_cpu_time: self.max_cpu_time = time @@ -430,6 +464,13 @@ class EventSummary: self.min_gpu_time = time self.gpu_time += time + def add_general_gpu_time(self, time): + if time > self.max_general_gpu_time: + self.max_general_gpu_time = time + if time < self.min_general_gpu_time: + self.min_general_gpu_time = time + self.general_gpu_time += time + def add_call(self): self.call += 1 @@ -437,6 +478,7 @@ class EventSummary: self.add_call() self.add_cpu_time(node.cpu_time) self.add_gpu_time(node.gpu_time) + self.add_general_gpu_time(node.general_gpu_time) def __init__(self): self.items = {} # for operator summary @@ -478,6 +520,8 @@ class EventSummary: self.add_model_perspective_item( child) #find first model perspective node else: + if child.type == TracerEventType.ProfileStep: + self.add_model_perspective_item(child) deque.append(child) def add_operator_item(self, operator_node): @@ -533,6 +577,8 @@ class EventSummary: name = 'Optimization' elif model_perspective_node.type == TracerEventType.Dataloader: name = 'Dataloader' + elif model_perspective_node.type == TracerEventType.ProfileStep: + name = 'ProfileStep' else: return if name not in self.model_perspective_items: @@ -626,7 +672,6 @@ def _build_table(statistic_data, # construct table string append(add_title(line_length, "Device Summary")) - append('Time unit: {}'.format(time_unit)) append(header_sep) append(row_format.format(*headers)) append(header_sep) @@ -661,7 +706,7 @@ def _build_table(statistic_data, return ''.join(result) ###### Print Overview Summary ###### - headers = ['Event Type', 'CPU Time', 'Ratio (%)'] + headers = ['Event Type', 'Calls', 'CPU Time', 'Ratio (%)'] row_format_list = [""] header_sep_list = [""] line_length_list = [-SPACING_SIZE] @@ -680,13 +725,13 @@ def _build_table(statistic_data, append(header_sep) append(row_format.format(*headers)) append(header_sep) - row_values = [ - 'Total Time', format_time( - total_time, unit=time_unit), format_ratio(1) - ] - append(row_format.format(*row_values)) cpu_type_time = collections.defaultdict(int) gpu_type_time = collections.defaultdict(int) + cpu_call_times = collections.defaultdict(int) + gpu_call_times = collections.defaultdict(int) + cpu_call_times.update(statistic_data.time_range_summary.call_times) + gpu_call_times.update(statistic_data.time_range_summary.call_times) + for event_type, value in statistic_data.time_range_summary.CPUTimeRangeSum.items( ): if event_type != TracerEventType.Communication: @@ -694,6 +739,19 @@ def _build_table(statistic_data, if statistic_data.distributed_summary.cpu_communication_range: cpu_type_time[TracerEventType.Communication] = sum_ranges( statistic_data.distributed_summary.cpu_communication_range) + cpu_call_times[ + TracerEventType. + Communication] = statistic_data.distributed_summary.cpu_calls + + for event_type in [ + TracerEventType.Dataloader, TracerEventType.Forward, + TracerEventType.Backward, TracerEventType.Optimization + ]: + event_type_name = str(event_type).split('.')[1] + if event_type in cpu_call_times and event_type_name in statistic_data.event_summary.model_perspective_items: + cpu_call_times[ + event_type] = statistic_data.event_summary.model_perspective_items[ + event_type_name].call gpu_time_range = collections.defaultdict(list) for device_id, device_time_ranges in statistic_data.time_range_summary.GPUTimeRange.items( @@ -706,22 +764,34 @@ def _build_table(statistic_data, if statistic_data.distributed_summary.gpu_communication_range: gpu_type_time[TracerEventType.Communication] = sum_ranges( statistic_data.distributed_summary.gpu_communication_range) + gpu_call_times[ + TracerEventType. + Communication] = statistic_data.distributed_summary.gpu_calls sorted_items = sorted( cpu_type_time.items(), key=lambda x: x[1], reverse=True) - for event_type, time in sorted_items: + event_type, time = sorted_items[0] + row_values = [ + '{}'.format(str(event_type).split('.')[1]), cpu_call_times[event_type], + format_time( + time, unit=time_unit), format_ratio(float(time) / total_time) + ] + append(row_format.format(*row_values)) + for event_type, time in sorted_items[1:]: row_values = [ - ' {}'.format(str(event_type).split('.')[1]), format_time( + ' {}'.format(str(event_type).split('.')[1]), + cpu_call_times[event_type], format_time( time, unit=time_unit), format_ratio(float(time) / total_time) ] append(row_format.format(*row_values)) append(header_sep) - headers = ['', 'GPU Time', 'Ratio (%)'] + headers = ['', 'Calls', 'GPU Time', 'Ratio (%)'] append(row_format.format(*headers)) append(header_sep) for event_type, time in gpu_type_time.items(): row_values = [ - ' {}'.format(str(event_type).split('.')[1]), format_time( + ' {}'.format(str(event_type).split('.')[1]), + gpu_call_times[event_type], format_time( time, unit=time_unit), format_ratio(float(time) / total_time) ] append(row_format.format(*row_values)) @@ -730,7 +800,7 @@ def _build_table(statistic_data, append( "Note:\nIn this table, We sum up all collected events in terms of event type.\n" "The time of events collected on host are presented as CPU Time, and as GPU Time if on device.\n" - "Ratio = CPU(GPU) Time / Total Time.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" "Events with different types may overlap or inclusion, e.g. Operator includes OperatorInner, so the sum of ratios is not 100%.\n" "The time of events in the same type with overlap will not calculate twice, and all time is summed after merged.\n" "Example:\n" @@ -746,21 +816,21 @@ def _build_table(statistic_data, ###### Print Model Summary Report ###### model_perspective_items = statistic_data.event_summary.model_perspective_items - if model_perspective_items: + if len(model_perspective_items) > 1: all_row_values = [] - row_values = [ - 'Total Time', '-', '{} / - / - / - / {}'.format( - format_time( - total_time, unit=time_unit), format_ratio(1)), - '- / - / - / -/ -' - ] - all_row_values.append(row_values) accmulation_time = 0 - for name in ['Dataloader', 'Forward', 'Backward', 'Optimization']: + gpu_accmulation_time = 0 + gpu_total_time = 0 + for name in [ + 'ProfileStep', 'Dataloader', 'Forward', 'Backward', + 'Optimization' + ]: if name in model_perspective_items: item = model_perspective_items[name] + name = '{}'.format( + name) if 'ProfileStep' in name else ' {}'.format(name) row_values = [ - ' {}'.format(name), item.call, + '{}'.format(name), item.call, '{} / {} / {} / {} / {}'.format( format_time( item.cpu_time, unit=time_unit), @@ -783,15 +853,23 @@ def _build_table(statistic_data, format_ratio(float(item.gpu_time) / total_time)) ] all_row_values.append(row_values) - accmulation_time += item.cpu_time + if 'ProfileStep' not in name: + accmulation_time += item.cpu_time + gpu_accmulation_time += item.gpu_time + else: + gpu_total_time = item.gpu_time other_time = total_time - accmulation_time + other_gpu_time = gpu_total_time - gpu_accmulation_time row_values = [ ' Others', '-', '{} / - / - / - / {}'.format( format_time( other_time, unit=time_unit), format_ratio(float(other_time) / total_time)), - '- / - / - / - / -' + '{} / - / - / - / {}'.format( + format_time( + other_gpu_time, unit=time_unit), + format_ratio(float(other_gpu_time) / gpu_total_time)) ] all_row_values.append(row_values) # Calculate the column width @@ -835,6 +913,7 @@ def _build_table(statistic_data, append( "Note:\nIn this table, GPU time is the sum of all device(GPU) events called in the phase.\n" "Unlike overview summary, if two device(GPU) events execute on different streams with overlap time, we sum them directly here.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" ) append('-' * line_length) append('') @@ -872,21 +951,27 @@ def _build_table(statistic_data, overlap_time = sum_ranges( statistic_data.distributed_summary.overlap_range) row_values = [ - 'Communication', format_time( + 'ProfileStep', format_time( + total_time, unit=time_unit), + format_ratio(float(total_time) / total_time) + ] + append(row_format.format(*row_values)) + row_values = [ + ' Communication', format_time( communication_time, unit=time_unit), format_ratio(float(communication_time) / total_time) ] append(row_format.format(*row_values)) row_values = [ - 'Computation', format_time( + ' Computation', format_time( computation_time, unit=time_unit), format_ratio(float(computation_time) / total_time) ] append(row_format.format(*row_values)) row_values = [ - 'Overlap', format_time( + ' Overlap', format_time( overlap_time, unit=time_unit), format_ratio(float(overlap_time) / total_time) ] @@ -896,6 +981,7 @@ def _build_table(statistic_data, "Note:\nCommunication time: Communication Event time, Communication Op time and its kernel time on gpu.\n" "Computation time: Kernel time, except kernels belong to communication(nccl kernels).\n" "Overlap time: Communication time intersects with computation time.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" "Example:\n" "Communication:\n" " CPU: |_________________|\n" @@ -938,20 +1024,22 @@ def _build_table(statistic_data, items.items(), key=lambda x: x[1].min_cpu_time) elif sorted_by == SortedKeys.GPUTotal: sorted_items = sorted( - items.items(), key=lambda x: x[1].gpu_time, reverse=True) + items.items(), + key=lambda x: x[1].general_gpu_time, + reverse=True) elif sorted_by == SortedKeys.GPUAvg: sorted_items = sorted( items.items(), - key=lambda x: x[1].avg_gpu_time, + key=lambda x: x[1].avg_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMax: sorted_items = sorted( items.items(), - key=lambda x: x[1].max_gpu_time, + key=lambda x: x[1].max_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMin: sorted_items = sorted( - items.items(), key=lambda x: x[1].min_gpu_time) + items.items(), key=lambda x: x[1].min_general_gpu_time) for name, item in sorted_items: row_values = [ @@ -967,14 +1055,15 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)) + item.min_general_gpu_time, unit=time_unit), + format_ratio( + float(item.general_gpu_time) / total_time)) ] all_row_values.append(row_values) if op_detail: @@ -998,18 +1087,23 @@ def _build_table(statistic_data, float(innerop_node.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - innerop_node.gpu_time, unit=time_unit), + innerop_node.general_gpu_time, + unit=time_unit), format_time( - innerop_node.avg_gpu_time, unit=time_unit), + innerop_node.avg_general_gpu_time, + unit=time_unit), format_time( - innerop_node.max_gpu_time, unit=time_unit), + innerop_node.max_general_gpu_time, + unit=time_unit), format_time( - innerop_node.min_gpu_time, unit=time_unit), + innerop_node.min_general_gpu_time, + unit=time_unit), format_ratio( - float(innerop_node.gpu_time) / total_time)) + float(innerop_node.general_gpu_time) / + total_time)) ] all_row_values.append(row_values) - for device_node_name, devicenode in innerop_node.devices.items( + for device_node_name, device_node in innerop_node.devices.items( ): if len(device_node_name) + 4 > name_column_width: device_node_name = device_node_name[: @@ -1018,21 +1112,21 @@ def _build_table(statistic_data, device_node_name += "..." row_values = [ ' {}'.format(device_node_name), - devicenode.call, '- / - / - / - / -', + device_node.call, '- / - / - / - / -', '{} / {} / {} / {} / {}'.format( format_time( - devicenode.gpu_time, unit=time_unit), + device_node.gpu_time, unit=time_unit), format_time( - devicenode.avg_gpu_time, + device_node.avg_gpu_time, unit=time_unit), format_time( - devicenode.max_gpu_time, + device_node.max_gpu_time, unit=time_unit), format_time( - devicenode.min_gpu_time, + device_node.min_gpu_time, unit=time_unit), format_ratio( - float(devicenode.gpu_time) / + float(device_node.gpu_time) / total_time)) ] all_row_values.append(row_values) @@ -1043,19 +1137,19 @@ def _build_table(statistic_data, - 5] device_node_name += "..." row_values = [ - ' {}'.format(device_node_name), devicenode.call, + ' {}'.format(device_node_name), device_node.call, '- / - / - / - / -', '{} / {} / {} / {} / {}'.format( format_time( - devicenode.gpu_time, unit=time_unit), + device_node.gpu_time, unit=time_unit), format_time( - devicenode.avg_gpu_time, unit=time_unit), + device_node.avg_gpu_time, unit=time_unit), format_time( - devicenode.max_gpu_time, unit=time_unit), + device_node.max_gpu_time, unit=time_unit), format_time( - devicenode.min_gpu_time, unit=time_unit), + device_node.min_gpu_time, unit=time_unit), format_ratio( - float(devicenode.gpu_time) / total_time)) + float(device_node.gpu_time) / total_time)) ] all_row_values.append(row_values) # Calculate the column width @@ -1123,14 +1217,14 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)), + item.min_general_gpu_time, unit=time_unit), + format_ratio(float(item.general_gpu_time) / total_time)), ] all_row_values.append(row_values) @@ -1207,20 +1301,22 @@ def _build_table(statistic_data, items.items(), key=lambda x: x[1].min_cpu_time) elif sorted_by == SortedKeys.GPUTotal: sorted_items = sorted( - items.items(), key=lambda x: x[1].gpu_time, reverse=True) + items.items(), + key=lambda x: x[1].general_gpu_time, + reverse=True) elif sorted_by == SortedKeys.GPUAvg: sorted_items = sorted( items.items(), - key=lambda x: x[1].avg_gpu_time, + key=lambda x: x[1].avg_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMax: sorted_items = sorted( items.items(), - key=lambda x: x[1].max_gpu_time, + key=lambda x: x[1].max_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMin: sorted_items = sorted( - items.items(), key=lambda x: x[1].min_gpu_time) + items.items(), key=lambda x: x[1].min_general_gpu_time) for name, item in sorted_items: row_values = [ @@ -1238,14 +1334,15 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)), + item.min_general_gpu_time, unit=time_unit), + format_ratio( + float(item.general_gpu_time) / total_time)), ] all_row_values.append(row_values) -- GitLab