Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c8697a70
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c8697a70
编写于
3月 02, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/src): python wrapper for cambricon and atlas runtime opr
GitOrigin-RevId: bd969d1339463645d559cf1d0d016713d04191d9
上级
8cfed4a1
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
213 addition
and
9 deletion
+213
-9
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+5
-1
imperative/python/megengine/device.py
imperative/python/megengine/device.py
+22
-0
imperative/python/megengine/functional/external.py
imperative/python/megengine/functional/external.py
+27
-0
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+11
-5
imperative/python/megengine/module/external.py
imperative/python/megengine/module/external.py
+54
-1
imperative/python/megengine/utils/comp_graph_tools.py
imperative/python/megengine/utils/comp_graph_tools.py
+3
-2
imperative/python/src/common.cpp
imperative/python/src/common.cpp
+2
-0
imperative/src/impl/ops/atlas_runtime.cpp
imperative/src/impl/ops/atlas_runtime.cpp
+36
-0
imperative/src/impl/ops/cambricon_runtime.cpp
imperative/src/impl/ops/cambricon_runtime.cpp
+37
-0
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+16
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
c8697a70
...
...
@@ -529,7 +529,11 @@ class InputNode(OpNode):
@
property
def
device
(
self
):
return
self
.
outputs
[
0
].
device
var
=
self
.
outputs
[
0
]
if
isinstance
(
var
,
VarNode
):
return
var
.
device
else
:
return
var
.
comp_node
@
property
def
dtype
(
self
):
...
...
imperative/python/megengine/device.py
浏览文件 @
c8697a70
...
...
@@ -36,6 +36,10 @@ def _str2device_type(type_str: str, allow_unspec: bool = True):
return
DeviceType
.
CPU
elif
type_str
==
"GPU"
or
type_str
==
"CUDA"
:
return
DeviceType
.
CUDA
elif
type_str
==
"CAMBRICON"
:
return
DeviceType
.
CAMBRICON
elif
type_str
==
"ATLAS"
:
return
DeviceType
.
ATLAS
else
:
assert
allow_unspec
and
str
==
"XPU"
,
"device type can only be cpu, gpu or xpu"
return
DeviceType
.
UNSPEC
...
...
@@ -65,6 +69,24 @@ def is_cuda_available() -> bool:
return
CompNode
.
_get_device_count
(
t
,
False
)
>
0
def
is_cambricon_available
()
->
bool
:
"""
Returns whether cambricon device is available on this system.
"""
t
=
_str2device_type
(
"cambricon"
)
return
CompNode
.
_get_device_count
(
t
,
False
)
>
0
def
is_atlas_available
()
->
bool
:
"""
Returns whether atlas device is available on this system.
"""
t
=
_str2device_type
(
"atlas"
)
return
CompNode
.
_get_device_count
(
t
,
False
)
>
0
def
set_default_device
(
device
:
str
=
"xpux"
):
r
"""
Sets default computing node.
...
...
imperative/python/megengine/functional/external.py
浏览文件 @
c8697a70
...
...
@@ -20,3 +20,30 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None):
op
=
builtin
.
TensorRTRuntime
(
data
,
len
(
data
))
# return sequence of outputs
return
apply
(
op
,
*
inputs
)
def
cambricon_runtime_opr
(
inputs
,
data
,
symbol
,
tensor_dim_mutable
):
r
"""
Load a serialized Cambricon model as a runtime operator in MegEngine.
:param inputs: list of input tensors.
:param data: the serialized Cambricon model.
:param symbol: name of the function in Cambricon model.
:param tensor_dim_mutable: whether the input tensors' shapes are mutable
in ``cnrtModel_t``.
"""
op
=
builtin
.
CambriconRuntime
(
data
,
len
(
data
),
symbol
,
tensor_dim_mutable
)
return
apply
(
op
,
*
inputs
)
def
atlas_runtime_opr
(
inputs
,
data
):
r
"""
Load a serialized Atlas model as a runtime operator in MegEngine.
:param inputs: list of input tensors.
:param data: the serialized Atlas model.
"""
op
=
builtin
.
AtlasRuntime
(
data
,
len
(
data
))
return
apply
(
op
,
*
inputs
)
imperative/python/megengine/jit/tracing.py
浏览文件 @
c8697a70
...
...
@@ -786,7 +786,11 @@ class trace:
)
output_names
=
output_names
or
self
.
_output_names
dumped_device
=
as_device
(
"xpux"
)
def
dumped_device
(
info
):
device_name
=
info
.
device
.
logical_name
if
device_name
[:
3
]
in
(
"cpu"
,
"gpu"
,
"xpu"
):
return
as_device
(
"xpux"
)
return
info
.
device
h2v
=
{}
graph
=
G
.
Graph
()
...
...
@@ -794,19 +798,21 @@ class trace:
# apply graph_opt_level in dump
if
self
.
_graph_opt_level
is
not
None
:
graph
.
options
.
graph_opt_level
=
self
.
_graph_opt_level
for
i
,
h
in
enumerate
(
self
.
_arg_bindings
):
info
=
self
.
_tinfo
[
h
]
h2v
[
h
]
=
graph
.
make_h2d
(
dtype
=
info
.
dtype
,
device
=
dumped_device
,
device
=
dumped_device
(
info
)
,
shape
=
info
.
shape
or
(
1
,),
name
=
arg_names
[
i
]
if
arg_names
else
None
,
)
for
k
,
h
in
self
.
_kwarg_bindings
.
items
():
info
=
self
.
_tinfo
[
h
]
h2v
[
h
]
=
graph
.
make_h2d
(
dtype
=
info
.
dtype
,
device
=
dumped_device
,
shape
=
info
.
shape
or
(
1
,),
name
=
k
dtype
=
info
.
dtype
,
device
=
dumped_device
(
info
),
shape
=
info
.
shape
or
(
1
,),
name
=
k
,
)
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
...
...
@@ -833,7 +839,7 @@ class trace:
h2v
[
h
]
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
dtype
=
info
.
dtype
,
device
=
dumped_device
,
device
=
dumped_device
(
info
)
,
name
=
info
.
name
,
)
ivars
.
append
(
h2v
[
h
])
...
...
imperative/python/megengine/module/external.py
浏览文件 @
c8697a70
...
...
@@ -9,7 +9,11 @@
# pylint: disable=redefined-builtin
import
numpy
as
np
from
..functional.external
import
tensorrt_runtime_opr
from
..functional.external
import
(
atlas_runtime_opr
,
cambricon_runtime_opr
,
tensorrt_runtime_opr
,
)
from
.module
import
Module
...
...
@@ -33,3 +37,52 @@ class TensorrtRuntimeSubgraph(Module):
def
forward
(
self
,
*
inputs
):
return
tensorrt_runtime_opr
(
inputs
,
data
=
self
.
_data
)
class
CambriconRuntimeSubgraph
(
Module
):
r
"""Load a serialized CambriconRuntime subgraph.
See :func:`~.cambricon_runtime_opr` for more details.
"""
def
__init__
(
self
,
data
,
symbol
,
tensor_dim_mutable
,
**
kwargs
):
super
(
CambriconRuntimeSubgraph
,
self
).
__init__
(
**
kwargs
)
self
.
_data
=
data
self
.
symbol
=
symbol
self
.
tensor_dim_mutable
=
tensor_dim_mutable
@
property
def
data
(
self
):
return
self
.
_data
@
data
.
setter
def
data
(
self
,
val
):
self
.
_data
=
np
.
frombuffer
(
val
,
dtype
=
np
.
uint8
)
def
forward
(
self
,
*
inputs
):
outputs
=
cambricon_runtime_opr
(
inputs
,
self
.
_data
,
self
.
symbol
,
self
.
tensor_dim_mutable
)
return
outputs
class
AtlasRuntimeSubgraph
(
Module
):
r
"""Load a serialized AtlasRuntime subgraph.
See :func:`~.atlas_runtime_opr` for more details.
"""
def
__init__
(
self
,
data
,
**
kwargs
):
super
(
AtlasRuntimeSubgraph
,
self
).
__init__
(
**
kwargs
)
self
.
_data
=
data
@
property
def
data
(
self
):
return
self
.
_data
@
data
.
setter
def
data
(
self
,
val
):
self
.
_data
=
np
.
frombuffer
(
val
,
dtype
=
np
.
uint8
)
def
forward
(
self
,
*
inputs
):
return
atlas_runtime_opr
(
inputs
,
data
=
self
.
_data
)
imperative/python/megengine/utils/comp_graph_tools.py
浏览文件 @
c8697a70
...
...
@@ -427,8 +427,9 @@ class GraphInference:
list
(
self
.
_inp_dict
.
keys
()),
list
(
inputs
.
keys
())
)
for
key
in
self
.
_inp_dict
:
self
.
_inp_dict
[
key
].
set_value
(
Tensor
(
inputs
[
key
]).
_dev_tensor
())
self
.
_inp_dict
[
key
].
set_value
(
Tensor
(
inputs
[
key
],
device
=
self
.
_inp_dict
[
key
].
device
).
_dev_tensor
()
)
self
.
_func
.
execute
()
self
.
_func
.
wait
()
...
...
imperative/python/src/common.cpp
浏览文件 @
c8697a70
...
...
@@ -171,6 +171,8 @@ void init_common(py::module m) {
.
value
(
"UNSPEC"
,
CompNode
::
DeviceType
::
UNSPEC
)
.
value
(
"CUDA"
,
CompNode
::
DeviceType
::
CUDA
)
.
value
(
"CPU"
,
CompNode
::
DeviceType
::
CPU
)
.
value
(
"CAMBRICON"
,
CompNode
::
DeviceType
::
CAMBRICON
)
.
value
(
"ATLAS"
,
CompNode
::
DeviceType
::
ATLAS
)
.
value
(
"MULTITHREAD"
,
CompNode
::
DeviceType
::
MULTITHREAD
)
.
value
(
"MAX_DEVICE_ID"
,
CompNode
::
DeviceType
::
MAX_DEVICE_ID
);
...
...
imperative/src/impl/ops/atlas_runtime.cpp
0 → 100644
浏览文件 @
c8697a70
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#if MGB_ATLAS
#include "megbrain/opr/atlas_runtime_op.h"
namespace
mgb
::
imperative
{
namespace
{
namespace
atlas_runtime
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
AtlasRuntime
&>
(
def
);
SymbolVarArray
symbol_var_inputs
(
inputs
.
begin
(),
inputs
.
end
());
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
AtlasRuntimeOpr
::
make
(
op
.
buf
.
c_str
(),
op
.
buf_size
,
symbol_var_inputs
,
config
);
}
OP_TRAIT_REG
(
AtlasRuntime
,
AtlasRuntime
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}
// namespace atlas_runtime
}
// namespace
}
// namespace mgb::imperative
#endif
imperative/src/impl/ops/cambricon_runtime.cpp
0 → 100644
浏览文件 @
c8697a70
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#if MGB_CAMBRICON
#include "megbrain/cambricon/cambricon_runtime_opr.h"
namespace
mgb
::
imperative
{
namespace
{
namespace
cambricon_runtime
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
CambriconRuntime
&>
(
def
);
SymbolVarArray
symbol_var_inputs
(
inputs
.
begin
(),
inputs
.
end
());
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
CambriconRuntimeOpr
::
make
(
op
.
buf
.
c_str
(),
op
.
buf_size
,
op
.
symbol
,
symbol_var_inputs
,
op
.
tensor_dim_mutable
,
config
);
}
OP_TRAIT_REG
(
CambriconRuntime
,
CambriconRuntime
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}
// namespace cambricon_runtime
}
// namespace
}
// namespace mgb::imperative
#endif
\ No newline at end of file
src/core/include/megbrain/ir/ops.td
浏览文件 @
c8697a70
...
...
@@ -266,6 +266,22 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
);
}
def AtlasRuntime: MgbHashableOp<"AtlasRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size
);
}
def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size,
MgbStringAttr:$symbol,
MgbBoolAttr:$tensor_dim_mutable
);
}
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
#endif // MGB_OPS
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录