BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 5c79dbb2 (unverified)
Marker op for profiling (#33034)
Authored on May 26, 2021 by Yuang Liu; committed via GitHub on May 26, 2021.
Parent commit: c711e913
Showing 8 changed files with 199 additions and 14 deletions (+199 / -14):
paddle/fluid/operators/marker_op.cc                      +76 / -0
paddle/fluid/operators/marker_op.cu                      +61 / -0
paddle/fluid/platform/device_tracer.cc                    +1 / -1
paddle/fluid/platform/event.h                             +3 / -2
paddle/fluid/platform/profiler.cc                        +15 / -8
paddle/fluid/platform/profiler.h                          +6 / -3
python/paddle/fluid/tests/unittests/test_marker_op.py    +36 / -0
tools/static_mode_white_list.py                           +1 / -0
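The new marker op has no inputs or outputs; when it runs, its marker_role and marker_pos attributes are concatenated into an event string ("marker_" + role + "_" + pos) and recorded through the profiler, so a pass or optimizer can drop lightweight markers into a program. As a rough illustration only (the program and executor setup below are assumptions, not taken from this diff; the exact integration used by the distributed optimizers is not shown here), a marker could be appended to a static-graph block like this:

# Illustrative sketch, not from this commit: append a `marker` op to a
# static-graph program. The attribute names follow MarkerOpMaker below; the
# surrounding program/executor setup is ordinary Paddle static-graph usage
# and is an assumption, not part of the diff.
import paddle

paddle.enable_static()
main_prog = paddle.static.default_main_program()

# The op takes no inputs or outputs; its attrs become the recorded event
# string "marker_" + marker_role + "_" + marker_pos, e.g. "marker_forward_B".
main_prog.global_block().append_op(
    type='marker',
    inputs={},
    outputs={},
    attrs={'marker_role': 'forward', 'marker_pos': 'B'})

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(main_prog)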
paddle/fluid/operators/marker_op.cc (new file, mode 100644)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

class MarkerOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    std::string marker_role = ctx->Attrs().Get<std::string>("marker_role");
    std::string marker_pos = ctx->Attrs().Get<std::string>("marker_pos");
    VLOG(3) << "The role is:" << marker_role << ";"
            << "The position is:" << marker_pos << ".";
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace());
  }
};

class MarkerOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddAttr<std::string>("marker_role",
                         "(string, default forward) forward or backward, "
                         "mark different stages of process.")
        .SetDefault("forward");
    AddAttr<std::string>("marker_pos",
                         "(string, default B) the position where the marker is placed, "
                         "B stands for begin of duration, "
                         "E stands for end of duration.")
        .SetDefault("B");
    AddComment(
        R"DOC(Marker Operator - Add marker at the beginning/end of a forward/backward process.)DOC");
  }
};

template <typename T>
class MarkerOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto marker_role = ctx.Attr<std::string>("marker_role");
    auto marker_pos = ctx.Attr<std::string>("marker_pos");
    platform::RecordEvent record_event(
        "MarkerCPU", platform::EventRole::kInnerOp,
        "marker_" + marker_role + "_" + marker_pos);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(marker, ops::MarkerOp, ops::MarkerOpMaker);
REGISTER_OP_CPU_KERNEL(marker, ops::MarkerOpCPUKernel<float>);
paddle/fluid/operators/marker_op.cu (new file, mode 100644)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void SimpleMarkerKernel(T* in, T* out, int ndim) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (; idx < ndim; idx += blockDim.x * gridDim.x) {
    out[idx] = in[idx];
  }
}

template <typename T>
class MarkerOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto marker_role = ctx.Attr<std::string>("marker_role");
    auto marker_pos = ctx.Attr<std::string>("marker_pos");
    VLOG(3) << "marker role: " << marker_role
            << " marker position: " << marker_pos;

    framework::Tensor A;
    framework::Tensor B;
    auto* in_temp = A.mutable_data<T>({32, 1}, ctx.GetPlace());
    auto* out_temp = B.mutable_data<T>({32, 1}, ctx.GetPlace());

    platform::RecordEvent record_event(
        "MarkerCUDA", platform::EventRole::kInnerOp,
        "marker_" + marker_role + "_" + marker_pos);
    SimpleMarkerKernel<T><<<1, 32, 0, dev_ctx.stream()>>>(in_temp, out_temp,
                                                          32);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(marker, ops::MarkerOpCUDAKernel<float>);
paddle/fluid/platform/device_tracer.cc
@@ -511,7 +511,7 @@ class DeviceTracerImpl : public DeviceTracer {
       auto c = correlations_.find(r.correlation_id);
       if (c != correlations_.end() && c->second != nullptr) {
         event->set_name(c->second->name());
-        event->set_detail_info(r.name);
+        event->set_detail_info(c->second->attr());
         find++;
       } else {
         VLOG(10) << "Missing Kernel Event: " + r.name;
paddle/fluid/platform/event.h
@@ -40,7 +40,7 @@ class Event {
   // The DeviceContext is used to get the cuda stream.
   // If CPU profiling mode, can pass nullptr.
   Event(EventType type, std::string name, uint32_t thread_id,
-        EventRole role = EventRole::kOrdinary);
+        EventRole role = EventRole::kOrdinary, std::string attr = "none");

   const EventType& type() const;
   Event* parent() const { return parent_; }
@@ -50,7 +50,7 @@ class Event {
   uint32_t thread_id() const { return thread_id_; }
   void set_name(std::string name) { name_ = name; }
   void set_role(EventRole role) { role_ = role; }
+  std::string attr() const { return attr_; }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifndef PADDLE_WITH_CUPTI
   gpuEvent_t event() const { return event_; }
@@ -69,6 +69,7 @@ class Event {
   EventRole role_{};
   int64_t cpu_ns_;
   bool visited_status_{false};
+  std::string attr_;
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifdef PADDLE_WITH_CUPTI
   int64_t gpu_ns_ = 0;
paddle/fluid/platform/profiler.cc
@@ -32,8 +32,12 @@ namespace platform {
 MemEvenRecorder MemEvenRecorder::recorder;

 Event::Event(EventType type, std::string name, uint32_t thread_id,
-             EventRole role)
-    : type_(type), name_(name), thread_id_(thread_id), role_(role) {
+             EventRole role, std::string attr)
+    : type_(type),
+      name_(name),
+      thread_id_(thread_id),
+      role_(role),
+      attr_(attr) {
   cpu_ns_ = GetTimeInNsec();
 }
@@ -52,7 +56,8 @@ double Event::CudaElapsedMs(const Event &e) const {
 #endif
 }

-RecordEvent::RecordEvent(const std::string &name, const EventRole role) {
+RecordEvent::RecordEvent(const std::string &name, const EventRole role,
+                         const std::string attr) {
 #ifndef _WIN32
 #ifdef PADDLE_WITH_CUDA
   if (g_enable_nvprof_hook) {
@@ -69,7 +74,7 @@ RecordEvent::RecordEvent(const std::string &name, const EventRole role) {
   is_enabled_ = true;
   // lock is not needed, the code below is thread-safe
   // Maybe need the same push/pop behavior.
-  Event *e = PushEvent(name, role);
+  Event *e = PushEvent(name, role, attr);
   SetCurAnnotation(e);
   name_ = e->name();
 }
@@ -186,12 +191,14 @@ void Mark(const std::string &name) {
   GetEventList().Record(EventType::kMark, name, g_thread_id);
 }

-Event *PushEvent(const std::string &name, const EventRole role) {
-  return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role);
+Event *PushEvent(const std::string &name, const EventRole role,
+                 std::string attr) {
+  return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role,
+                               attr);
 }

-void PopEvent(const std::string &name, const EventRole role) {
-  GetEventList().Record(EventType::kPopRange, name, g_thread_id, role);
+void PopEvent(const std::string &name, const EventRole role, std::string attr) {
+  GetEventList().Record(EventType::kPopRange, name, g_thread_id, role, attr);
 }

 void EnableProfiler(ProfilerState state) {
   PADDLE_ENFORCE_NE(state, ProfilerState::kDisabled,
paddle/fluid/platform/profiler.h
@@ -126,7 +126,8 @@ struct MemEvenRecorder {
 struct RecordEvent {
   RecordEvent(const std::string &name,
-              const EventRole role = EventRole::kOrdinary);
+              const EventRole role = EventRole::kOrdinary,
+              const std::string attr = "none");

   ~RecordEvent();
@@ -200,8 +201,10 @@ void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
                   const Place &place, const std::string &annotation);
 void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
                  const Place &place, const std::string &annotation);
-Event *PushEvent(const std::string &name, const EventRole role);
-void PopEvent(const std::string &name, const EventRole role);
+Event *PushEvent(const std::string &name, const EventRole role,
+                 const std::string attr = "none");
+void PopEvent(const std::string &name, const EventRole role,
+              const std::string attr = "none");

 // Return the event list of all threads. Assumed the returned value calls
 // event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
 std::vector<std::vector<Event>> GetAllEvents();
python/paddle/fluid/tests/unittests/test_marker_op.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
from paddle.distributed.fleet.meta_optimizers.common import OpRole


class TestMarkerOp(OpTest):
    def setUp(self):
        self.op_type = "marker"
        self.inputs = {}
        self.attrs = {
            'marker_role': 'forward',
            'marker_pos': 'B',
            'op_role': OpRole.Forward
        }
        self.outputs = {}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
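The unit test above only verifies that the op runs with empty inputs and outputs. To actually see the marker in a profiling report, the program containing the op has to execute under the profiler; a minimal sketch of one way to do that (an assumption for illustration, not part of this commit; the profiler arguments 'CPU' and 'total' are just one reasonable configuration) might look like this:

# Illustrative sketch, not from this commit: run a program that contains a
# marker op under the fluid profiler so the "marker_forward_B" event shows
# up in the printed summary.
import paddle
import paddle.fluid.profiler as profiler

paddle.enable_static()
main_prog = paddle.static.default_main_program()
main_prog.global_block().append_op(
    type='marker', inputs={}, outputs={},
    attrs={'marker_role': 'forward', 'marker_pos': 'B'})

exe = paddle.static.Executor(paddle.CPUPlace())
with profiler.profiler('CPU', 'total'):  # state='CPU', sorted_key='total'
    exe.run(main_prog)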
tools/static_mode_white_list.py
@@ -710,4 +710,5 @@ STATIC_MODE_TESTING_LIST = [
     'test_lamb_op_xpu',
     'test_model_cast_to_bf16',
     'test_sgd_op_bf16',
+    'test_marker_op',
 ]