Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8c47c1f1
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8c47c1f1
编写于
8月 02, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(syncbn): reimplement with subgraph
GitOrigin-RevId: 13e7e3d3c0d0e9cd8939ad5ddf62bc91a5dabde0
上级
53da5c79
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
257 addition
and
56 deletion
+257
-56
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+47
-0
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+158
-55
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+1
-1
imperative/python/src/common.cpp
imperative/python/src/common.cpp
+4
-0
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+47
-0
未找到文件。
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
8c47c1f1
...
@@ -13,6 +13,7 @@ import numpy as np
...
@@ -13,6 +13,7 @@ import numpy as np
from
.._imperative_rt
import
make_const
from
.._imperative_rt
import
make_const
from
.._imperative_rt.core2
import
SymbolVar
,
Tensor
,
apply
,
dtype_promotion
,
get_device
from
.._imperative_rt.core2
import
SymbolVar
,
Tensor
,
apply
,
dtype_promotion
,
get_device
from
.._imperative_rt.ops
import
SubgraphBuilder
as
_SubgraphBuilder
from
.._wrap
import
as_device
from
.._wrap
import
as_device
from
..ops
import
builtin
from
..ops
import
builtin
from
..ops.special
import
Const
from
..ops.special
import
Const
...
@@ -219,3 +220,49 @@ def _normalize_axis(
...
@@ -219,3 +220,49 @@ def _normalize_axis(
)
)
return
axis
return
axis
raise
raise
def
subgraph
(
name
,
dtype
,
device
,
nr_inputs
,
gopt_level
=
None
):
if
device
.
physical_name
.
startswith
(
"cpu"
):
gopt_level
=
None
# disable jit and compile
binary_ops
=
{
"+"
:
builtin
.
Elemwise
(
mode
=
"add"
),
"-"
:
builtin
.
Elemwise
(
mode
=
"sub"
),
"*"
:
builtin
.
Elemwise
(
mode
=
"mul"
),
"/"
:
builtin
.
Elemwise
(
mode
=
"true_div"
),
"//"
:
builtin
.
Elemwise
(
mode
=
"floor_div"
),
"**"
:
builtin
.
Elemwise
(
mode
=
"pow"
),
"√"
:
builtin
.
Elemwise
(
mode
=
"expm1"
),
"max"
:
builtin
.
Elemwise
(
mode
=
"max"
),
"additive"
:
builtin
.
Elemwise
(
mode
=
"add"
),
}
unary_ops
=
{
"-"
:
builtin
.
Elemwise
(
mode
=
"negate"
),
}
def
decorator
(
func
):
builder
=
_SubgraphBuilder
(
name
)
def
apply_expr
(
op
,
*
args
):
if
isinstance
(
op
,
str
):
if
len
(
args
)
==
2
:
op
=
binary_ops
[
op
]
elif
len
(
args
)
==
1
:
op
=
unary_ops
[
op
]
return
builder
.
apply
(
op
,
args
,
1
)[
0
]
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
return
builder
.
apply_const
(
value
,
dtype
,
device
)
inputs
=
[
builder
.
input
()
for
_
in
range
(
nr_inputs
)]
outputs
,
outputs_has_grad
=
func
(
inputs
,
apply_expr
,
apply_const
)
builder
.
outputs
(
outputs
)
builder
.
outputs_has_grad
(
outputs_has_grad
)
if
gopt_level
is
None
:
return
builder
.
get
()
else
:
return
builder
.
compile
(
gopt_level
)
return
decorator
imperative/python/megengine/functional/nn.py
浏览文件 @
8c47c1f1
...
@@ -7,11 +7,13 @@
...
@@ -7,11 +7,13 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
functools
import
lru_cache
from
typing
import
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.ops
import
SubgraphBuilder
as
_SubgraphBuilder
from
..core.ops
import
builtin
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
,
GetVarShape
,
Reduce
,
TypeCvt
from
..core.ops.special
import
Const
from
..core.ops.special
import
Const
from
..core.tensor
import
amp
,
megbrain_graph
from
..core.tensor
import
amp
,
megbrain_graph
from
..core.tensor.array_method
import
_elwise_apply
from
..core.tensor.array_method
import
_elwise_apply
...
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
...
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
astype
,
astype
,
cast_tensors
,
cast_tensors
,
convert_single_value
,
convert_single_value
,
make_shape_tuple
,
setscalar
,
setscalar
,
subgraph
,
)
)
from
..device
import
get_default_device
from
..device
import
get_default_device
from
..distributed
import
WORLD
,
is_distributed
from
..distributed
import
WORLD
,
is_distributed
from
..jit
import
exclude_from_trace
from
..random
import
uniform
from
..random
import
uniform
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
..utils.deprecation
import
deprecated_func
from
..utils.deprecation
import
deprecated_func
...
@@ -1153,6 +1158,111 @@ def batch_norm(
...
@@ -1153,6 +1158,111 @@ def batch_norm(
return
inp
return
inp
@
lru_cache
(
maxsize
=
None
)
def
_get_sync_bn_ops
(
device
,
dtype
,
eps_mode
,
ndim
,
channels
):
# fmt: off
@
subgraph
(
"SyncBnStage0"
,
dtype
,
device
,
1
)
def
syncbn_stage0
(
inputs
,
f
,
c
):
input
=
inputs
[
0
]
reduce_shape
=
c
((
1
,
channels
)
+
(
1
,)
*
(
ndim
-
2
),
dtype
=
"int32"
,
device
=
device
)
input_shape
=
f
(
GetVarShape
(),
input
)
input_elems
=
f
(
Reduce
(
mode
=
"product"
,
axis
=
0
),
input_shape
)
reduce_elems
=
f
(
Reduce
(
mode
=
"product"
,
axis
=
0
),
reduce_shape
)
reduce_size
=
f
(
"//"
,
input_elems
,
reduce_elems
)
channel_x1s
=
f
(
Reduce
(
mode
=
"sum"
),
input
,
reduce_shape
)
channel_x2s
=
f
(
Reduce
(
mode
=
"sum_sqr"
),
input
,
reduce_shape
)
reduce_size_f
=
f
(
TypeCvt
(
dtype
=
dtype
),
reduce_size
)
return
(
reduce_shape
,
reduce_size_f
,
channel_x1s
,
channel_x2s
),
(
False
,
False
,
True
,
True
)
@
subgraph
(
"SyncBnStage1"
,
dtype
,
device
,
7
,
gopt_level
=
3
)
def
syncbn_stage1
(
inputs
,
f
,
c
):
input
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
=
inputs
[
0
:
5
]
weight
,
bias
=
inputs
[
5
:
7
]
channel_mean
=
f
(
"/"
,
channel_x1s
,
reduce_size
)
channel_var
=
\
f
(
"+"
,
f
(
"/"
,
f
(
"**"
,
channel_x1s
,
c
(
2
)),
f
(
"-"
,
f
(
"*"
,
reduce_size
,
reduce_size
))),
f
(
"/"
,
channel_x2s
,
reduce_size
))
invsqrt_channel_var
=
f
(
"**"
,
f
(
eps_mode
,
channel_var
,
eps
),
c
(
-
0.5
))
inv_var_wt
=
f
(
"*"
,
invsqrt_channel_var
,
weight
)
neg_channel_mean
=
f
(
"-"
,
channel_mean
)
outvar
=
\
f
(
"+"
,
f
(
"*"
,
input
,
inv_var_wt
),
f
(
"+"
,
f
(
"*"
,
neg_channel_mean
,
inv_var_wt
),
bias
))
return
(
outvar
,
channel_mean
,
channel_var
,
inv_var_wt
),
(
True
,
False
,
False
,
False
)
@
subgraph
(
"SyncBnStage1Inference"
,
dtype
,
device
,
6
,
gopt_level
=
3
)
def
syncbn_stage1_inference
(
inputs
,
f
,
c
):
input
,
channel_mean
,
channel_var
,
eps
=
inputs
[
0
:
4
]
weight
,
bias
=
inputs
[
4
:
6
]
invsqrt_channel_var
=
f
(
"**"
,
f
(
eps_mode
,
channel_var
,
eps
),
c
(
-
0.5
))
inv_var_wt
=
f
(
"*"
,
invsqrt_channel_var
,
weight
)
neg_channel_mean
=
f
(
"-"
,
channel_mean
)
outvar
=
\
f
(
"+"
,
f
(
"*"
,
input
,
inv_var_wt
),
f
(
"+"
,
f
(
"*"
,
neg_channel_mean
,
inv_var_wt
),
bias
))
return
(
outvar
,),
(
True
,)
@
subgraph
(
"SyncBnStage2"
,
dtype
,
device
,
7
,
gopt_level
=
3
)
def
syncbn_stage2
(
inputs
,
f
,
c
):
running_mean
,
running_var
,
momentum
=
inputs
[
0
:
3
]
reduce_size
,
channel_x1s
,
channel_x2s
,
channel_mean
=
inputs
[
3
:
7
]
running_mean
=
f
(
"*"
,
running_mean
,
momentum
)
running_mean
=
\
f
(
"+"
,
running_mean
,
f
(
"*"
,
f
(
"-"
,
c
(
1
),
momentum
),
channel_mean
))
channel_variance_unbiased
=
\
f
(
"+"
,
f
(
"/"
,
f
(
"**"
,
channel_x1s
,
c
(
2
)),
f
(
"*"
,
f
(
"-"
,
reduce_size
),
f
(
"-"
,
reduce_size
,
c
(
1
)))),
f
(
"/"
,
channel_x2s
,
f
(
"-"
,
reduce_size
,
c
(
1
))))
running_var
=
f
(
"*"
,
running_var
,
momentum
)
running_var
=
\
f
(
"+"
,
running_var
,
f
(
"*"
,
f
(
"-"
,
c
(
1
),
momentum
),
channel_variance_unbiased
))
return
(
running_mean
,
running_var
),
(
True
,
True
)
@
subgraph
(
"SyncBnConcatStats"
,
dtype
,
device
,
3
,
gopt_level
=
3
)
def
syncbn_concat_stats
(
inputs
,
f
,
c
):
reduce_size
,
channel_x1s
,
channel_x2s
=
inputs
[
0
:
3
]
reduce_size
=
f
(
builtin
.
Broadcast
(),
reduce_size
,
c
([
1
]
*
ndim
,
dtype
=
"int32"
))
stats
=
f
(
builtin
.
Concat
(
axis
=
1
,
comp_node
=
device
),
reduce_size
,
channel_x1s
,
channel_x2s
)
return
(
stats
,),
(
True
,)
@
subgraph
(
"SyncBnSplitStats"
,
dtype
,
device
,
1
,
gopt_level
=
3
)
def
syncbn_split_stats
(
inputs
,
f
,
c
):
stats
=
inputs
[
0
]
c_1
=
c
(
1
,
dtype
=
"int32"
)
channel_x1s_end
=
c
(
channels
+
1
,
dtype
=
"int32"
)
def
_subtensor
(
src
,
axis
,
begin
,
end
):
items
=
(
axis
,
(
begin
is
not
None
),
(
end
is
not
None
),
False
,
False
),
args
=
()
if
begin
is
not
None
:
args
+=
begin
,
if
end
is
not
None
:
args
+=
end
,
return
f
(
builtin
.
Subtensor
(
items
=
items
),
src
,
*
args
)
reduce_size
=
_subtensor
(
stats
,
1
,
None
,
c_1
)
channel_x1s
=
_subtensor
(
stats
,
1
,
c_1
,
channel_x1s_end
)
channel_x2s
=
_subtensor
(
stats
,
1
,
channel_x1s_end
,
None
)
reduce_size
=
f
(
builtin
.
Reshape
(),
reduce_size
,
c_1
)
return
(
reduce_size
,
channel_x1s
,
channel_x2s
),
(
False
,
True
,
True
)
# fmt: on
return
(
syncbn_stage0
,
syncbn_stage1
,
syncbn_stage1_inference
,
syncbn_stage2
,
syncbn_concat_stats
,
syncbn_split_stats
,
)
def
sync_batch_norm
(
def
sync_batch_norm
(
inp
:
Tensor
,
inp
:
Tensor
,
running_mean
:
Tensor
,
running_mean
:
Tensor
,
...
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
...
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
assert
eps_mode
.
lower
()
in
{
"max"
,
"additive"
},
"unknown eps_mode: {}"
.
format
(
assert
eps_mode
.
lower
()
in
{
"max"
,
"additive"
},
"unknown eps_mode: {}"
.
format
(
eps_mode
eps_mode
)
)
_channels
=
inp
.
shape
[
1
]
# TODO: cudnnBn fastpath
_channels
=
make_shape_tuple
(
inp
.
shape
)[
1
]
_ndim
=
inp
.
ndim
_ndim
=
inp
.
ndim
_device
=
inp
.
device
_device
=
inp
.
device
_dtype
=
inp
.
dtype
_dtype
=
inp
.
dtype
_param_shape
=
(
1
,
_channels
)
+
(
1
,)
*
(
_ndim
-
2
)
_reduce_axis
=
[
0
]
+
[
i
for
i
in
range
(
2
,
_ndim
)]
if
training
:
def
_make_full_if_none
(
x
,
value
):
if
x
is
None
:
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
_device
)()
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
reduce_shape
)
return
result
elif
x
.
ndim
==
1
:
(
result
,)
=
apply
(
builtin
.
Reshape
(),
x
,
reduce_shape
)
return
result
return
x
def
_sum_on_channel
(
inp
):
(
return
inp
.
sum
(
axis
=
_reduce_axis
,
keepdims
=
True
)
syncbn_stage0
,
syncbn_stage1
,
syncbn_stage1_inference
,
syncbn_stage2
,
syncbn_concat_stats
,
syncbn_split_stats
,
)
=
_get_sync_bn_ops
(
_device
,
_dtype
,
eps_mode
,
_ndim
,
_channels
)
reduce_size
=
inp
.
shape
[
0
]
reduce_shape
,
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_stage0
,
inp
)
for
i
in
range
(
2
,
_ndim
):
reduce_size
=
reduce_size
*
inp
.
shape
[
i
]
channel_x1s
=
_sum_on_channel
(
inp
)
channel_x2s
=
_sum_on_channel
(
inp
**
2
)
eps
=
convert_single_value
(
eps
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
weight
=
_make_full_if_none
(
weight
,
1
)
bias
=
_make_full_if_none
(
bias
,
0
)
if
training
:
if
is_distributed
():
if
is_distributed
():
# reduce all nodes' data to calculate mean and variance
# reduce all nodes' data to calculate mean and variance
reduce_size
=
broadcast_to
(
(
stat
,)
=
apply
(
syncbn_concat_stats
,
reduce_size
,
channel_x1s
,
channel_x2s
)
Tensor
(
reduce_size
).
astype
(
dtype
=
_dtype
),
[
1
]
*
_ndim
)
stat
=
concat
([
reduce_size
,
channel_x1s
,
channel_x2s
],
axis
=
1
)
stat
=
all_reduce_sum
(
stat
,
group
)
stat
=
all_reduce_sum
(
stat
,
group
)
reduce_size
=
stat
[:,
:
1
].
reshape
(
1
)
reduce_size
,
channel_x1s
,
channel_x2s
=
apply
(
syncbn_split_stats
,
stat
)
channel_x1s
=
stat
[:,
1
:
1
+
_channels
]
channel_x2s
=
stat
[:,
1
+
_channels
:]
channel_mean
=
channel_x1s
/
reduce_size
outvar
,
channel_mean
,
*
_
=
apply
(
channel_variance
=
(
syncbn_stage1
,
inp
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
,
weight
,
bias
channel_x1s
**
2
/
(
-
reduce_size
*
reduce_size
)
+
channel_x2s
/
reduce_size
)
)
else
:
else
:
assert
running_var
is
not
None
and
running_mean
is
not
None
assert
running_var
is
not
None
and
running_mean
is
not
None
channel_variance
=
running_var
.
reshape
(
*
_param_shape
)
channel_mean
=
running_mean
channel_mean
=
running_mean
.
reshape
(
*
_param_shape
)
channel_var
=
running_var
outvar
,
*
_
=
apply
(
invsqrt_channel_variance
=
(
syncbn_stage1_inference
,
inp
,
channel_mean
,
channel_var
,
eps
,
weight
,
bias
maximum
(
channel_variance
,
eps
)
if
eps_mode
==
"max"
else
channel_variance
+
eps
)
)
**
-
0.5
if
weight
is
not
None
:
weight
=
weight
.
reshape
(
*
_param_shape
)
if
bias
is
not
None
:
bias
=
bias
.
reshape
(
*
_param_shape
)
# outvar = output * weight + bias
# outvar = output * weight + bias
# where output = inp * invsqrt_channel_variance + (
# where output = inp * invsqrt_channel_variance + (
...
@@ -1246,28 +1359,18 @@ def sync_batch_norm(
...
@@ -1246,28 +1359,18 @@ def sync_batch_norm(
# )
# )
# Manually expand output for gopt
# Manually expand output for gopt
if
weight
is
not
None
:
inv_var_wt
=
invsqrt_channel_variance
*
weight
neg_channel_mean
=
-
channel_mean
if
bias
is
not
None
:
outvar
=
inp
*
inv_var_wt
+
(
neg_channel_mean
*
inv_var_wt
+
bias
)
else
:
outvar
=
inp
*
inv_var_wt
+
neg_channel_mean
*
inv_var_wt
else
:
outvar
=
inp
*
invsqrt_channel_variance
+
(
-
channel_mean
*
invsqrt_channel_variance
)
if
bias
is
not
None
:
outvar
=
outvar
+
bias
if
training
and
running_var
is
not
None
and
running_mean
is
not
None
:
if
training
and
running_var
is
not
None
and
running_mean
is
not
None
:
running_mean
*=
momentum
momentum
=
convert_single_value
(
momentum
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
running_mean
+=
(
1
-
momentum
)
*
channel_mean
running_mean
[...],
running_var
[...]
=
apply
(
channel_variance_unbiased
=
channel_x1s
**
2
/
(
syncbn_stage2
,
-
reduce_size
*
(
reduce_size
-
1
)
running_mean
,
)
+
channel_x2s
/
(
reduce_size
-
1
)
running_var
,
running_var
*=
momentum
momentum
,
running_var
+=
(
1
-
momentum
)
*
channel_variance_unbiased
reduce_size
,
channel_x1s
,
channel_x2s
,
channel_mean
,
)
return
outvar
return
outvar
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
8c47c1f1
...
@@ -66,7 +66,7 @@ def is_tracing():
...
@@ -66,7 +66,7 @@ def is_tracing():
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
exclude_from_trace
():
def
exclude_from_trace
():
global
skip_tracing
global
skip_tracing
if
skip_tracing
:
if
skip_tracing
or
(
active_trace
is
None
)
:
yield
yield
return
return
try
:
try
:
...
...
imperative/python/src/common.cpp
浏览文件 @
8c47c1f1
...
@@ -58,6 +58,9 @@ void init_common(py::module m) {
...
@@ -58,6 +58,9 @@ void init_common(py::module m) {
.
def_property_readonly
(
"logical_name"
,
[](
const
CompNode
&
cn
)
{
.
def_property_readonly
(
"logical_name"
,
[](
const
CompNode
&
cn
)
{
return
cn
.
to_string_logical
();
return
cn
.
to_string_logical
();
})
})
.
def_property_readonly
(
"physical_name"
,
[](
const
CompNode
&
cn
)
{
return
cn
.
to_string
();
})
.
def_property_readonly
(
"get_mem_status_bytes"
,
[](
const
CompNode
&
cn
)
{
.
def_property_readonly
(
"get_mem_status_bytes"
,
[](
const
CompNode
&
cn
)
{
return
cn
.
get_mem_status_bytes
();
return
cn
.
get_mem_status_bytes
();
})
})
...
@@ -70,6 +73,7 @@ void init_common(py::module m) {
...
@@ -70,6 +73,7 @@ void init_common(py::module m) {
cn
.
to_string_physical
().
c_str
(),
cn
.
to_string_physical
().
c_str
(),
cn
.
to_string_logical
().
c_str
());
cn
.
to_string_logical
().
c_str
());
})
})
.
def
(
"__hash__"
,
[](
CompNode
cn
){
return
mgb
::
hash
(
cn
);
})
.
def_static
(
"_sync_all"
,
&
CompNode
::
sync_all
)
.
def_static
(
"_sync_all"
,
&
CompNode
::
sync_all
)
.
def
(
py
::
self
==
py
::
self
)
.
def
(
py
::
self
==
py
::
self
)
.
def_static
(
"_get_device_count"
,
&
CompNode
::
get_device_count
,
.
def_static
(
"_get_device_count"
,
&
CompNode
::
get_device_count
,
...
...
imperative/python/src/ops.cpp
浏览文件 @
8c47c1f1
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "megbrain/common.h"
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/utility.h"
...
@@ -477,4 +478,50 @@ void init_ops(py::module m) {
...
@@ -477,4 +478,50 @@ void init_ops(py::module m) {
m
.
def
(
"set_global_rng_seed"
,
&
rng
::
set_global_rng_seed
);
m
.
def
(
"set_global_rng_seed"
,
&
rng
::
set_global_rng_seed
);
m
.
def
(
"get_global_rng_seed"
,
&
rng
::
get_global_rng_seed
);
m
.
def
(
"get_global_rng_seed"
,
&
rng
::
get_global_rng_seed
);
m
.
def
(
"get_rng_handle_compnode"
,
&
rng
::
get_rng_handle_compnode
);
m
.
def
(
"get_rng_handle_compnode"
,
&
rng
::
get_rng_handle_compnode
);
struct
PySubgraphBuilder
{
explicit
PySubgraphBuilder
(
std
::
string
name
)
:
name
{
name
}{}
std
::
string
name
;
Subgraph
graph
;
mgb
::
SmallVector
<
bool
>
output_grad_mask
;
Subgraph
::
var_t
next_var
=
1
;
};
py
::
class_
<
PySubgraphBuilder
>
(
m
,
"SubgraphBuilder"
)
.
def
(
py
::
init
<
std
::
string
>
())
.
def
(
"input"
,
[](
PySubgraphBuilder
&
self
){
auto
var
=
self
.
next_var
++
;
self
.
graph
.
inputs
.
push_back
(
var
);
return
var
;
})
.
def
(
"apply"
,
[](
PySubgraphBuilder
&
self
,
std
::
shared_ptr
<
OpDef
>
op
,
Subgraph
::
vars_t
inputs
,
size_t
nr_outputs
){
Subgraph
::
vars_t
outputs
;
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
outputs
.
push_back
(
self
.
next_var
++
);
}
self
.
graph
.
exprs
.
push_back
({
op
,
inputs
,
outputs
});
return
outputs
;
})
.
def
(
"apply_const"
,
[](
PySubgraphBuilder
&
self
,
py
::
object
value
,
mgb
::
DType
dtype
,
mgb
::
CompNode
cn
){
auto
var
=
self
.
next_var
++
;
mgb
::
HostTensorND
hvalue
(
cn
);
npy
::
np2tensor
(
value
.
cast
<
py
::
array
>
().
ptr
(),
npy
::
Meth
::
copy_into
(
&
hvalue
),
dtype
);
self
.
graph
.
constants
.
push_back
({
var
,
Tensor
::
make
(
hvalue
)});
return
var
;
})
.
def
(
"outputs"
,
[](
PySubgraphBuilder
&
self
,
Subgraph
::
vars_t
outputs
){
self
.
graph
.
outputs
=
outputs
;
self
.
output_grad_mask
.
resize
(
outputs
.
size
(),
true
);
})
.
def
(
"outputs_has_grad"
,
[](
PySubgraphBuilder
&
self
,
mgb
::
SmallVector
<
bool
>
outputs_has_grad
){
mgb_assert
(
self
.
graph
.
outputs
.
size
()
==
self
.
output_grad_mask
.
size
());
self
.
output_grad_mask
=
outputs_has_grad
;
})
.
def
(
"get"
,
[](
PySubgraphBuilder
&
self
){
return
(
std
::
shared_ptr
<
OpDef
>
)
SubgraphOp
::
make
(
self
.
name
,
self
.
graph
,
self
.
output_grad_mask
);
})
.
def
(
"compile"
,
[](
PySubgraphBuilder
&
self
,
int
gopt_level
){
auto
op
=
SubgraphOp
::
make
(
self
.
name
,
self
.
graph
,
self
.
output_grad_mask
);
return
(
std
::
shared_ptr
<
OpDef
>
)
CompiledOp
::
make
(
op
,
gopt_level
);
});
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录