未验证 提交 53619873 编写于 作者: Y Yiqun Liu 提交者: GitHub

Implement the new profiler api. (#24344)

上级 a851b97a
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import os import os
import tempfile import tempfile
import numpy as np import numpy as np
import paddle.utils as utils
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
...@@ -31,16 +32,9 @@ class TestProfiler(unittest.TestCase): ...@@ -31,16 +32,9 @@ class TestProfiler(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
def net_profiler(self, def build_program(self, compile_program=True):
state,
option,
iter_range=None,
use_parallel_executor=False):
profile_path = os.path.join(tempfile.gettempdir(), "profile")
open(profile_path, "w").write("")
startup_program = fluid.Program() startup_program = fluid.Program()
main_program = fluid.Program() main_program = fluid.Program()
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
image = fluid.layers.data(name='x', shape=[784], dtype='float32') image = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=64, act='relu') hidden1 = fluid.layers.fc(input=image, size=64, act='relu')
...@@ -70,34 +64,19 @@ class TestProfiler(unittest.TestCase): ...@@ -70,34 +64,19 @@ class TestProfiler(unittest.TestCase):
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
opts = optimizer.minimize(avg_cost, startup_program=startup_program) opts = optimizer.minimize(avg_cost, startup_program=startup_program)
place = fluid.CPUPlace() if state == 'CPU' else fluid.CUDAPlace(0) if compile_program:
exe = fluid.Executor(place) train_program = fluid.compiler.CompiledProgram(
exe.run(startup_program) main_program).with_data_parallel(loss_name=avg_cost.name)
if use_parallel_executor: else:
pe = fluid.ParallelExecutor( train_program = main_program
state != 'CPU', return train_program, startup_program, avg_cost, batch_size, batch_acc
loss_name=avg_cost.name,
main_program=main_program) def get_profile_path(self):
profile_path = os.path.join(tempfile.gettempdir(), "profile")
pass_acc_calculator = fluid.average.WeightedAverage() open(profile_path, "w").write("")
with profiler.profiler(state, 'total', profile_path, option) as prof: return profile_path
for iter in range(10):
if iter == 2: def check_profile_result(self, profile_path):
profiler.reset_profiler()
x = np.random.random((32, 784)).astype("float32")
y = np.random.randint(0, 10, (32, 1)).astype("int64")
if use_parallel_executor:
pe.run(feed={'x': x, 'y': y}, fetch_list=[avg_cost.name])
continue
outs = exe.run(main_program,
feed={'x': x,
'y': y},
fetch_list=[avg_cost, batch_acc, batch_size])
acc = np.array(outs[1])
b_size = np.array(outs[2])
pass_acc_calculator.add(value=acc, weight=b_size)
pass_acc = pass_acc_calculator.eval()
data = open(profile_path, 'rb').read() data = open(profile_path, 'rb').read()
if (len(data) > 0): if (len(data) > 0):
profile_pb = profiler_pb2.Profile() profile_pb = profiler_pb2.Profile()
...@@ -115,21 +94,114 @@ class TestProfiler(unittest.TestCase): ...@@ -115,21 +94,114 @@ class TestProfiler(unittest.TestCase):
event.name.startswith("Runtime API")): event.name.startswith("Runtime API")):
print("Warning: unregister", event.name) print("Warning: unregister", event.name)
def run_iter(self, exe, main_program, fetch_list, pass_acc_calculator):
x = np.random.random((32, 784)).astype("float32")
y = np.random.randint(0, 10, (32, 1)).astype("int64")
outs = exe.run(main_program,
feed={'x': x,
'y': y},
fetch_list=fetch_list)
acc = np.array(outs[1])
b_size = np.array(outs[2])
pass_acc_calculator.add(value=acc, weight=b_size)
pass_acc = pass_acc_calculator.eval()
def net_profiler(self,
exe,
state,
tracer_option,
batch_range=None,
use_parallel_executor=False,
use_new_api=False):
main_program, startup_program, avg_cost, batch_size, batch_acc = self.build_program(
compile_program=use_parallel_executor)
exe.run(startup_program)
profile_path = self.get_profile_path()
if not use_new_api:
with profiler.profiler(state, 'total', profile_path, tracer_option):
pass_acc_calculator = fluid.average.WeightedAverage()
for iter in range(10):
if iter == 2:
profiler.reset_profiler()
self.run_iter(exe, main_program,
[avg_cost, batch_acc, batch_size],
pass_acc_calculator)
else:
options = utils.ProfilerOptions(options={
'state': state,
'sorted_key': 'total',
'tracer_level': tracer_option,
'batch_range': [0, 10] if batch_range is None else batch_range,
'profile_path': profile_path
})
with utils.Profiler(enabled=True, options=options) as prof:
pass_acc_calculator = fluid.average.WeightedAverage()
for iter in range(10):
self.run_iter(exe, main_program,
[avg_cost, batch_acc, batch_size],
pass_acc_calculator)
utils.get_profiler().record_step()
if batch_range is None and iter == 2:
utils.get_profiler().reset()
self.check_profile_result(profile_path)
def test_cpu_profiler(self): def test_cpu_profiler(self):
self.net_profiler('CPU', "Default") exe = fluid.Executor(fluid.CPUPlace())
#self.net_profiler('CPU', "Default", use_parallel_executor=True) for use_new_api in [False, True]:
self.net_profiler(
exe,
'CPU',
"Default",
batch_range=[5, 10],
use_new_api=use_new_api)
#self.net_profiler('CPU', "Default", use_parallel_executor=True)
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU") "profiler is enabled only with GPU")
def test_cuda_profiler(self): def test_cuda_profiler(self):
self.net_profiler('GPU', "OpDetail") exe = fluid.Executor(fluid.CUDAPlace(0))
#self.net_profiler('GPU', "OpDetail", use_parallel_executor=True) for use_new_api in [False, True]:
self.net_profiler(
exe,
'GPU',
"OpDetail",
batch_range=[0, 100],
use_new_api=use_new_api)
#self.net_profiler('GPU', "OpDetail", use_parallel_executor=True)
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"profiler is enabled only with GPU") "profiler is enabled only with GPU")
def test_all_profiler(self): def test_all_profiler(self):
self.net_profiler('All', "AllOpDetail") exe = fluid.Executor(fluid.CUDAPlace(0))
#self.net_profiler('All', "AllOpDetail", use_parallel_executor=True) for use_new_api in [False, True]:
self.net_profiler(
exe,
'All',
"AllOpDetail",
batch_range=None,
use_new_api=use_new_api)
#self.net_profiler('All', "AllOpDetail", use_parallel_executor=True)
class TestProfilerAPIError(unittest.TestCase):
def test_errors(self):
options = utils.ProfilerOptions()
self.assertTrue(options['profile_path'] is None)
self.assertTrue(options['timeline_path'] is None)
options = options.with_state('All')
self.assertTrue(options['state'] == 'All')
try:
print(options['test'])
except ValueError:
pass
global_profiler = utils.get_profiler()
with utils.Profiler(enabled=True) as prof:
self.assertTrue(utils.get_profiler() == prof)
self.assertTrue(global_profiler != prof)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -13,15 +13,13 @@ ...@@ -13,15 +13,13 @@
# limitations under the License. # limitations under the License.
from .plot import Ploter from .plot import Ploter
from .profiler import ProfilerOptions
from .profiler import Profiler
from .profiler import get_profiler
__all__ = ['dump_config', 'Ploter'] __all__ = ['dump_config', 'Ploter']
#TODO: define new api under this directory #TODO: define new api under this directory
# __all__ = ['profiler', # __all__ = ['unique_name',
# 'profiler.cuda_profiler',
# 'profiler.profiler',
# 'profiler.reset_profiler',
# 'profiler.start_profiler',
# 'profiler.stop_profiler',
# 'unique_name',
# 'load_op_library', # 'load_op_library',
# 'require_version'] # 'require_version']
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,9 +12,124 @@ ...@@ -12,9 +12,124 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#TODO: define new api of profiler from __future__ import print_function
# __all__ = ['cuda_profiler',
# 'profiler', import sys
# 'reset_profiler', import warnings
# 'start_profiler',
# 'stop_profiler'] from ..fluid import core
from ..fluid.profiler import *
__all__ = ['ProfilerOptions', 'Profiler', 'get_profiler']
class ProfilerOptions(object):
def __init__(self, options=None):
self.options = {
'state': 'All',
'sorted_key': 'default',
'tracer_level': 'Default',
'batch_range': [0, sys.maxsize],
'output_thread_detail': False,
'profile_path': 'none',
'timeline_path': 'none',
'op_summary_path': 'none'
}
if options is not None:
for key in self.options.keys():
if options.get(key, None) is not None:
self.options[key] = options[key]
# function to set one specified option
def with_state(self, state):
self.options['state'] = state
return self
def __getitem__(self, name):
if self.options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
else:
if isinstance(self.options[name],
str) and self.options[name] == 'none':
return None
else:
return self.options[name]
_current_profiler = None
class Profiler(object):
def __init__(self, enabled=True, options=None):
if options is not None:
self.profiler_options = options
else:
self.profiler_options = ProfilerOptions()
self.batch_id = 0
self.enabled = enabled
def __enter__(self):
# record current profiler
global _current_profiler
self.previous_profiler = _current_profiler
_current_profiler = self
if self.enabled:
if self.profiler_options['batch_range'][0] == 0:
self.start()
return self
def __exit__(self, exception_type, exception_value, traceback):
global _current_profiler
_current_profiler = self.previous_profiler
if self.enabled:
self.stop()
def start(self):
if self.enabled:
try:
start_profiler(
state=self.profiler_options['state'],
tracer_option=self.profiler_options['tracer_level'])
except Exception as e:
warnings.warn(
"Profiler is not enabled becuase following exception:\n{}".
format(e))
def stop(self):
if self.enabled:
try:
stop_profiler(
sorted_key=self.profiler_options['sorted_key'],
profile_path=self.profiler_options['profile_path'])
except Exception as e:
warnings.warn(
"Profiler is not disabled becuase following exception:\n{}".
format(e))
def reset(self):
if self.enabled and core.is_profiler_enabled():
reset_profiler()
def record_step(self, change_profiler_status=True):
if not self.enabled:
return
self.batch_id = self.batch_id + 1
if change_profiler_status:
if self.batch_id == self.profiler_options['batch_range'][0]:
if core.is_profiler_enabled():
self.reset()
else:
self.start()
if self.batch_id == self.profiler_options['batch_range'][1]:
self.stop()
def get_profiler():
global _current_profiler
if _current_profiler is None:
_current_profiler = Profiler()
return _current_profiler
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册