Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
72ff5a09
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72ff5a09
编写于
3月 03, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
3月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix print bug of profile, test=develop (#22804)
上级
4e8bc024
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
63 addition
and
32 deletion
+63
-32
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+4
-4
paddle/fluid/platform/event.h
paddle/fluid/platform/event.h
+11
-1
paddle/fluid/platform/profiler.cc
paddle/fluid/platform/profiler.cc
+7
-6
paddle/fluid/platform/profiler.h
paddle/fluid/platform/profiler.h
+4
-9
paddle/fluid/platform/profiler_helper.h
paddle/fluid/platform/profiler_helper.h
+34
-10
paddle/fluid/platform/profiler_test.cc
paddle/fluid/platform/profiler_test.cc
+3
-2
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
72ff5a09
...
...
@@ -173,7 +173,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
platform
::
RecordEvent
op_type_record_event
(
Type
());
auto
op_name
=
platform
::
OpName
(
outputs_
,
Type
());
platform
::
RecordEvent
op_name_record_event
(
op_name
,
platform
::
Record
Role
::
kUniqueOp
);
op_name
,
platform
::
Event
Role
::
kUniqueOp
);
RunImpl
(
scope
,
place
);
}
...
...
@@ -957,7 +957,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
Scope
*
transfer_scope
=
nullptr
;
{
platform
::
RecordEvent
record_event
(
"prepare_data"
,
platform
::
Record
Role
::
kInnerOp
);
platform
::
Event
Role
::
kInnerOp
);
transfer_scope
=
PrepareData
(
scope
,
*
kernel_type_
,
&
transfered_inplace_vars
,
runtime_ctx
);
}
...
...
@@ -971,7 +971,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if
(
!
all_kernels_must_compute_runtime_shape_
)
{
platform
::
RecordEvent
record_event
(
"infer_shape"
,
platform
::
Record
Role
::
kInnerOp
);
platform
::
Event
Role
::
kInnerOp
);
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
*
runtime_ctx
);
this
->
InferShape
(
&
infer_shape_ctx
);
}
...
...
@@ -984,7 +984,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// not Scope. Imperative mode only pass inputs and get outputs.
{
platform
::
RecordEvent
record_event
(
"compute"
,
platform
::
Record
Role
::
kInnerOp
);
platform
::
Event
Role
::
kInnerOp
);
(
*
kernel_func_
)(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
,
*
runtime_ctx
,
kernel_configs
));
}
...
...
paddle/fluid/platform/event.h
浏览文件 @
72ff5a09
...
...
@@ -25,18 +25,27 @@ namespace platform {
enum
class
EventType
{
kMark
,
kPushRange
,
kPopRange
};
enum
class
EventRole
{
kOrdinary
,
// only record op time with op type key
kInnerOp
,
// record op detail time with op type key
kUniqueOp
,
// record op detail time with op unique name key
};
class
Event
{
public:
// The DeviceContext is used to get the cuda stream.
// If CPU profiling mode, can pass nullptr.
Event
(
EventType
type
,
std
::
string
name
,
uint32_t
thread_id
);
Event
(
EventType
type
,
std
::
string
name
,
uint32_t
thread_id
,
EventRole
role
=
EventRole
::
kOrdinary
);
const
EventType
&
type
()
const
;
Event
*
parent
()
const
{
return
parent_
;
}
void
set_parent
(
Event
*
parent
)
{
parent_
=
parent
;
}
std
::
string
name
()
const
{
return
name_
;
}
EventRole
role
()
const
{
return
role_
;
}
uint32_t
thread_id
()
const
{
return
thread_id_
;
}
void
set_name
(
std
::
string
name
)
{
name_
=
name
;
}
void
set_role
(
EventRole
role
)
{
role_
=
role
;
}
#ifdef PADDLE_WITH_CUDA
#ifndef PADDLE_WITH_CUPTI
...
...
@@ -53,6 +62,7 @@ class Event {
std
::
string
name_
{};
Event
*
parent_
{
nullptr
};
uint32_t
thread_id_
;
EventRole
role_
{};
int64_t
cpu_ns_
;
bool
visited_status_
{
false
};
#ifdef PADDLE_WITH_CUDA
...
...
paddle/fluid/platform/profiler.cc
浏览文件 @
72ff5a09
...
...
@@ -42,8 +42,9 @@ namespace platform {
MemEvenRecorder
MemEvenRecorder
::
recorder
;
Event
::
Event
(
EventType
type
,
std
::
string
name
,
uint32_t
thread_id
)
:
type_
(
type
),
name_
(
name
),
thread_id_
(
thread_id
)
{
Event
::
Event
(
EventType
type
,
std
::
string
name
,
uint32_t
thread_id
,
EventRole
role
)
:
type_
(
type
),
name_
(
name
),
thread_id_
(
thread_id
),
role_
(
role
)
{
cpu_ns_
=
GetTimeInNsec
();
}
...
...
@@ -62,12 +63,12 @@ double Event::CudaElapsedMs(const Event &e) const {
#endif
}
RecordEvent
::
RecordEvent
(
const
std
::
string
&
name
,
const
Record
Role
role
)
RecordEvent
::
RecordEvent
(
const
std
::
string
&
name
,
const
Event
Role
role
)
:
is_enabled_
(
false
),
start_ns_
(
PosixInNsec
()),
role_
(
role
)
{
if
(
g_state
==
ProfilerState
::
kDisabled
||
name
.
empty
())
return
;
// lock is not needed, the code below is thread-safe
is_enabled_
=
true
;
Event
*
e
=
PushEvent
(
name
);
Event
*
e
=
PushEvent
(
name
,
role
);
// Maybe need the same push/pop behavior.
SetCurAnnotation
(
e
);
name_
=
e
->
name
();
...
...
@@ -179,8 +180,8 @@ void Mark(const std::string &name) {
GetEventList
().
Record
(
EventType
::
kMark
,
name
,
g_thread_id
);
}
Event
*
PushEvent
(
const
std
::
string
&
name
)
{
return
GetEventList
().
Record
(
EventType
::
kPushRange
,
name
,
g_thread_id
);
Event
*
PushEvent
(
const
std
::
string
&
name
,
const
EventRole
role
)
{
return
GetEventList
().
Record
(
EventType
::
kPushRange
,
name
,
g_thread_id
,
role
);
}
void
PopEvent
(
const
std
::
string
&
name
)
{
...
...
paddle/fluid/platform/profiler.h
浏览文件 @
72ff5a09
...
...
@@ -43,12 +43,6 @@ enum class ProfilerState {
kAll
,
// Profile both CPU and GPU. (Currently experimental).
};
enum
class
RecordRole
{
kOrdinary
,
// only record op time with op type key
kInnerOp
,
// record op detail time with op type key
kUniqueOp
,
// record op detail time with op unique name key
};
// it is the flag to control to print the profiling result
enum
class
TracerOption
{
kDefault
,
// print the different op type profiling result
...
...
@@ -86,6 +80,7 @@ struct EventItem {
double
cpu_time
;
double
gpu_time
;
float
ratio
;
EventRole
role
;
};
struct
OverHead
{
...
...
@@ -128,7 +123,7 @@ struct MemEvenRecorder {
struct
RecordEvent
{
RecordEvent
(
const
std
::
string
&
name
,
const
RecordRole
role
=
Record
Role
::
kOrdinary
);
const
EventRole
role
=
Event
Role
::
kOrdinary
);
~
RecordEvent
();
...
...
@@ -139,7 +134,7 @@ struct RecordEvent {
// Need to distinguish name by op type, block_id, program_id and perhaps
// different kernel invocations within an op.
std
::
string
full_name_
;
RecordRole
role_
{
Record
Role
::
kOrdinary
};
EventRole
role_
{
Event
Role
::
kOrdinary
};
};
class
RecordRPCEvent
{
...
...
@@ -201,7 +196,7 @@ void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const
Place
&
place
,
const
std
::
string
&
annotation
);
void
PopMemEvent
(
uint64_t
start_ns
,
uint64_t
end_ns
,
size_t
bytes
,
const
Place
&
place
,
const
std
::
string
&
annotation
);
Event
*
PushEvent
(
const
std
::
string
&
name
);
Event
*
PushEvent
(
const
std
::
string
&
name
,
const
EventRole
role
);
void
PopEvent
(
const
std
::
string
&
name
);
// Return the event list of all threads. Assumed the returned value calls
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
...
...
paddle/fluid/platform/profiler_helper.h
浏览文件 @
72ff5a09
...
...
@@ -304,9 +304,9 @@ void SetEvent(bool merge_thread, Event analyze_event, size_t *max_name_width,
if
(
event_idx
->
find
(
event_name
)
==
event_idx
->
end
())
{
event_idx
->
insert
({
event_name
,
event_items
->
size
()});
EventItem
event_item
=
{
event_name
,
1
,
event_time
,
event_time
,
event_time
,
event
_time
,
cpu_time
,
gpu_time
,
0.
};
EventItem
event_item
=
{
event_name
,
1
,
event_time
,
event_time
,
event_time
,
event_time
,
cpu_time
,
gpu
_time
,
0.
,
rit
->
role
()
};
event_items
->
push_back
(
event_item
);
}
else
{
int
index
=
event_idx
->
at
(
event_name
);
...
...
@@ -335,8 +335,10 @@ void SetEvent(bool merge_thread, Event analyze_event, size_t *max_name_width,
void
ComputeOverhead
(
const
std
::
multimap
<
std
::
string
,
EventItem
>
&
sub_child_map
,
OverHead
*
overhead
)
{
EventItem
memcpy_async
=
{
"GpuMemcpyAsync"
,
0
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.0
f
};
EventItem
memcpy_sync
=
{
"GpuMemcpySync"
,
0
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.0
f
};
EventItem
memcpy_async
=
{
"GpuMemcpyAsync"
,
0
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.0
f
,
EventRole
::
kOrdinary
};
EventItem
memcpy_sync
=
{
"GpuMemcpySync"
,
0
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.0
f
,
EventRole
::
kOrdinary
};
for
(
auto
it
=
sub_child_map
.
begin
();
it
!=
sub_child_map
.
end
();
it
++
)
{
if
(
it
->
second
.
name
.
find
(
"compute"
)
!=
std
::
string
::
npos
)
{
overhead
->
compute_ratio
+=
it
->
second
.
ratio
;
...
...
@@ -359,6 +361,29 @@ void ComputeOverhead(const std::multimap<std::string, EventItem> &sub_child_map,
overhead
->
sub_memcpy_items
=
{
memcpy_async
,
memcpy_sync
};
}
std
::
string
FindOrdinaryParent
(
const
std
::
multimap
<
std
::
string
,
EventItem
>
&
sub_child_map
,
std
::
string
name
)
{
bool
find_name
=
false
;
std
::
string
parent
=
name
;
EventRole
role
;
for
(
auto
it
=
sub_child_map
.
begin
();
it
!=
sub_child_map
.
end
();
it
++
)
{
if
(
it
->
second
.
name
==
name
)
{
role
=
it
->
second
.
role
;
parent
=
it
->
first
;
find_name
=
true
;
break
;
}
}
if
(
find_name
&&
role
==
EventRole
::
kOrdinary
)
{
return
name
;
}
else
if
(
find_name
&&
role
!=
EventRole
::
kOrdinary
)
{
return
FindOrdinaryParent
(
sub_child_map
,
parent
);
}
else
{
return
parent
;
}
}
// When TracerOption is KDefault, OpDetail will be recorded but only default
// profile result will be printed.
// GpuMemcpy should be printed in kDefault setting, however it offten occurs
...
...
@@ -376,11 +401,7 @@ void GetChildMap(const std::multimap<std::string, EventItem> &sub_child_map,
}
else
{
for
(
auto
it
=
sub_child_map
.
begin
();
it
!=
sub_child_map
.
end
();
it
++
)
{
if
(
it
->
second
.
name
.
find
(
"GpuMemcpy"
)
!=
std
::
string
::
npos
)
{
std
::
string
parent_name
=
it
->
first
;
auto
left_pos
=
it
->
first
.
find
(
"/"
);
if
(
left_pos
!=
std
::
string
::
npos
)
{
parent_name
=
it
->
first
.
substr
(
0
,
left_pos
);
}
std
::
string
parent_name
=
FindOrdinaryParent
(
sub_child_map
,
it
->
first
);
auto
item
=
it
->
second
;
auto
right_pos
=
item
.
name
.
rfind
(
"/"
);
if
(
right_pos
!=
std
::
string
::
npos
)
{
...
...
@@ -389,6 +410,9 @@ void GetChildMap(const std::multimap<std::string, EventItem> &sub_child_map,
item
.
name
=
parent_name
+
"/"
+
child_name
;
}
child_map
->
insert
(
std
::
pair
<
std
::
string
,
EventItem
>
(
parent_name
,
item
));
}
else
if
(
it
->
second
.
role
==
EventRole
::
kOrdinary
)
{
child_map
->
insert
(
std
::
pair
<
std
::
string
,
EventItem
>
(
it
->
first
,
it
->
second
));
}
}
}
...
...
paddle/fluid/platform/profiler_test.cc
浏览文件 @
72ff5a09
...
...
@@ -40,6 +40,7 @@ TEST(RecordEvent, RecordEvent) {
using
paddle
::
platform
::
PopEvent
;
using
paddle
::
platform
::
ProfilerState
;
using
paddle
::
platform
::
EventSortingKey
;
using
paddle
::
platform
::
EventRole
;
ProfilerState
state
=
ProfilerState
::
kCPU
;
EnableProfiler
(
state
);
...
...
@@ -55,7 +56,7 @@ TEST(RecordEvent, RecordEvent) {
for
(
int
loop
=
0
;
loop
<
3
;
++
loop
)
{
for
(
int
i
=
1
;
i
<
5
;
++
i
)
{
std
::
string
name
=
"op_"
+
std
::
to_string
(
i
);
PushEvent
(
name
);
PushEvent
(
name
,
EventRole
::
kOrdinary
);
int
counter
=
1
;
while
(
counter
!=
i
*
1000
)
counter
++
;
PopEvent
(
name
);
...
...
@@ -107,7 +108,7 @@ TEST(RecordEvent, RecordEvent) {
}
// Bad Usage:
PushEvent
(
"event_without_pop"
);
PushEvent
(
"event_without_pop"
,
EventRole
::
kOrdinary
);
PopEvent
(
"event_without_push"
);
std
::
vector
<
std
::
vector
<
Event
>>
events
=
paddle
::
platform
::
GetAllEvents
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录