Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
VisualDL
提交
59c86e52
V
VisualDL
项目概览
PaddlePaddle
/
VisualDL
大约 1 年 前同步成功
通知
88
Star
4655
Fork
642
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
2
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
V
VisualDL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
2
合并请求
2
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
59c86e52
编写于
11月 28, 2022
作者:
C
chenjian
提交者:
GitHub
11月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add profiler_test demo code (#1143)
上级
a27a4f25
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
131 addition
and
0 deletion
+131
-0
demo/components/profiler_test.py
demo/components/profiler_test.py
+131
-0
未找到文件。
demo/components/profiler_test.py
0 → 100644
浏览文件 @
59c86e52
import
numpy
as
np
import
paddle
import
paddle.nn.functional
as
F
import
paddle.profiler
as
profiler
from
paddle.vision.transforms
import
ToTensor
transform
=
ToTensor
()
cifar10_train
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
'train'
,
transform
=
transform
)
cifar10_test
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
'test'
,
transform
=
transform
)
class
MyNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
num_classes
=
1
):
super
(
MyNet
,
self
).
__init__
()
self
.
conv1
=
paddle
.
nn
.
Conv2D
(
in_channels
=
3
,
out_channels
=
32
,
kernel_size
=
(
3
,
3
))
self
.
pool1
=
paddle
.
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)
self
.
conv2
=
paddle
.
nn
.
Conv2D
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
(
3
,
3
))
self
.
pool2
=
paddle
.
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)
self
.
conv3
=
paddle
.
nn
.
Conv2D
(
in_channels
=
64
,
out_channels
=
64
,
kernel_size
=
(
3
,
3
))
self
.
flatten
=
paddle
.
nn
.
Flatten
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
in_features
=
1024
,
out_features
=
64
)
self
.
linear2
=
paddle
.
nn
.
Linear
(
in_features
=
64
,
out_features
=
num_classes
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
pool1
(
x
)
x
=
self
.
conv2
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
pool2
(
x
)
x
=
self
.
conv3
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
linear1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
linear2
(
x
)
return
x
epoch_num
=
10
batch_size
=
32
learning_rate
=
0.001
val_acc_history
=
[]
val_loss_history
=
[]
def
train
(
model
):
print
(
'start training ... '
)
# turn into training mode
model
.
train
()
opt
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
learning_rate
,
parameters
=
model
.
parameters
())
train_loader
=
paddle
.
io
.
DataLoader
(
cifar10_train
,
shuffle
=
True
,
batch_size
=
batch_size
,
num_workers
=
1
)
valid_loader
=
paddle
.
io
.
DataLoader
(
cifar10_test
,
batch_size
=
batch_size
)
# 创建性能分析器相关的代码
def
my_on_trace_ready
(
prof
):
# 定义回调函数,性能分析器结束采集数据时会被调用
callback
=
profiler
.
export_chrome_tracing
(
'./profiler_demo'
)
# 创建导出性能数据到 profiler_demo 文件夹的回调函数
callback
(
prof
)
# 执行该导出函数
prof
.
summary
(
sorted_by
=
profiler
.
SortedKeys
.
GPUTotal
)
# 打印表单,按 GPUTotal 排序表单项
p
=
profiler
.
Profiler
(
scheduler
=
[
3
,
14
],
on_trace_ready
=
my_on_trace_ready
,
timer_only
=
False
)
# 初始化 Profiler 对象
p
.
start
()
# 性能分析器进入第 0 个 step
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
x_data
=
data
[
0
]
y_data
=
paddle
.
to_tensor
(
data
[
1
])
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
if
batch_id
%
1000
==
0
:
print
(
"epoch: {}, batch_id: {}, loss is: {}"
.
format
(
epoch
,
batch_id
,
loss
.
numpy
()))
loss
.
backward
()
opt
.
step
()
opt
.
clear_grad
()
p
.
step
()
# 指示性能分析器进入下一个 step
if
batch_id
==
19
:
p
.
stop
()
# 关闭性能分析器
exit
()
# 做性能分析时,可以将程序提前退出
# evaluate model after one epoch
model
.
eval
()
accuracies
=
[]
losses
=
[]
for
batch_id
,
data
in
enumerate
(
valid_loader
()):
x_data
=
data
[
0
]
y_data
=
paddle
.
to_tensor
(
data
[
1
])
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
acc
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
)
accuracies
.
append
(
acc
.
numpy
())
losses
.
append
(
loss
.
numpy
())
avg_acc
,
avg_loss
=
np
.
mean
(
accuracies
),
np
.
mean
(
losses
)
print
(
"[validation] accuracy/loss: {}/{}"
.
format
(
avg_acc
,
avg_loss
))
val_acc_history
.
append
(
avg_acc
)
val_loss_history
.
append
(
avg_loss
)
model
.
train
()
model
=
MyNet
(
num_classes
=
10
)
train
(
model
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录