Commit 2775f458
Authored on Sep 26, 2021 by Megvii Engine Team

feat(subgraph): subgraph builder supports jit and custom grad

GitOrigin-RevId: e1a1ebdf1c1f8b3d7fd8b3795d618a8e71b0dcc4
Parent: 3c61e0e0
Changes: 4 changed files with 143 additions and 16 deletions (+143 −16)
imperative/python/megengine/core/tensor/utils.py    +96 −9
imperative/python/megengine/jit/tracing.py           +3 −0
imperative/python/src/ops.cpp                       +23 −7
imperative/src/impl/transformations/scalar.cpp      +21 −0
imperative/python/megengine/core/tensor/utils.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import itertools
 from typing import Iterable, Union

 import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
 )
 from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
+from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
 from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(
 _opr_map = {
     ("-", 1): builtin.Elemwise(mode="negate"),
     ("abs", 1): builtin.Elemwise(mode="abs"),
     ("exp", 1): builtin.Elemwise(mode="exp"),
+    ("log1p", 1): builtin.Elemwise(mode="log1p"),
+    ("relu", 1): builtin.Elemwise(mode="relu"),
+    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
+    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
+    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
 }

 for name, mode in [
@@ -209,15 +218,21 @@ for name, mode in [
     ("//", "floor_div"),
     ("**", "pow"),
     ("max", "max"),
     ("min", "min"),
     ("additive", "add"),
     ("exp", "EXP"),
     ("switch_gt0", "switch_gt0"),
     ("abs_grad", "abs_grad"),
 ]:
     _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)


-def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+def subgraph(name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
+        jit_fusion = False

     def as_op(op, nargs):
         if isinstance(op, str):
@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)

+        def build(builder, outputs, outputs_has_grad):
+            builder = type(builder)(builder)
+            builder.outputs(outputs)
+            builder.outputs_has_grad(outputs_has_grad)
+            if jit_fusion:
+                assert gopt_level is None
+                op = lambda: builder.jit_fuse()
+            elif gopt_level is None:
+                op = lambda: builder.get()
+            else:
+                op = lambda: builder.compile(gopt_level)
+            return op
+
         inputs = [builder.input() for _ in range(nr_inputs)]
-        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
-        builder.outputs(outputs)
-        builder.outputs_has_grad(outputs_has_grad)
-        if gopt_level is None:
-            return lambda: builder.get()
+        if not custom_grad:
+            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+            return build(builder, outputs, outputs_has_grad)
         else:
-            return lambda: builder.compile(gopt_level)
+            gen = func(inputs, apply_expr, apply_const)
+            outputs = gen.send(None)
+            nr_outputs = len(outputs)
+            forward_fn = build(builder, outputs, [False] * nr_outputs)
+            output_grads = [builder.input() for _ in range(nr_outputs)]
+            input_grads = gen.send(output_grads)
+            assert len(input_grads) == nr_inputs
+            input_grads_mask = [input_grad is not None for input_grad in input_grads]
+            indices = [
+                i - 1 if mask else None
+                for i, mask in zip(itertools.accumulate(input_grads_mask), input_grads_mask)
+            ]
+            encoded_input_grads = [grad for grad in input_grads if grad is not None]
+            backward_fn = build(
+                builder, encoded_input_grads, [False] * len(encoded_input_grads)
+            )
+
+            class SubgraphOp(Function):
+                def __init__(self):
+                    self.inputs = None
+
+                def forward(self, *inputs):
+                    self.inputs = inputs
+                    return apply(forward_fn(), *inputs)
+
+                def backward(self, *output_grads):
+                    inputs = self.inputs
+                    self.inputs = None
+                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
+                    input_grads = [
+                        encoded_input_grads[i] if i is not None else None
+                        for i in indices
+                    ]
+                    return input_grads
+
+            gen.close()
+            return SubgraphOp

     return decorator
@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
             return Const(value, dtype=dtype, device=device)()[0]

         outputs, outputs_has_grad = func(args, apply_expr, apply_const)
         outputs = [
             output if has_grad else output.detach()
             for output, has_grad in zip(outputs, outputs_has_grad)
         ]
         return outputs

     return decorated_func


-def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+def subgraph_fn(
+    name,
+    dtype,
+    device,
+    nr_inputs,
+    gopt_level=None,
+    jit_fusion=False,
+    custom_grad=False,
+    *,
+    interpret=False
+):
     def decorator(func):
         if not interpret:
-            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            op = subgraph(
+                name,
+                dtype,
+                device,
+                nr_inputs,
+                gopt_level=gopt_level,
+                jit_fusion=jit_fusion,
+                custom_grad=custom_grad,
+            )(func)
             return lambda *args: apply(op(), *args)
         else:
             return interpret_subgraph(func, dtype, device)
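With custom_grad=True, the decorated function is driven as a generator: the first gen.send(None) collects the forward outputs, and the second gen.send(output_grads) collects one gradient per input (None where no gradient is needed). The following usage sketch is illustrative only and not part of this commit; it assumes ("*", 2) is registered in _opr_map and that dtype/device are obtained the same way as by other subgraph_fn callers.

# Illustrative sketch only; the names "Square" and make_square are made up.
from megengine.core.tensor.utils import subgraph_fn


def make_square(dtype, device):
    @subgraph_fn(
        "Square", dtype=dtype, device=device, nr_inputs=1, jit_fusion=True, custom_grad=True
    )
    def square(inputs, f, c):
        (x,) = inputs
        y = f("*", x, x)
        # first yield: forward outputs; the value sent back is the list of output grads
        (dy,) = yield (y,)
        # second yield: one gradient per input (use None for inputs without grad)
        yield (f("*", f("*", dy, x), c(2)),)

    return square


# usage (sketch): sq = make_square(x.dtype, x.device); (y,) = sq(x)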
imperative/python/megengine/jit/tracing.py
@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
     ExternOpr,
     RemoteRecv,
     RemoteSend,
+    set_jit_enabled,
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:
         graph = G.Graph()

+        jit_enabled = set_jit_enabled(False)
         dest_vars = self._trace.dump(
             graph,
             input_bindings,
             [*zip(self._output_bindings, output_names)],
             prefer_input_names,
         )
+        set_jit_enabled(jit_enabled)
         # dest_vars = [i._node for i in dest_vars]
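set_jit_enabled returns the previous setting, which is what lets trace.dump restore it afterwards. A small sketch of the same save/restore idiom wrapped in a context manager (not part of the commit; the helper name _jit_disabled is made up), which keeps the restore exception-safe:

from contextlib import contextmanager

from megengine.core._imperative_rt.ops import set_jit_enabled


@contextmanager
def _jit_disabled():
    previous = set_jit_enabled(False)  # returns the previous enabled state
    try:
        yield
    finally:
        set_jit_enabled(previous)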
imperative/python/src/ops.cpp
@@ -577,21 +577,26 @@ void init_ops(py::module m) {
 struct PySubgraphBuilder {
     explicit PySubgraphBuilder(std::string name) : name{name} {}
     std::string name;
-    std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
-    std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
-    Subgraph& graph = *graph_storage;
+    Subgraph graph;
     mgb::SmallVector<bool> output_grad_mask;
     Subgraph::var_t next_var = 1;
+    std::shared_ptr<mgb::Hashable> key = nullptr;

-    std::shared_ptr<OpDef> build() const {
-        return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+    std::shared_ptr<OpDef> build() {
+        if (key == nullptr) {
+            key = std::make_shared<UniqueKey>();
+        }
+        return SubgraphOp::make(
+                name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
     }
 };

 py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
         .def(py::init<std::string>())
+        .def(py::init<PySubgraphBuilder>())
         .def("input", [](PySubgraphBuilder& self) {
+            mgb_assert(self.key == nullptr);
             auto var = self.next_var++;
             self.graph.inputs.push_back(var);
             return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
         .def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                          Subgraph::vars_t inputs, size_t nr_outputs) {
+            mgb_assert(self.key == nullptr);
             Subgraph::vars_t outputs;
             for (size_t i = 0; i < nr_outputs; ++i) {
                 outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
         .def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                                mgb::CompNode cn) {
+            mgb_assert(self.key == nullptr);
             auto var = self.next_var++;
             mgb::HostTensorND hvalue(cn);
             npy::np2tensor(
@@ -619,11 +626,13 @@ void init_ops(py::module m) {
         })
         .def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
+            mgb_assert(self.key == nullptr);
             self.graph.outputs = outputs;
             self.output_grad_mask.resize(outputs.size(), true);
         })
         .def("outputs_has_grad",
              [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
+                 mgb_assert(self.key == nullptr);
                  mgb_assert(self.graph.outputs.size() == self.output_grad_mask.size());
                  self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@ void init_ops(py::module m) {
              [](PySubgraphBuilder& self) {
                  return (std::shared_ptr<OpDef>)self.build();
              })
         .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
             return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
-        });
+        })
+        .def("jit_fuse", [](PySubgraphBuilder& self) {
+            return (std::shared_ptr<OpDef>)CompiledOp::make(
+                    JITFusionOp::make(self.build()));
+        });
+    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
     auto custom = submodule(m, "_custom");
     init_custom(custom);
 }
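The new mgb_assert(self.key == nullptr) checks freeze a builder once build() has run (via get, compile, or jit_fuse); that is why the Python-side build() helper copies the builder before finalizing it. Below is a rough sketch of driving the binding directly, for illustration only: the Elemwise op is just an example, and apply is assumed to return the list of newly created output vars.

# Illustrative sketch only, not from the commit.
from megengine.core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from megengine.core.ops import builtin

builder = _SubgraphBuilder("example")
x = builder.input()
y = builder.input()
(z,) = builder.apply(builtin.Elemwise(mode="add"), [x, y], 1)  # assumed to return output vars
builder.outputs([z])
builder.outputs_has_grad([True])

frozen = type(builder)(builder)  # copy first, as utils.py's build() helper does
op = frozen.get()                # or frozen.compile(gopt_level) / frozen.jit_fuse()
# after get()/compile()/jit_fuse(), further frozen.input()/frozen.apply(...) calls
# would trip mgb_assert(self.key == nullptr) above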
imperative/src/impl/transformations/scalar.cpp
@@ -12,6 +12,7 @@
 #include "megbrain/imperative/transformations/scalar.h"

 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"

 namespace mgb {
 namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
     }
 }

+template <typename T>
+std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
+    // TODO: add flag instead of assume
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
+    if (all_scalar) {
+        for (auto& output : outputs) {
+            output = ScalarValue::make(output);
+        }
+    }
+    return outputs;
+}
+
 struct ScalarRuleRegistry {
     ScalarRuleRegistry() {
         register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
         register_scalar_rule(broadcast_rule);
         register_scalar_rule(copy_rule);
         register_scalar_rule(inplace_add_rule);
+        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
+        register_scalar_rule(subgraph_op_rule<CompiledOp>);
     }
 } _;
 }  // namespace