From 8ea834002152812d9d1781b4243444510852ab00 Mon Sep 17 00:00:00 2001 From: chenjian Date: Thu, 11 Aug 2022 14:55:24 +0800 Subject: [PATCH] Add input shape record for new dygraph operator (#44999) * fix * add control flag and input shapes for new dygraph * fix file mode * improve code coverage * fix a bug in statstic * fix according to review * optimize performance * fix --- paddle/fluid/imperative/basic_engine.cc | 2 +- paddle/fluid/platform/profiler.cc | 50 +++++++++++ paddle/fluid/platform/profiler.h | 6 ++ paddle/fluid/platform/profiler/common_event.h | 17 ++++ .../fluid/platform/profiler/event_tracing.h | 1 + paddle/fluid/platform/profiler/mem_tracing.h | 1 + .../platform/profiler/supplement_tracing.h | 12 +++ paddle/fluid/pybind/pybind.cc | 6 ++ paddle/phi/api/yaml/generator/api_base.py | 90 +++++++++++++++++-- paddle/phi/api/yaml/generator/api_gen.py | 1 + .../api/yaml/generator/backward_api_gen.py | 1 + .../yaml/generator/intermediate_api_gen.py | 1 + python/paddle/fluid/dygraph/layers.py | 2 +- .../fluid/tests/unittests/test_newprofiler.py | 4 +- python/paddle/profiler/profiler.py | 19 +++- python/paddle/profiler/profiler_statistic.py | 7 +- 16 files changed, 207 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index e7caf15ee77..c4b622f9850 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -400,7 +400,7 @@ static void PerformBackwardInplace(const std::string& op_type, void BasicEngine::Execute() { platform::RecordEvent backward_record_event( - "backward", platform::TracerEventType::Operator, 1); + "backward", platform::TracerEventType::UserDefined, 1); if (init_nodes_.empty()) { return; diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 59233568512..d6103698971 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -42,6 +42,10 @@ DEFINE_bool(enable_host_event_recorder_hook, false, "enable HostEventRecorder, hook Profiler"); +DEFINE_bool(enable_record_input_shape, false, "enable input shape recorder"); + +DEFINE_bool(enable_record_memory, false, "enable memory recorder"); + namespace paddle { namespace platform { @@ -258,6 +262,9 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( if (FLAGS_enable_host_event_recorder_hook == false) { return; } + if (IsEnabled() == false) { + return; + } std::map> input_shapes; std::map> dtypes; for (auto it = ctx.inputs.begin(); it != ctx.inputs.end(); it++) { @@ -285,6 +292,9 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( if (FLAGS_enable_host_event_recorder_hook == false) { return; } + if (IsEnabled() == false) { + return; + } std::map> input_shapes; std::map> dtypes; for (auto it = kernel_signature.input_names.begin(); @@ -308,6 +318,33 @@ RecordOpInfoSupplement::RecordOpInfoSupplement( PosixInNsec(), type, input_shapes, dtypes, callstack); } +RecordOpInfoSupplement::RecordOpInfoSupplement( + const std::string &type, + const std::vector>> + &input_shapes) { + if (FLAGS_enable_host_event_recorder_hook == false) { + return; + } + if (IsEnabled() == false) { + return; + } + std::map> dtypes; + std::vector callstack; + HostEventRecorder::GetInstance().RecordEvent( + PosixInNsec(), type, input_shapes, dtypes, callstack); +} + +bool RecordEvent::IsEnabled() { + return FLAGS_enable_host_event_recorder_hook || g_enable_nvprof_hook || + g_state != ProfilerState::kDisabled; +} + +bool RecordOpInfoSupplement::IsEnabled() { + return 
FLAGS_enable_record_input_shape; +} + +bool RecordMemEvent::IsEnabled() { return FLAGS_enable_record_memory; } + std::map>> RecordMemEvent::size_cache; @@ -322,6 +359,11 @@ RecordMemEvent::RecordMemEvent(const void *ptr, FLAGS_enable_host_event_recorder_hook == false) { return; } + + if (IsEnabled() == false) { + return; + } + if (type == TracerMemEventType::Allocate) { uint64_t current_allocated; uint64_t peak_allocated; @@ -1045,6 +1087,14 @@ void DisableHostEventRecorder() { FLAGS_enable_host_event_recorder_hook = false; } +void EnableInputShapeRecorder() { FLAGS_enable_record_input_shape = true; } + +void DisableInputShapeRecorder() { FLAGS_enable_record_input_shape = false; } + +void EnableMemoryRecorder() { FLAGS_enable_record_memory = true; } + +void DisableMemoryRecorder() { FLAGS_enable_record_memory = false; } + std::string PrintHostEvents() { std::ostringstream oss; auto host_evt_sec = diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 6046e54b6c8..80e74b20eeb 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -248,6 +248,12 @@ void NvprofDisableRecordEvent(); void EnableHostEventRecorder(); void DisableHostEventRecorder(); +void EnableMemoryRecorder(); +void DisableMemoryRecorder(); + +void EnableInputShapeRecorder(); +void DisableInputShapeRecorder(); + // Defined for UT std::string PrintHostEvents(); diff --git a/paddle/fluid/platform/profiler/common_event.h b/paddle/fluid/platform/profiler/common_event.h index 837abbec9ad..aeab169b27d 100644 --- a/paddle/fluid/platform/profiler/common_event.h +++ b/paddle/fluid/platform/profiler/common_event.h @@ -124,6 +124,23 @@ struct OperatorSupplementOriginEvent { strncpy(buf, type_name.c_str(), type_name.length() + 1); op_type = buf; } + OperatorSupplementOriginEvent( + std::function arena_allocator, + uint64_t timestamp_ns, + const std::string &type_name, + const std::vector>> + &shapes, + const std::map> + &dtypes, + const std::vector callstack) + : timestamp_ns(timestamp_ns), dtypes(dtypes), callstack(callstack) { + auto buf = static_cast(arena_allocator(type_name.length() + 1)); + strncpy(buf, type_name.c_str(), type_name.length() + 1); + op_type = buf; + for (auto it = shapes.begin(); it != shapes.end(); it++) { + input_shapes[std::string((*it).first)] = (*it).second; + } + } uint64_t timestamp_ns; const char *op_type = nullptr; // not owned, designed for performance // input shapes diff --git a/paddle/fluid/platform/profiler/event_tracing.h b/paddle/fluid/platform/profiler/event_tracing.h index 0219614bd1e..82d66068c33 100644 --- a/paddle/fluid/platform/profiler/event_tracing.h +++ b/paddle/fluid/platform/profiler/event_tracing.h @@ -48,6 +48,7 @@ struct RecordInstantEvent { // Chrome Trace Viewer Format: Duration Event/Complte Event class RecordEvent { public: + static bool IsEnabled(); /** * @param name: If your string argument has a longer lifetime (e.g.: string * literal, static variables, etc) than the event, use 'const char* name'. diff --git a/paddle/fluid/platform/profiler/mem_tracing.h b/paddle/fluid/platform/profiler/mem_tracing.h index 5b2a2391c2e..d180791e4d2 100644 --- a/paddle/fluid/platform/profiler/mem_tracing.h +++ b/paddle/fluid/platform/profiler/mem_tracing.h @@ -27,6 +27,7 @@ namespace platform { // The events can be used to draw memory variation curve. class RecordMemEvent { public: + static bool IsEnabled(); /** * @param ptr: Pointer address allocated or free. * @param place: Device for this memory event. 
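
Note: the two new gflags above (FLAGS_enable_record_input_shape, FLAGS_enable_record_memory) gate RecordOpInfoSupplement and RecordMemEvent through the IsEnabled() checks, and are toggled by EnableInputShapeRecorder()/EnableMemoryRecorder() and their Disable counterparts; later in this patch they are bound to Python in pybind.cc as enable_input_shape_recorder() etc. A minimal sketch of the low-level toggles, assuming a build with this patch applied (the tensors are illustrative; in normal use the Profiler's record_shapes/profile_memory arguments call these for you, and the supplement events are only emitted while a profiler session keeps the host event recorder active):

    import paddle
    from paddle.fluid import core

    # Low-level switches added by this patch; they flip the gflags shown above.
    core.enable_input_shape_recorder()   # FLAGS_enable_record_input_shape = true
    core.enable_memory_recorder()        # FLAGS_enable_record_memory = true

    x = paddle.ones([2, 3])
    y = x + 1.0  # ops run here record input shapes / memory events,
                 # provided the host event recorder (a running profiler) is active

    core.disable_input_shape_recorder()
    core.disable_memory_recorder()
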
diff --git a/paddle/fluid/platform/profiler/supplement_tracing.h b/paddle/fluid/platform/profiler/supplement_tracing.h index 270223d13b2..7f82155809d 100644 --- a/paddle/fluid/platform/profiler/supplement_tracing.h +++ b/paddle/fluid/platform/profiler/supplement_tracing.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once +#include #include +#include #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/type_defs.h" @@ -30,6 +32,8 @@ namespace platform { class RecordOpInfoSupplement { public: + static bool IsEnabled(); + /** * @param type: Operator type name. * @param attrs: Attribute map of op. @@ -50,6 +54,14 @@ class RecordOpInfoSupplement { const framework::AttributeMap& attrs, const framework::InferShapeContext& shape_ctx, const phi::KernelSignature& kernel_signature); + + /** + * + */ + explicit RecordOpInfoSupplement( + const std::string& type, + const std::vector>>& + input_shapes); }; } // namespace platform diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 5575d839a2f..1fe424eccd9 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2133,6 +2133,12 @@ All parameter, weight, gradient are variables in Paddle. .value("PythonUserDefined", paddle::platform::TracerEventType::PythonUserDefined); m.def("load_profiler_result", &paddle::platform::LoadProfilerResult); + m.def("enable_memory_recorder", &paddle::platform::EnableMemoryRecorder); + m.def("disable_memory_recorder", &paddle::platform::DisableMemoryRecorder); + m.def("enable_input_shape_recorder", + &paddle::platform::EnableInputShapeRecorder); + m.def("disable_input_shape_recorder", + &paddle::platform::DisableInputShapeRecorder); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) m.def("set_cublas_switch", platform::SetAllowTF32Cublas); diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 88903763d80..7e32bcf3e5c 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -13,6 +13,7 @@ # limitations under the License. 
import re +import collections PREFIX_TENSOR_NAME = 'input_' PREFIX_META_TENSOR_NAME = 'meta_' @@ -569,6 +570,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d kernel_param = input_names + attr_names input_tensor_code = "" + input_name_tensor_map = collections.defaultdict(list) for i, input_name in enumerate(input_names): # set input code if input_name in kernel_param: @@ -582,17 +584,23 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d 'support_trans_dtype']: trans_flag = "{false, true}" if input_name in self.optional_vars: + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", False)) input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});""" else: if self.inputs['input_info'][ input_name] == "const Tensor&": + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", False)) input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});""" elif self.inputs['input_info'][ input_name] == "const std::vector&": + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)) input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag}); {code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size()); @@ -604,8 +612,11 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d # do nothing pass else: # input is selected_rows + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", False)) input_tensor_code = input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name}); +""" else: if input_name in self.infer_meta['param']: if input_name in self.optional_vars: @@ -621,7 +632,65 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d else: input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();""" - + input_tensor_code = input_tensor_code + f""" +{code_indent} if(platform::RecordOpInfoSupplement::IsEnabled()){{""" + single_tensor_names = [] + list_tensor_names = [] + for input_name, input_tensors in input_name_tensor_map.items(): + has_vector_tensor = False + for input_tensor, is_vector in input_tensors: + if is_vector is True: + has_vector_tensor = True + if has_vector_tensor is False: + single_tensor_names.append(input_name) + else: + list_tensor_names.append(input_name) + if not single_tensor_names: + input_tensor_code = input_tensor_code + f""" +{code_indent} std::vector>> input_shapes;""" + else: + input_tensor_code = input_tensor_code + f""" +{code_indent} std::vector>> input_shapes{{""" + for input_name in single_tensor_names[:-1]: + input_tensors = input_name_tensor_map[input_name] + input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{input_name}", {{""" + for input_tensor, _ in input_tensors[:-1]: + input_tensor_code = input_tensor_code + f""" +{code_indent} (*{input_tensor}).dims(),""" + input_tensor_code = 
input_tensor_code + f""" +{code_indent} (*{input_tensors[-1][0]}).dims()}}}},""" + input_tensors = input_name_tensor_map[single_tensor_names[-1]] + input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{single_tensor_names[-1]}", {{""" + for input_tensor, _ in input_tensors[:-1]: + input_tensor_code = input_tensor_code + f""" +{code_indent} (*{input_tensor}).dims(),""" + input_tensor_code = input_tensor_code + f""" +{code_indent} (*{input_tensors[-1][0]}).dims()}}}}}};""" + if list_tensor_names: + input_tensor_code = input_tensor_code + f""" +{code_indent} std::vector ddims_vec;""" + for input_name in list_tensor_names: + input_tensor_code = input_tensor_code + f""" +{code_indent} ddims_vec.clear();""" + for input_tensor, is_vector in input_name_tensor_map[input_name]: + if is_vector: + input_tensor_code = input_tensor_code + f""" +{code_indent} ddims_vec.reserve({input_tensor[:-4]}.size()); +{code_indent} for (size_t i = 0; i < {input_tensor[:-4]}.size(); ++i) {{ +{code_indent} ddims_vec.emplace_back((*{input_tensor[:-4]}[i]).dims()); +{code_indent} }}""" + else: + input_tensor_code = input_tensor_code + f""" + ddims_vec.emplace_back((*{input_tensor}).dims()); +{code_indent} """ + input_tensor_code = input_tensor_code + f""" +{code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);""" + + input_tensor_code = input_tensor_code + f""" +{code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes); +{code_indent} }}""" kernel_args = ["*dev_ctx"] for param in kernel_param: if param in input_names: @@ -709,17 +778,26 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args}); {code_indent} const auto& kernel = kernel_result.kernel; {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel; - {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? 
Backend::CPU : kernel_backend); {input_tensors} {output_create} +{code_indent} paddle::platform::RecordEvent *infer_shape_record_event = nullptr; +{code_indent} if(paddle::platform::RecordEvent::IsEnabled()){{ +{code_indent} infer_shape_record_event = new paddle::platform::RecordEvent(\"{self.api} infer_meta\", paddle::platform::TracerEventType::OperatorInner, 1); +{code_indent} }} {self.gene_infer_meta(kernel_output_names, code_indent)} - +{code_indent} if(infer_shape_record_event != nullptr){{ +{code_indent} delete infer_shape_record_event; +{code_indent} }} {code_indent} using kernel_signature = {kernel_signature}; {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn(); -{code_indent} {{ -{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{kernel_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1); +{code_indent} paddle::platform::RecordEvent* kernel_record_event = nullptr; +{code_indent} if(paddle::platform::RecordEvent::IsEnabled()){{ +{code_indent} kernel_record_event = new paddle::platform::RecordEvent(\"{self.api} compute\", paddle::platform::TracerEventType::OperatorInner, 1); +{code_indent} }} {code_indent} (*kernel_fn)({kernel_args}, {", ".join(outputs_args)}); +{code_indent} if(kernel_record_event != nullptr){{ +{code_indent} delete kernel_record_event; {code_indent} }} {code_indent} if (kernel_result.has_fallback_cpu) {{ {fallback_kernel_output_trans} diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index dc4581472f8..1eb030f7f9b 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -252,6 +252,7 @@ def source_include(header_file_path): #include "paddle/phi/infermeta/ternary.h" #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/fluid/platform/profiler/supplement_tracing.h" DECLARE_bool(conv2d_disable_cudnn); """ diff --git a/paddle/phi/api/yaml/generator/backward_api_gen.py b/paddle/phi/api/yaml/generator/backward_api_gen.py index cb57b044597..187f8e8e4fa 100644 --- a/paddle/phi/api/yaml/generator/backward_api_gen.py +++ b/paddle/phi/api/yaml/generator/backward_api_gen.py @@ -220,6 +220,7 @@ def source_include(header_file_path): #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/fluid/platform/profiler/supplement_tracing.h" DECLARE_bool(conv2d_disable_cudnn); """ diff --git a/paddle/phi/api/yaml/generator/intermediate_api_gen.py b/paddle/phi/api/yaml/generator/intermediate_api_gen.py index c8ba88d054a..7834e5c230c 100644 --- a/paddle/phi/api/yaml/generator/intermediate_api_gen.py +++ b/paddle/phi/api/yaml/generator/intermediate_api_gen.py @@ -52,6 +52,7 @@ def source_include(header_file_path): #include "paddle/phi/infermeta/ternary.h" #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/fluid/platform/profiler/supplement_tracing.h" """ diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 54c2cb29d92..cf2bd3a69bd 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -927,7 +927,7 @@ class Layer(object): self._built = True if in_profiler_mode(): - with profiler.RecordEvent(self.full_name(), + with profiler.RecordEvent(self.__class__.__name__, profiler.TracerEventType.Forward): outputs = self.forward(*inputs, **kwargs) else: diff --git a/python/paddle/fluid/tests/unittests/test_newprofiler.py 
b/python/paddle/fluid/tests/unittests/test_newprofiler.py index 99097aaf004..097911c2459 100755 --- a/python/paddle/fluid/tests/unittests/test_newprofiler.py +++ b/python/paddle/fluid/tests/unittests/test_newprofiler.py @@ -135,7 +135,9 @@ class TestProfiler(unittest.TestCase): record=2, repeat=1, skip_first=1), - on_trace_ready=my_trace_back) as prof: + on_trace_ready=my_trace_back, + profile_memory=True, + record_shapes=True) as prof: for i in range(5): y = x / 2.0 paddle.grad(outputs=y, inputs=[x], grad_outputs=ones_like_y) diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index 2df26020b8f..b2e95d24b0b 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -23,7 +23,10 @@ import json import paddle from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions, - TracerEventType) + TracerEventType, enable_memory_recorder, + enable_input_shape_recorder, + disable_memory_recorder, + disable_input_shape_recorder) from .utils import RecordEvent, wrap_optimizers from .profiler_statistic import StatisticData, _build_table, SortedKeys @@ -279,6 +282,8 @@ class Profiler: This callable object will be called when ``scheduler`` returns ``ProfilerState.RECORD_AND_RETURN``. The default value is :ref:`export_chrome_tracing ` (./profiler_log/). timer_only (bool, optional): If it is True, the cost of Dataloader and every step of the model will be count without profiling. Otherwise, the model will be timed and profiled. Default: False. + record_shapes (bool, optional): If it is True, collect op's input shape information. Default: False. + profile_memory (bool, optional): If it is True, collect tensor memory allocation and release information. Default: False. Examples: 1. profiling range [2, 5). 
@@ -396,6 +401,8 @@ class Profiler: scheduler: Union[Callable[[int], ProfilerState], tuple, None] = None, on_trace_ready: Optional[Callable[..., Any]] = None, + record_shapes: Optional[bool] = False, + profile_memory=False, timer_only: Optional[bool] = False): supported_targets = _get_supported_targets() if targets: @@ -447,6 +454,8 @@ class Profiler: self.record_event = None self.profiler_result = None self.timer_only = timer_only + self.record_shapes = record_shapes + self.profile_memory = profile_memory def __enter__(self): self.start() @@ -481,6 +490,10 @@ class Profiler: benchmark().begin() if self.timer_only: return + if self.record_shapes: + enable_input_shape_recorder() + if self.profile_memory: + enable_memory_recorder() # CLOSED -> self.current_state utils._is_profiler_used = True if self.current_state == ProfilerState.READY: @@ -520,6 +533,10 @@ class Profiler: benchmark().end() if self.timer_only: return + if self.record_shapes: + disable_input_shape_recorder() + if self.profile_memory: + disable_memory_recorder() # self.current_state -> CLOSED # In this situation, RECORD state is regarded as RECORD_AND_RETURN if self.record_event: diff --git a/python/paddle/profiler/profiler_statistic.py b/python/paddle/profiler/profiler_statistic.py index 63cdafbbf8d..5c8afd1b3b5 100755 --- a/python/paddle/profiler/profiler_statistic.py +++ b/python/paddle/profiler/profiler_statistic.py @@ -86,6 +86,7 @@ class HostStatisticNode: for rt in self.runtime_node: rt.cal_statistic() self.cpu_time = self.hostnode.end_ns - self.hostnode.start_ns + self.self_cpu_time = self.cpu_time for child in self.children_node: self.gpu_time += child.gpu_time self.general_gpu_time += child.general_gpu_time @@ -918,7 +919,7 @@ def _build_table(statistic_data, accmulation_time = 0 gpu_accmulation_time = 0 gpu_total_time = statistic_data.event_summary.model_perspective_items[ - 'ProfileStep'].general_gpu_time + 'ProfileStep'].gpu_time for name in [ 'ProfileStep', 'Dataloader', 'Forward', 'Backward', 'Optimization' @@ -928,7 +929,7 @@ def _build_table(statistic_data, if gpu_total_time == 0: gpu_ratio = 0 else: - gpu_ratio = float(item.general_gpu_time) / gpu_total_time + gpu_ratio = float(item.gpu_time) / gpu_total_time name = '{}'.format( name) if 'ProfileStep' in name else ' {}'.format(name) row_values = [ @@ -949,7 +950,7 @@ def _build_table(statistic_data, all_row_values.append(row_values) if 'ProfileStep' not in name: accmulation_time += item.cpu_time - gpu_accmulation_time += item.general_gpu_time + gpu_accmulation_time += item.gpu_time other_time = total_time - accmulation_time other_gpu_time = gpu_total_time - gpu_accmulation_time -- GitLab
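
End to end, the user-facing change is the pair of new Profiler arguments, record_shapes and profile_memory, which call the recorder toggles in start()/stop(). A minimal usage sketch mirroring the updated unit test (tensor sizes, step count, and the CPU-only target are illustrative, not taken from the patch):

    import paddle
    import paddle.profiler as profiler

    x = paddle.randn([4, 4])
    x.stop_gradient = False

    with profiler.Profiler(
            targets=[profiler.ProfilerTarget.CPU],
            scheduler=(2, 5),            # profile steps in the range [2, 5)
            record_shapes=True,          # new in this patch: collect op input shapes
            profile_memory=True          # new in this patch: collect memory events
    ) as prof:
        for step in range(5):
            y = (x * 2.0).sum()
            y.backward()
            prof.step()

    prof.summary()
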