Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ed38f45
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2ed38f45
编写于
11月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(test): add some comment and test in imperative
GitOrigin-RevId: c2363abe29d230af584604bbfb2961deba074dd1
上级
0ebd4400
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
98 addition
and
0 deletion
+98
-0
imperative/src/impl/interpreter/interpreter_impl.h
imperative/src/impl/interpreter/interpreter_impl.h
+7
-0
imperative/src/include/megbrain/imperative/interpreter.h
imperative/src/include/megbrain/imperative/interpreter.h
+10
-0
imperative/src/test/interpreter_test.cpp
imperative/src/test/interpreter_test.cpp
+81
-0
未找到文件。
imperative/src/impl/interpreter/interpreter_impl.h
浏览文件 @
2ed38f45
...
...
@@ -27,6 +27,11 @@ struct InterpreterImpl : Interpreter {
std
::
unique_ptr
<
Channel
>
create_channel
()
override
;
};
/*!
* \brief implement Channel to execute the commands asynchronously,
* almost commands are executed by the worker threads, commands are sent
* by the interface
*/
struct
ChannelImpl
:
Interpreter
::
Channel
,
NonCopyableObj
,
NonMoveableObj
{
ChannelImpl
();
~
ChannelImpl
()
override
;
...
...
@@ -304,3 +309,5 @@ private:
};
}
// namespace mgb::imperative::interpreter::intl
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/include/megbrain/imperative/interpreter.h
浏览文件 @
2ed38f45
...
...
@@ -7,6 +7,7 @@
namespace
mgb
::
imperative
::
interpreter
{
//! nested throw the stored exceptions
struct
AsyncError
:
std
::
nested_exception
,
std
::
exception
{
const
char
*
what
()
const
noexcept
{
try
{
...
...
@@ -20,9 +21,16 @@ struct AsyncError : std::nested_exception, std::exception {
};
struct
Interpreter
{
/*!
* HandleImpl* just act as a key to represent TensorInfo
*/
struct
HandleImpl
{};
using
Handle
=
HandleImpl
*
;
/*!
* \brief the base command execution interface, Channel is similar to channel in
* golang, it executes the commands put into asynchronously.
*/
struct
Channel
{
virtual
~
Channel
()
=
default
;
...
...
@@ -66,3 +74,5 @@ protected:
};
}
// namespace mgb::imperative::interpreter
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/test/interpreter_test.cpp
0 → 100644
浏览文件 @
2ed38f45
#include "megbrain/imperative/interpreter.h"
#include "../impl/interpreter/tensor_info.h"
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
using
namespace
mgb
;
using
namespace
cg
;
using
namespace
imperative
;
using
namespace
interpreter
;
TEST
(
TestImperative
,
InterpreterPut
)
{
HostTensorGenerator
<>
gen
;
auto
h0
=
gen
({
3
});
auto
&&
channel
=
Interpreter
::
inst
().
create_channel
();
auto
tensor_handle
=
channel
->
put
(
*
h0
,
true
);
auto
tensor_info
=
reinterpret_cast
<
intl
::
TensorInfo
*>
(
tensor_handle
);
channel
->
sync
();
ASSERT_TRUE
(
tensor_info
->
status
==
intl
::
TensorInfo
::
Produced
);
//! because tensor elems is less than TensorShape::MAX_NDIM, stored it
//! directly
ASSERT_EQ
(
tensor_info
->
ptr
->
get_value
().
raw_ptr
(),
h0
->
raw_ptr
());
auto
shape
=
channel
->
get_shape
(
tensor_handle
);
ASSERT_TRUE
(
shape
.
ndim
==
1
);
ASSERT_TRUE
(
shape
.
total_nr_elems
()
==
3
);
auto
h2
=
gen
({
10
});
auto
tensor_handle2
=
channel
->
put
(
*
h2
,
false
);
auto
tensor_handle2_once
=
channel
->
put
(
*
h2
,
false
);
channel
->
sync
();
ASSERT_NE
(
tensor_handle2_once
,
tensor_handle2
);
auto
finded
=
MultiCNConstTensorCache
::
inst
().
lookup
(
*
h2
);
ASSERT_TRUE
(
finded
.
get
());
//! Device tensor ptr is not equal host tensor ptr
ASSERT_NE
(
finded
->
raw_ptr_not_for_readwrite
(),
h2
->
raw_ptr
());
channel
->
del
(
tensor_handle
);
channel
->
del
(
tensor_handle2
);
channel
->
del
(
tensor_handle2_once
);
}
TEST
(
TestImperative
,
InterpreterApplyOp
)
{
HostTensorGenerator
<>
gen
;
size_t
add
=
2
,
dim0
=
5
,
dim1
=
10
;
auto
h0
=
gen
({
1
});
h0
->
ptr
<
float
>
()[
0
]
=
add
;
auto
h1
=
gen
({
dim0
,
dim1
});
for
(
size_t
i
=
0
;
i
<
dim0
*
dim1
;
i
++
)
{
h1
->
ptr
<
float
>
()[
i
]
=
i
;
}
auto
&&
channel
=
Interpreter
::
inst
().
create_channel
();
auto
tensor_handle0
=
channel
->
put
(
*
h0
,
false
);
auto
tensor_handle1
=
channel
->
put
(
*
h1
,
false
);
SmallVector
<
Interpreter
::
Handle
>
inputs
{
tensor_handle0
,
tensor_handle1
};
auto
op
=
OprAttr
::
make
(
"Elemwise"
);
auto
&&
attr
=
op
->
cast_final_safe
<
OprAttr
>
();
using
Param
=
opr
::
Elemwise
::
Param
;
Param
param
{
Param
::
Mode
::
ADD
};
attr
.
param
.
write_pod
(
param
);
auto
outputs
=
channel
->
apply_op
(
op
,
inputs
);
channel
->
sync
();
auto
out_tensor
=
reinterpret_cast
<
intl
::
TensorInfo
*>
(
outputs
[
0
])
->
ptr
->
get_value
();
ASSERT_EQ
(
out_tensor
.
layout
().
ndim
,
2
);
ASSERT_EQ
(
out_tensor
.
shape
(
0
),
dim0
);
ASSERT_EQ
(
out_tensor
.
shape
(
1
),
dim1
);
float
*
output
=
out_tensor
.
ptr
<
float
>
();
for
(
size_t
i
=
0
;
i
<
dim0
*
dim1
;
i
++
)
{
ASSERT_EQ
(
output
[
i
],
i
+
add
);
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录