From f43af2759c9fc6e8aed797f3bb96c126f0624b87 Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 8 Apr 2022 14:30:58 +0800 Subject: [PATCH] Refine statistic table (#41524) --- .../unittests/test_profiler_statistic.py | 88 +++---- python/paddle/profiler/profiler_statistic.py | 231 ++++++++++++------ 2 files changed, 205 insertions(+), 114 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py index adc42d0447f..dc944e68c7f 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py +++ b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py @@ -185,20 +185,22 @@ class TestProfilerStatistic(unittest.TestCase): profiler.TracerEventType.Communication), 5) self.assertEqual(len(event_summary.items), 2) self.assertEqual(len(event_summary.userdefined_items), 1) - self.assertEqual(len(event_summary.model_perspective_items), 3) + self.assertEqual(len(event_summary.model_perspective_items), 4) self.assertEqual(len(event_summary.memory_manipulation_items), 1) self.assertEqual(event_summary.items['conv2d'].cpu_time, 15) - self.assertEqual(event_summary.items['conv2d'].gpu_time, 25) + self.assertEqual(event_summary.items['conv2d'].general_gpu_time, 25) self.assertEqual( event_summary.model_perspective_items['Forward'].cpu_time, 100) self.assertEqual( - event_summary.model_perspective_items['Forward'].gpu_time, 135) + event_summary.model_perspective_items['Forward'].general_gpu_time, + 135) self.assertEqual( - event_summary.model_perspective_items['Backward'].gpu_time, 0) + event_summary.model_perspective_items['Backward'].general_gpu_time, + 0) self.assertEqual( event_summary.memory_manipulation_items['AsyncMemcpy'].cpu_time, 15) - self.assertEqual( - event_summary.memory_manipulation_items['AsyncMemcpy'].gpu_time, 60) + self.assertEqual(event_summary.memory_manipulation_items['AsyncMemcpy'] + .general_gpu_time, 60) print( profiler.profiler_statistic._build_table( statistic_data, @@ -226,31 +228,31 @@ class TestProfilerStatistic(unittest.TestCase): userdefined_node = HostPythonNode('Communication Time', profiler.TracerEventType.UserDefined, 100, 110, 1000, 1001) - reduce_all_launchkernel0 = HostPythonNode( + allreduce_launchkernel0 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 102, 104, 1000, 1001) - nccl_reduce_all_kernel0 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 105, 120, + nccl_allreduce_kernel0 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 105, 120, 0, 0, 2) communication_node = HostPythonNode( 'Communication', profiler.TracerEventType.Communication, 105, 110, 1000, 1001) - reduce_all_op1 = HostPythonNode('reduce_all_op1', - profiler.TracerEventType.Operator, 105, - 108, 1000, 1001) - reduce_all_op1_infershape = HostPythonNode( - 'reduce_all_op1::infershape', - profiler.TracerEventType.OperatorInner, 105, 106, 1000, 1001) + allreduce_op1 = HostPythonNode('allreduce_op1', + profiler.TracerEventType.Operator, 105, + 108, 1000, 1001) + allreduce_op1_infershape = HostPythonNode( + 'allreduce_op1::infershape', profiler.TracerEventType.OperatorInner, + 105, 106, 1000, 1001) - reduce_all_launchkernel1 = HostPythonNode( + allreduce_launchkernel1 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 106, 107, 1000, 1001) - nccl_reduce_all_kernel1 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 130, 150, + nccl_allreduce_kernel1 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 130, 150, 0, 0, 2) backward_node = HostPythonNode('Gradient Backward', @@ -305,19 +307,19 @@ class TestProfilerStatistic(unittest.TestCase): 'sync_batch_norm_memcpy', profiler.TracerEventType.Memcpy, 150, 200, 0, 0, 1) - reduce_all_node2 = HostPythonNode('reduce_all', - profiler.TracerEventType.Operator, - 230, 250, 1000, 1001) + allreduce_node2 = HostPythonNode('allreduce', + profiler.TracerEventType.Operator, 230, + 250, 1000, 1001) - reduce_all_node2_infershape = HostPythonNode( - 'reduce_all_node2::infershape', + allreduce_node2_infershape = HostPythonNode( + 'allreduce_node2::infershape', profiler.TracerEventType.OperatorInner, 231, 232, 1000, 1001) - reduce_all_launchkernel2 = HostPythonNode( + allreduce_launchkernel2 = HostPythonNode( 'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 235, 240, 1000, 1001) - nccl_reduce_all_kernel2 = DevicePythonNode( - 'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 250, 280, + nccl_allreduce_kernel2 = DevicePythonNode( + 'nccl_allreduce_kernel', profiler.TracerEventType.Kernel, 250, 280, 0, 0, 2) root_node.children_node.append(profilerstep_node) @@ -329,12 +331,12 @@ class TestProfilerStatistic(unittest.TestCase): yolonet_node.children_node.extend( [sync_batch_norm_node, userdefined_node]) userdefined_node.children_node.append(communication_node) - userdefined_node.runtime_node.append(reduce_all_launchkernel0) - reduce_all_launchkernel0.device_node.append(nccl_reduce_all_kernel0) - communication_node.children_node.append(reduce_all_op1) - reduce_all_op1.children_node.append(reduce_all_op1_infershape) - reduce_all_op1.runtime_node.append(reduce_all_launchkernel1) - reduce_all_launchkernel1.device_node.append(nccl_reduce_all_kernel1) + userdefined_node.runtime_node.append(allreduce_launchkernel0) + allreduce_launchkernel0.device_node.append(nccl_allreduce_kernel0) + communication_node.children_node.append(allreduce_op1) + allreduce_op1.children_node.append(allreduce_op1_infershape) + allreduce_op1.runtime_node.append(allreduce_launchkernel1) + allreduce_launchkernel1.device_node.append(nccl_allreduce_kernel1) conv2d_node.children_node.extend( [conv2d_infer_shape, conv2d_compute, conv2d_MemCpy]) conv2d_compute.runtime_node.append(conv2d_launchkernel) @@ -350,10 +352,10 @@ class TestProfilerStatistic(unittest.TestCase): sync_batch_norm_MemCpy.runtime_node.append(sync_batch_norm_cudaMemCpy) sync_batch_norm_launchkernel.device_node.append(sync_batch_norm_kernel) sync_batch_norm_cudaMemCpy.device_node.append(sync_batch_norm_memcpy) - optimization_node.children_node.append(reduce_all_node2) - reduce_all_node2.children_node.append(reduce_all_node2_infershape) - reduce_all_node2.runtime_node.append(reduce_all_launchkernel2) - reduce_all_launchkernel2.device_node.append(nccl_reduce_all_kernel2) + optimization_node.children_node.append(allreduce_node2) + allreduce_node2.children_node.append(allreduce_node2_infershape) + allreduce_node2.runtime_node.append(allreduce_launchkernel2) + allreduce_launchkernel2.device_node.append(nccl_allreduce_kernel2) thread_tree = {'thread1001': root_node} extra_info = { 'Process Cpu Utilization': '1.02', @@ -415,20 +417,22 @@ class TestProfilerStatistic(unittest.TestCase): distributed_summary.overlap_range), 85) self.assertEqual(len(event_summary.items), 4) self.assertEqual(len(event_summary.userdefined_items), 1) - self.assertEqual(len(event_summary.model_perspective_items), 3) + self.assertEqual(len(event_summary.model_perspective_items), 4) self.assertEqual(len(event_summary.memory_manipulation_items), 1) self.assertEqual(event_summary.items['conv2d'].cpu_time, 15) - self.assertEqual(event_summary.items['conv2d'].gpu_time, 25) + self.assertEqual(event_summary.items['conv2d'].general_gpu_time, 25) self.assertEqual( event_summary.model_perspective_items['Forward'].cpu_time, 100) self.assertEqual( - event_summary.model_perspective_items['Forward'].gpu_time, 315) + event_summary.model_perspective_items['Forward'].general_gpu_time, + 315) self.assertEqual( - event_summary.model_perspective_items['Backward'].gpu_time, 0) + event_summary.model_perspective_items['Backward'].general_gpu_time, + 0) self.assertEqual( event_summary.memory_manipulation_items['AsyncMemcpy'].cpu_time, 15) - self.assertEqual( - event_summary.memory_manipulation_items['AsyncMemcpy'].gpu_time, 60) + self.assertEqual(event_summary.memory_manipulation_items['AsyncMemcpy'] + .general_gpu_time, 60) print( profiler.profiler_statistic._build_table( statistic_data, diff --git a/python/paddle/profiler/profiler_statistic.py b/python/paddle/profiler/profiler_statistic.py index 3be6088a484..5fed5147613 100755 --- a/python/paddle/profiler/profiler_statistic.py +++ b/python/paddle/profiler/profiler_statistic.py @@ -28,7 +28,7 @@ _AllTracerEventType = [ TracerEventType.PythonOp, TracerEventType.PythonUserDefined ] -_CommunicationOpName = ['reduce', 'broadcast', 'rpc'] +_CommunicationOpName = ['allreduce', 'broadcast', 'rpc'] class SortedKeys(Enum): @@ -74,8 +74,10 @@ class HostStatisticNode: self.runtime_node = [] self.cpu_time = 0 self.self_cpu_time = 0 - self.gpu_time = 0 + self.gpu_time = 0 # kernel time self.self_gpu_time = 0 + self.general_gpu_time = 0 # besides kernel, include time of gpu events like memcpy and memset + self.self_general_gpu_time = 0 def cal_statistic(self): for child in self.children_node: @@ -86,14 +88,20 @@ class HostStatisticNode: self.cpu_time = self.hostnode.end_ns - self.hostnode.start_ns for child in self.children_node: self.gpu_time += child.gpu_time + self.general_gpu_time += child.general_gpu_time self.self_cpu_time -= (child.end_ns - child.start_ns) for rt in self.runtime_node: self.self_cpu_time -= (rt.end_ns - rt.start_ns) self.gpu_time += rt.gpu_time self.self_gpu_time += rt.gpu_time + self.general_gpu_time += rt.general_gpu_time + self.self_general_gpu_time += rt.general_gpu_time for device in self.hostnode.device_node: - self.gpu_time += (device.end_ns - device.start_ns) - self.self_gpu_time += (device.end_ns - device.start_ns) + if device.type == TracerEventType.Kernel: + self.gpu_time += (device.end_ns - device.start_ns) + self.self_gpu_time += (device.end_ns - device.start_ns) + self.general_gpu_time += (device.end_ns - device.start_ns) + self.self_general_gpu_time += (device.end_ns - device.start_ns) @property def end_ns(self): @@ -258,6 +266,8 @@ class DistributedSummary: self.communication_range = [] self.computation_range = [] self.overlap_range = [] + self.cpu_calls = 0 + self.gpu_calls = 0 def parse(self, nodetrees): ''' @@ -300,6 +310,8 @@ class DistributedSummary: else: self.computation_range.append(( devicenode.start_ns, devicenode.end_ns)) + self.cpu_calls = len(set(self.cpu_communication_range)) + self.gpu_calls = len(set(self.gpu_communication_range)) self.cpu_communication_range = merge_self_ranges( self.cpu_communication_range, is_sorted=False) self.gpu_communication_range = merge_self_ranges( @@ -354,6 +366,9 @@ class EventSummary: self.min_gpu_time = float('inf') self.devices = {} self.operator_inners = {} + self.general_gpu_time = 0 + self.min_general_gpu_time = float('inf') + self.max_general_gpu_time = 0 @property def avg_cpu_time(self): @@ -363,6 +378,10 @@ class EventSummary: def avg_gpu_time(self): return self.gpu_time / self.call + @property + def avg_general_gpu_time(self): + return self.general_gpu_time / self.call + def add_cpu_time(self, time): if time > self.max_cpu_time: self.max_cpu_time = time @@ -377,6 +396,13 @@ class EventSummary: self.min_gpu_time = time self.gpu_time += time + def add_general_gpu_time(self, time): + if time > self.max_general_gpu_time: + self.max_general_gpu_time = time + if time < self.min_general_gpu_time: + self.min_general_gpu_time = time + self.general_gpu_time += time + def add_call(self): self.call += 1 @@ -384,6 +410,7 @@ class EventSummary: self.add_call() self.add_cpu_time(node.cpu_time) self.add_gpu_time(node.gpu_time) + self.add_general_gpu_time(node.general_gpu_time) for child in node.children_node: if child.name not in self.operator_inners: self.operator_inners[ @@ -407,6 +434,9 @@ class EventSummary: self.gpu_time = 0 self.max_gpu_time = 0 self.min_gpu_time = float('inf') + self.general_gpu_time = 0 + self.min_general_gpu_time = float('inf') + self.max_general_gpu_time = 0 @property def avg_cpu_time(self): @@ -416,6 +446,10 @@ class EventSummary: def avg_gpu_time(self): return self.gpu_time / self.call + @property + def avg_general_gpu_time(self): + return self.general_gpu_time / self.call + def add_cpu_time(self, time): if time > self.max_cpu_time: self.max_cpu_time = time @@ -430,6 +464,13 @@ class EventSummary: self.min_gpu_time = time self.gpu_time += time + def add_general_gpu_time(self, time): + if time > self.max_general_gpu_time: + self.max_general_gpu_time = time + if time < self.min_general_gpu_time: + self.min_general_gpu_time = time + self.general_gpu_time += time + def add_call(self): self.call += 1 @@ -437,6 +478,7 @@ class EventSummary: self.add_call() self.add_cpu_time(node.cpu_time) self.add_gpu_time(node.gpu_time) + self.add_general_gpu_time(node.general_gpu_time) def __init__(self): self.items = {} # for operator summary @@ -478,6 +520,8 @@ class EventSummary: self.add_model_perspective_item( child) #find first model perspective node else: + if child.type == TracerEventType.ProfileStep: + self.add_model_perspective_item(child) deque.append(child) def add_operator_item(self, operator_node): @@ -533,6 +577,8 @@ class EventSummary: name = 'Optimization' elif model_perspective_node.type == TracerEventType.Dataloader: name = 'Dataloader' + elif model_perspective_node.type == TracerEventType.ProfileStep: + name = 'ProfileStep' else: return if name not in self.model_perspective_items: @@ -626,7 +672,6 @@ def _build_table(statistic_data, # construct table string append(add_title(line_length, "Device Summary")) - append('Time unit: {}'.format(time_unit)) append(header_sep) append(row_format.format(*headers)) append(header_sep) @@ -661,7 +706,7 @@ def _build_table(statistic_data, return ''.join(result) ###### Print Overview Summary ###### - headers = ['Event Type', 'CPU Time', 'Ratio (%)'] + headers = ['Event Type', 'Calls', 'CPU Time', 'Ratio (%)'] row_format_list = [""] header_sep_list = [""] line_length_list = [-SPACING_SIZE] @@ -680,13 +725,13 @@ def _build_table(statistic_data, append(header_sep) append(row_format.format(*headers)) append(header_sep) - row_values = [ - 'Total Time', format_time( - total_time, unit=time_unit), format_ratio(1) - ] - append(row_format.format(*row_values)) cpu_type_time = collections.defaultdict(int) gpu_type_time = collections.defaultdict(int) + cpu_call_times = collections.defaultdict(int) + gpu_call_times = collections.defaultdict(int) + cpu_call_times.update(statistic_data.time_range_summary.call_times) + gpu_call_times.update(statistic_data.time_range_summary.call_times) + for event_type, value in statistic_data.time_range_summary.CPUTimeRangeSum.items( ): if event_type != TracerEventType.Communication: @@ -694,6 +739,9 @@ def _build_table(statistic_data, if statistic_data.distributed_summary.cpu_communication_range: cpu_type_time[TracerEventType.Communication] = sum_ranges( statistic_data.distributed_summary.cpu_communication_range) + cpu_call_times[ + TracerEventType. + Communication] = statistic_data.distributed_summary.cpu_calls gpu_time_range = collections.defaultdict(list) for device_id, device_time_ranges in statistic_data.time_range_summary.GPUTimeRange.items( @@ -706,22 +754,34 @@ def _build_table(statistic_data, if statistic_data.distributed_summary.gpu_communication_range: gpu_type_time[TracerEventType.Communication] = sum_ranges( statistic_data.distributed_summary.gpu_communication_range) + gpu_call_times[ + TracerEventType. + Communication] = statistic_data.distributed_summary.gpu_calls sorted_items = sorted( cpu_type_time.items(), key=lambda x: x[1], reverse=True) - for event_type, time in sorted_items: + event_type, time = sorted_items[0] + row_values = [ + '{}'.format(str(event_type).split('.')[1]), cpu_call_times[event_type], + format_time( + time, unit=time_unit), format_ratio(float(time) / total_time) + ] + append(row_format.format(*row_values)) + for event_type, time in sorted_items[1:]: row_values = [ - ' {}'.format(str(event_type).split('.')[1]), format_time( + ' {}'.format(str(event_type).split('.')[1]), + cpu_call_times[event_type], format_time( time, unit=time_unit), format_ratio(float(time) / total_time) ] append(row_format.format(*row_values)) append(header_sep) - headers = ['', 'GPU Time', 'Ratio (%)'] + headers = ['', 'Calls', 'GPU Time', 'Ratio (%)'] append(row_format.format(*headers)) append(header_sep) for event_type, time in gpu_type_time.items(): row_values = [ - ' {}'.format(str(event_type).split('.')[1]), format_time( + ' {}'.format(str(event_type).split('.')[1]), + gpu_call_times[event_type], format_time( time, unit=time_unit), format_ratio(float(time) / total_time) ] append(row_format.format(*row_values)) @@ -730,7 +790,7 @@ def _build_table(statistic_data, append( "Note:\nIn this table, We sum up all collected events in terms of event type.\n" "The time of events collected on host are presented as CPU Time, and as GPU Time if on device.\n" - "Ratio = CPU(GPU) Time / Total Time.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" "Events with different types may overlap or inclusion, e.g. Operator includes OperatorInner, so the sum of ratios is not 100%.\n" "The time of events in the same type with overlap will not calculate twice, and all time is summed after merged.\n" "Example:\n" @@ -746,21 +806,21 @@ def _build_table(statistic_data, ###### Print Model Summary Report ###### model_perspective_items = statistic_data.event_summary.model_perspective_items - if model_perspective_items: + if len(model_perspective_items) > 1: all_row_values = [] - row_values = [ - 'Total Time', '-', '{} / - / - / - / {}'.format( - format_time( - total_time, unit=time_unit), format_ratio(1)), - '- / - / - / -/ -' - ] - all_row_values.append(row_values) accmulation_time = 0 - for name in ['Dataloader', 'Forward', 'Backward', 'Optimization']: + gpu_accmulation_time = 0 + gpu_total_time = 0 + for name in [ + 'ProfileStep', 'Dataloader', 'Forward', 'Backward', + 'Optimization' + ]: if name in model_perspective_items: item = model_perspective_items[name] + name = '{}'.format( + name) if 'ProfileStep' in name else ' {}'.format(name) row_values = [ - ' {}'.format(name), item.call, + '{}'.format(name), item.call, '{} / {} / {} / {} / {}'.format( format_time( item.cpu_time, unit=time_unit), @@ -783,15 +843,23 @@ def _build_table(statistic_data, format_ratio(float(item.gpu_time) / total_time)) ] all_row_values.append(row_values) - accmulation_time += item.cpu_time + if 'ProfileStep' not in name: + accmulation_time += item.cpu_time + gpu_accmulation_time += item.gpu_time + else: + gpu_total_time = item.gpu_time other_time = total_time - accmulation_time + other_gpu_time = gpu_total_time - gpu_accmulation_time row_values = [ ' Others', '-', '{} / - / - / - / {}'.format( format_time( other_time, unit=time_unit), format_ratio(float(other_time) / total_time)), - '- / - / - / - / -' + '{} / - / - / - / {}'.format( + format_time( + other_gpu_time, unit=time_unit), + format_ratio(float(other_gpu_time) / gpu_total_time)) ] all_row_values.append(row_values) # Calculate the column width @@ -835,6 +903,7 @@ def _build_table(statistic_data, append( "Note:\nIn this table, GPU time is the sum of all device(GPU) events called in the phase.\n" "Unlike overview summary, if two device(GPU) events execute on different streams with overlap time, we sum them directly here.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" ) append('-' * line_length) append('') @@ -872,21 +941,27 @@ def _build_table(statistic_data, overlap_time = sum_ranges( statistic_data.distributed_summary.overlap_range) row_values = [ - 'Communication', format_time( + 'ProfileStep', format_time( + total_time, unit=time_unit), + format_ratio(float(total_time) / total_time) + ] + append(row_format.format(*row_values)) + row_values = [ + ' Communication', format_time( communication_time, unit=time_unit), format_ratio(float(communication_time) / total_time) ] append(row_format.format(*row_values)) row_values = [ - 'Computation', format_time( + ' Computation', format_time( computation_time, unit=time_unit), format_ratio(float(computation_time) / total_time) ] append(row_format.format(*row_values)) row_values = [ - 'Overlap', format_time( + ' Overlap', format_time( overlap_time, unit=time_unit), format_ratio(float(overlap_time) / total_time) ] @@ -896,6 +971,7 @@ def _build_table(statistic_data, "Note:\nCommunication time: Communication Event time, Communication Op time and its kernel time on gpu.\n" "Computation time: Kernel time, except kernels belong to communication(nccl kernels).\n" "Overlap time: Communication time intersects with computation time.\n" + "The time with ratio 100% is the base time for calculating ratio. \n" "Example:\n" "Communication:\n" " CPU: |_________________|\n" @@ -938,20 +1014,22 @@ def _build_table(statistic_data, items.items(), key=lambda x: x[1].min_cpu_time) elif sorted_by == SortedKeys.GPUTotal: sorted_items = sorted( - items.items(), key=lambda x: x[1].gpu_time, reverse=True) + items.items(), + key=lambda x: x[1].general_gpu_time, + reverse=True) elif sorted_by == SortedKeys.GPUAvg: sorted_items = sorted( items.items(), - key=lambda x: x[1].avg_gpu_time, + key=lambda x: x[1].avg_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMax: sorted_items = sorted( items.items(), - key=lambda x: x[1].max_gpu_time, + key=lambda x: x[1].max_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMin: sorted_items = sorted( - items.items(), key=lambda x: x[1].min_gpu_time) + items.items(), key=lambda x: x[1].min_general_gpu_time) for name, item in sorted_items: row_values = [ @@ -967,14 +1045,15 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)) + item.min_general_gpu_time, unit=time_unit), + format_ratio( + float(item.general_gpu_time) / total_time)) ] all_row_values.append(row_values) if op_detail: @@ -998,18 +1077,23 @@ def _build_table(statistic_data, float(innerop_node.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - innerop_node.gpu_time, unit=time_unit), + innerop_node.general_gpu_time, + unit=time_unit), format_time( - innerop_node.avg_gpu_time, unit=time_unit), + innerop_node.avg_general_gpu_time, + unit=time_unit), format_time( - innerop_node.max_gpu_time, unit=time_unit), + innerop_node.max_general_gpu_time, + unit=time_unit), format_time( - innerop_node.min_gpu_time, unit=time_unit), + innerop_node.min_general_gpu_time, + unit=time_unit), format_ratio( - float(innerop_node.gpu_time) / total_time)) + float(innerop_node.general_gpu_time) / + total_time)) ] all_row_values.append(row_values) - for device_node_name, devicenode in innerop_node.devices.items( + for device_node_name, device_node in innerop_node.devices.items( ): if len(device_node_name) + 4 > name_column_width: device_node_name = device_node_name[: @@ -1018,21 +1102,21 @@ def _build_table(statistic_data, device_node_name += "..." row_values = [ ' {}'.format(device_node_name), - devicenode.call, '- / - / - / - / -', + device_node.call, '- / - / - / - / -', '{} / {} / {} / {} / {}'.format( format_time( - devicenode.gpu_time, unit=time_unit), + device_node.gpu_time, unit=time_unit), format_time( - devicenode.avg_gpu_time, + device_node.avg_gpu_time, unit=time_unit), format_time( - devicenode.max_gpu_time, + device_node.max_gpu_time, unit=time_unit), format_time( - devicenode.min_gpu_time, + device_node.min_gpu_time, unit=time_unit), format_ratio( - float(devicenode.gpu_time) / + float(device_node.gpu_time) / total_time)) ] all_row_values.append(row_values) @@ -1043,19 +1127,19 @@ def _build_table(statistic_data, - 5] device_node_name += "..." row_values = [ - ' {}'.format(device_node_name), devicenode.call, + ' {}'.format(device_node_name), device_node.call, '- / - / - / - / -', '{} / {} / {} / {} / {}'.format( format_time( - devicenode.gpu_time, unit=time_unit), + device_node.gpu_time, unit=time_unit), format_time( - devicenode.avg_gpu_time, unit=time_unit), + device_node.avg_gpu_time, unit=time_unit), format_time( - devicenode.max_gpu_time, unit=time_unit), + device_node.max_gpu_time, unit=time_unit), format_time( - devicenode.min_gpu_time, unit=time_unit), + device_node.min_gpu_time, unit=time_unit), format_ratio( - float(devicenode.gpu_time) / total_time)) + float(device_node.gpu_time) / total_time)) ] all_row_values.append(row_values) # Calculate the column width @@ -1123,14 +1207,14 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)), + item.min_general_gpu_time, unit=time_unit), + format_ratio(float(item.general_gpu_time) / total_time)), ] all_row_values.append(row_values) @@ -1207,20 +1291,22 @@ def _build_table(statistic_data, items.items(), key=lambda x: x[1].min_cpu_time) elif sorted_by == SortedKeys.GPUTotal: sorted_items = sorted( - items.items(), key=lambda x: x[1].gpu_time, reverse=True) + items.items(), + key=lambda x: x[1].general_gpu_time, + reverse=True) elif sorted_by == SortedKeys.GPUAvg: sorted_items = sorted( items.items(), - key=lambda x: x[1].avg_gpu_time, + key=lambda x: x[1].avg_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMax: sorted_items = sorted( items.items(), - key=lambda x: x[1].max_gpu_time, + key=lambda x: x[1].max_general_gpu_time, reverse=True) elif sorted_by == SortedKeys.GPUMin: sorted_items = sorted( - items.items(), key=lambda x: x[1].min_gpu_time) + items.items(), key=lambda x: x[1].min_general_gpu_time) for name, item in sorted_items: row_values = [ @@ -1238,14 +1324,15 @@ def _build_table(statistic_data, format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( - item.gpu_time, unit=time_unit), + item.general_gpu_time, unit=time_unit), format_time( - item.avg_gpu_time, unit=time_unit), + item.avg_general_gpu_time, unit=time_unit), format_time( - item.max_gpu_time, unit=time_unit), + item.max_general_gpu_time, unit=time_unit), format_time( - item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_time)), + item.min_general_gpu_time, unit=time_unit), + format_ratio( + float(item.general_gpu_time) / total_time)), ] all_row_values.append(row_values) -- GitLab