未验证 提交 10325a82 编写于 作者: C chenjian 提交者: GitHub

add python profiler package (#40065)

* add python profiler package

* update according to review

* fix bug

* fix bug

* fix bug

* add unit test

* Revert "add unit test"

This reverts commit 4e69ff71b0645e069afe5dd8fea0d07717852c48.

* reduce for pr

* add unit test

* modify for pr

* fix unittest

* update for ci coverage

* modify according to review

* fix bug

* improve coverage
上级 b798fb07
......@@ -489,6 +489,10 @@ void NvprofDisableRecordEvent() { g_enable_nvprof_hook = false; }
void EnableHostEventRecorder() { FLAGS_enable_host_event_recorder_hook = true; }
void DisableHostEventRecorder() {
FLAGS_enable_host_event_recorder_hook = false;
}
std::string PrintHostEvents() {
std::ostringstream oss;
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents();
......
......@@ -216,6 +216,7 @@ void NvprofEnableRecordEvent();
void NvprofDisableRecordEvent();
void EnableHostEventRecorder();
void DisableHostEventRecorder();
// Defined for UT
std::string PrintHostEvents();
......
......@@ -2,7 +2,7 @@ set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model cuda_graph_with_memory_pool fleet_executor global_utils phi_utils tcp_store)
cost_model cuda_graph_with_memory_pool fleet_executor global_utils phi_utils tcp_store new_profiler)
if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
......
......@@ -78,6 +78,9 @@ limitations under the License. */
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_python.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager.h"
......@@ -2913,6 +2916,88 @@ All parameter, weight, gradient are variables in Paddle.
});
m.def("size_of_dtype", framework::SizeOfType);
py::class_<paddle::platform::ProfilerResult>(m, "_ProfilerResult")
.def(py::init<>())
.def("get_data", &paddle::platform::ProfilerResult::GetData,
py::return_value_policy::automatic_reference)
.def("save", &paddle::platform::ProfilerResult::Save)
.def("get_extra_info", &paddle::platform::ProfilerResult::GetExtraInfo);
py::class_<paddle::platform::DevicePythonNode>(m, "DevicePythonNode")
.def(py::init<>())
.def_readwrite("name", &paddle::platform::DevicePythonNode::name)
.def_readwrite("type", &paddle::platform::DevicePythonNode::type)
.def_readwrite("start_ns", &paddle::platform::DevicePythonNode::start_ns)
.def_readwrite("end_ns", &paddle::platform::DevicePythonNode::end_ns)
.def_readwrite("device_id",
&paddle::platform::DevicePythonNode::device_id)
.def_readwrite("context_id",
&paddle::platform::DevicePythonNode::context_id)
.def_readwrite("stream_id",
&paddle::platform::DevicePythonNode::stream_id);
py::class_<paddle::platform::HostPythonNode>(m, "HostPythonNode")
.def(py::init<>())
.def_readwrite("name", &paddle::platform::HostPythonNode::name)
.def_readwrite("type", &paddle::platform::HostPythonNode::type)
.def_readwrite("start_ns", &paddle::platform::HostPythonNode::start_ns)
.def_readwrite("end_ns", &paddle::platform::HostPythonNode::end_ns)
.def_readwrite("process_id",
&paddle::platform::HostPythonNode::process_id)
.def_readwrite("thread_id", &paddle::platform::HostPythonNode::thread_id)
.def_readwrite("children_node",
&paddle::platform::HostPythonNode::children_node_ptrs)
.def_readwrite("runtime_node",
&paddle::platform::HostPythonNode::runtime_node_ptrs)
.def_readwrite("device_node",
&paddle::platform::HostPythonNode::device_node_ptrs);
py::class_<paddle::platform::Profiler>(m, "_Profiler")
.def("create", &paddle::platform::Profiler::Create,
py::return_value_policy::take_ownership)
.def("prepare",
[](paddle::platform::Profiler *profiler) {
platform::EnableHostEventRecorder();
profiler->Prepare();
})
.def("start", &paddle::platform::Profiler::Start)
.def("stop",
[](paddle::platform::Profiler *profiler) {
platform::DisableHostEventRecorder();
return profiler->Stop();
},
py::return_value_policy::automatic_reference);
py::class_<paddle::platform::ProfilerOptions>(m, "ProfilerOptions")
.def(py::init<>())
.def_readwrite("trace_switch",
&paddle::platform::ProfilerOptions::trace_switch);
py::class_<platform::RecordEvent>(m, "_RecordEvent")
.def(py::init([](std::string name, platform::TracerEventType type) {
return std::make_unique<platform::RecordEvent>(
name, type, 1, paddle::platform::EventRole::kOrdinary);
}))
.def("end", [](platform::RecordEvent *event) { event->End(); });
py::enum_<paddle::platform::TracerEventType>(m, "TracerEventType")
.value("Operator", paddle::platform::TracerEventType::Operator)
.value("Dataloader", paddle::platform::TracerEventType::Dataloader)
.value("ProfileStep", paddle::platform::TracerEventType::ProfileStep)
.value("CudaRuntime", paddle::platform::TracerEventType::CudaRuntime)
.value("Kernel", paddle::platform::TracerEventType::Kernel)
.value("Memcpy", paddle::platform::TracerEventType::Memcpy)
.value("Memset", paddle::platform::TracerEventType::Memset)
.value("UserDefined", paddle::platform::TracerEventType::UserDefined)
.value("OperatorInner", paddle::platform::TracerEventType::OperatorInner)
.value("Forward", paddle::platform::TracerEventType::Forward)
.value("Backward", paddle::platform::TracerEventType::Backward)
.value("Optimization", paddle::platform::TracerEventType::Optimization)
.value("Communication", paddle::platform::TracerEventType::Communication)
.value("PythonOp", paddle::platform::TracerEventType::PythonOp)
.value("PythonUserDefined",
paddle::platform::TracerEventType::PythonUserDefined);
m.def("load_profiler_result", &paddle::platform::LoadProfilerResult);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("set_cublas_switch", platform::SetAllowTF32Cublas);
......
......@@ -283,6 +283,7 @@ if avx_supported():
from .core_avx import _set_cached_executor_build_strategy
from .core_avx import _device_synchronize
from .core_avx import _get_current_stream
from .core_avx import _Profiler, _ProfilerResult, _RecordEvent
from .core_avx import _set_current_stream
if sys.platform != 'win32':
from .core_avx import _set_process_pids
......@@ -344,6 +345,7 @@ if load_noavx:
from .core_noavx import _device_synchronize
from .core_noavx import _get_current_stream
from .core_noavx import _set_current_stream
from .core_noavx import _Profiler, _ProfilerResult, _RecordEvent
if sys.platform != 'win32':
from .core_noavx import _set_process_pids
from .core_noavx import _erase_process_pids
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.profiler as profiler
class TestProfiler(unittest.TestCase):
def test_profiler(self):
def my_trace_back(prof):
profiler.export_chrome_tracing('./test_profiler_chrometracing/')(
prof)
profiler.export_protobuf('./test_profiler_pb/')(prof)
x_value = np.random.randn(2, 3, 3)
x = paddle.to_tensor(
x_value, stop_gradient=False, place=paddle.CPUPlace())
y = x / 2.0
ones_like_y = paddle.ones_like(y)
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU], ) as prof:
y = x / 2.0
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=(1, 2)) as prof:
with profiler.RecordEvent(name='test'):
y = x / 2.0
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=profiler.make_scheduler(
closed=0, ready=1, record=1, repeat=1),
on_trace_ready=my_trace_back) as prof:
y = x / 2.0
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=profiler.make_scheduler(
closed=0, ready=0, record=2, repeat=1),
on_trace_ready=my_trace_back) as prof:
for i in range(3):
y = x / 2.0
prof.step()
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=lambda x: profiler.ProfilerState.RECORD_AND_RETURN,
on_trace_ready=my_trace_back) as prof:
for i in range(2):
y = x / 2.0
prof.step()
def my_sheduler(num_step):
if num_step % 5 < 2:
return profiler.ProfilerState.RECORD_AND_RETURN
elif num_step % 5 < 3:
return profiler.ProfilerState.READY
elif num_step % 5 < 4:
return profiler.ProfilerState.RECORD
else:
return profiler.ProfilerState.CLOSED
def my_sheduler1(num_step):
if num_step % 5 < 2:
return profiler.ProfilerState.RECORD
elif num_step % 5 < 3:
return profiler.ProfilerState.READY
elif num_step % 5 < 4:
return profiler.ProfilerState.RECORD
else:
return profiler.ProfilerState.CLOSED
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=lambda x: profiler.ProfilerState.RECORD_AND_RETURN,
on_trace_ready=my_trace_back) as prof:
for i in range(2):
y = x / 2.0
prof.step()
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=my_sheduler,
on_trace_ready=my_trace_back) as prof:
for i in range(5):
y = x / 2.0
prof.step()
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=my_sheduler1) as prof:
for i in range(5):
y = x / 2.0
prof.step()
prof = None
with profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU],
scheduler=profiler.make_scheduler(
closed=1, ready=1, record=2, repeat=1, skip_first=1),
on_trace_ready=my_trace_back) as prof:
for i in range(5):
y = x / 2.0
paddle.grad(outputs=y, inputs=[x], grad_outputs=ones_like_y)
prof.step()
prof.export(path='./test_profiler_pb.pb', format='pb')
prof.summary()
result = profiler.utils.load_profiler_result('./test_profiler_pb.pb')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .profiler import ProfilerState, ProfilerTarget
from .profiler import make_scheduler, export_chrome_tracing, export_protobuf
from .profiler import Profiler
from .profiler import TracerEventType
from .utils import RecordEvent, load_profiler_result
from .profiler_statistic import SortedKeys
__all__ = [
'ProfilerState', 'ProfilerTarget', 'TracerEventType', 'make_scheduler',
'export_chrome_tracing', 'export_protobuf', 'Profiler', 'RecordEvent',
'load_profiler_result', 'SortedKeys'
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import datetime
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Union
from warnings import warn
import paddle
from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions,
TracerEventType)
from .utils import RecordEvent, wrap_optimizers
from .profiler_statistic import SortedKeys
class ProfilerState(Enum):
r"""
Profiler state that can be specified to control profiler action.
CLOSED: The profilers are closed.
READY: The profilers are open, but the data will not be recorded.
This state is used for reducing overhead influence when profilers start.
RECORD: The profilers are open, and the data will be recorded.
RECORD_AND_RETURN: The profilers are open, and at the last batch of current profiler period,
the collected data will be returned.
"""
CLOSED = 0
READY = 1
RECORD = 2
RECORD_AND_RETURN = 3 # the last step of RECORD
class ProfilerTarget(Enum):
r"""
Target device for profiling.
"""
CPU = 0
GPU = 1
def make_scheduler(*,
closed: int,
ready: int,
record: int,
repeat: int=0,
skip_first: int=0) -> Callable:
r"""
Return a scheduler function, which scheduler the state according to the setting.
The state transform confirms to:
(CLOSED) (CLOSED) (CLOSED) (READY) (RECORD,last RETURN) (CLOSED)
START -> skip_first -> closed -> ready -> record -> END
| |
| | (if has_repeated < repeat)
- - - - - - - - - - - -
Note that repeat <= 0 means the cycle will continue until the profiler exits.
Parameters:
closed(int): The number of steps in state ProfilerState.CLOSED.
ready(int): The number of steps in state ProfilerState.READY.
record(int): The number of steps in state ProfilerState.RECORD.
repeat(int): The number of cycles to repeat above state transform.
skip_first(int): The number of first steps to drop, not participate in the state transform.
Returns:
A scheduler function, conforms to above state transform setting.
Examples:
1. profiling range [2, 5]
batch 0: closed, batch 1: ready, batch [2, 5] record
.. code-block:: python
make_scheduler(closed=1, ready=1, record=4, repeat=1)
2. profiling range [3,6], [9,12], [15,18]...
batch 0: skiped, batch 1: closed, batch 2: ready, batch [3,6]: record, repeat
.. code-block:: python
make_scheduler(closed=1, ready=1, record=4, skip_first=1)
"""
def getScheduleState(step: int) -> ProfilerState:
assert step >= 0
if step < skip_first: # within skip_first, just skip
return ProfilerState.CLOSED
step = step - skip_first
period_steps = closed + ready + record
has_repeated = step // period_steps
if repeat > 0 and has_repeated >= repeat: # the period has repeated repeat times, return CLOSED state
return ProfilerState.CLOSED
mod_step = step % period_steps
if mod_step < closed:
return ProfilerState.CLOSED
elif mod_step >= closed and mod_step < closed + ready:
return ProfilerState.READY
else:
if mod_step < period_steps - 1:
return ProfilerState.RECORD
else:
return ProfilerState.RECORD_AND_RETURN
assert closed >= 0 and ready >= 0 and record > 0 and \
repeat >= 0 and skip_first >= 0, "Invalid profiler scheduler arguments"
if ready == 0:
warn("Profiler will record data after enabling profiler immediately, \
some data collected at the beginning of profiling may be 'noisy' because of overhead."
)
return getScheduleState
def _default_state_scheduler(step: int):
r"""
A default state scheduler, keep recording from the begining of the profiler until ending.
"""
return ProfilerState.RECORD
def export_chrome_tracing(dir_name: str,
worker_name: Optional[str]=None) -> Callable:
r"""
Return a callable, used for outputing tracing data to chrome tracing format file.
The output file will be saved in directory 'dir_name', and file name will be set as worker_name.
if worker_name is not set, the default name is [hostname]_[pid].
Parameters:
dir_name(str): Directory to save profiling data.
worker_name(Optional[str]): Prefix of the file name saved, default is [hostname]_[pid].
Examples:
.. code-block:: python
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
scheduler = (3, 10),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
p.step()
"""
if not os.path.exists(dir_name):
try:
os.makedirs(dir_name, exist_ok=True)
except Exception:
raise RuntimeError(
"Can not create directory '{}' for saving profiling results.".
format(dir_name))
def handle_fn(prof):
nonlocal worker_name
if not worker_name:
worker_name = "host_{}pid_{}".format(socket.gethostname(),
str(os.getpid()))
now = datetime.datetime.now()
filename = '{}_time_{}.paddle_trace.json'.format(
worker_name, now.strftime('%Y_%m_%d_%H_%M_%S_%f'))
prof.export(os.path.join(dir_name, filename), "json")
return handle_fn
def export_protobuf(dir_name: str, worker_name: Optional[str]=None) -> Callable:
r"""
Return a callable, used for outputing tracing data to protobuf file.
The output file will be saved in directory 'dir_name', and file name will be set as worker_name.
if worker_name is not set, the default name is [hostname]_[pid].
Parameters:
dir_name(str): Directory to save profiling data.
worker_name(Optional[str]): Prefix of the file name saved, default is [hostname]_[pid].
Examples:
.. code-block:: python
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
scheduler = (3, 10),
on_trace_ready = profiler.export_protobuf('./log')
) as p:
for iter in range(N):
train()
p.step()
"""
if not os.path.exists(dir_name):
try:
os.makedirs(dir_name, exist_ok=True)
except Exception:
raise RuntimeError(
"Can not create directory '{}' for saving profiling results.".
format(dir_name))
def handle_fn(prof):
nonlocal worker_name
if not worker_name:
worker_name = "host_{}pid_{}".format(socket.gethostname(),
str(os.getpid()))
now = datetime.datetime.now()
filename = '{}_time_{}.paddle_trace.pb'.format(
worker_name, now.strftime('%Y_%m_%d_%H_%M_%S_%f'))
prof.export(os.path.join(dir_name, filename), "pb")
return handle_fn
def _get_supported_targets() -> Iterable[ProfilerTarget]:
r"""
Get the current supported profiler target in the system.
"""
if paddle.device.is_compiled_with_cuda():
return [ProfilerTarget.CPU, ProfilerTarget.GPU]
return [ProfilerTarget.CPU]
class Profiler:
r"""
Profiler context manager, user interface to manage profile process.
Parameters:
targets (iterable): list of tracing targets, currently supported values:
``paddle.profiler.ProfilerTarget.CPU``,
``paddle.profiler.ProfilerTarget.GPU``.
scheduler (callable or tuple): If it is a callable object, it takes a step number as parameter and return the corresponding ``ProfilerState``.
If not provided, the default sheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch,
which means profiling range [start_batch, end_batch).
on_trace_ready (callable): callable object, takes the Profiler object as parameter, which provides a way for users to do post-processing.
This callable object will be called when ``sheduler`` returns ``ProfilerState.RECORD_AND_RETURN``.
Examples:
1. profiling range [2, 5)
.. code-block:: python
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
scheduler = (2, 5),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
p.step()
2. profiling range [2,4], [7, 9], [11,13]
.. code-block:: python
import paddle.profiler as profiler
with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
profiler.ProfilerTarget.GPU],
scheduler = profiler.make_scheduler(closed=1, ready=1, record=3, repeat=3),
on_trace_ready = profiler.export_chrome_tracing('./log')
) as p:
for iter in range(N):
train()
p.step()
3. Use profiler without context manager, and use default parameters
.. code-block:: python
import paddle.profiler as profiler
p = profiler.Profiler()
p.start()
for iter in range(N):
train()
p.step()
p.stop()
p.summary()
"""
def __init__(
self,
*,
targets: Optional[Iterable[ProfilerTarget]]=None,
scheduler: Union[Callable[[int], ProfilerState], tuple, None]=None,
on_trace_ready: Optional[Callable[..., Any]]=None):
supported_targets = _get_supported_targets()
if targets:
self.targets = set(targets)
for target in targets:
if target not in supported_targets:
self.targets.remove(target)
warn("Profiling {} is not supported in current context.".
format(target))
else:
self.targets = supported_targets
profileoption = ProfilerOptions()
if ProfilerTarget.CPU in self.targets:
profileoption.trace_switch |= 1
if ProfilerTarget.GPU in self.targets:
profileoption.trace_switch |= (1 << 1)
wrap_optimizers()
self.profiler = _Profiler.create(profileoption)
if callable(scheduler):
self.scheduler = scheduler
elif isinstance(scheduler, (tuple, list)):
assert len(scheduler) == 2 and scheduler[1] > scheduler[0]
start_batch, end_batch = scheduler
start_batch = max(start_batch, 0)
if start_batch >= 1:
self.scheduler = make_scheduler(
closed=max(start_batch - 1, 0),
ready=1,
record=(end_batch - start_batch),
repeat=1)
else:
self.scheduler = make_scheduler(
closed=0,
ready=0,
record=(end_batch - start_batch),
repeat=1)
else:
self.scheduler = _default_state_scheduler
if on_trace_ready == None:
self.on_trace_ready = export_chrome_tracing('./profiler_log/')
else:
self.on_trace_ready = on_trace_ready
self.step_num = 0
self.previous_state = ProfilerState.CLOSED
self.current_state = self.scheduler(self.step_num)
self.record_event = None
self.profiler_result = None
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
r'''
Start profiler and enter the first profiler step(0).
State transformed from CLOSED to self.current_state and trigger corresponding action.
'''
# CLOSED -> self.current_state
if self.current_state == ProfilerState.READY:
self.profiler.prepare()
elif self.current_state == ProfilerState.RECORD:
self.profiler.prepare()
self.profiler.start()
elif self.current_state == ProfilerState.RECORD_AND_RETURN:
self.profiler.prepare()
self.profiler.start()
self.record_event = RecordEvent(
name="ProfileStep#{}".format(self.step_num),
event_type=TracerEventType.ProfileStep)
self.record_event.begin()
def stop(self):
r'''
Stop profiler and State transformed from self.current_state to CLOSED.
Trigger corresponding action and post-process profiler result using self.on_trace_ready if result exists.
'''
# self.current_state -> CLOSED
# In this situation, RECORD state is regarded as RECORD_AND_RETURN
if self.record_event:
self.record_event.end()
self.record_event = None
if self.current_state == ProfilerState.READY:
warn(
"Inproper Profiler state transform: READY->CLOSED, profiler will start and stop without saving data"
)
self.profiler.start()
self.profiler.stop()
if self.current_state == ProfilerState.RECORD or self.current_state == ProfilerState.RECORD_AND_RETURN:
self.profiler_result = self.profiler.stop()
if self.on_trace_ready:
self.on_trace_ready(self)
def step(self):
r"""
Signals the profiler that the next profiling step has started.
Get the new ProfilerState and trigger corresponding action.
"""
if self.record_event:
self.record_event.end()
self.record_event = None
self.previous_state = self.current_state
self.step_num += 1
self.current_state = self.scheduler(self.step_num)
self._trigger_action()
self.record_event = RecordEvent(
name="ProfileStep#{}".format(self.step_num),
event_type=TracerEventType.ProfileStep)
self.record_event.begin()
def _trigger_action(self):
if self.previous_state == ProfilerState.CLOSED:
if self.current_state == ProfilerState.READY: # CLOSED -> READY
self.profiler.prepare()
if self.current_state == ProfilerState.RECORD: # CLOSED -> RECORD
self.profiler.prepare()
self.profiler.start()
if self.current_state == ProfilerState.RECORD_AND_RETURN: # CLOSED -> RECORD_AND_RETURN
self.profiler.prepare()
self.profiler.start()
elif self.previous_state == ProfilerState.READY:
if self.current_state == ProfilerState.CLOSED: # READY -> CLOSED
warn(
"Improper schedule: READY->CLOSED, profiler will start and stop without saving data"
)
self.profiler.start()
self.profiler.stop()
if self.current_state == ProfilerState.RECORD: # READY -> RECORD
self.profiler.start()
if self.current_state == ProfilerState.RECORD_AND_RETURN: # READY -> RECORD_AND_RETURN
self.profiler.start()
elif self.previous_state == ProfilerState.RECORD:
if self.current_state == ProfilerState.CLOSED: # RECORD -> CLOSED
warn(
"Improper schedule: RECORD->CLOSED, profiler will not saving data"
)
self.profiler.stop()
if self.current_state == ProfilerState.READY: # RECORD -> READY
warn(
"Improper schedule: RECORD->READY, profiler will stop and re-prepare"
)
self.profiler.stop()
self.profiler.prepare()
if self.current_state == ProfilerState.RECORD_AND_RETURN: # RECORD -> RECORD_AND_RETURN
pass
else:
assert self.previous_state == ProfilerState.RECORD_AND_RETURN
if self.current_state == ProfilerState.CLOSED: # RECORD_AND_RETURN -> CLOSED
self.profiler_result = self.profiler.stop()
if self.current_state == ProfilerState.READY: # RECORD_AND_RETURN -> READY
self.profiler_result = self.profiler.stop()
self.profiler.prepare()
if self.current_state == ProfilerState.RECORD: # RECORD_AND_RETURN -> RECORD
self.profiler_result = self.profiler.stop()
self.profiler.prepare()
self.profiler.start()
if self.current_state == ProfilerState.RECORD_AND_RETURN: # RECORD_AND_RETURN -> RECORD_AND_RETURN
self.profiler_result = self.profiler.stop()
self.profiler.prepare()
self.profiler.start()
if self.on_trace_ready:
self.on_trace_ready(self)
def export(self, path="", format="json"):
r"""
Exports the tracing data in Chrome tracing data format.
"""
if self.profiler_result:
self.profiler_result.save(path, format)
def summary(self,
sorted_by=SortedKeys.CPUTotal,
op_detail=True,
thread_sep=False,
time_unit='ms'):
r"""
Print the Summary table.
Parameters:
sorted_by: how to rank the op table items.
detail: expand each operator detail information.
thread_sep: print op table each thread.
time_unit: can be chosen form ['s', 'ms', 'us', 'ns']
"""
pass
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from enum import Enum
from paddle.fluid.core import TracerEventType
class SortedKeys(Enum):
r"""
Sorted keys for printing summary table.
"""
CPUTotal = 0
CPUAvg = 1
CPUMax = 2
CPUMin = 3
GPUTotal = 4
GPUAvg = 5
GPUMax = 6
GPUMin = 7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.core import (_RecordEvent, TracerEventType,
load_profiler_result)
from typing import Any
from warnings import warn
import functools
from contextlib import ContextDecorator
_AllowedEventTypeList = [
TracerEventType.Dataloader, TracerEventType.ProfileStep,
TracerEventType.UserDefined, TracerEventType.Forward,
TracerEventType.Backward, TracerEventType.Optimization,
TracerEventType.PythonOp, TracerEventType.PythonUserDefined
]
class RecordEvent(ContextDecorator):
r"""
Interface for recording a time range.
Parameters:
name(str): Name of the record event
event_type(TracerEventType): Type of the record event, can be used for statistics.
Examples:
.. code-block:: python
import paddle.profiler as profiler
with profiler.RecordEvent(name='op1', event_type=TracerEventType=TracerEventType.UserDefined):
op1()
"""
def __init__(self,
name: str,
event_type: TracerEventType=TracerEventType.UserDefined):
self.name = name
self.event_type = event_type
self.event = None
def __enter__(self):
self.begin()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
self.end()
def begin(self):
if self.event_type not in _AllowedEventTypeList:
warn("Only TracerEvent Type in [{}, {}, {}, {}, {}, {},{}]\
can be recorded.".format(*_AllowedEventTypeList))
self.event = None
else:
if self.event_type == TracerEventType.UserDefined:
self.event_type == TracerEventType.PythonUserDefined
self.event = _RecordEvent(self.name, self.event_type)
def end(self):
if self.event:
self.event.end()
def wrap_optimizers():
def optimizer_warpper(func):
@functools.wraps(func)
def warpper(*args, **kwargs):
with RecordEvent(
'Optimization Step',
event_type=TracerEventType.Optimization):
return func(*args, **kwargs)
return warpper
import paddle.optimizer as optimizer
for classname in optimizer.__all__:
if classname != 'Optimizer':
classobject = getattr(optimizer, classname)
if getattr(classobject, 'step', None) != None:
classobject.step = optimizer_warpper(classobject.step)
......@@ -372,6 +372,7 @@ packages=['paddle',
'paddle.device',
'paddle.device.cuda',
'paddle.version',
'paddle.profiler'
]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册