Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
Docs
提交
cf198a43
D
Docs
项目概览
MegEngine 天元
/
Docs
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
Docs
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cf198a43
编写于
4月 01, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the bug of memory increasing during evaluation
GitOrigin-RevId: 5d7ab3be4077849e1c186ef0700a63c471700a83
上级
d89667ad
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
7 addition
and
14 deletion
+7
-14
source/basic/dynamic_and_static_graph.rst
source/basic/dynamic_and_static_graph.rst
+4
-9
source/basic/train_and_evaluation.rst
source/basic/train_and_evaluation.rst
+3
-5
未找到文件。
source/basic/dynamic_and_static_graph.rst
浏览文件 @
cf198a43
...
...
@@ -174,17 +174,12 @@ MegEngine 提供了很方便的动静态图转换的方法,几乎无需代码
trace.enabled = True # 开启trace,使用静态图模式
le_net.eval() # 将网络设为测试模式
data = mge.tensor()
label = mge.tensor(dtype="int32")
correct = 0
total = 0
for idx, (batch_data, batch_label) in enumerate(dataloader_test):
data.set_value(batch_data)
label.set_value(batch_label)
logits = eval_func(data, net=le_net) # 测试函数
logits = eval_func(batch_data, net=le_net) # 测试函数
predicted =
F.argmax(logits,
axis=1)
correct += (predicted==
label).sum().numpy().ite
m()
total += label.shape[0]
predicted =
logits.numpy().argmax(
axis=1)
correct += (predicted==
batch_label).su
m()
total +=
batch_
label.shape[0]
print("correct: {}, total: {}, accuracy: {}".format(correct, total, float(correct)/total))
source/basic/train_and_evaluation.rst
浏览文件 @
cf198a43
...
...
@@ -261,16 +261,14 @@ MegEngine 在GPU和CPU同时存在时默认使用GPU进行训练。用户可以
le_net.eval() # 设置为测试模式
data = mge.tensor()
label = mge.tensor(dtype="int32")
correct = 0
total = 0
for idx, (batch_data, batch_label) in enumerate(dataloader_test):
data.set_value(batch_data)
label.set_value(batch_label)
logits = le_net(data)
predicted =
F.argmax(logits,
axis=1)
correct += (predicted==
label).sum().numpy().ite
m()
total += label.shape[0]
predicted =
logits.numpy().argmax(
axis=1)
correct += (predicted==
batch_label).su
m()
total +=
batch_
label.shape[0]
print("correct: {}, total: {}, accuracy: {}".format(correct, total, float(correct)/total))
测试输出如下,可以看到经过训练的 ``LeNet`` 在 MNIST 测试数据集上的准确率已经达到98.84%:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录