Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2cc85487
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2cc85487
编写于
9月 04, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): fix hardcode of default device
GitOrigin-RevId: 722c4debfaf3c4a27029ea9f207e65c35dd16f21
上级
403a1e7b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
24 addition
and
7 deletion
+24
-7
imperative/python/megengine/device.py
imperative/python/megengine/device.py
+5
-5
imperative/python/src/common.cpp
imperative/python/src/common.cpp
+12
-0
imperative/python/src/common.h
imperative/python/src/common.h
+3
-0
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+2
-1
imperative/python/src/imperative_rt.cpp
imperative/python/src/imperative_rt.cpp
+2
-1
未找到文件。
imperative/python/megengine/device.py
浏览文件 @
2cc85487
...
...
@@ -17,8 +17,6 @@ __all__ = [
"set_default_device"
,
]
_default_device
=
os
.
getenv
(
"MGE_DEFAULT_DEVICE"
,
"xpux"
)
def
_valid_device
(
inp
):
if
isinstance
(
inp
,
str
)
and
len
(
inp
)
==
4
:
...
...
@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"):
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
"""
global
_default_device
# pylint: disable=global-statement
assert
_valid_device
(
device
),
"Invalid device name {}"
.
format
(
device
)
_default_device
=
device
CompNode
.
_set_default_device
(
device
)
def
get_default_device
()
->
str
:
...
...
@@ -86,4 +83,7 @@ def get_default_device() -> str:
It returns the value set by :func:`~.set_default_device`.
"""
return
_default_device
return
CompNode
.
_get_default_device
()
set_default_device
(
os
.
getenv
(
"MGE_DEFAULT_DEVICE"
,
"xpux"
))
imperative/python/src/common.cpp
浏览文件 @
2cc85487
...
...
@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) {
&
XTensorND
::
template
copy_from_fixlayout
<
HostTensorStorage
>));
}
std
::
string
default_device
=
"xpux"
;
}
// namespace
void
set_default_device
(
const
std
::
string
&
device
)
{
default_device
=
device
;
}
std
::
string
get_default_device
()
{
return
default_device
;
}
void
init_common
(
py
::
module
m
)
{
auto
&&
PyCompNode
=
py
::
class_
<
CompNode
>
(
m
,
"CompNode"
)
.
def
(
py
::
init
())
.
def
(
py
::
init
(
py
::
overload_cast
<
const
std
::
string
&>
(
&
CompNode
::
load
)))
.
def
(
"create_event"
,
&
CompNode
::
create_event
,
py
::
arg
(
"flags"
)
=
0ul
)
.
def
(
"_set_default_device"
,
&
set_default_device
)
.
def
(
"_get_default_device"
,
&
get_default_device
)
.
def
(
"__str__"
,
&
CompNode
::
to_string_logical
)
.
def_static
(
"_sync_all"
,
&
CompNode
::
sync_all
)
.
def
(
py
::
self
==
py
::
self
)
...
...
imperative/python/src/common.h
浏览文件 @
2cc85487
...
...
@@ -14,3 +14,6 @@
#include "./helper.h"
void
init_common
(
pybind11
::
module
m
);
void
set_default_device
(
const
std
::
string
&
device
);
std
::
string
get_default_device
();
\ No newline at end of file
imperative/python/src/graph_rt.cpp
浏览文件 @
2cc85487
...
...
@@ -19,6 +19,7 @@
#include "megbrain/imperative.h"
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
namespace
py
=
pybind11
;
...
...
@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) {
m
.
def
(
"make_const"
,
[](
cg
::
ComputingGraph
*
graph
,
py
::
array
data
,
CompNode
cn
,
DType
dtype
)
{
if
(
!
cn
.
valid
())
{
cn
=
CompNode
::
load
(
"xpux"
);
cn
=
CompNode
::
load
(
get_default_device
()
);
}
auto
hv
=
npy
::
np2tensor
(
data
.
ptr
(),
npy
::
Meth
::
borrow
(
cn
),
dtype
);
return
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
OperatorNodeConfig
(
cn
)).
node
();
...
...
imperative/python/src/imperative_rt.cpp
浏览文件 @
2cc85487
...
...
@@ -21,6 +21,7 @@
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "./helper.h"
#include "./common.h"
namespace
py
=
pybind11
;
...
...
@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) {
py
::
class_
<
Interpreter
::
Channel
>
(
m
,
"Interpreter"
)
.
def
(
"put"
,
[](
Interpreter
::
Channel
&
self
,
py
::
array
data
,
DType
dtype
,
CompNode
cn
)
{
if
(
!
cn
.
valid
())
{
cn
=
CompNode
::
load
(
"xpux"
);
cn
=
CompNode
::
load
(
get_default_device
()
);
}
constexpr
int
size_threshhold
=
TensorShape
::
MAX_NDIM
;
if
(
data
.
size
()
>
size_threshhold
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录