diff --git a/paddle/fluid/platform/profiler/cpu_utilization.cc b/paddle/fluid/platform/profiler/cpu_utilization.cc index ce2e49a1ccd39accb8830943759d361d15d12d9d..d507153d3f5b47ef072f9da0276073448127fb9c 100644 --- a/paddle/fluid/platform/profiler/cpu_utilization.cc +++ b/paddle/fluid/platform/profiler/cpu_utilization.cc @@ -118,8 +118,9 @@ float CpuUtilization::GetCpuUtilization() { float busy_time = (system_kernel_time_end - system_kernel_time_start) + (system_user_time_end - system_user_time_start); float idle_time = system_idle_time_end - system_idle_time_start; - cpu_utilization = busy_time / (busy_time + idle_time); - + if (busy_time + idle_time != 0) { + cpu_utilization = busy_time / (busy_time + idle_time); + } #elif defined(__linux__) float busy_time = (system_tms_end_.tms_utime - system_tms_start_.tms_utime) + (system_tms_end_.tms_stime - system_tms_start_.tms_stime) + @@ -127,7 +128,9 @@ float CpuUtilization::GetCpuUtilization() { (irq_end_ - irq_start_) + (softirq_end_ - softirq_start_) + (steal_end_ - steal_start_); float idle_time = (idle_end_ - idle_start_) + (iowait_end_ - iowait_start_); - cpu_utilization = busy_time / (busy_time + idle_time); + if (busy_time + idle_time != 0) { + cpu_utilization = busy_time / (busy_time + idle_time); + } #else LOG(WARNING) << "Current System is not supported to get system cpu utilization" @@ -148,13 +151,16 @@ float CpuUtilization::GetCpuCurProcessUtilization() { uint64_t end = FileTimeToUint64(end_); float busy_time = (process_kernel_time_end - process_kernel_time_start) + (process_user_time_end - process_user_time_start); - cpu_process_utilization = busy_time / (end - start); - LOG(INFO) << "Process Utilization = " << cpu_process_utilization << std::endl; + if (end - start != 0) { + cpu_process_utilization = busy_time / (end - start); + } #elif defined(__linux__) float busy_time = (process_tms_end_.tms_utime - process_tms_start_.tms_utime) + (process_tms_end_.tms_stime - process_tms_start_.tms_stime); - cpu_process_utilization = busy_time / (end_ - start_); + if (end_ - start_ != 0) { + cpu_process_utilization = busy_time / (end_ - start_); + } #else LOG(WARNING) << "Current System is not supported to get process cpu utilization" diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index 46cbb3358c6c4d6b2b17cfc1e549db6376931389..ac46fbed10a2022324438e4718261813a5c38b19 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -44,6 +44,14 @@ std::unique_ptr Profiler::Create(const ProfilerOptions& options) { return std::unique_ptr(new Profiler(options)); } +bool Profiler::IsCuptiSupported() { + bool supported = false; +#ifdef PADDLE_WITH_CUPTI + supported = true; +#endif + return supported; +} + Profiler::Profiler(const ProfilerOptions& options) { options_ = options; std::bitset<32> trace_switch(options_.trace_switch); diff --git a/paddle/fluid/platform/profiler/profiler.h b/paddle/fluid/platform/profiler/profiler.h index f9a8ece050492805226cccce001251c3cd2ad0c2..d24ee504bc6407230da875ab5e29251740d72822 100644 --- a/paddle/fluid/platform/profiler/profiler.h +++ b/paddle/fluid/platform/profiler/profiler.h @@ -43,6 +43,8 @@ class Profiler { public: static std::unique_ptr Create(const ProfilerOptions& options); + static bool IsCuptiSupported(); + void Prepare(); void Start(); diff --git a/paddle/fluid/platform/profiler/utils.cc b/paddle/fluid/platform/profiler/utils.cc index b43389866c7a8150846bef874f49bd72907f446f..de314d298c90ea7c70d9d244b78cbc46feae9a9c 100644 --- a/paddle/fluid/platform/profiler/utils.cc +++ b/paddle/fluid/platform/profiler/utils.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/dynload/cupti.h" namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/profiler/utils.h b/paddle/fluid/platform/profiler/utils.h index cd56d343842686abc31343effc93cf1a4887411c..b471d6b79833a17eca35fe44c9d4917684aa8bcc 100644 --- a/paddle/fluid/platform/profiler/utils.h +++ b/paddle/fluid/platform/profiler/utils.h @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/platform/dynload/cupti.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/os_info.h" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dcfad030a689c278b72a0061cfb170762d1a3156..f5c853fb4b8ee251edac8bc69cf64da87ac71189 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3322,6 +3322,7 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "_Profiler") .def("create", &paddle::platform::Profiler::Create, py::return_value_policy::take_ownership) + .def("is_cupti_supported", &paddle::platform::Profiler::IsCuptiSupported) .def("prepare", [](paddle::platform::Profiler *profiler) { platform::EnableHostEventRecorder(); diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 5385ac28b90f614fcd6003994b9a7000bc16702a..da66530f81b0a50ad432f72a10eeee354127c53a 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -30,6 +30,7 @@ from paddle.fluid.framework import _set_expected_place, _current_expected_place, import queue import paddle +import paddle.profiler as profiler from .. import core, layers from ..framework import in_dygraph_mode, _in_eager_mode from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL, CleanupFuncRegistrar @@ -250,6 +251,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): self._exit_thread_expectedly() def __next__(self): + trace_event = profiler.RecordEvent( + name="_DataLoaderIterSingleProcess", + event_type=profiler.TracerEventType.Dataloader) + trace_event.begin() try: if in_dygraph_mode(): if _in_eager_mode(): @@ -283,6 +288,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): self._reader.shutdown() self._try_shutdown_all() six.reraise(*sys.exc_info()) + finally: + trace_event.end() def _shutdown_thread(self): if self._thread: @@ -695,6 +702,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._try_shutdown_all(1) def __next__(self): + trace_event = profiler.RecordEvent( + name="_DataLoaderIterMultiProcess", + event_type=profiler.TracerEventType.Dataloader) + trace_event.begin() try: # _batches_outstanding here record the total batch data number # in 'from after _try_put_indices to beforeoutput data', this @@ -743,6 +754,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._reader.shutdown() self._try_shutdown_all() six.reraise(*sys.exc_info()) + finally: + trace_event.end() # python2 compatibility def next(self): diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index f4334085620f510e3d520f89332b754a93aa120a..37db9f8fce77a63773223888c8896822d56ba1e4 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -25,6 +25,7 @@ from copy import deepcopy import inspect import paddle +import paddle.profiler as profiler from . import parallel_helper from .. import unique_name @@ -905,7 +906,9 @@ class Layer(object): self._built = True - outputs = self.forward(*inputs, **kwargs) + with profiler.RecordEvent(self.full_name(), + profiler.TracerEventType.Forward): + outputs = self.forward(*inputs, **kwargs) for forward_post_hook in self._forward_post_hooks.values(): hook_result = forward_post_hook(self, inputs, outputs) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index af30b2b2444b44f1b27e8f277eb380557255517d..24284ca78c1ce98116945d7578e0f0cdf557e89b 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -28,6 +28,7 @@ from .math_op_patch import monkey_patch_math_varbase from .parallel import scale_loss from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE import paddle.utils.deprecated as deprecated +import paddle.profiler as profiler from paddle import _C_ops @@ -199,8 +200,8 @@ def monkey_patch_varbase(): You can clear gradient by ``Tensor.clear_grad()`` . Args: - grad_tensor(Tensor, optional): initial gradient values of the current Tensor. If `grad_tensor` is None, - the initial gradient values of the current Tensor would be Tensor filled with 1.0; + grad_tensor(Tensor, optional): initial gradient values of the current Tensor. If `grad_tensor` is None, + the initial gradient values of the current Tensor would be Tensor filled with 1.0; if `grad_tensor` is not None, it must have the same length as the current Tensor. Teh default value is None. @@ -243,6 +244,9 @@ def monkey_patch_varbase(): """ if framework.in_dygraph_mode(): + record_event = profiler.RecordEvent( + "Gradient Backward", profiler.TracerEventType.Backward) + record_event.begin() if grad_tensor is not None: if core._in_eager_mode(): assert isinstance( @@ -278,6 +282,7 @@ def monkey_patch_varbase(): core.dygraph_run_backward([self], [grad_tensor], retain_graph, framework._dygraph_tracer()) + record_event.end() else: raise ValueError( "Variable.backward() is only available in DyGraph mode") @@ -476,7 +481,7 @@ def monkey_patch_varbase(): def grad(self): """ .. warning:: - This API will return the tensor value of the gradient. If you want + This API will return the tensor value of the gradient. If you want to get the numpy value of the gradient, you can use :code:`x.grad.numpy()`. Get the Gradient of Current Tensor. @@ -515,7 +520,7 @@ def monkey_patch_varbase(): def item(self, *args): """ - Convert element at specific position in Tensor into Python scalars. If the position is not specified, the Tensor must be a + Convert element at specific position in Tensor into Python scalars. If the position is not specified, the Tensor must be a single-element Tensor. Args: @@ -526,7 +531,7 @@ def monkey_patch_varbase(): Raises: ValueError: If the Tensor has more than one element, there must be coordinates. - + Examples: .. code-block:: python @@ -588,7 +593,7 @@ def monkey_patch_varbase(): import paddle x = paddle.rand([2, 5]) print(x) - + # Tensor(shape=[2, 5], dtype=float32, place=CPUPlace, # [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436], # [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]]) @@ -611,7 +616,7 @@ def monkey_patch_varbase(): import copy x = paddle.to_tensor(2.) y = copy.deepcopy(x) - + print(x) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, # [2.]) @@ -655,7 +660,7 @@ def monkey_patch_varbase(): def __array__(self, dtype=None): """ Returns a numpy array shows the value of current Tensor. - + Returns: ndarray: The numpy value of current Tensor. diff --git a/python/paddle/fluid/profiler.py b/python/paddle/fluid/profiler.py index 183a00bd70bdff1ec37767f06a5a3944aa9882e8..4d39d38853063616bced2b76f86c3f8e9b66aa48 100644 --- a/python/paddle/fluid/profiler.py +++ b/python/paddle/fluid/profiler.py @@ -20,6 +20,8 @@ import os import six import sys +from paddle.utils.deprecated import deprecated + __all__ = [ 'cuda_profiler', 'reset_profiler', 'profiler', 'start_profiler', 'stop_profiler' @@ -36,10 +38,16 @@ NVPROF_CONFIG = [ ] +@deprecated( + since="2.3.0", + update_to="paddle.profiler.Profiler", + level=1, + reason="Please use new profiler tool, this profiler tool is no longer maintained." +) @signature_safe_contextmanager def cuda_profiler(output_file, output_mode=None, config=None): """ - API cuda_profiler has been abandoned. If you have relevant requirements, you can use `paddle.utils.profiler.start_profiler` and `paddle.utils.profiler.stop_profiler`. + API cuda_profiler has been abandoned. If you have relevant requirements, you can use `paddle.utils.profiler.start_profiler` and `paddle.utils.profiler.stop_profiler`. The relevant reference documents are as follows: @@ -54,18 +62,18 @@ def cuda_profiler(output_file, output_mode=None, config=None): def npu_profiler(output_file, config=None): """ The NPU profiler. - + This fuctions is used to profile NPU program by NPU runtime application programming interface. The profiling result will be written into - `output_file`. The users can set set the NPU profiling config by `config` argument. - - After getting the profiling result file, users can use - `tools provided by Ascend `_ + `output_file`. The users can set set the NPU profiling config by `config` argument. + + After getting the profiling result file, users can use + `tools provided by Ascend `_ to load this output file to visualize results. Args: output_file (str) : The output file name, the result will be - written into this file. It should be absolute path. + written into this file. It should be absolute path. config (list, optional) : NPU profile config. For more details, please refer to `User Guide `_ . @@ -109,6 +117,12 @@ def npu_profiler(output_file, config=None): core.npu_prof_finalize() +@deprecated( + since="2.3.0", + update_to="paddle.profiler.Profiler", + level=1, + reason="Please use new profiler tool, this profiler tool is no longer maintained." +) def reset_profiler(): """ Clear the previous time record. It works for @@ -131,31 +145,38 @@ def reset_profiler(): core.reset_profiler() +@deprecated( + since="2.3.0", + update_to="paddle.profiler.Profiler", + level=1, + reason="Please use new profiler tool, this profiler tool is no longer maintained." +) def start_profiler(state, tracer_option='Default'): """ Enable the profiler. Uers can use `fluid.profiler.start_profiler` and - `fluid.profiler.stop_profiler` to profile, which is equal to the usage + `fluid.profiler.stop_profiler` to profile, which is equal to the usage of `fluid.profiler.profiler` interface. Args: state (str) : The profiling state, which should be one of 'CPU', 'GPU' or 'All'. 'CPU' means only profiling CPU; 'GPU' means profiling - both CPU and GPU; 'All' means profiling both CPU and GPU, and + both CPU and GPU; 'All' means profiling both CPU and GPU, and generates timeline as well. tracer_option (str, optional) : tracer_option can be one of ['Default', 'OpDetail', 'AllOpDetail'], it - can control the profile level and print the different level profile result. `Default` option print - the different Op type profiling result and the `OpDetail` option print the detail profiling - result of different op types such as compute and data transform, `AllOpDetail` option + can control the profile level and print the different level profile result. `Default` option print + the different Op type profiling result and the `OpDetail` option print the detail profiling + result of different op types such as compute and data transform, `AllOpDetail` option print the detail profiling result of different op name same as `OpDetail`. Raises: - ValueError: If `state` is not in ['CPU', 'GPU', 'All'] or `tracer_option` + ValueError: If `state` is not in ['CPU', 'GPU', 'All'] or `tracer_option` is not in ['Default', 'OpDetail', 'AllOpDetail']. Examples: .. code-block:: python + # required: gpu import paddle.fluid as fluid import paddle.fluid.profiler as profiler @@ -165,7 +186,7 @@ def start_profiler(state, tracer_option='Default'): profiler.reset_profiler() # except each iteration profiler.stop_profiler('total', '/tmp/profile') - + profiler.start_profiler('GPU', "OpDetail") for iter in range(10): if iter == 2: @@ -198,14 +219,20 @@ def start_profiler(state, tracer_option='Default'): core.enable_profiler(prof_state) +@deprecated( + since="2.3.0", + update_to="paddle.profiler.Profiler", + level=1, + reason="Please use new profiler tool, this profiler tool is no longer maintained." +) def stop_profiler(sorted_key=None, profile_path='/tmp/profile'): """ Stop the profiler. Uers can use `fluid.profiler.start_profiler` and - `fluid.profiler.stop_profiler` to profile, which is equal to the usage + `fluid.profiler.stop_profiler` to profile, which is equal to the usage of `fluid.profiler.profiler` interface. Args: - sorted_key (str, optional) : The order of profiling results, which + sorted_key (str, optional) : The order of profiling results, which should be one of None, 'calls', 'total', 'max', 'min' or 'ave'. Default is None, means the profiling results will be printed in the order of first end time of events. @@ -214,7 +241,7 @@ def stop_profiler(sorted_key=None, profile_path='/tmp/profile'): The `max` means sorting by the maximum execution time. The `min` means sorting by the minimum execution time. The `ave` means sorting by the average execution time. - and write it into `profile_path`. The default profile_path is `/tmp/profile`. + and write it into `profile_path`. The default profile_path is `/tmp/profile`. profile_path (str, optional) : If state == 'All', it will generate timeline, Raises: @@ -225,6 +252,7 @@ def stop_profiler(sorted_key=None, profile_path='/tmp/profile'): .. code-block:: python + # required: gpu import paddle.fluid as fluid import paddle.fluid.profiler as profiler @@ -254,6 +282,12 @@ def stop_profiler(sorted_key=None, profile_path='/tmp/profile'): core.disable_profiler(key_map[sorted_key], profile_path) +@deprecated( + since="2.3.0", + update_to="paddle.profiler.Profiler", + level=1, + reason="Please use new profiler tool, this profiler tool is no longer maintained." +) @signature_safe_contextmanager def profiler(state, sorted_key=None, @@ -265,9 +299,9 @@ def profiler(state, Args: state (str) : The profiling state, which should be one of 'CPU', 'GPU' or 'All'. 'CPU' means only profiling CPU; 'GPU' means profiling - both CPU and GPU; 'All' means profiling both CPU and GPU, and + both CPU and GPU; 'All' means profiling both CPU and GPU, and generates timeline as well. - sorted_key (str, optional) : The order of profiling results, which + sorted_key (str, optional) : The order of profiling results, which should be one of None, 'calls', 'total', 'max', 'min' or 'ave'. Default is None, means the profiling results will be printed in the order of first end time of events. @@ -277,11 +311,11 @@ def profiler(state, The `min` means sorting by the minimum execution time. The `ave` means sorting by the average execution time. profile_path (str, optional) : If state == 'All', it will generate timeline, - and write it into `profile_path`. The default profile_path is `/tmp/profile`. + and write it into `profile_path`. The default profile_path is `/tmp/profile`. tracer_option (str, optional) : tracer_option can be one of ['Default', 'OpDetail', 'AllOpDetail'], it - can control the profile level and print the different level profile result. `Default` option print - the different Op type profiling result and the `OpDetail` option print the detail profiling - result of different op types such as compute and data transform, `AllOpDetail` option + can control the profile level and print the different level profile result. `Default` option print + the different Op type profiling result and the `OpDetail` option print the detail profiling + result of different op types such as compute and data transform, `AllOpDetail` option print the detail profiling result of different op name same as `OpDetail`. Raises: @@ -319,7 +353,7 @@ def profiler(state, #### Examples Results #### #### 1) sorted_key = 'total', 'calls', 'max', 'min', 'ave' #### - # The only difference in 5 sorted_key results is the following sentence: + # The only difference in 5 sorted_key results is the following sentence: # "Sorted by number of xxx in descending order in the same thread." # The reason is that in this example, above 5 columns are already sorted. -------------------------> Profiling Report <------------------------- @@ -339,7 +373,7 @@ def profiler(state, #### 2) sorted_key = None #### # Since the profiling results are printed in the order of first end time of Ops, - # the printed order is feed->conv2d->elementwise_add + # the printed order is feed->conv2d->elementwise_add -------------------------> Profiling Report <------------------------- Place: CPU @@ -366,7 +400,7 @@ def _nvprof_range(iter_id, start, end, exit_after_prof=True): Examples: .. code-block:: python - + model = Model() for i in range(max_iter): paddle.fluid.profiler._nvprof_range(i, 10, 20): diff --git a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py index 838ccae37cfa5fb7dbdedcb5d39655cb62ad429f..73b501c9c7eade28e94281b6d07ce21140b72c53 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler_statistic.py +++ b/python/paddle/fluid/tests/unittests/test_profiler_statistic.py @@ -56,7 +56,15 @@ class TestProfilerStatistic(unittest.TestCase): mobilenet_node = HostPythonNode( 'MobileNet', profiler.TracerEventType.Forward, 20, 50, 1000, 1001) yolonet_node = HostPythonNode( - 'Yolov3Net', profiler.TracerEventType.Forward, 50, 100, 1000, 1001) + 'Yolov3Net', profiler.TracerEventType.Forward, 50, 110, 1000, 1001) + + userdefined_node = HostPythonNode('Communication Time', + profiler.TracerEventType.UserDefined, + 100, 110, 1000, 1001) + + communication_node = HostPythonNode( + 'Communication', profiler.TracerEventType.Communication, 105, 110, + 1000, 1001) backward_node = HostPythonNode('Gradient Backward', profiler.TracerEventType.Backward, 120, 200, 1000, 1001) @@ -114,7 +122,9 @@ class TestProfilerStatistic(unittest.TestCase): optimization_node ]) mobilenet_node.children_node.append(conv2d_node) - yolonet_node.children_node.append(sync_batch_norm_node) + yolonet_node.children_node.extend( + [sync_batch_norm_node, userdefined_node]) + userdefined_node.children_node.append(communication_node) conv2d_node.children_node.extend( [conv2d_infer_shape, conv2d_compute, conv2d_MemCpy]) conv2d_compute.runtime_node.append(conv2d_launchkernel) @@ -145,7 +155,7 @@ class TestProfilerStatistic(unittest.TestCase): profiler.TracerEventType.ProfileStep), 400) self.assertEqual( time_range_summary.get_cpu_range_sum( - profiler.TracerEventType.Forward), 90) + profiler.TracerEventType.Forward), 100) self.assertEqual( time_range_summary.get_cpu_range_sum( profiler.TracerEventType.Backward), 80) @@ -169,15 +179,18 @@ class TestProfilerStatistic(unittest.TestCase): 0, profiler.TracerEventType.Memcpy), 60) self.assertEqual( time_range_summary.get_cpu_range_sum( - profiler.TracerEventType.UserDefined), 15) + profiler.TracerEventType.UserDefined), 25) + self.assertEqual( + time_range_summary.get_cpu_range_sum( + profiler.TracerEventType.Communication), 5) self.assertEqual(len(event_summary.items), 2) - self.assertEqual(len(event_summary.userdefined_items), 0) + self.assertEqual(len(event_summary.userdefined_items), 1) self.assertEqual(len(event_summary.model_perspective_items), 3) self.assertEqual(len(event_summary.memory_manipulation_items), 1) self.assertEqual(event_summary.items['conv2d'].cpu_time, 15) self.assertEqual(event_summary.items['conv2d'].gpu_time, 25) self.assertEqual( - event_summary.model_perspective_items['Forward'].cpu_time, 90) + event_summary.model_perspective_items['Forward'].cpu_time, 100) self.assertEqual( event_summary.model_perspective_items['Forward'].gpu_time, 135) self.assertEqual( diff --git a/python/paddle/profiler/__init__.py b/python/paddle/profiler/__init__.py index 4999e703f2a5a31be2cd5c20b70bc7b9dfb7e60a..ae190b8a7846cd3c0d765f1831914df2ab98c77f 100644 --- a/python/paddle/profiler/__init__.py +++ b/python/paddle/profiler/__init__.py @@ -20,7 +20,7 @@ from .utils import RecordEvent, load_profiler_result from .profiler_statistic import SortedKeys __all__ = [ - 'ProfilerState', 'ProfilerTarget', 'TracerEventType', 'make_scheduler', + 'ProfilerState', 'ProfilerTarget', 'make_scheduler', 'export_chrome_tracing', 'export_protobuf', 'Profiler', 'RecordEvent', 'load_profiler_result', 'SortedKeys' ] diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index dc637bf983046b8025962257744b0e1bb4763b4b..efbe88583b776d623b757628998e583ac65f6179 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -24,7 +24,7 @@ from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions, TracerEventType) from .utils import RecordEvent, wrap_optimizers -from .profiler_statistic import SortedKeys +from .profiler_statistic import StatisticData, _build_table, SortedKeys class ProfilerState(Enum): @@ -32,21 +32,28 @@ class ProfilerState(Enum): Profiler state that can be specified to control profiler action. CLOSED: The profilers are closed. + READY: The profilers are open, but the data will not be recorded. - This state is used for reducing overhead influence when profilers start. + This state is used for reducing overhead influence when profilers start. + RECORD: The profilers are open, and the data will be recorded. - RECORD_AND_RETURN: The profilers are open, and at the last batch of current profiler period, - the collected data will be returned. + + RECORD_AND_RETURN: The profilers are open, and at the last batch of current profiler period, + the collected data will be returned. """ CLOSED = 0 READY = 1 RECORD = 2 - RECORD_AND_RETURN = 3 # the last step of RECORD + RECORD_AND_RETURN = 3 # the last step of RECORD class ProfilerTarget(Enum): r""" Target device for profiling. + + CPU: Profile events on CPU. + + GPU: Profile events on GPU. """ CPU = 0 GPU = 1 @@ -62,17 +69,19 @@ def make_scheduler(*, Return a scheduler function, which scheduler the state according to the setting. The state transform confirms to: - (CLOSED) (CLOSED) (CLOSED) (READY) (RECORD,last RETURN) (CLOSED) - START -> skip_first -> closed -> ready -> record -> END - | | - | | (if has_repeated < repeat) - - - - - - - - - - - - - - Note that repeat <= 0 means the cycle will continue until the profiler exits. + .. code-block:: text + + (CLOSED) (CLOSED) (CLOSED) (READY) (RECORD,last RETURN) (CLOSED) + START -> skip_first -> closed -> ready -> record -> END + | | + | | (if has_repeated < repeat) + - - - - - - - - - - - - + Note that repeat <= 0 means the cycle will continue until the profiler exits. Parameters: closed(int): The number of steps in state ProfilerState.CLOSED. - ready(int): The number of steps in state ProfilerState.READY. - record(int): The number of steps in state ProfilerState.RECORD. + ready(int): The number of steps in state ProfilerState.READY. + record(int): The number of steps in state ProfilerState.RECORD. repeat(int): The number of cycles to repeat above state transform. skip_first(int): The number of first steps to drop, not participate in the state transform. @@ -81,13 +90,23 @@ def make_scheduler(*, Examples: 1. profiling range [2, 5] + batch 0: closed, batch 1: ready, batch [2, 5] record - .. code-block:: python - make_scheduler(closed=1, ready=1, record=4, repeat=1) + + .. code-block:: python + + import paddle.profiler as profiler + profiler.make_scheduler(closed=1, ready=1, record=4, repeat=1) + + 2. profiling range [3,6], [9,12], [15,18]... + batch 0: skiped, batch 1: closed, batch 2: ready, batch [3,6]: record, repeat - .. code-block:: python - make_scheduler(closed=1, ready=1, record=4, skip_first=1) + + .. code-block:: python + + import paddle.profiler as profiler + profiler.make_scheduler(closed=1, ready=1, record=4, skip_first=1) """ def getScheduleState(step: int) -> ProfilerState: @@ -138,15 +157,16 @@ def export_chrome_tracing(dir_name: str, Examples: .. code-block:: python - import paddle.profiler as profiler - with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU, - profiler.ProfilerTarget.GPU], - scheduler = (3, 10), - on_trace_ready = profiler.export_chrome_tracing('./log') - ) as p: - for iter in range(N): - train() - p.step() + + # required: gpu + import paddle.profiler as profiler + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 10), + on_trace_ready=profiler.export_protobuf('./log')) as p: + for iter in range(10): + #train() + p.step() """ if not os.path.exists(dir_name): try: @@ -181,15 +201,16 @@ def export_protobuf(dir_name: str, worker_name: Optional[str]=None) -> Callable: Examples: .. code-block:: python - import paddle.profiler as profiler - with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU, - profiler.ProfilerTarget.GPU], - scheduler = (3, 10), - on_trace_ready = profiler.export_protobuf('./log') - ) as p: - for iter in range(N): - train() - p.step() + + # required: gpu + import paddle.profiler as profiler + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 10), + on_trace_ready = profiler.export_protobuf('./log')) as p: + for iter in range(10): + #train() + p.step() """ if not os.path.exists(dir_name): try: @@ -216,7 +237,7 @@ def _get_supported_targets() -> Iterable[ProfilerTarget]: r""" Get the current supported profiler target in the system. """ - if paddle.device.is_compiled_with_cuda(): + if _Profiler.is_cupti_supported(): return [ProfilerTarget.CPU, ProfilerTarget.GPU] return [ProfilerTarget.CPU] @@ -226,48 +247,56 @@ class Profiler: Profiler context manager, user interface to manage profile process. Parameters: - targets (iterable): list of tracing targets, currently supported values: - ``paddle.profiler.ProfilerTarget.CPU``, - ``paddle.profiler.ProfilerTarget.GPU``. - scheduler (callable or tuple): If it is a callable object, it takes a step number as parameter and return the corresponding ``ProfilerState``. - If not provided, the default sheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch, + targets (iterable): list of tracing targets, currently supported values, ``ProfilerTarget.CPU``, ``ProfilerTarget.GPU`` . + scheduler (callable or tuple): If it is a callable object, it takes a step number as parameter and return the corresponding ``ProfilerState``. + If not provided, the default scheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch, which means profiling range [start_batch, end_batch). on_trace_ready (callable): callable object, takes the Profiler object as parameter, which provides a way for users to do post-processing. - This callable object will be called when ``sheduler`` returns ``ProfilerState.RECORD_AND_RETURN``. - + This callable object will be called when ``scheduler`` returns ``ProfilerState.RECORD_AND_RETURN``. + Examples: 1. profiling range [2, 5) - .. code-block:: python - import paddle.profiler as profiler - with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU, - profiler.ProfilerTarget.GPU], - scheduler = (2, 5), - on_trace_ready = profiler.export_chrome_tracing('./log') - ) as p: - for iter in range(N): - train() - p.step() + + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (2, 5), + on_trace_ready = profiler.export_chrome_tracing('./log')) as p: + for iter in range(10): + #train() + p.step() + 2. profiling range [2,4], [7, 9], [11,13] - .. code-block:: python - import paddle.profiler as profiler - with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU, - profiler.ProfilerTarget.GPU], - scheduler = profiler.make_scheduler(closed=1, ready=1, record=3, repeat=3), - on_trace_ready = profiler.export_chrome_tracing('./log') - ) as p: - for iter in range(N): - train() - p.step() + + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = profiler.make_scheduler(closed=1, ready=1, record=3, repeat=3), + on_trace_ready = profiler.export_chrome_tracing('./log')) as p: + for iter in range(10): + #train() + p.step() + 3. Use profiler without context manager, and use default parameters - .. code-block:: python - import paddle.profiler as profiler - p = profiler.Profiler() - p.start() - for iter in range(N): - train() - p.step() - p.stop() - p.summary() + + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + p = profiler.Profiler() + p.start() + for iter in range(10): + #train() + p.step() + p.stop() + p.summary() + """ def __init__( @@ -334,7 +363,22 @@ class Profiler: def start(self): r''' Start profiler and enter the first profiler step(0). - State transformed from CLOSED to self.current_state and trigger corresponding action. + State transformed from CLOSED to self.current_state and trigger corresponding action. + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (1, 9), + on_trace_ready = profiler.export_chrome_tracing('./log')) + prof.start() + for iter in range(10): + #train() + prof.step() + prof.stop() ''' # CLOSED -> self.current_state if self.current_state == ProfilerState.READY: @@ -354,6 +398,21 @@ class Profiler: r''' Stop profiler and State transformed from self.current_state to CLOSED. Trigger corresponding action and post-process profiler result using self.on_trace_ready if result exists. + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (1, 7), + on_trace_ready = profiler.export_chrome_tracing('./log')) + prof.start() + for iter in range(10): + #train() + prof.step() + prof.stop() ''' # self.current_state -> CLOSED # In this situation, RECORD state is regarded as RECORD_AND_RETURN @@ -375,6 +434,22 @@ class Profiler: r""" Signals the profiler that the next profiling step has started. Get the new ProfilerState and trigger corresponding action. + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 7), + on_trace_ready = profiler.export_chrome_tracing('./log')) + + prof.start() + for iter in range(10): + #train() + prof.step() + prof.stop() """ if self.record_event: self.record_event.end() @@ -448,6 +523,21 @@ class Profiler: def export(self, path="", format="json"): r""" Exports the tracing data in Chrome tracing data format. + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 7)) + prof.start() + for iter in range(10): + #train() + prof.step() + prof.stop() + prof.export(path="./profiler_data.json", format="json") """ if self.profiler_result: self.profiler_result.save(path, format) @@ -461,9 +551,35 @@ class Profiler: Print the Summary table. Parameters: - sorted_by: how to rank the op table items. - detail: expand each operator detail information. - thread_sep: print op table each thread. - time_unit: can be chosen form ['s', 'ms', 'us', 'ns'] + sorted_by(SortedKeys): how to rank the op table items. + op_detail(bool): expand each operator detail information. + thread_sep(bool): print op table each thread. + time_unit(str): can be chosen form ['s', 'ms', 'us', 'ns'] + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 7), + on_trace_ready = profiler.export_chrome_tracing('./log')) + prof.start() + for iter in range(10): + #train() + prof.step() + prof.stop() + prof.summary(sorted_by=profiler.SortedKeys.CPUTotal, op_detail=True, thread_sep=False, time_unit='ms') """ - pass + if self.profiler_result: + statistic_data = StatisticData( + self.profiler_result.get_data(), + self.profiler_result.get_extra_info()) + print( + _build_table( + statistic_data, + sorted_by=sorted_by, + op_detail=op_detail, + thread_sep=thread_sep, + time_unit=time_unit)) diff --git a/python/paddle/profiler/profiler_statistic.py b/python/paddle/profiler/profiler_statistic.py index 7400f21e91365efeaef6a03d008691bdc837131b..a0bbd6b633ef017dc983c8458eb5551494425989 100755 --- a/python/paddle/profiler/profiler_statistic.py +++ b/python/paddle/profiler/profiler_statistic.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -34,6 +34,22 @@ _CommunicationOpName = ['reduce', 'broadcast', 'rpc'] class SortedKeys(Enum): r""" Sorted keys for printing summary table. + + CPUTotal: Sorted by CPU total time. + + CPUAvg: Sorted by CPU average time. + + CPUMax: Sorted by CPU max time. + + CPUMin: Sorted by CPU min time. + + GPUTotal: Sorted by GPU total time. + + GPUAvg: Sorted by GPU average time. + + GPUMax: Sorted by GPU max time. + + GPUMin: Sorted by GPU min time. """ CPUTotal = 0 CPUAvg = 1 @@ -642,6 +658,171 @@ def _build_table(statistic_data, append('') append('') + ###### Print Model Summary Report ###### + model_perspective_items = statistic_data.event_summary.model_perspective_items + if model_perspective_items: + headers = [ + 'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)', + 'GPU Total / Avg / Max / Min / Ratio(%)' + ] + row_format_list = [""] + header_sep_list = [""] + line_length_list = [-SPACING_SIZE] + name_column_width = 15 + add_column(name_column_width) + add_column(6) + add_column(40) + add_column(40) + + row_format = row_format_list[0] + header_sep = header_sep_list[0] + line_length = line_length_list[0] + + # construct table string + append(add_title(line_length, "Model Summary")) + append('Time unit: {}'.format(time_unit)) + append(header_sep) + append(row_format.format(*headers)) + append(header_sep) + accmulation_time = 0 + row_values = [ + 'Total Time', '-', '{} / - / - / - / {}'.format( + format_time( + total_time, unit=time_unit), format_ratio(1)), + '- / - / - / -/ -' + ] + append(row_format.format(*row_values)) + for name in ['Dataloader', 'Forward', 'Backward', 'Optimization']: + if name in model_perspective_items: + item = model_perspective_items[name] + row_values = [ + ' {}'.format(name), item.call, + '{} / {} / {} / {} / {}'.format( + format_time( + item.cpu_time, unit=time_unit), + format_time( + item.avg_cpu_time, unit=time_unit), + format_time( + item.max_cpu_time, unit=time_unit), + format_time( + item.min_cpu_time, unit=time_unit), + format_ratio(float(item.cpu_time) / total_time)), + '{} / {} / {} / {} / {}'.format( + format_time( + item.gpu_time, unit=time_unit), + format_time( + item.avg_gpu_time, unit=time_unit), + format_time( + item.max_gpu_time, unit=time_unit), + format_time( + item.min_gpu_time, unit=time_unit), + format_ratio(float(item.gpu_time) / total_time)) + ] + append(row_format.format(*row_values)) + accmulation_time += item.cpu_time + + other_time = total_time - accmulation_time + row_values = [ + ' Others', '-', '{} / - / - / - / {}'.format( + format_time( + other_time, unit=time_unit), + format_ratio(float(other_time) / total_time)), + '- / - / - / - / -' + ] + append(row_format.format(*row_values)) + append(header_sep) + append('') + append('') + + ###### Print Distribution Summary Report ###### + if TracerEventType.Communication in statistic_data.time_range_summary.CPUTimeRange: + headers = [ + 'Name', + 'Total Time', + 'Ratio (%)', + ] + row_format_list = [""] + header_sep_list = [""] + line_length_list = [-SPACING_SIZE] + + DEFAULT_COLUMN_WIDTH = 20 + for _ in headers: + add_column(DEFAULT_COLUMN_WIDTH) + + row_format = row_format_list[0] + header_sep = header_sep_list[0] + line_length = line_length_list[0] + + # construct table string + append(add_title(line_length, "Distribution Summary")) + append('Time unit: {}'.format(time_unit)) + append(header_sep) + append(row_format.format(*headers)) + append(header_sep) + cpu_communication_time_range = [] + gpu_communication_time_range = [] + cpu_communication_time_range = merge_ranges( + statistic_data.time_range_summary.CPUTimeRange[ + TracerEventType.Communication], cpu_communication_time_range) + kernel_time_range = [] + for device_id, device_time_ranges in statistic_data.time_range_summary.GPUTimeRange.items( + ): + kernel_time_range = merge_ranges( + device_time_ranges[TracerEventType.Kernel], + kernel_time_range, + is_sorted=True) + gpu_communication_time_range = merge_ranges( + device_time_ranges[TracerEventType.Communication], + gpu_communication_time_range, + is_sorted=True) + communication_time_range = merge_ranges( + cpu_communication_time_range, + gpu_communication_time_range, + is_sorted=True) + computation_time_range = subtract_ranges(kernel_time_range, + gpu_communication_time_range) + overlap_time_range = intersection_ranges(communication_time_range, + computation_time_range) + communication_time = sum_ranges(communication_time_range) + computation_time = sum_ranges(computation_time_range) + overlap_time = sum_ranges(overlap_time_range) + row_values = [ + 'Communication', format_time( + communication_time, unit=time_unit), + format_ratio(float(communication_time) / total_time) + ] + append(row_format.format(*row_values)) + + row_values = [ + 'Computation', format_time( + computation_time, unit=time_unit), + format_ratio(float(computation_time) / total_time) + ] + append(row_format.format(*row_values)) + + row_values = [ + 'Overlap', format_time( + overlap_time, unit=time_unit), + format_ratio(float(overlap_time) / total_time) + ] + append(row_format.format(*row_values)) + append(header_sep) + append( + "Note:\nCommunication time: Communication Op time and its kernel time on gpu.\n" + "Computation time: Kernel time, substract kernels belong to communication op.\n" + "Overlap time: Communication time intersect with computation time.\n" + "Example:\n" + "Communication:\n" + " CPU: |_________________|\n" + " GPU: |______________|\n" + " Total: |_________________| |______________|\n" + "Computation time(Kernel):\n" + " GPU: |________________|\n" + "Overlap time: |___________|\n") + append('-' * line_length) + append('') + append('') + ###### Print Operator Summary Report ###### if statistic_data.event_summary.items: headers = [ @@ -708,11 +889,6 @@ def _build_table(statistic_data, sorted_items = sorted( items.items(), key=lambda x: x[1].min_gpu_time) - total_cpu_time = 0 - total_gpu_time = 0 - for name, item in sorted_items: - total_cpu_time += item.cpu_time - total_gpu_time += item.gpu_time for name, item in sorted_items: row_values = [ name, item.call, '{} / {} / {} / {} / {}'.format( @@ -724,7 +900,7 @@ def _build_table(statistic_data, item.max_cpu_time, unit=time_unit), format_time( item.min_cpu_time, unit=time_unit), - format_ratio(float(item.cpu_time) / total_cpu_time)), + format_ratio(float(item.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( item.gpu_time, unit=time_unit), @@ -734,7 +910,7 @@ def _build_table(statistic_data, item.max_gpu_time, unit=time_unit), format_time( item.min_gpu_time, unit=time_unit), - format_ratio(float(item.gpu_time) / total_gpu_time)) + format_ratio(float(item.gpu_time) / total_time)) ] append(row_format.format(*row_values)) if op_detail: @@ -752,8 +928,7 @@ def _build_table(statistic_data, format_time( innerop_node.min_cpu_time, unit=time_unit), format_ratio( - float(innerop_node.cpu_time) / - total_cpu_time)), + float(innerop_node.cpu_time) / total_time)), '{} / {} / {} / {} / {}'.format( format_time( innerop_node.gpu_time, unit=time_unit), @@ -764,8 +939,7 @@ def _build_table(statistic_data, format_time( innerop_node.min_gpu_time, unit=time_unit), format_ratio( - float(innerop_node.gpu_time) / - total_gpu_time)) + float(innerop_node.gpu_time) / total_time)) ] append(row_format.format(*row_values)) for device_node_name, devicenode in innerop_node.devices.items( @@ -792,7 +966,7 @@ def _build_table(statistic_data, unit=time_unit), format_ratio( float(devicenode.gpu_time) / - total_gpu_time)) + total_time)) ] append(row_format.format(*row_values)) for device_node_name, device_node in item.devices.items(): @@ -814,11 +988,160 @@ def _build_table(statistic_data, format_time( devicenode.min_gpu_time, unit=time_unit), format_ratio( - float(devicenode.gpu_time) / - total_gpu_time)) + float(devicenode.gpu_time) / total_time)) ] append(row_format.format(*row_values)) append(header_sep) append('') append('') + + ###### Print Memory Manipulation Summary Report ###### + if statistic_data.event_summary.memory_manipulation_items: + headers = [ + 'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)', + 'GPU Total / Avg / Max / Min / Ratio(%)' + ] + row_format_list = [""] + header_sep_list = [""] + line_length_list = [-SPACING_SIZE] + name_column_width = 30 + add_column(name_column_width) + add_column(6) + add_column(40) + add_column(40) + + row_format = row_format_list[0] + header_sep = header_sep_list[0] + line_length = line_length_list[0] + + # construct table string + append(add_title(line_length, "Memory Manipulation Summary")) + append('Time unit: {}'.format(time_unit)) + append(header_sep) + append(row_format.format(*headers)) + append(header_sep) + memory_manipulation_items = statistic_data.event_summary.memory_manipulation_items + for name, item in memory_manipulation_items.items(): + row_values = [ + name, + item.call, + '{} / {} / {} / {} / {}'.format( + format_time( + item.cpu_time, unit=time_unit), + format_time( + item.avg_cpu_time, unit=time_unit), + format_time( + item.max_cpu_time, unit=time_unit), + format_time( + item.min_cpu_time, unit=time_unit), + format_ratio(float(item.cpu_time) / total_time)), + '{} / {} / {} / {} / {}'.format( + format_time( + item.gpu_time, unit=time_unit), + format_time( + item.avg_gpu_time, unit=time_unit), + format_time( + item.max_gpu_time, unit=time_unit), + format_time( + item.min_gpu_time, unit=time_unit), + format_ratio(float(item.gpu_time) / total_time)), + ] + append(row_format.format(*row_values)) + append(header_sep) + append('') + append('') + ###### Print UserDefined Summary Report ###### + if statistic_data.event_summary.userdefined_items: + headers = [ + 'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)', + 'GPU Total / Avg / Max / Min / Ratio(%)' + ] + row_format_list = [""] + header_sep_list = [""] + line_length_list = [-SPACING_SIZE] + name_column_width = 30 + add_column(name_column_width) + add_column(6) + add_column(40) + add_column(40) + + row_format = row_format_list[0] + header_sep = header_sep_list[0] + line_length = line_length_list[0] + + # construct table string + append(add_title(line_length, "UserDefined Summary")) + append('Time unit: {}'.format(time_unit)) + append(header_sep) + append(row_format.format(*headers)) + append(header_sep) + if thread_sep == True: + userdefined_thread_items = statistic_data.event_summary.userdefined_thread_items + else: + userdefined_thread_items = { + 'All threads merged': + statistic_data.event_summary.userdefined_items + } + for thread_id, items in userdefined_thread_items.items(): + append(add_title(line_length, "Thread: {}".format(thread_id))) + if sorted_by == SortedKeys.CPUTotal: + sorted_items = sorted( + items.items(), key=lambda x: x[1].cpu_time, reverse=True) + elif sorted_by == SortedKeys.CPUAvg: + sorted_items = sorted( + items.items(), + key=lambda x: x[1].avg_cpu_time, + reverse=True) + elif sorted_by == SortedKeys.CPUMax: + sorted_items = sorted( + items.items(), + key=lambda x: x[1].max_cpu_time, + reverse=True) + elif sorted_by == SortedKeys.CPUMin: + sorted_items = sorted( + items.items(), key=lambda x: x[1].min_cpu_time) + elif sorted_by == SortedKeys.GPUTotal: + sorted_items = sorted( + items.items(), key=lambda x: x[1].gpu_time, reverse=True) + elif sorted_by == SortedKeys.GPUAvg: + sorted_items = sorted( + items.items(), + key=lambda x: x[1].avg_gpu_time, + reverse=True) + elif sorted_by == SortedKeys.GPUMax: + sorted_items = sorted( + items.items(), + key=lambda x: x[1].max_gpu_time, + reverse=True) + elif sorted_by == SortedKeys.GPUMin: + sorted_items = sorted( + items.items(), key=lambda x: x[1].min_gpu_time) + + for name, item in sorted_items: + row_values = [ + name, + item.call, + '{} / {} / {} / {} / {}'.format( + format_time( + item.cpu_time, unit=time_unit), + format_time( + item.avg_cpu_time, unit=time_unit), + format_time( + item.max_cpu_time, unit=time_unit), + format_time( + item.min_cpu_time, unit=time_unit), + format_ratio(float(item.cpu_time) / total_time)), + '{} / {} / {} / {} / {}'.format( + format_time( + item.gpu_time, unit=time_unit), + format_time( + item.avg_gpu_time, unit=time_unit), + format_time( + item.max_gpu_time, unit=time_unit), + format_time( + item.min_gpu_time, unit=time_unit), + format_ratio(float(item.gpu_time) / total_time)), + ] + append(row_format.format(*row_values)) + append(header_sep) return ''.join(result) diff --git a/python/paddle/profiler/utils.py b/python/paddle/profiler/utils.py index 642001dfbfc5a307d5064860136034ba7b3bdbc5..7fa7a27bad7bf5ffbefdddb28d67e2d65e319e6d 100644 --- a/python/paddle/profiler/utils.py +++ b/python/paddle/profiler/utils.py @@ -1,24 +1,25 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from paddle.fluid.core import (_RecordEvent, TracerEventType, - load_profiler_result) from typing import Any from warnings import warn import functools from contextlib import ContextDecorator +from paddle.fluid.core import (_RecordEvent, TracerEventType) +import paddle.fluid.core as core + _AllowedEventTypeList = [ TracerEventType.Dataloader, TracerEventType.ProfileStep, TracerEventType.UserDefined, TracerEventType.Forward, @@ -32,14 +33,28 @@ class RecordEvent(ContextDecorator): Interface for recording a time range. Parameters: - name(str): Name of the record event - event_type(TracerEventType): Type of the record event, can be used for statistics. + name(str): Name of the record event Examples: .. code-block:: python - import paddle.profiler as profiler - with profiler.RecordEvent(name='op1', event_type=TracerEventType=TracerEventType.UserDefined): - op1() + + import paddle + import paddle.profiler as profiler + # method1: using context manager + with profiler.RecordEvent("record_add"): + data1 = paddle.randn(shape=[3]) + data2 = paddle.randn(shape=[3]) + result = data1 + data2 + # method2: call begin() and end() + record_event = profiler.RecordEvent("record_add") + record_event.begin() + data1 = paddle.randn(shape=[3]) + data2 = paddle.randn(shape=[3]) + result = data1 + data2 + record_event.end() + + Note: + RecordEvent will take effect only when profiler is on and at the state of RECORD. """ def __init__(self, @@ -57,6 +72,20 @@ class RecordEvent(ContextDecorator): self.end() def begin(self): + r""" + Record the time of begining. + + .. code-block:: python + + import paddle + import paddle.profiler as profiler + record_event = profiler.RecordEvent("record_sub") + record_event.begin() + data1 = paddle.randn(shape=[3]) + data2 = paddle.randn(shape=[3]) + result = data1 - data2 + record_event.end() + """ if self.event_type not in _AllowedEventTypeList: warn("Only TracerEvent Type in [{}, {}, {}, {}, {}, {},{}]\ can be recorded.".format(*_AllowedEventTypeList)) @@ -67,10 +96,51 @@ class RecordEvent(ContextDecorator): self.event = _RecordEvent(self.name, self.event_type) def end(self): + r''' + Record the time of ending. + + .. code-block:: python + + import paddle + import paddle.profiler as profiler + record_event = profiler.RecordEvent("record_mul") + record_event.begin() + data1 = paddle.randn(shape=[3]) + data2 = paddle.randn(shape=[3]) + result = data1 * data2 + record_event.end() + ''' if self.event: self.event.end() +def load_profiler_result(filename: str): + r""" + Load dumped profiler data back to memory. + + Parameters: + filename(str): Name of the exported protobuf file of profiler data. + + Returns: + ProfilerResult object. + + Examples: + .. code-block:: python + + # required: gpu + import paddle.profiler as profiler + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], + scheduler = (3, 10)) as p: + for iter in range(10): + #train() + p.step() + p.export('test_export_protobuf.pb', format='pb') + profiler_result = profiler.load_profiler_result('test_export_protobuf.pb') + """ + return core.load_profiler_result(filename) + + def wrap_optimizers(): def optimizer_warpper(func): @functools.wraps(func)