未验证 提交 c15e3823 编写于 作者: C chenjian 提交者: GitHub

Add profiler features (#40357)

* add event record for model profiling

* fix format

* fix format

* fix code example bug

* no

* add profiler statistic

* add profiler feature

* fix bug

* fix bug

* fix bug

* fix bug

* required: gpu

* required: gpu

* fix bug

* required: gpu

* fix ci bug

* fix ci error

* fix ci error

* upgrade document

* fix doc

* fix ci bug

* add doc and fix bug

* nothing

* fix bug

* fix format bug

* modify format

* add deprecated description for old profiler

* fix bug

* fix bug

* fix

* add load_profiler_reuslt doc

* add load_profiler_reuslt doc

* add load_profiler_reuslt doc

* help fix old profiler sample code

* add api doc

* fix format

* fix api doc

* fix api doc format

* fix api doc format

* fix api doc c format

* fix api doc format
上级 58970995
......@@ -118,8 +118,9 @@ float CpuUtilization::GetCpuUtilization() {
float busy_time = (system_kernel_time_end - system_kernel_time_start) +
(system_user_time_end - system_user_time_start);
float idle_time = system_idle_time_end - system_idle_time_start;
if (busy_time + idle_time != 0) {
cpu_utilization = busy_time / (busy_time + idle_time);
}
#elif defined(__linux__)
float busy_time = (system_tms_end_.tms_utime - system_tms_start_.tms_utime) +
(system_tms_end_.tms_stime - system_tms_start_.tms_stime) +
......@@ -127,7 +128,9 @@ float CpuUtilization::GetCpuUtilization() {
(irq_end_ - irq_start_) + (softirq_end_ - softirq_start_) +
(steal_end_ - steal_start_);
float idle_time = (idle_end_ - idle_start_) + (iowait_end_ - iowait_start_);
if (busy_time + idle_time != 0) {
cpu_utilization = busy_time / (busy_time + idle_time);
}
#else
LOG(WARNING)
<< "Current System is not supported to get system cpu utilization"
......@@ -148,13 +151,16 @@ float CpuUtilization::GetCpuCurProcessUtilization() {
uint64_t end = FileTimeToUint64(end_);
float busy_time = (process_kernel_time_end - process_kernel_time_start) +
(process_user_time_end - process_user_time_start);
if (end - start != 0) {
cpu_process_utilization = busy_time / (end - start);
LOG(INFO) << "Process Utilization = " << cpu_process_utilization << std::endl;
}
#elif defined(__linux__)
float busy_time =
(process_tms_end_.tms_utime - process_tms_start_.tms_utime) +
(process_tms_end_.tms_stime - process_tms_start_.tms_stime);
if (end_ - start_ != 0) {
cpu_process_utilization = busy_time / (end_ - start_);
}
#else
LOG(WARNING)
<< "Current System is not supported to get process cpu utilization"
......
......@@ -44,6 +44,14 @@ std::unique_ptr<Profiler> Profiler::Create(const ProfilerOptions& options) {
return std::unique_ptr<Profiler>(new Profiler(options));
}
// Reports whether this binary was compiled with CUPTI support.
// The result is decided at compile time: true when PADDLE_WITH_CUPTI is
// defined, false otherwise.
bool Profiler::IsCuptiSupported() {
#ifdef PADDLE_WITH_CUPTI
  return true;
#else
  return false;
#endif
}
Profiler::Profiler(const ProfilerOptions& options) {
options_ = options;
std::bitset<32> trace_switch(options_.trace_switch);
......
......@@ -43,6 +43,8 @@ class Profiler {
public:
static std::unique_ptr<Profiler> Create(const ProfilerOptions& options);
static bool IsCuptiSupported();
void Prepare();
void Start();
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cupti.h"
namespace paddle {
namespace platform {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <ctime>
#include <string>
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/os_info.h"
......
......@@ -3322,6 +3322,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<paddle::platform::Profiler>(m, "_Profiler")
.def("create", &paddle::platform::Profiler::Create,
py::return_value_policy::take_ownership)
.def("is_cupti_supported", &paddle::platform::Profiler::IsCuptiSupported)
.def("prepare",
[](paddle::platform::Profiler *profiler) {
platform::EnableHostEventRecorder();
......
......@@ -30,6 +30,7 @@ from paddle.fluid.framework import _set_expected_place, _current_expected_place,
import queue
import paddle
import paddle.profiler as profiler
from .. import core, layers
from ..framework import in_dygraph_mode, _in_eager_mode
from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL, CleanupFuncRegistrar
......@@ -250,6 +251,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._exit_thread_expectedly()
def __next__(self):
trace_event = profiler.RecordEvent(
name="_DataLoaderIterSingleProcess",
event_type=profiler.TracerEventType.Dataloader)
trace_event.begin()
try:
if in_dygraph_mode():
if _in_eager_mode():
......@@ -283,6 +288,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._reader.shutdown()
self._try_shutdown_all()
six.reraise(*sys.exc_info())
finally:
trace_event.end()
def _shutdown_thread(self):
if self._thread:
......@@ -695,6 +702,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._try_shutdown_all(1)
def __next__(self):
trace_event = profiler.RecordEvent(
name="_DataLoaderIterMultiProcess",
event_type=profiler.TracerEventType.Dataloader)
trace_event.begin()
try:
# _batches_outstanding here records the total batch data number
# in 'from after _try_put_indices to before output data', this
......@@ -743,6 +754,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._reader.shutdown()
self._try_shutdown_all()
six.reraise(*sys.exc_info())
finally:
trace_event.end()
# python2 compatibility
def next(self):
......
......@@ -25,6 +25,7 @@ from copy import deepcopy
import inspect
import paddle
import paddle.profiler as profiler
from . import parallel_helper
from .. import unique_name
......@@ -905,6 +906,8 @@ class Layer(object):
self._built = True
with profiler.RecordEvent(self.full_name(),
profiler.TracerEventType.Forward):
outputs = self.forward(*inputs, **kwargs)
for forward_post_hook in self._forward_post_hooks.values():
......
......@@ -28,6 +28,7 @@ from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
import paddle.utils.deprecated as deprecated
import paddle.profiler as profiler
from paddle import _C_ops
......@@ -243,6 +244,9 @@ def monkey_patch_varbase():
"""
if framework.in_dygraph_mode():
record_event = profiler.RecordEvent(
"Gradient Backward", profiler.TracerEventType.Backward)
record_event.begin()
if grad_tensor is not None:
if core._in_eager_mode():
assert isinstance(
......@@ -278,6 +282,7 @@ def monkey_patch_varbase():
core.dygraph_run_backward([self], [grad_tensor],
retain_graph,
framework._dygraph_tracer())
record_event.end()
else:
raise ValueError(
"Variable.backward() is only available in DyGraph mode")
......
......@@ -20,6 +20,8 @@ import os
import six
import sys
from paddle.utils.deprecated import deprecated
__all__ = [
'cuda_profiler', 'reset_profiler', 'profiler', 'start_profiler',
'stop_profiler'
......@@ -36,6 +38,12 @@ NVPROF_CONFIG = [
]
@deprecated(
since="2.3.0",
update_to="paddle.profiler.Profiler",
level=1,
reason="Please use new profiler tool, this profiler tool is no longer maintained."
)
@signature_safe_contextmanager
def cuda_profiler(output_file, output_mode=None, config=None):
"""
......@@ -109,6 +117,12 @@ def npu_profiler(output_file, config=None):
core.npu_prof_finalize()
@deprecated(
since="2.3.0",
update_to="paddle.profiler.Profiler",
level=1,
reason="Please use new profiler tool, this profiler tool is no longer maintained."
)
def reset_profiler():
"""
Clear the previous time record. It works for
......@@ -131,6 +145,12 @@ def reset_profiler():
core.reset_profiler()
@deprecated(
since="2.3.0",
update_to="paddle.profiler.Profiler",
level=1,
reason="Please use new profiler tool, this profiler tool is no longer maintained."
)
def start_profiler(state, tracer_option='Default'):
"""
Enable the profiler. Users can use `fluid.profiler.start_profiler` and
......@@ -156,6 +176,7 @@ def start_profiler(state, tracer_option='Default'):
.. code-block:: python
# required: gpu
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
......@@ -198,6 +219,12 @@ def start_profiler(state, tracer_option='Default'):
core.enable_profiler(prof_state)
@deprecated(
since="2.3.0",
update_to="paddle.profiler.Profiler",
level=1,
reason="Please use new profiler tool, this profiler tool is no longer maintained."
)
def stop_profiler(sorted_key=None, profile_path='/tmp/profile'):
"""
Stop the profiler. Users can use `fluid.profiler.start_profiler` and
......@@ -225,6 +252,7 @@ def stop_profiler(sorted_key=None, profile_path='/tmp/profile'):
.. code-block:: python
# required: gpu
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
......@@ -254,6 +282,12 @@ def stop_profiler(sorted_key=None, profile_path='/tmp/profile'):
core.disable_profiler(key_map[sorted_key], profile_path)
@deprecated(
since="2.3.0",
update_to="paddle.profiler.Profiler",
level=1,
reason="Please use new profiler tool, this profiler tool is no longer maintained."
)
@signature_safe_contextmanager
def profiler(state,
sorted_key=None,
......
......@@ -56,7 +56,15 @@ class TestProfilerStatistic(unittest.TestCase):
mobilenet_node = HostPythonNode(
'MobileNet', profiler.TracerEventType.Forward, 20, 50, 1000, 1001)
yolonet_node = HostPythonNode(
'Yolov3Net', profiler.TracerEventType.Forward, 50, 100, 1000, 1001)
'Yolov3Net', profiler.TracerEventType.Forward, 50, 110, 1000, 1001)
userdefined_node = HostPythonNode('Communication Time',
profiler.TracerEventType.UserDefined,
100, 110, 1000, 1001)
communication_node = HostPythonNode(
'Communication', profiler.TracerEventType.Communication, 105, 110,
1000, 1001)
backward_node = HostPythonNode('Gradient Backward',
profiler.TracerEventType.Backward, 120,
200, 1000, 1001)
......@@ -114,7 +122,9 @@ class TestProfilerStatistic(unittest.TestCase):
optimization_node
])
mobilenet_node.children_node.append(conv2d_node)
yolonet_node.children_node.append(sync_batch_norm_node)
yolonet_node.children_node.extend(
[sync_batch_norm_node, userdefined_node])
userdefined_node.children_node.append(communication_node)
conv2d_node.children_node.extend(
[conv2d_infer_shape, conv2d_compute, conv2d_MemCpy])
conv2d_compute.runtime_node.append(conv2d_launchkernel)
......@@ -145,7 +155,7 @@ class TestProfilerStatistic(unittest.TestCase):
profiler.TracerEventType.ProfileStep), 400)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Forward), 90)
profiler.TracerEventType.Forward), 100)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Backward), 80)
......@@ -169,15 +179,18 @@ class TestProfilerStatistic(unittest.TestCase):
0, profiler.TracerEventType.Memcpy), 60)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.UserDefined), 15)
profiler.TracerEventType.UserDefined), 25)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Communication), 5)
self.assertEqual(len(event_summary.items), 2)
self.assertEqual(len(event_summary.userdefined_items), 0)
self.assertEqual(len(event_summary.userdefined_items), 1)
self.assertEqual(len(event_summary.model_perspective_items), 3)
self.assertEqual(len(event_summary.memory_manipulation_items), 1)
self.assertEqual(event_summary.items['conv2d'].cpu_time, 15)
self.assertEqual(event_summary.items['conv2d'].gpu_time, 25)
self.assertEqual(
event_summary.model_perspective_items['Forward'].cpu_time, 90)
event_summary.model_perspective_items['Forward'].cpu_time, 100)
self.assertEqual(
event_summary.model_perspective_items['Forward'].gpu_time, 135)
self.assertEqual(
......
......@@ -20,7 +20,7 @@ from .utils import RecordEvent, load_profiler_result
from .profiler_statistic import SortedKeys
__all__ = [
'ProfilerState', 'ProfilerTarget', 'TracerEventType', 'make_scheduler',
'ProfilerState', 'ProfilerTarget', 'make_scheduler',
'export_chrome_tracing', 'export_protobuf', 'Profiler', 'RecordEvent',
'load_profiler_result', 'SortedKeys'
]
......@@ -24,7 +24,7 @@ from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions,
TracerEventType)
from .utils import RecordEvent, wrap_optimizers
from .profiler_statistic import SortedKeys
from .profiler_statistic import StatisticData, _build_table, SortedKeys
class ProfilerState(Enum):
......@@ -32,9 +32,12 @@ class ProfilerState(Enum):
Profiler state that can be specified to control profiler action.
CLOSED: The profilers are closed.
READY: The profilers are open, but the data will not be recorded.
This state is used for reducing overhead influence when profilers start.
RECORD: The profilers are open, and the data will be recorded.
RECORD_AND_RETURN: The profilers are open, and at the last batch of current profiler period,
the collected data will be returned.
"""
......@@ -47,6 +50,10 @@ class ProfilerState(Enum):
class ProfilerTarget(Enum):
r"""
Target device for profiling.
CPU: Profile events on CPU.
GPU: Profile events on GPU.
"""
CPU = 0
GPU = 1
......@@ -62,6 +69,8 @@ def make_scheduler(*,
Return a scheduler function, which schedules the state according to the setting.
The state transition conforms to:
.. code-block:: text
(CLOSED) (CLOSED) (CLOSED) (READY) (RECORD,last RETURN) (CLOSED)
START -> skip_first -> closed -> ready -> record -> END
| |
......@@ -81,13 +90,23 @@ def make_scheduler(*,
Examples:
1. profiling range [2, 5]
batch 0: closed, batch 1: ready, batch [2, 5] record
.. code-block:: python
make_scheduler(closed=1, ready=1, record=4, repeat=1)
import paddle.profiler as profiler
profiler.make_scheduler(closed=1, ready=1, record=4, repeat=1)
2. profiling range [3,6], [9,12], [15,18]...
batch 0: skipped, batch 1: closed, batch 2: ready, batch [3,6]: record, repeat
.. code-block:: python
make_scheduler(closed=1, ready=1, record=4, skip_first=1)
import paddle.profiler as profiler
profiler.make_scheduler(closed=1, ready=1, record=4, skip_first=1)
"""
def getScheduleState(step: int) -> ProfilerState:
......@@ -138,14 +157,15 @@ def export_chrome_tracing(dir_name: str,
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (3, 10),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
on_trace_ready=profiler.export_chrome_tracing('./log')) as p:
for iter in range(10):
#train()
p.step()
"""
if not os.path.exists(dir_name):
......@@ -181,14 +201,15 @@ def export_protobuf(dir_name: str, worker_name: Optional[str]=None) -> Callable:
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (3, 10),
on_trace_ready = profiler.export_protobuf('./log')
) as p:
for iter in range(N):
train()
on_trace_ready = profiler.export_protobuf('./log')) as p:
for iter in range(10):
#train()
p.step()
"""
if not os.path.exists(dir_name):
......@@ -216,7 +237,7 @@ def _get_supported_targets() -> Iterable[ProfilerTarget]:
r"""
Get the current supported profiler target in the system.
"""
if paddle.device.is_compiled_with_cuda():
if _Profiler.is_cupti_supported():
return [ProfilerTarget.CPU, ProfilerTarget.GPU]
return [ProfilerTarget.CPU]
......@@ -226,48 +247,56 @@ class Profiler:
Profiler context manager, user interface to manage profile process.
Parameters:
targets (iterable): list of tracing targets, currently supported values:
``paddle.profiler.ProfilerTarget.CPU``,
``paddle.profiler.ProfilerTarget.GPU``.
targets (iterable): list of tracing targets, currently supported values, ``ProfilerTarget.CPU``, ``ProfilerTarget.GPU`` .
scheduler (callable or tuple): If it is a callable object, it takes a step number as parameter and return the corresponding ``ProfilerState``.
If not provided, the default sheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch,
If not provided, the default scheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch,
which means profiling range [start_batch, end_batch).
on_trace_ready (callable): callable object, takes the Profiler object as parameter, which provides a way for users to do post-processing.
This callable object will be called when ``sheduler`` returns ``ProfilerState.RECORD_AND_RETURN``.
This callable object will be called when ``scheduler`` returns ``ProfilerState.RECORD_AND_RETURN``.
Examples:
1. profiling range [2, 5)
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (2, 5),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
on_trace_ready = profiler.export_chrome_tracing('./log')) as p:
for iter in range(10):
#train()
p.step()
2. profiling range [2,4], [7, 9], [11,13]
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = profiler.make_scheduler(closed=1, ready=1, record=3, repeat=3),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
on_trace_ready = profiler.export_chrome_tracing('./log')) as p:
for iter in range(10):
#train()
p.step()
3. Use profiler without context manager, and use default parameters
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
p = profiler.Profiler()
p.start()
for iter in range(N):
train()
for iter in range(10):
#train()
p.step()
p.stop()
p.summary()
"""
def __init__(
......@@ -335,6 +364,21 @@ class Profiler:
r'''
Start profiler and enter the first profiler step(0).
State transformed from CLOSED to self.current_state and trigger corresponding action.
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (1, 9),
on_trace_ready = profiler.export_chrome_tracing('./log'))
prof.start()
for iter in range(10):
#train()
prof.step()
prof.stop()
'''
# CLOSED -> self.current_state
if self.current_state == ProfilerState.READY:
......@@ -354,6 +398,21 @@ class Profiler:
r'''
Stop profiler and State transformed from self.current_state to CLOSED.
Trigger corresponding action and post-process profiler result using self.on_trace_ready if result exists.
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (1, 7),
on_trace_ready = profiler.export_chrome_tracing('./log'))
prof.start()
for iter in range(10):
#train()
prof.step()
prof.stop()
'''
# self.current_state -> CLOSED
# In this situation, RECORD state is regarded as RECORD_AND_RETURN
......@@ -375,6 +434,22 @@ class Profiler:
r"""
Signals the profiler that the next profiling step has started.
Get the new ProfilerState and trigger corresponding action.
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (3, 7),
on_trace_ready = profiler.export_chrome_tracing('./log'))
prof.start()
for iter in range(10):
#train()
prof.step()
prof.stop()
"""
if self.record_event:
self.record_event.end()
......@@ -448,6 +523,21 @@ class Profiler:
def export(self, path="", format="json"):
r"""
Exports the tracing data in Chrome tracing data format.
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (3, 7))
prof.start()
for iter in range(10):
#train()
prof.step()
prof.stop()
prof.export(path="./profiler_data.json", format="json")
"""
if self.profiler_result:
self.profiler_result.save(path, format)
......@@ -461,9 +551,35 @@ class Profiler:
Print the Summary table.
Parameters:
sorted_by: how to rank the op table items.
detail: expand each operator detail information.
thread_sep: print op table each thread.
time_unit: can be chosen from ['s', 'ms', 'us', 'ns']
sorted_by(SortedKeys): how to rank the op table items.
op_detail(bool): expand each operator detail information.
thread_sep(bool): print op table each thread.
time_unit(str): can be chosen from ['s', 'ms', 'us', 'ns']
Examples:
.. code-block:: python
# required: gpu
import paddle.profiler as profiler
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
scheduler = (3, 7),
on_trace_ready = profiler.export_chrome_tracing('./log'))
prof.start()
for iter in range(10):
#train()
prof.step()
prof.stop()
prof.summary(sorted_by=profiler.SortedKeys.CPUTotal, op_detail=True, thread_sep=False, time_unit='ms')
"""
pass
if self.profiler_result:
statistic_data = StatisticData(
self.profiler_result.get_data(),
self.profiler_result.get_extra_info())
print(
_build_table(
statistic_data,
sorted_by=sorted_by,
op_detail=op_detail,
thread_sep=thread_sep,
time_unit=time_unit))
......@@ -34,6 +34,22 @@ _CommunicationOpName = ['reduce', 'broadcast', 'rpc']
class SortedKeys(Enum):
r"""
Sorted keys for printing summary table.
CPUTotal: Sorted by CPU total time.
CPUAvg: Sorted by CPU average time.
CPUMax: Sorted by CPU max time.
CPUMin: Sorted by CPU min time.
GPUTotal: Sorted by GPU total time.
GPUAvg: Sorted by GPU average time.
GPUMax: Sorted by GPU max time.
GPUMin: Sorted by GPU min time.
"""
CPUTotal = 0
CPUAvg = 1
......@@ -642,6 +658,171 @@ def _build_table(statistic_data,
append('')
append('')
###### Print Model Summary Report ######
model_perspective_items = statistic_data.event_summary.model_perspective_items
if model_perspective_items:
headers = [
'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)',
'GPU Total / Avg / Max / Min / Ratio(%)'
]
row_format_list = [""]
header_sep_list = [""]
line_length_list = [-SPACING_SIZE]
name_column_width = 15
add_column(name_column_width)
add_column(6)
add_column(40)
add_column(40)
row_format = row_format_list[0]
header_sep = header_sep_list[0]
line_length = line_length_list[0]
# construct table string
append(add_title(line_length, "Model Summary"))
append('Time unit: {}'.format(time_unit))
append(header_sep)
append(row_format.format(*headers))
append(header_sep)
accmulation_time = 0
row_values = [
'Total Time', '-', '{} / - / - / - / {}'.format(
format_time(
total_time, unit=time_unit), format_ratio(1)),
'- / - / - / -/ -'
]
append(row_format.format(*row_values))
for name in ['Dataloader', 'Forward', 'Backward', 'Optimization']:
if name in model_perspective_items:
item = model_perspective_items[name]
row_values = [
' {}'.format(name), item.call,
'{} / {} / {} / {} / {}'.format(
format_time(
item.cpu_time, unit=time_unit),
format_time(
item.avg_cpu_time, unit=time_unit),
format_time(
item.max_cpu_time, unit=time_unit),
format_time(
item.min_cpu_time, unit=time_unit),
format_ratio(float(item.cpu_time) / total_time)),
'{} / {} / {} / {} / {}'.format(
format_time(
item.gpu_time, unit=time_unit),
format_time(
item.avg_gpu_time, unit=time_unit),
format_time(
item.max_gpu_time, unit=time_unit),
format_time(
item.min_gpu_time, unit=time_unit),
format_ratio(float(item.gpu_time) / total_time))
]
append(row_format.format(*row_values))
accmulation_time += item.cpu_time
other_time = total_time - accmulation_time
row_values = [
' Others', '-', '{} / - / - / - / {}'.format(
format_time(
other_time, unit=time_unit),
format_ratio(float(other_time) / total_time)),
'- / - / - / - / -'
]
append(row_format.format(*row_values))
append(header_sep)
append('')
append('')
###### Print Distribution Summary Report ######
if TracerEventType.Communication in statistic_data.time_range_summary.CPUTimeRange:
headers = [
'Name',
'Total Time',
'Ratio (%)',
]
row_format_list = [""]
header_sep_list = [""]
line_length_list = [-SPACING_SIZE]
DEFAULT_COLUMN_WIDTH = 20
for _ in headers:
add_column(DEFAULT_COLUMN_WIDTH)
row_format = row_format_list[0]
header_sep = header_sep_list[0]
line_length = line_length_list[0]
# construct table string
append(add_title(line_length, "Distribution Summary"))
append('Time unit: {}'.format(time_unit))
append(header_sep)
append(row_format.format(*headers))
append(header_sep)
cpu_communication_time_range = []
gpu_communication_time_range = []
cpu_communication_time_range = merge_ranges(
statistic_data.time_range_summary.CPUTimeRange[
TracerEventType.Communication], cpu_communication_time_range)
kernel_time_range = []
for device_id, device_time_ranges in statistic_data.time_range_summary.GPUTimeRange.items(
):
kernel_time_range = merge_ranges(
device_time_ranges[TracerEventType.Kernel],
kernel_time_range,
is_sorted=True)
gpu_communication_time_range = merge_ranges(
device_time_ranges[TracerEventType.Communication],
gpu_communication_time_range,
is_sorted=True)
communication_time_range = merge_ranges(
cpu_communication_time_range,
gpu_communication_time_range,
is_sorted=True)
computation_time_range = subtract_ranges(kernel_time_range,
gpu_communication_time_range)
overlap_time_range = intersection_ranges(communication_time_range,
computation_time_range)
communication_time = sum_ranges(communication_time_range)
computation_time = sum_ranges(computation_time_range)
overlap_time = sum_ranges(overlap_time_range)
row_values = [
'Communication', format_time(
communication_time, unit=time_unit),
format_ratio(float(communication_time) / total_time)
]
append(row_format.format(*row_values))
row_values = [
'Computation', format_time(
computation_time, unit=time_unit),
format_ratio(float(computation_time) / total_time)
]
append(row_format.format(*row_values))
row_values = [
'Overlap', format_time(
overlap_time, unit=time_unit),
format_ratio(float(overlap_time) / total_time)
]
append(row_format.format(*row_values))
append(header_sep)
append(
"Note:\nCommunication time: Communication Op time and its kernel time on gpu.\n"
"Computation time: Kernel time, substract kernels belong to communication op.\n"
"Overlap time: Communication time intersect with computation time.\n"
"Example:\n"
"Communication:\n"
" CPU: |_________________|\n"
" GPU: |______________|\n"
" Total: |_________________| |______________|\n"
"Computation time(Kernel):\n"
" GPU: |________________|\n"
"Overlap time: |___________|\n")
append('-' * line_length)
append('')
append('')
###### Print Operator Summary Report ######
if statistic_data.event_summary.items:
headers = [
......@@ -708,11 +889,6 @@ def _build_table(statistic_data,
sorted_items = sorted(
items.items(), key=lambda x: x[1].min_gpu_time)
total_cpu_time = 0
total_gpu_time = 0
for name, item in sorted_items:
total_cpu_time += item.cpu_time
total_gpu_time += item.gpu_time
for name, item in sorted_items:
row_values = [
name, item.call, '{} / {} / {} / {} / {}'.format(
......@@ -724,7 +900,7 @@ def _build_table(statistic_data,
item.max_cpu_time, unit=time_unit),
format_time(
item.min_cpu_time, unit=time_unit),
format_ratio(float(item.cpu_time) / total_cpu_time)),
format_ratio(float(item.cpu_time) / total_time)),
'{} / {} / {} / {} / {}'.format(
format_time(
item.gpu_time, unit=time_unit),
......@@ -734,7 +910,7 @@ def _build_table(statistic_data,
item.max_gpu_time, unit=time_unit),
format_time(
item.min_gpu_time, unit=time_unit),
format_ratio(float(item.gpu_time) / total_gpu_time))
format_ratio(float(item.gpu_time) / total_time))
]
append(row_format.format(*row_values))
if op_detail:
......@@ -752,8 +928,7 @@ def _build_table(statistic_data,
format_time(
innerop_node.min_cpu_time, unit=time_unit),
format_ratio(
float(innerop_node.cpu_time) /
total_cpu_time)),
float(innerop_node.cpu_time) / total_time)),
'{} / {} / {} / {} / {}'.format(
format_time(
innerop_node.gpu_time, unit=time_unit),
......@@ -764,8 +939,7 @@ def _build_table(statistic_data,
format_time(
innerop_node.min_gpu_time, unit=time_unit),
format_ratio(
float(innerop_node.gpu_time) /
total_gpu_time))
float(innerop_node.gpu_time) / total_time))
]
append(row_format.format(*row_values))
for device_node_name, devicenode in innerop_node.devices.items(
......@@ -792,7 +966,7 @@ def _build_table(statistic_data,
unit=time_unit),
format_ratio(
float(devicenode.gpu_time) /
total_gpu_time))
total_time))
]
append(row_format.format(*row_values))
for device_node_name, device_node in item.devices.items():
......@@ -814,11 +988,160 @@ def _build_table(statistic_data,
format_time(
devicenode.min_gpu_time, unit=time_unit),
format_ratio(
float(devicenode.gpu_time) /
total_gpu_time))
float(devicenode.gpu_time) / total_time))
]
append(row_format.format(*row_values))
append(header_sep)
append('')
append('')
###### Print Memory Manipulation Summary Report ######
if statistic_data.event_summary.memory_manipulation_items:
headers = [
'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)',
'GPU Total / Avg / Max / Min / Ratio(%)'
]
row_format_list = [""]
header_sep_list = [""]
line_length_list = [-SPACING_SIZE]
name_column_width = 30
add_column(name_column_width)
add_column(6)
add_column(40)
add_column(40)
row_format = row_format_list[0]
header_sep = header_sep_list[0]
line_length = line_length_list[0]
# construct table string
append(add_title(line_length, "Memory Manipulation Summary"))
append('Time unit: {}'.format(time_unit))
append(header_sep)
append(row_format.format(*headers))
append(header_sep)
memory_manipulation_items = statistic_data.event_summary.memory_manipulation_items
for name, item in memory_manipulation_items.items():
row_values = [
name,
item.call,
'{} / {} / {} / {} / {}'.format(
format_time(
item.cpu_time, unit=time_unit),
format_time(
item.avg_cpu_time, unit=time_unit),
format_time(
item.max_cpu_time, unit=time_unit),
format_time(
item.min_cpu_time, unit=time_unit),
format_ratio(float(item.cpu_time) / total_time)),
'{} / {} / {} / {} / {}'.format(
format_time(
item.gpu_time, unit=time_unit),
format_time(
item.avg_gpu_time, unit=time_unit),
format_time(
item.max_gpu_time, unit=time_unit),
format_time(
item.min_gpu_time, unit=time_unit),
format_ratio(float(item.gpu_time) / total_time)),
]
append(row_format.format(*row_values))
append(header_sep)
append('')
append('')
###### Print UserDefined Summary Report ######
if statistic_data.event_summary.userdefined_items:
headers = [
'Name', 'Calls', 'CPU Total / Avg / Max / Min / Ratio(%)',
'GPU Total / Avg / Max / Min / Ratio(%)'
]
row_format_list = [""]
header_sep_list = [""]
line_length_list = [-SPACING_SIZE]
name_column_width = 30
add_column(name_column_width)
add_column(6)
add_column(40)
add_column(40)
row_format = row_format_list[0]
header_sep = header_sep_list[0]
line_length = line_length_list[0]
# construct table string
append(add_title(line_length, "UserDefined Summary"))
append('Time unit: {}'.format(time_unit))
append(header_sep)
append(row_format.format(*headers))
append(header_sep)
if thread_sep == True:
userdefined_thread_items = statistic_data.event_summary.userdefined_thread_items
else:
userdefined_thread_items = {
'All threads merged':
statistic_data.event_summary.userdefined_items
}
for thread_id, items in userdefined_thread_items.items():
append(add_title(line_length, "Thread: {}".format(thread_id)))
if sorted_by == SortedKeys.CPUTotal:
sorted_items = sorted(
items.items(), key=lambda x: x[1].cpu_time, reverse=True)
elif sorted_by == SortedKeys.CPUAvg:
sorted_items = sorted(
items.items(),
key=lambda x: x[1].avg_cpu_time,
reverse=True)
elif sorted_by == SortedKeys.CPUMax:
sorted_items = sorted(
items.items(),
key=lambda x: x[1].max_cpu_time,
reverse=True)
elif sorted_by == SortedKeys.CPUMin:
sorted_items = sorted(
items.items(), key=lambda x: x[1].min_cpu_time)
elif sorted_by == SortedKeys.GPUTotal:
sorted_items = sorted(
items.items(), key=lambda x: x[1].gpu_time, reverse=True)
elif sorted_by == SortedKeys.GPUAvg:
sorted_items = sorted(
items.items(),
key=lambda x: x[1].avg_gpu_time,
reverse=True)
elif sorted_by == SortedKeys.GPUMax:
sorted_items = sorted(
items.items(),
key=lambda x: x[1].max_gpu_time,
reverse=True)
elif sorted_by == SortedKeys.GPUMin:
sorted_items = sorted(
items.items(), key=lambda x: x[1].min_gpu_time)
for name, item in sorted_items:
row_values = [
name,
item.call,
'{} / {} / {} / {} / {}'.format(
format_time(
item.cpu_time, unit=time_unit),
format_time(
item.avg_cpu_time, unit=time_unit),
format_time(
item.max_cpu_time, unit=time_unit),
format_time(
item.min_cpu_time, unit=time_unit),
format_ratio(float(item.cpu_time) / total_time)),
'{} / {} / {} / {} / {}'.format(
format_time(
item.gpu_time, unit=time_unit),
format_time(
item.avg_gpu_time, unit=time_unit),
format_time(
item.max_gpu_time, unit=time_unit),
format_time(
item.min_gpu_time, unit=time_unit),
format_ratio(float(item.gpu_time) / total_time)),
]
append(row_format.format(*row_values))
append(header_sep)
return ''.join(result)
......@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.core import (_RecordEvent, TracerEventType,
load_profiler_result)
from typing import Any
from warnings import warn
import functools
from contextlib import ContextDecorator
from paddle.fluid.core import (_RecordEvent, TracerEventType)
import paddle.fluid.core as core
_AllowedEventTypeList = [
TracerEventType.Dataloader, TracerEventType.ProfileStep,
TracerEventType.UserDefined, TracerEventType.Forward,
......@@ -33,13 +34,27 @@ class RecordEvent(ContextDecorator):
Parameters:
name(str): Name of the record event
event_type(TracerEventType): Type of the record event, can be used for statistics.
Examples:
.. code-block:: python
import paddle
import paddle.profiler as profiler
with profiler.RecordEvent(name='op1', event_type=TracerEventType=TracerEventType.UserDefined):
op1()
# method1: using context manager
with profiler.RecordEvent("record_add"):
data1 = paddle.randn(shape=[3])
data2 = paddle.randn(shape=[3])
result = data1 + data2
# method2: call begin() and end()
record_event = profiler.RecordEvent("record_add")
record_event.begin()
data1 = paddle.randn(shape=[3])
data2 = paddle.randn(shape=[3])
result = data1 + data2
record_event.end()
Note:
RecordEvent takes effect only when the profiler is on and in the RECORD state.
"""
def __init__(self,
......@@ -57,6 +72,20 @@ class RecordEvent(ContextDecorator):
self.end()
def begin(self):
r"""
Record the time of beginning.
.. code-block:: python
import paddle
import paddle.profiler as profiler
record_event = profiler.RecordEvent("record_sub")
record_event.begin()
data1 = paddle.randn(shape=[3])
data2 = paddle.randn(shape=[3])
result = data1 - data2
record_event.end()
"""
if self.event_type not in _AllowedEventTypeList:
warn("Only TracerEvent Type in [{}, {}, {}, {}, {}, {},{}]\
can be recorded.".format(*_AllowedEventTypeList))
......@@ -67,10 +96,51 @@ class RecordEvent(ContextDecorator):
self.event = _RecordEvent(self.name, self.event_type)
def end(self):
    r"""
    Mark the end of the recorded event.

    The call is forwarded to the underlying event object stored in
    ``self.event``; when ``self.event`` is falsy, this method does nothing.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.profiler as profiler

            record_event = profiler.RecordEvent("record_mul")
            record_event.begin()
            data1 = paddle.randn(shape=[3])
            data2 = paddle.randn(shape=[3])
            result = data1 * data2
            record_event.end()
    """
    active_event = self.event
    # Nothing to finish when no underlying event object is set.
    if not active_event:
        return
    active_event.end()
def load_profiler_result(filename: str):
    r"""
    Read previously dumped profiler data back into memory.

    Parameters:
        filename(str): Name of the exported protobuf file of profiler data.

    Returns:
        ProfilerResult object.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle.profiler as profiler
            with profiler.Profiler(
                    targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
                    scheduler = (3, 10)) as p:
                for iter in range(10):
                    #train()
                    p.step()
            p.export('test_export_protobuf.pb', format='pb')
            profiler_result = profiler.load_profiler_result('test_export_protobuf.pb')
    """
    # Loading is delegated entirely to the core binding.
    loaded_result = core.load_profiler_result(filename)
    return loaded_result
def wrap_optimizers():
def optimizer_warpper(func):
@functools.wraps(func)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册