Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
13e6ea34
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
13e6ea34
编写于
2月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/opr): rebase rng refactoring to dev & add python module
GitOrigin-RevId: ee5984c52d3fa346d5f26d737bf40ec4ed43b2c7
上级
cded8ef1
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
371 addition
and
247 deletion
+371
-247
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+2
-1
imperative/python/megengine/distributed/group.py
imperative/python/megengine/distributed/group.py
+3
-0
imperative/python/megengine/random/__init__.py
imperative/python/megengine/random/__init__.py
+1
-1
imperative/python/megengine/random/distribution.py
imperative/python/megengine/random/distribution.py
+18
-24
imperative/python/megengine/random/rng.py
imperative/python/megengine/random/rng.py
+85
-4
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+14
-17
imperative/python/test/unit/random/test_rng.py
imperative/python/test/unit/random/test_rng.py
+121
-0
imperative/src/impl/ops/rng.cpp
imperative/src/impl/ops/rng.cpp
+84
-84
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+0
-28
imperative/src/include/megbrain/imperative/ops/rng.h
imperative/src/include/megbrain/imperative/ops/rng.h
+9
-78
imperative/src/test/rng.cpp
imperative/src/test/rng.cpp
+14
-6
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+20
-4
未找到文件。
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
13e6ea34
...
...
@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
def
_remove_axis
(
inp
:
Tensor
,
axis
)
->
Tensor
:
def
get_axes
():
if
axis
is
None
:
return
[
i
for
i
,
s
in
enumerate
(
inp
.
shape
)
if
s
==
1
]
shp
=
inp
.
shape
return
[
i
for
i
,
s
in
enumerate
(
shp
)
if
s
==
1
]
try
:
return
[
int
(
axis
)]
except
(
TypeError
,
ValueError
):
...
...
imperative/python/megengine/distributed/group.py
浏览文件 @
13e6ea34
...
...
@@ -6,9 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
time
from
typing
import
List
,
Optional
,
Tuple
from
..device
import
set_default_device
,
what_is_xpu
from
..random
import
seed
from
.server
import
Client
,
Server
...
...
@@ -156,6 +158,7 @@ def init_process_group(
WORLD
.
reset
(
list
(
range
(
world_size
)))
set_default_device
(
"{}{}"
.
format
(
device_type
,
device
))
seed
(
int
(
time
.
time
())
+
rank
)
def
is_distributed
()
->
bool
:
...
...
imperative/python/megengine/random/__init__.py
浏览文件 @
13e6ea34
...
...
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.distribution
import
normal
,
uniform
from
.rng
import
seed
from
.rng
import
RNG
,
seed
# pylint: disable=undefined-variable
del
distribution
,
rng
# type: ignore[name-defined]
imperative/python/megengine/random/distribution.py
浏览文件 @
13e6ea34
...
...
@@ -9,11 +9,8 @@
from
typing
import
Iterable
,
Optional
from
..
import
Tensor
from
..core._imperative_rt
import
invoke_op
from
..core._imperative_rt.core2
import
apply
from
..core.ops.builtin
import
GaussianRNG
,
UniformRNG
from
..core.tensor
import
utils
from
.rng
import
_random_seed_generator
from
..core._imperative_rt.ops
import
get_global_rng_seed
as
_get_global_rng_seed
from
.rng
import
_normal
,
_uniform
__all__
=
[
"normal"
,
"uniform"
]
...
...
@@ -48,14 +45,14 @@ def normal(
[-1.4939808 -1.5824696 ]]
"""
if
size
is
None
:
size
=
(
1
,)
op
=
GaussianRNG
(
mean
,
std
)
_ref
=
Tensor
([],
dtype
=
"int32"
)
shape
=
utils
.
astensor1d
(
size
,
_ref
,
dtype
=
"int32"
)
shape
=
Tensor
(
shape
,
dtype
=
"int32"
)
(
output
,)
=
apply
(
op
,
shape
)
return
output
return
_normal
(
mean
=
mean
,
std
=
std
,
size
=
size
,
seed
=
_get_global_rng_seed
(),
device
=
None
,
handle
=
0
,
)
def
uniform
(
...
...
@@ -88,14 +85,11 @@ def uniform(
[0.09365904 0.62957656]]
"""
assert
low
<
high
,
"Uniform is not defined when low >= high"
if
size
is
None
:
size
=
(
1
,)
op
=
UniformRNG
()
_ref
=
Tensor
([],
dtype
=
"int32"
)
shape
=
utils
.
astensor1d
(
size
,
_ref
,
dtype
=
"int32"
)
shape
=
Tensor
(
shape
,
dtype
=
"int32"
)
(
output
,)
=
apply
(
op
,
shape
)
return
low
+
(
high
-
low
)
*
output
return
_uniform
(
low
=
low
,
high
=
high
,
size
=
size
,
seed
=
_get_global_rng_seed
(),
device
=
None
,
handle
=
0
,
)
imperative/python/megengine/random/rng.py
浏览文件 @
13e6ea34
...
...
@@ -7,17 +7,94 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
time
from
typing
import
Iterable
,
Optional
from
numpy.random
import
MT19937
from
..
import
Tensor
from
..core._imperative_rt.core2
import
apply
from
..core._imperative_rt.ops
import
delete_rng_handle
as
_delete_rng_handle
from
..core._imperative_rt.ops
import
get_global_rng_seed
as
_get_global_rng_seed
from
..core._imperative_rt.ops
import
new_rng_handle
as
_new_rng_handle
from
..core._imperative_rt.ops
import
set_global_rng_seed
as
_set_global_rng_seed
from
..core.ops.builtin
import
GaussianRNG
,
UniformRNG
from
..core.tensor
import
utils
from
..device
import
get_default_device
_rng
=
None
def
_random_seed_generator
():
if
_rng
is
None
:
from
..distributed.group
import
get_rank
def
_normal
(
mean
:
float
,
std
:
float
,
size
:
Optional
[
Iterable
[
int
]],
seed
:
int
,
device
:
str
,
handle
:
int
,
)
->
Tensor
:
if
size
is
None
:
size
=
(
1
,)
op
=
GaussianRNG
(
seed
=
seed
,
mean
=
mean
,
std
=
std
,
handle
=
handle
)
_ref
=
Tensor
([],
dtype
=
"int32"
,
device
=
device
)
shape
=
utils
.
astensor1d
(
size
,
_ref
,
dtype
=
"int32"
,
device
=
device
)
(
output
,)
=
apply
(
op
,
shape
)
return
output
def
_uniform
(
low
:
float
,
high
:
float
,
size
:
Optional
[
Iterable
[
int
]],
seed
:
int
,
device
:
str
,
handle
:
int
,
)
->
Tensor
:
assert
low
<
high
,
"Uniform is not defined when low >= high"
if
size
is
None
:
size
=
(
1
,)
op
=
UniformRNG
(
seed
=
seed
,
handle
=
handle
)
_ref
=
Tensor
([],
dtype
=
"int32"
,
device
=
device
)
shape
=
utils
.
astensor1d
(
size
,
_ref
,
dtype
=
"int32"
,
device
=
device
)
(
output
,)
=
apply
(
op
,
shape
)
return
low
+
(
high
-
low
)
*
output
class
RNG
:
def
__init__
(
self
,
seed
=
0
,
device
=
None
):
self
.
seed
=
seed
self
.
device
=
device
if
device
else
get_default_device
()
self
.
handle
=
_new_rng_handle
(
self
.
device
,
self
.
seed
)
def
uniform
(
self
,
low
:
float
=
0
,
high
:
float
=
1
,
size
:
Optional
[
Iterable
[
int
]]
=
None
):
return
_uniform
(
low
=
low
,
high
=
high
,
size
=
size
,
seed
=
self
.
seed
,
device
=
self
.
device
,
handle
=
self
.
handle
,
)
seed
(
seed
=
int
(
time
.
time
())
+
get_rank
())
def
normal
(
self
,
mean
:
float
=
0
,
std
:
float
=
1
,
size
:
Optional
[
Iterable
[
int
]]
=
None
):
return
_normal
(
mean
=
mean
,
std
=
std
,
size
=
size
,
seed
=
self
.
seed
,
device
=
self
.
device
,
handle
=
self
.
handle
,
)
def
__del__
(
self
):
_delete_rng_handle
(
self
.
handle
)
def
_random_seed_generator
():
assert
_rng
while
True
:
yield
_rng
.
random_raw
()
...
...
@@ -25,3 +102,7 @@ def _random_seed_generator():
def
seed
(
seed
:
int
):
global
_rng
# pylint: disable=global-statement
_rng
=
MT19937
(
seed
=
seed
)
_set_global_rng_seed
(
seed
)
seed
(
int
(
time
.
time
()))
imperative/python/src/ops.cpp
浏览文件 @
13e6ea34
...
...
@@ -10,7 +10,10 @@
*/
#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
...
...
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
_init_py_op_base
(
m
);
INIT_ALL_OP
(
m
)
m
.
def
(
"new_rng_handle"
,
&
RNGMixin
::
new_handle
);
// FIXME: RNG op might execute after handle released due to async dispatch,
// which would cause memory leak or use-after-free
m
.
def
(
"delete_rng_handle"
,
&
RNGMixin
::
delete_handle
);
m
.
def
(
"set_rng_seed"
,
&
set_rng_seed
);
py
::
class_
<
UniformRNG
,
std
::
shared_ptr
<
UniformRNG
>
,
OpDef
>
(
m
,
"UniformRNG"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
mgb
::
CompNode
>
())
.
def
(
py
::
init
<
RNGMixin
::
Handle
>
());
py
::
class_
<
GaussianRNG
,
std
::
shared_ptr
<
GaussianRNG
>
,
OpDef
>
(
m
,
"GaussianRNG"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<
mgb
::
CompNode
>
())
.
def
(
py
::
init
<
float
,
float
>
())
.
def
(
py
::
init
<
float
,
float
,
mgb
::
CompNode
>
())
.
def
(
py
::
init
<
float
,
float
,
RNGMixin
::
Handle
>
());
m
.
def
(
"new_rng_handle"
,
&
rng
::
new_handle
);
m
.
def
(
"delete_rng_handle"
,
[](
size_t
handle
){
// RNG op might execute after handle released due to async dispatch, so
// we need sync before delete a handle to avoid memory leak or use-after-free
python
::
interpreter_for_py
->
sync
();
mgb
::
CompNode
::
sync_all
();
py_task_q
.
wait_all_task_finish
();
rng
::
delete_handle
(
handle
);
},
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"set_global_rng_seed"
,
&
rng
::
set_global_rng_seed
);
m
.
def
(
"get_global_rng_seed"
,
&
rng
::
get_global_rng_seed
);
}
imperative/python/test/unit/test_rng.py
→
imperative/python/test/unit/
random/
test_rng.py
浏览文件 @
13e6ea34
...
...
@@ -8,14 +8,21 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
megengine
from
megengine
import
tensor
from
megengine.core._imperative_rt
import
CompNode
from
megengine.core._imperative_rt.ops
import
delete_rng_handle
,
new_rng_handle
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core._imperative_rt.ops
import
(
delete_rng_handle
,
get_global_rng_seed
,
new_rng_handle
,
)
from
megengine.core.ops.builtin
import
GaussianRNG
,
UniformRNG
from
megengine.core.tensor.core
import
apply
from
megengine.random
import
RNG
from
megengine.random.rng
import
_normal
,
_uniform
def
test_gaussian_
rng
():
def
test_gaussian_
op
():
shape
=
(
8
,
9
,
...
...
@@ -23,23 +30,16 @@ def test_gaussian_rng():
12
,
)
shape
=
tensor
(
shape
,
dtype
=
"int32"
)
op
=
GaussianRNG
(
1.0
,
3.0
)
op
=
GaussianRNG
(
seed
=
get_global_rng_seed
(),
mean
=
1.0
,
std
=
3.0
)
(
output
,)
=
apply
(
op
,
shape
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
1.0
)
<
1e-1
assert
np
.
sqrt
(
output
.
numpy
().
var
())
-
3.0
<
1e-1
assert
str
(
output
.
device
)
==
str
(
CompNode
(
"xpux"
))
cn
=
CompNode
(
"xpu1"
)
op
=
GaussianRNG
(
-
1.0
,
2.0
,
cn
)
(
output
,)
=
apply
(
op
,
shape
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
(
-
1.0
))
<
1e-1
assert
np
.
sqrt
(
output
.
numpy
().
var
())
-
2.0
<
1e-1
assert
str
(
output
.
device
)
==
str
(
cn
)
cn
=
CompNode
(
"xpu2"
)
seed
=
233333
h
=
new_rng_handle
(
cn
,
seed
)
op
=
GaussianRNG
(
3.0
,
1.0
,
h
)
op
=
GaussianRNG
(
seed
=
seed
,
mean
=
3.0
,
std
=
1.0
,
handle
=
h
)
(
output
,)
=
apply
(
op
,
shape
)
delete_rng_handle
(
h
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
3.0
)
<
1e-1
...
...
@@ -47,7 +47,7 @@ def test_gaussian_rng():
assert
str
(
output
.
device
)
==
str
(
cn
)
def
test_uniform_
rng
():
def
test_uniform_
op
():
shape
=
(
8
,
9
,
...
...
@@ -55,22 +55,67 @@ def test_uniform_rng():
12
,
)
shape
=
tensor
(
shape
,
dtype
=
"int32"
)
op
=
UniformRNG
()
op
=
UniformRNG
(
seed
=
get_global_rng_seed
()
)
(
output
,)
=
apply
(
op
,
shape
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
0.5
)
<
1e-1
assert
str
(
output
.
device
)
==
str
(
CompNode
(
"xpux"
))
cn
=
CompNode
(
"xpu1"
)
op
=
UniformRNG
(
cn
)
(
output
,)
=
apply
(
op
,
shape
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
0.5
)
<
1e-1
assert
str
(
output
.
device
)
==
str
(
cn
)
cn
=
CompNode
(
"xpu2"
)
seed
=
233333
h
=
new_rng_handle
(
cn
,
seed
)
op
=
UniformRNG
(
h
)
op
=
UniformRNG
(
seed
=
seed
,
handle
=
h
)
(
output
,)
=
apply
(
op
,
shape
)
delete_rng_handle
(
h
)
assert
np
.
fabs
(
output
.
numpy
().
mean
()
-
0.5
)
<
1e-1
assert
str
(
output
.
device
)
==
str
(
cn
)
def
test_UniformRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
m2
=
RNG
(
seed
=
111
,
device
=
"xpu1"
)
m3
=
RNG
(
seed
=
222
,
device
=
"xpu0"
)
out1
=
m1
.
uniform
(
size
=
(
100
,))
out1_
=
m1
.
uniform
(
size
=
(
100
,))
out2
=
m2
.
uniform
(
size
=
(
100
,))
out3
=
m3
.
uniform
(
size
=
(
100
,))
np
.
testing
.
assert_equal
(
out1
.
numpy
(),
out2
.
numpy
())
assert
out1
.
device
==
"xpu0"
and
out2
.
device
==
"xpu1"
assert
not
(
out1
.
numpy
()
==
out3
.
numpy
()).
all
()
assert
not
(
out1
.
numpy
()
==
out1_
.
numpy
()).
all
()
low
=
-
234
high
=
123
out
=
m1
.
uniform
(
low
=
low
,
high
=
high
,
size
=
(
20
,
30
,
40
))
out_shp
=
out
.
shape
if
isinstance
(
out_shp
,
tuple
):
assert
out_shp
==
(
20
,
30
,
40
)
else
:
assert
all
(
out
.
shape
.
numpy
()
==
np
.
array
([
20
,
30
,
40
]))
assert
np
.
abs
(
out
.
mean
().
numpy
()
-
((
low
+
high
)
/
2
))
/
(
high
-
low
)
<
0.1
def
test_NormalRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
m2
=
RNG
(
seed
=
111
,
device
=
"xpu1"
)
m3
=
RNG
(
seed
=
222
,
device
=
"xpu0"
)
out1
=
m1
.
normal
(
size
=
(
100
,))
out1_
=
m1
.
uniform
(
size
=
(
100
,))
out2
=
m2
.
normal
(
size
=
(
100
,))
out3
=
m3
.
normal
(
size
=
(
100
,))
np
.
testing
.
assert_equal
(
out1
.
numpy
(),
out2
.
numpy
())
assert
out1
.
device
==
"xpu0"
and
out2
.
device
==
"xpu1"
assert
not
(
out1
.
numpy
()
==
out3
.
numpy
()).
all
()
assert
not
(
out1
.
numpy
()
==
out1_
.
numpy
()).
all
()
mean
=
-
1
std
=
2
out
=
m1
.
normal
(
mean
=
mean
,
std
=
std
,
size
=
(
20
,
30
,
40
))
out_shp
=
out
.
shape
if
isinstance
(
out_shp
,
tuple
):
assert
out_shp
==
(
20
,
30
,
40
)
else
:
assert
all
(
out
.
shape
.
numpy
()
==
np
.
array
([
20
,
30
,
40
]))
assert
np
.
abs
(
out
.
mean
().
numpy
()
-
mean
)
/
std
<
0.1
assert
np
.
abs
(
np
.
std
(
out
.
numpy
())
-
std
)
<
0.1
imperative/src/impl/ops/rng.cpp
浏览文件 @
13e6ea34
...
...
@@ -2,7 +2,7 @@
* \file imperative/src/impl/ops/rng.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-202
0
Megvii Inc. All rights reserved.
* Copyright (c) 2014-202
1
Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
...
...
@@ -10,23 +10,23 @@
*/
#include "megbrain/imperative/ops/rng.h"
#include <bits/stdint-uintn.h>
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
//#include "megbrain/common.h"
#include "../op_trait.h"
#include "../dnn_op_helper.h"
namespace
mgb
{
namespace
imperative
{
namespace
mgb
::
imperative
::
rng
{
namespace
{
template
<
typename
HandleFactory
,
typename
THandle
>
class
DnnOpManagerT
:
public
CompNodeDepedentObject
,
public
NonCopyableObj
{
public:
using
DT
=
CompNode
::
DeviceType
;
using
Handle
=
THandle
;
using
OpTypeInfo
=
size_t
;
template
<
typename
...
Args
>
Handle
new_handle
(
Args
&&
...
args
)
{
...
...
@@ -38,27 +38,26 @@ public:
size_t
removed
=
0
;
if
(
!
is_finalized
())
{
MGB_LOCK_GUARD
(
m_mtx
);
removed
=
m_handle2op
.
erase
(
handle
);
removed
=
m_handle2op
s
.
erase
(
handle
);
}
static_cast
<
HandleFactory
*>
(
this
)
->
do_delete_handle
(
handle
);
return
removed
;
}
template
<
typename
DnnOp
>
auto
get_dnn_op
(
Handle
handle
,
CompNode
cn
)
{
auto
get_dnn_op
(
Handle
handle
,
OpTypeInfo
tpinfo
,
CompNode
cn
)
{
mgb_assert
(
!
is_finalized
());
DnnOpWithMutex
*
dnn_op_with_mtx
;
{
MGB_LOCK_GUARD
(
m_mtx
);
dnn_op_with_mtx
=
&
m_handle2op
[
handle
];
dnn_op_with_mtx
=
&
m_handle2op
s
[
handle
][
tpinfo
];
}
auto
dnn_handle
=
MegDNNHandle
::
get
(
CompNodeEnv
::
from_comp_node
(
cn
)).
handle
();
DnnOp
*
dnn_op
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
dnn_op_with_mtx
->
mtx
);
bool
initialized
=
false
;
if
((
dnn_op
=
dynamic_cast
<
DnnOp
*>
(
dnn_op_with_mtx
->
op
.
get
()))
!=
nullptr
)
{
DnnOp
*
dnn_op
=
static_cast
<
DnnOp
*>
(
dnn_op_with_mtx
->
op
.
get
());
if
(
dnn_op
!=
nullptr
)
{
mgb_assert
(
dnn_op
->
handle
()
==
dnn_handle
);
initialized
=
true
;
}
else
{
...
...
@@ -77,35 +76,30 @@ private:
struct
DnnOpWithMutex
{
std
::
mutex
mtx
;
std
::
unique_ptr
<
megdnn
::
OperatorBase
>
op
;
DnnOpWithMutex
()
:
op
{
nullptr
}
{}
};
std
::
shared_ptr
<
void
>
on_comp_node_finalize
()
override
{
MGB_LOCK_GUARD
(
m_mtx
);
m_handle2op
.
clear
();
m_handle2op
s
.
clear
();
return
{};
}
std
::
unordered_map
<
Handle
,
DnnOpWithMutex
>
m_handle2op
;
std
::
unordered_map
<
Handle
,
std
::
unordered_map
<
OpTypeInfo
,
DnnOpWithMutex
>
>
m_handle2ops
;
std
::
mutex
m_mtx
;
};
class
RNGDnnOpManager
final
:
public
DnnOpManagerT
<
RNGDnnOpManager
,
RNGMixin
::
Handle
>
{
:
public
DnnOpManagerT
<
RNGDnnOpManager
,
Handle
>
{
public:
Handle
new_handle
(
CompNode
comp_node
,
uint64_t
seed
)
{
MGB_LOCK_GUARD
(
sm_mtx
);
return
DnnOpManagerBase
::
new_handle
(
comp_node
,
seed
);
}
size_t
delete_handle
(
Handle
handle
)
{
size_t
ret
=
0
;
{
MGB_LOCK_GUARD
(
sm_mtx
);
auto
iter
=
sm_partial2full
.
find
(
handle
);
if
(
iter
!=
sm_partial2full
.
end
())
{
for
(
auto
&&
h
:
iter
->
second
)
{
ret
+=
DnnOpManagerBase
::
delete_handle
(
h
.
second
);
}
sm_partial2full
.
erase
(
iter
);
}
}
ret
+=
DnnOpManagerBase
::
delete_handle
(
handle
);
return
ret
;
MGB_LOCK_GUARD
(
sm_mtx
);
return
DnnOpManagerBase
::
delete_handle
(
handle
);
}
Handle
do_new_handle
(
CompNode
comp_node
,
uint64_t
seed
)
{
...
...
@@ -118,32 +112,26 @@ public:
}
static
uint64_t
get_seed
(
Handle
handle
)
{
if
(
!
handle
)
{
return
glob_default_seed
;
}
return
reinterpret_cast
<
HandleData
*>
(
handle
)
->
seed
;
}
static
CompNode
get_comp_node
(
Handle
handle
)
{
mgb_assert
(
handle
,
"invalid handle"
);
return
reinterpret_cast
<
HandleData
*>
(
handle
)
->
comp_node
;
}
static
Handle
get_full_handle
(
Handle
handle
,
CompNode
comp_node
)
{
if
(
get_comp_node
(
handle
).
valid
())
{
return
handle
;
}
MGB_LOCK_GUARD
(
sm_mtx
);
auto
&&
full
=
sm_partial2full
[
handle
][
comp_node
];
if
(
!
full
)
{
full
=
inst
().
new_handle
(
comp_node
,
get_seed
(
handle
));
}
return
full
;
}
static
Handle
get_default_handle
(
CompNode
comp_node
)
{
static
Handle
glob_partial_handle
=
inst
().
new_handle
(
CompNode
{},
glob_default_seed
);
if
(
!
comp_node
.
valid
())
{
return
glob_partial_handle
;
mgb_assert
(
comp_node
.
valid
());
MGB_LOCK_GUARD
(
sm_mtx
);
auto
&&
glob_handle
=
glob_default_handles
[
comp_node
];
if
(
!
glob_handle
)
{
glob_handle
=
inst
().
do_new_handle
(
comp_node
,
glob_default_seed
);
}
else
if
(
get_seed
(
glob_handle
)
!=
glob_default_seed
)
{
inst
().
DnnOpManagerBase
::
delete_handle
(
glob_handle
);
glob_handle
=
inst
().
do_new_handle
(
comp_node
,
glob_default_seed
);
}
return
g
et_full_handle
(
glob_partial_handle
,
comp_node
)
;
return
g
lob_handle
;
}
static
RNGDnnOpManager
&
inst
()
{
...
...
@@ -152,9 +140,15 @@ public:
}
static
void
set_glob_default_seed
(
uint64_t
seed
)
{
MGB_LOCK_GUARD
(
sm_mtx
);
glob_default_seed
=
seed
;
}
static
uint64_t
get_glob_default_seed
()
{
MGB_LOCK_GUARD
(
sm_mtx
);
return
glob_default_seed
;
}
private:
struct
HandleData
{
CompNode
comp_node
;
...
...
@@ -165,16 +159,13 @@ private:
MemPool
<
HandleData
>
m_handle_pool
;
static
std
::
mutex
sm_mtx
;
static
std
::
unordered_map
<
Handle
,
CompNode
::
UnorderedMap
<
Handle
>>
sm_partial2full
;
static
CompNode
::
UnorderedMap
<
Handle
>
glob_default_handles
;
static
uint64_t
glob_default_seed
;
};
uint64_t
RNGDnnOpManager
::
glob_default_seed
=
0
;
std
::
mutex
RNGDnnOpManager
::
sm_mtx
;
std
::
unordered_map
<
RNGDnnOpManager
::
Handle
,
CompNode
::
UnorderedMap
<
RNGDnnOpManager
::
Handle
>>
RNGDnnOpManager
::
sm_partial2full
;
CompNode
::
UnorderedMap
<
Handle
>
RNGDnnOpManager
::
glob_default_handles
;
template
<
typename
Op
>
struct
OpMeth
;
...
...
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
using
Param
=
DnnOp
::
Param
;
using
OpNode
=
mgb
::
opr
::
UniformRNG
;
static
Param
make_param
(
const
UniformRNG
&
rng
)
{
return
{
RNGDnnOpManager
::
get_seed
(
rng
.
handle
())};
auto
handle_seed
=
RNGDnnOpManager
::
get_seed
(
rng
.
handle
);
mgb_assert
(
handle_seed
==
rng
.
seed
,
"inconsistent rng seed: rng op: %lu handle: %lu"
,
handle_seed
,
rng
.
seed
);
return
{
handle_seed
};
}
};
...
...
@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
using
Param
=
DnnOp
::
Param
;
using
OpNode
=
mgb
::
opr
::
GaussianRNG
;
static
Param
make_param
(
const
GaussianRNG
&
rng
)
{
return
{
RNGDnnOpManager
::
get_seed
(
rng
.
handle
()),
rng
.
mean
,
rng
.
std
};
auto
handle_seed
=
RNGDnnOpManager
::
get_seed
(
rng
.
handle
);
mgb_assert
(
handle_seed
==
rng
.
seed
,
"inconsistent rng seed: rng op: %lu handle: %lu"
,
handle_seed
,
rng
.
seed
);
return
{
handle_seed
,
rng
.
mean
,
rng
.
std
};
}
};
...
...
@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
auto
dest
=
outputs
[
0
];
auto
cn
=
dest
->
comp_node
();
auto
handle
=
RNGDnnOpManager
::
get_full_handle
(
rng
.
handle
(),
cn
);
{
auto
handle_cn
=
RNGDnnOpManager
::
get_comp_node
(
handle
);
mgb_assert
(
cn
==
handle_cn
,
"inconsistent comp_node: handle: %s, output: %s"
,
cn
.
to_string
().
c_str
(),
handle_cn
.
to_string
().
c_str
());
auto
handle
=
rng
.
handle
;
if
(
!
handle
)
{
handle
=
RNGDnnOpManager
::
get_default_handle
(
cn
);
}
// retrieve dnn_op from glob cache
auto
dnn_op_thread_safe
=
RNGDnnOpManager
::
inst
()
.
get_dnn_op
<
typename
OpMeth
<
Op
>::
DnnOp
>
(
handle
,
cn
);
.
get_dnn_op
<
typename
OpMeth
<
Op
>::
DnnOp
>
(
handle
,
reinterpret_cast
<
size_t
>
(
op
.
dyn_typeinfo
()),
cn
);
auto
initialized
=
std
::
get
<
0
>
(
dnn_op_thread_safe
);
auto
dnn_op
=
std
::
get
<
1
>
(
dnn_op_thread_safe
);
if
(
initialized
)
{
auto
handle_seed
=
RNGDnnOpManager
::
get_seed
(
handle
);
mgb_assert
(
dnn_op
->
param
().
seed
==
handle_seed
,
"inconsistent rng seed: handle: %
zu, dnn_op: %z
u"
,
"inconsistent rng seed: handle: %
lu, dnn_op: %l
u"
,
handle_seed
,
dnn_op
->
param
().
seed
);
}
dnn_op
->
param
()
=
OpMeth
<
Op
>::
make_param
(
rng
);
...
...
@@ -239,9 +237,12 @@ template <typename Op>
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs
(
const
OpDef
&
op
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
LogicalTensorDesc
dest
;
dest
.
comp_node
=
op
.
cast_final_safe
<
Op
>
().
comp_node
();
if
(
!
dest
.
comp_node
.
valid
())
auto
handle
=
op
.
cast_final_safe
<
Op
>
().
handle
;
if
(
handle
)
{
dest
.
comp_node
=
RNGDnnOpManager
::
get_comp_node
(
handle
);
}
else
{
dest
.
comp_node
=
inputs
[
0
]
->
comp_node
();
}
auto
hv
=
inputs
[
0
]
->
get_value
().
proxy_to_default_cpu
();
TensorShape
tshape
;
...
...
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
template
<
typename
Op
>
cg
::
OperatorNodeBase
*
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
SymbolVar
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
size_t
nr_inp
=
inputs
.
size
();
mgb_assert
(
nr_inp
==
1
,
"UniformRNG expects 1 inputs; got %lu actually"
,
nr_inp
);
auto
&&
rng
=
def
.
cast_final_safe
<
Op
>
();
mgb_assert
(
nr_inp
==
1
,
"%s expects 1 inputs; got %lu actually"
,
rng
.
dyn_typeinfo
()
->
name
,
nr_inp
);
auto
param
=
OpMeth
<
Op
>::
make_param
(
rng
);
return
OpMeth
<
Op
>::
OpNode
::
make
(
inputs
[
0
],
param
,
{
rng
.
comp_node
()}).
node
()
->
owner_opr
();
OperatorNodeConfig
config
;
if
(
rng
.
handle
)
{
config
=
{
rng
.
make_name
(),
RNGDnnOpManager
::
get_comp_node
(
rng
.
handle
)};
}
else
{
config
=
{
rng
.
make_name
()};
}
return
OpMeth
<
Op
>::
OpNode
::
make
(
inputs
[
0
],
param
,
config
);
}
template
<
typename
T
>
...
...
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
// anonymous namespace
RNGMixin
::
RNGMixin
(
CompNode
cn
)
:
m_handle
(
RNGDnnOpManager
::
get_default_handle
(
cn
))
{}
uint64_t
RNGMixin
::
seed
()
const
{
return
RNGDnnOpManager
::
get_seed
(
m_handle
);
}
CompNode
RNGMixin
::
comp_node
()
const
{
return
RNGDnnOpManager
::
get_comp_node
(
m_handle
);
}
RNGMixin
::
Handle
RNGMixin
::
new_handle
(
CompNode
comp_node
,
uint64_t
seed
)
{
Handle
new_handle
(
CompNode
comp_node
,
uint64_t
seed
)
{
return
RNGDnnOpManager
::
inst
().
new_handle
(
comp_node
,
seed
);
}
size_t
RNGMixin
::
delete_handle
(
Handle
handle
)
{
size_t
delete_handle
(
Handle
handle
)
{
return
RNGDnnOpManager
::
inst
().
delete_handle
(
handle
);
}
void
set_rng_seed
(
uint64_t
seed
)
{
void
set_
global_
rng_seed
(
uint64_t
seed
)
{
RNGDnnOpManager
::
set_glob_default_seed
(
seed
);
}
uint64_t
get_global_rng_seed
()
{
return
RNGDnnOpManager
::
get_glob_default_seed
();
}
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
...
...
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.fallback(); \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
REG_RNG_OP
(
UniformRNG
)
REG_RNG_OP
(
GaussianRNG
)
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb::imperative::rng
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/impl/ops/specializations.cpp
浏览文件 @
13e6ea34
...
...
@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
.
fallback
();
}}
// assert_equal
namespace
{
namespace
uniform_rng
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
UniformRNG
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
1
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
UniformRNG
::
make
(
inputs
[
0
],
op
.
param
(),
config
);
}
OP_TRAIT_REG
(
UniformRNG
,
UniformRNG
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}}
// uniform_rng
namespace
{
namespace
gaussian_rng
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
GaussianRNG
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
1
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
GaussianRNG
::
make
(
inputs
[
0
],
op
.
param
(),
config
);
}
OP_TRAIT_REG
(
GaussianRNG
,
GaussianRNG
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}}
// gaussian_rng
namespace
{
namespace
roi_align
{
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
...
...
imperative/src/include/megbrain/imperative/ops/rng.h
浏览文件 @
13e6ea34
...
...
@@ -2,7 +2,7 @@
* \file imperative/src/include/megbrain/imperative/ops/rng.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-202
0
Megvii Inc. All rights reserved.
* Copyright (c) 2014-202
1
Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
...
...
@@ -12,84 +12,15 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
::
imperative
{
namespace
mgb
::
imperative
::
rng
{
class
RNGMixin
{
public:
using
Handle
=
size_t
;
using
Handle
=
size_t
;
static
Handle
new_handle
(
CompNode
comp_node
=
{},
uint64_t
seed
=
0
);
Handle
new_handle
(
CompNode
comp_node
,
uint64_t
seed
);
size_t
delete_handle
(
Handle
handle
);
void
set_global_rng_seed
(
uint64_t
seed
);
uint64_t
get_global_rng_seed
();
static
size_t
delete_handle
(
Handle
handle
);
Handle
handle
()
const
{
return
m_handle
;
}
uint64_t
seed
()
const
;
CompNode
comp_node
()
const
;
protected:
RNGMixin
(
Handle
handle
)
:
m_handle
(
handle
)
{}
RNGMixin
(
CompNode
comp_node
);
private:
Handle
m_handle
;
};
class
GaussianRNG
:
public
OpDefImplBase
<
GaussianRNG
>
,
public
RNGMixin
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
public:
float
mean
=
1.0
f
,
std
=
0.0
;
GaussianRNG
(
CompNode
comp_node_
)
:
RNGMixin
(
comp_node_
)
{}
GaussianRNG
(
float
mean_
=
1.0
,
float
std_
=
0.0
,
CompNode
comp_node_
=
{})
:
GaussianRNG
(
comp_node_
)
{
mean
=
mean_
;
std
=
std_
;
}
GaussianRNG
(
float
mean_
,
float
std_
,
Handle
handle
)
:
RNGMixin
(
handle
),
mean
(
mean_
),
std
(
std_
)
{}
size_t
hash
()
const
override
{
XXHash
xxhash
{};
auto
append
=
[
&
xxhash
](
auto
field
){
auto
hash_val
=
HashTrait
<
decltype
(
field
)
>::
eval
(
field
);
xxhash
.
update
(
reinterpret_cast
<
void
*>
(
&
hash_val
),
sizeof
(
hash_val
));
};
append
(
dyn_typeinfo
());
append
(
seed
());
append
(
mean
);
append
(
std
);
return
xxhash
.
digest
();
}
bool
is_same_st
(
const
Hashable
&
rhs_
)
const
override
{
auto
&&
rhs
=
static_cast
<
const
GaussianRNG
&>
(
rhs_
);
return
rhs
.
seed
()
==
seed
()
&&
rhs
.
mean
==
mean
&&
rhs
.
std
==
std
;
}
};
class
UniformRNG
:
public
OpDefImplBase
<
UniformRNG
>
,
public
RNGMixin
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
public:
UniformRNG
(
CompNode
comp_node_
=
{})
:
RNGMixin
(
comp_node_
)
{}
UniformRNG
(
Handle
handle
)
:
RNGMixin
(
handle
)
{}
size_t
hash
()
const
override
{
return
hash_pair_combine
(
mgb
::
hash
(
seed
()),
reinterpret_cast
<
std
::
uintptr_t
>
(
dyn_typeinfo
()));
}
bool
is_same_st
(
const
Hashable
&
rhs_
)
const
override
{
auto
&&
rhs
=
static_cast
<
const
UniformRNG
&>
(
rhs_
);
return
rhs
.
dyn_typeinfo
()
==
dyn_typeinfo
()
&&
rhs
.
seed
()
==
seed
();
}
};
void
set_rng_seed
(
uint64_t
seed
);
}
// namespace mgb::imperative
}
// namespace mgb::imperative::rng
imperative/src/test/rng.cpp
浏览文件 @
13e6ea34
...
...
@@ -14,6 +14,7 @@
using
namespace
mgb
;
using
namespace
imperative
;
using
namespace
imperative
::
rng
;
template
<
typename
Op
,
typename
...
Args
>
void
check_rng_basic
(
Args
&&
...
args
)
{
...
...
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
{
3
,
4
,
5
,
6
},
{
2333
}})
for
(
auto
&&
cn
:
{
CompNode
::
load
(
"
c
pu0"
),
CompNode
::
load
(
"xpu
0
"
)})
CompNode
::
load
(
"
x
pu0"
),
CompNode
::
load
(
"xpu
1
"
)})
{
auto
op
=
Op
::
make
(
std
::
forward
<
Args
>
(
args
)...,
cn
);
Handle
h
=
new_handle
(
cn
,
123
);
auto
op
=
Op
::
make
(
std
::
forward
<
Args
>
(
args
)...,
h
);
DeviceTensorND
tshape_dev
;
cg
::
copy_shape_to_tensor_value
(
tshape_dev
,
tshape
);
auto
outputs
=
OpDef
::
apply_on_physical_tensor
(
*
op
,
{
Tensor
::
make
(
tshape_dev
)});
SmallVector
<
TensorPtr
>
inputs
=
{
Tensor
::
make
(
tshape_dev
)};
auto
outputs
=
OpDef
::
apply_on_physical_tensor
(
*
op
,
inputs
);
ASSERT_TRUE
(
outputs
[
0
]
->
layout
().
eq_shape
(
tshape
));
ASSERT_TRUE
(
cn
==
outputs
[
0
]
->
comp_node
());
// sync before delete handle
for
(
auto
&&
p
:
outputs
)
{
p
->
get_value
();
}
delete_handle
(
h
);
}
}
TEST
(
TestImperative
,
UniformRNGBasic
)
{
check_rng_basic
<
UniformRNG
>
();
check_rng_basic
<
UniformRNG
>
(
123
);
}
TEST
(
TestImperative
,
GaussianRNGBasic
)
{
check_rng_basic
<
GaussianRNG
>
(
2.
f
,
3.
f
);
check_rng_basic
<
GaussianRNG
>
(
123
,
2.
f
,
3.
f
);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/core/include/megbrain/ir/ops.td
浏览文件 @
13e6ea34
...
...
@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
let cmpFunction = [{return true;}];
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash($_self.handle));
}];
let cmpFunction = [{return $0.handle == $1.handle;}];
}
def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
mgb::hash_pair_combine(
mgb::hash($_self.handle),
mgb::hash_pair_combine(
mgb::hash($_self.mean),
mgb::hash($_self.std))
)
);
}];
let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
let cmpFunction = [{return $0.
handle == $1.handle && $0.
mean == $1.mean && $0.std == $1.std;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录