Commit 13e6ea34

Authored on Feb 24, 2021 by Megvii Engine Team

feat(imperative/opr): rebase rng refactoring to dev & add python module

GitOrigin-RevId: ee5984c52d3fa346d5f26d737bf40ec4ed43b2c7
Parent: cded8ef1
Showing 12 changed files with 371 additions and 247 deletions (+371 −247).

The commit reworks the UniformRNG/GaussianRNG ops around explicit seed and handle arguments (moving the C++ helpers into the mgb::imperative::rng namespace) and adds a Python-side RNG class with a per-instance seed, device, and handle.

Changed files:

imperative/python/megengine/core/tensor/array_method.py   +2   −1
imperative/python/megengine/distributed/group.py          +3   −0
imperative/python/megengine/random/__init__.py            +1   −1
imperative/python/megengine/random/distribution.py        +18  −24
imperative/python/megengine/random/rng.py                 +85  −4
imperative/python/src/ops.cpp                             +14  −17
imperative/python/test/unit/random/test_rng.py            +121 −0
imperative/src/impl/ops/rng.cpp                           +84  −84
imperative/src/impl/ops/specializations.cpp               +0   −28
imperative/src/include/megbrain/imperative/ops/rng.h      +9   −78
imperative/src/test/rng.cpp                               +14  −6
src/core/include/megbrain/ir/ops.td                       +20  −4
imperative/python/megengine/core/tensor/array_method.py

@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
 def _remove_axis(inp: Tensor, axis) -> Tensor:
     def get_axes():
         if axis is None:
-            return [i for i, s in enumerate(inp.shape) if s == 1]
+            shp = inp.shape
+            return [i for i, s in enumerate(shp) if s == 1]
         try:
             return [int(axis)]
         except (TypeError, ValueError):
imperative/python/megengine/distributed/group.py

@@ -6,9 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import time
 from typing import List, Optional, Tuple

 from ..device import set_default_device, what_is_xpu
+from ..random import seed
 from .server import Client, Server

@@ -156,6 +158,7 @@ def init_process_group(
     WORLD.reset(list(range(world_size)))
     set_default_device("{}{}".format(device_type, device))
+    seed(int(time.time()) + rank)


 def is_distributed() -> bool:
imperative/python/megengine/random/__init__.py

@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .distribution import normal, uniform
-from .rng import seed
+from .rng import RNG, seed

 # pylint: disable=undefined-variable
 del distribution, rng  # type: ignore[name-defined]
imperative/python/megengine/random/distribution.py

@@ -9,11 +9,8 @@
 from typing import Iterable, Optional

 from .. import Tensor
-from ..core._imperative_rt import invoke_op
-from ..core._imperative_rt.core2 import apply
-from ..core.ops.builtin import GaussianRNG, UniformRNG
-from ..core.tensor import utils
-from .rng import _random_seed_generator
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from .rng import _normal, _uniform

 __all__ = ["normal", "uniform"]

@@ -48,14 +45,14 @@ def normal(
         [-1.4939808 -1.5824696 ]]

     """
-    if size is None:
-        size = (1,)
-    op = GaussianRNG(mean, std)
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return output
+    return _normal(
+        mean=mean,
+        std=std,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )


 def uniform(

@@ -88,14 +85,11 @@ def uniform(
         [0.09365904 0.62957656]]

     """
-    assert low < high, "Uniform is not defined when low >= high"
-    if size is None:
-        size = (1,)
-    op = UniformRNG()
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return low + (high - low) * output
+    return _uniform(
+        low=low,
+        high=high,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )
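Note: after this change, megengine.random.normal/uniform become thin wrappers over _normal/_uniform, passing the global seed and the default handle (handle=0). A minimal usage sketch of the global-state API, assuming a MegEngine build at this revision (output values are illustrative):

    # Sketch only: exercises the refactored global-state entry points.
    import megengine.random as rand

    rand.seed(42)                                # reseeds MT19937 and the global RNG seed
    x = rand.uniform(low=0, high=1, size=(2, 2)) # dispatches UniformRNG with handle=0
    y = rand.normal(mean=0, std=1, size=(2, 2))  # dispatches GaussianRNG with handle=0
    print(x.numpy(), y.numpy())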
imperative/python/megengine/random/rng.py

@@ -7,17 +7,94 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from typing import Iterable, Optional

 from numpy.random import MT19937

+from .. import Tensor
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
+from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
+from ..core.ops.builtin import GaussianRNG, UniformRNG
+from ..core.tensor import utils
+from ..device import get_default_device
+
 _rng = None


-def _random_seed_generator():
-    if _rng is None:
-        from ..distributed.group import get_rank
-
-        seed(seed=int(time.time()) + get_rank())
+def _normal(
+    mean: float,
+    std: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    if size is None:
+        size = (1,)
+    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return output
+
+
+def _uniform(
+    low: float,
+    high: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    assert low < high, "Uniform is not defined when low >= high"
+    if size is None:
+        size = (1,)
+    op = UniformRNG(seed=seed, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return low + (high - low) * output
+
+
+class RNG:
+    def __init__(self, seed=0, device=None):
+        self.seed = seed
+        self.device = device if device else get_default_device()
+        self.handle = _new_rng_handle(self.device, self.seed)
+
+    def uniform(
+        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _uniform(
+            low=low,
+            high=high,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def normal(
+        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _normal(
+            mean=mean,
+            std=std,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def __del__(self):
+        _delete_rng_handle(self.handle)
+
+
+def _random_seed_generator():
+    assert _rng
     while True:
         yield _rng.random_raw()

@@ -25,3 +102,7 @@ def _random_seed_generator():
 def seed(seed: int):
     global _rng  # pylint: disable=global-statement
     _rng = MT19937(seed=seed)
+    _set_global_rng_seed(seed)
+
+
+seed(int(time.time()))
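Note: the new RNG class is the user-facing half of the handle mechanism: each instance allocates a dedicated handle via new_rng_handle and releases it in __del__. A short usage sketch (hypothetical script; device strings follow the tests below and assume a multi-device build):

    # Sketch only: per-instance generator added by this commit.
    from megengine.random import RNG

    gen = RNG(seed=111, device="xpu0")              # allocates a dedicated RNG handle
    u = gen.uniform(low=-1, high=1, size=(20, 30))
    n = gen.normal(mean=0, std=2, size=(20, 30))
    # two instances with the same seed produce identical streams,
    # even on different devices (see test_UniformRNG below)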
imperative/python/src/ops.cpp

@@ -10,7 +10,10 @@
 */

 #include "./ops.h"
+#include "./helper.h"
+#include "./tensor.h"

+#include "megbrain/common.h"
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"

@@ -491,21 +494,15 @@ void init_ops(py::module m) {
     _init_py_op_base(m);
     INIT_ALL_OP(m)
-    m.def("new_rng_handle", &RNGMixin::new_handle);
-    m.def("delete_rng_handle", &RNGMixin::delete_handle);
-    m.def("set_rng_seed", &set_rng_seed);
-    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<RNGMixin::Handle>());
-    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<float, float>())
-        .def(py::init<float, float, mgb::CompNode>())
-        .def(py::init<float, float, RNGMixin::Handle>());
+    m.def("new_rng_handle", &rng::new_handle);
+    // FIXME: RNG op might execute after handle released due to async dispatch,
+    // which would cause memory leak or use-after-free
+    m.def("delete_rng_handle", [](size_t handle){
+        // RNG op might execute after handle released due to async dispatch, so
+        // we need sync before delete a handle to avoid memory leak or use-after-free
+        python::interpreter_for_py->sync();
+        mgb::CompNode::sync_all();
+        py_task_q.wait_all_task_finish();
+        rng::delete_handle(handle);
+    }, py::call_guard<py::gil_scoped_release>());
+    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
+    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
 }
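Note: the delete_rng_handle binding synchronizes before freeing because op dispatch is asynchronous, so an RNG kernel may still be queued when Python drops the handle. A sketch of the same ordering from the Python side, mirroring the unit tests at this revision (requires an "xpu0" device):

    # Sketch only: reproduces the sync-before-delete ordering the tests use.
    from megengine import tensor
    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
    from megengine.core.ops.builtin import UniformRNG

    cn = CompNode("xpu0")
    h = new_rng_handle(cn, 233333)
    op = UniformRNG(seed=233333, handle=h)
    (output,) = apply(op, tensor((8, 9), dtype="int32"))
    output.numpy()        # force execution so the kernel cannot outlive its handle
    delete_rng_handle(h)  # the binding additionally syncs internally, as above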
imperative/python/test/unit/test_rng.py → imperative/python/test/unit/random/test_rng.py

@@ -8,14 +8,21 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

-import megengine
 from megengine import tensor
 from megengine.core._imperative_rt import CompNode
-from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core._imperative_rt.ops import (
+    delete_rng_handle,
+    get_global_rng_seed,
+    new_rng_handle,
+)
 from megengine.core.ops.builtin import GaussianRNG, UniformRNG
-from megengine.core.tensor.core import apply
+from megengine.random import RNG
+from megengine.random.rng import _normal, _uniform


-def test_gaussian_rng():
+def test_gaussian_op():
     shape = (
         8,
         9,

@@ -23,23 +30,16 @@ def test_gaussian_rng():
         12,
     )
     shape = tensor(shape, dtype="int32")
-    op = GaussianRNG(1.0, 3.0)
+    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
     (output,) = apply(op, shape)
     assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
     assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
     assert str(output.device) == str(CompNode("xpux"))

-    cn = CompNode("xpu1")
-    op = GaussianRNG(-1.0, 2.0, cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
-    assert str(output.device) == str(cn)
-
     cn = CompNode("xpu2")
     seed = 233333
     h = new_rng_handle(cn, seed)
-    op = GaussianRNG(3.0, 1.0, h)
+    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
     (output,) = apply(op, shape)
     delete_rng_handle(h)
     assert np.fabs(output.numpy().mean() - 3.0) < 1e-1

@@ -47,7 +47,7 @@ def test_gaussian_rng():
     assert str(output.device) == str(cn)


-def test_uniform_rng():
+def test_uniform_op():
     shape = (
         8,
         9,

@@ -55,22 +55,67 @@ def test_uniform_rng():
         12,
     )
     shape = tensor(shape, dtype="int32")
-    op = UniformRNG()
+    op = UniformRNG(seed=get_global_rng_seed())
     (output,) = apply(op, shape)
     assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
     assert str(output.device) == str(CompNode("xpux"))

-    cn = CompNode("xpu1")
-    op = UniformRNG(cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
-
     cn = CompNode("xpu2")
     seed = 233333
     h = new_rng_handle(cn, seed)
-    op = UniformRNG(h)
+    op = UniformRNG(seed=seed, handle=h)
     (output,) = apply(op, shape)
     delete_rng_handle(h)
     assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
     assert str(output.device) == str(cn)
+
+
+def test_UniformRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.uniform(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.uniform(size=(100,))
+    out3 = m3.uniform(size=(100,))
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    low = -234
+    high = 123
+    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
+
+
+def test_NormalRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.normal(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.normal(size=(100,))
+    out3 = m3.normal(size=(100,))
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    mean = -1
+    std = 2
+    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - mean) / std < 0.1
+    assert np.abs(np.std(out.numpy()) - std) < 0.1
imperative/src/impl/ops/rng.cpp

@@ -2,7 +2,7 @@
  * \file imperative/src/impl/ops/rng.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an

@@ -10,23 +10,23 @@
  */

 #include "megbrain/imperative/ops/rng.h"
-#include <bits/stdint-uintn.h>
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/helper.h"
 #include "megbrain/opr/rand.h"
-//#include "megbrain/common.h"

 #include "../op_trait.h"
+#include "../dnn_op_helper.h"

-namespace mgb {
-namespace imperative {
+namespace mgb::imperative::rng {

 namespace {

 template <typename HandleFactory, typename THandle>
 class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
 public:
-    using DT = CompNode::DeviceType;
     using Handle = THandle;
+    using OpTypeInfo = size_t;

     template <typename... Args>
     Handle new_handle(Args&& ...args) {

@@ -38,27 +38,26 @@ public:
         size_t removed = 0;
         if (!is_finalized()) {
             MGB_LOCK_GUARD(m_mtx);
-            removed = m_handle2op.erase(handle);
+            removed = m_handle2ops.erase(handle);
         }
         static_cast<HandleFactory*>(this)->do_delete_handle(handle);
         return removed;
     }

     template <typename DnnOp>
-    auto get_dnn_op(Handle handle, CompNode cn) {
+    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
         mgb_assert(!is_finalized());
         DnnOpWithMutex* dnn_op_with_mtx;
         {
             MGB_LOCK_GUARD(m_mtx);
-            dnn_op_with_mtx = &m_handle2op[handle];
+            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
         }
         auto dnn_handle =
                 MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
-        DnnOp* dnn_op;
         std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
         bool initialized = false;
-        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) != nullptr) {
+        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
+        if (dnn_op != nullptr) {
             mgb_assert(dnn_op->handle() == dnn_handle);
             initialized = true;
         } else {

@@ -77,35 +76,30 @@ private:
     struct DnnOpWithMutex {
         std::mutex mtx;
         std::unique_ptr<megdnn::OperatorBase> op;
+        DnnOpWithMutex(): op{nullptr} {}
     };

     std::shared_ptr<void> on_comp_node_finalize() override {
         MGB_LOCK_GUARD(m_mtx);
-        m_handle2op.clear();
+        m_handle2ops.clear();
         return {};
     }

-    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
+    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>> m_handle2ops;
     std::mutex m_mtx;
 };

 class RNGDnnOpManager final
-        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
+        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
 public:
+    Handle new_handle(CompNode comp_node, uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::new_handle(comp_node, seed);
+    }
+
     size_t delete_handle(Handle handle) {
-        size_t ret = 0;
-        {
-            MGB_LOCK_GUARD(sm_mtx);
-            auto iter = sm_partial2full.find(handle);
-            if (iter != sm_partial2full.end()) {
-                for (auto&& h : iter->second) {
-                    ret += DnnOpManagerBase::delete_handle(h.second);
-                }
-                sm_partial2full.erase(iter);
-            }
-        }
-        ret += DnnOpManagerBase::delete_handle(handle);
-        return ret;
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::delete_handle(handle);
     }

     Handle do_new_handle(CompNode comp_node, uint64_t seed) {

@@ -118,32 +112,26 @@ public:
     }

     static uint64_t get_seed(Handle handle) {
+        if (!handle) {
+            return glob_default_seed;
+        }
         return reinterpret_cast<HandleData*>(handle)->seed;
     }

     static CompNode get_comp_node(Handle handle) {
+        mgb_assert(handle, "invalid handle");
         return reinterpret_cast<HandleData*>(handle)->comp_node;
     }

-    static Handle get_full_handle(Handle handle, CompNode comp_node) {
-        if (get_comp_node(handle).valid()) {
-            return handle;
-        }
-        MGB_LOCK_GUARD(sm_mtx);
-        auto&& full = sm_partial2full[handle][comp_node];
-        if (!full) {
-            full = inst().new_handle(comp_node, get_seed(handle));
-        }
-        return full;
-    }
-
     static Handle get_default_handle(CompNode comp_node) {
-        static Handle glob_partial_handle =
-                inst().new_handle(CompNode{}, glob_default_seed);
-        if (!comp_node.valid()) {
-            return glob_partial_handle;
-        }
-        return get_full_handle(glob_partial_handle, comp_node);
+        mgb_assert(comp_node.valid());
+        MGB_LOCK_GUARD(sm_mtx);
+        auto&& glob_handle = glob_default_handles[comp_node];
+        if (!glob_handle) {
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
+        } else if (get_seed(glob_handle) != glob_default_seed) {
+            inst().DnnOpManagerBase::delete_handle(glob_handle);
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
+        }
+        return glob_handle;
     }

     static RNGDnnOpManager& inst() {

@@ -152,9 +140,15 @@ public:
     }

     static void set_glob_default_seed(uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
         glob_default_seed = seed;
     }

+    static uint64_t get_glob_default_seed() {
+        MGB_LOCK_GUARD(sm_mtx);
+        return glob_default_seed;
+    }
+
 private:
     struct HandleData {
         CompNode comp_node;

@@ -165,16 +159,13 @@ private:
     MemPool<HandleData> m_handle_pool;

     static std::mutex sm_mtx;
-    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
-            sm_partial2full;
+    static CompNode::UnorderedMap<Handle> glob_default_handles;
     static uint64_t glob_default_seed;
 };

 uint64_t RNGDnnOpManager::glob_default_seed = 0;
 std::mutex RNGDnnOpManager::sm_mtx;
-std::unordered_map<RNGDnnOpManager::Handle, CompNode::UnorderedMap<Handle>>
-        RNGDnnOpManager::sm_partial2full;
+CompNode::UnorderedMap<RNGDnnOpManager::Handle> RNGDnnOpManager::glob_default_handles;

 template <typename Op>
 struct OpMeth;

@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::UniformRNG;
     static Param make_param(const UniformRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle())};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+            "inconsistent rng seed: rng op: %lu handle: %lu",
+            handle_seed, rng.seed);
+        return {handle_seed};
     }
 };

@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::GaussianRNG;
     static Param make_param(const GaussianRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+            "inconsistent rng seed: rng op: %lu handle: %lu",
+            handle_seed, rng.seed);
+        return {handle_seed, rng.mean, rng.std};
     }
 };

@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
     auto dest = outputs[0];
     auto cn = dest->comp_node();
-    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
-    {
-        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
-        mgb_assert(cn == handle_cn,
-            "inconsistent comp_node: handle: %s, output: %s",
-            cn.to_string().c_str(), handle_cn.to_string().c_str());
-    }
+    auto handle = rng.handle;
+    if (!handle) {
+        handle = RNGDnnOpManager::get_default_handle(cn);
+    }

     // retrieve dnn_op from glob cache
     auto dnn_op_thread_safe = RNGDnnOpManager::inst()
-            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
+            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle,
+                reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
     auto initialized = std::get<0>(dnn_op_thread_safe);
     auto dnn_op = std::get<1>(dnn_op_thread_safe);
     if (initialized) {
         auto handle_seed = RNGDnnOpManager::get_seed(handle);
         mgb_assert(dnn_op->param().seed == handle_seed,
-            "inconsistent rng seed: handle: %zu, dnn_op: %zu",
+            "inconsistent rng seed: handle: %lu, dnn_op: %lu",
            handle_seed, dnn_op->param().seed);
     }
     dnn_op->param() = OpMeth<Op>::make_param(rng);

@@ -239,9 +237,12 @@ template <typename Op>
 SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
     LogicalTensorDesc dest;
-    dest.comp_node = op.cast_final_safe<Op>().comp_node();
-    if (!dest.comp_node.valid())
+    auto handle = op.cast_final_safe<Op>().handle;
+    if (handle) {
+        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
+    } else {
         dest.comp_node = inputs[0]->comp_node();
+    }
     auto hv = inputs[0]->get_value().proxy_to_default_cpu();
     TensorShape tshape;

@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 }

 template <typename Op>
-cg::OperatorNodeBase* apply_on_var_node(
+SymbolVar apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually", nr_inp);
     auto&& rng = def.cast_final_safe<Op>();
+    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
+        rng.dyn_typeinfo()->name, nr_inp);
     auto param = OpMeth<Op>::make_param(rng);
-    return OpMeth<Op>::OpNode::make(
-        inputs[0], param, {rng.comp_node()}).node()->owner_opr();
+    OperatorNodeConfig config;
+    if (rng.handle) {
+        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
+    } else {
+        config = {rng.make_name()};
+    }
+    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
 }

 template <typename T>

@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 } // anonymous namespace

-RNGMixin::RNGMixin(CompNode cn):
-    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
-
-uint64_t RNGMixin::seed() const {
-    return RNGDnnOpManager::get_seed(m_handle);
-}
-
-CompNode RNGMixin::comp_node() const {
-    return RNGDnnOpManager::get_comp_node(m_handle);
-}
-
-RNGMixin::Handle RNGMixin::new_handle(
-        CompNode comp_node, uint64_t seed) {
+Handle new_handle(CompNode comp_node, uint64_t seed) {
     return RNGDnnOpManager::inst().new_handle(comp_node, seed);
 }

-size_t RNGMixin::delete_handle(Handle handle) {
+size_t delete_handle(Handle handle) {
     return RNGDnnOpManager::inst().delete_handle(handle);
 }

-void set_rng_seed(uint64_t seed) {
+void set_global_rng_seed(uint64_t seed) {
     RNGDnnOpManager::set_glob_default_seed(seed);
 }

+uint64_t get_global_rng_seed() {
+    return RNGDnnOpManager::get_glob_default_seed();
+}
+
 #define REG_RNG_OP(NAME)\
 namespace { \
 OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \

@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
 .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
 .fallback(); \
 } \
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);

 REG_RNG_OP(UniformRNG)
 REG_RNG_OP(GaussianRNG)

-} // namespace imperative
-} // namespace mgb
+} // namespace mgb::imperative::rng

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
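Note: one structural change worth calling out is that DnnOpManagerT used to cache a single dnn op per handle (m_handle2op); it now keys the cache by handle and op type (m_handle2ops), so ops of different types can share one handle without evicting each other's cached dnn op. An illustrative Python model of that reshaping (not MegEngine code; the names are invented for the sketch):

    # Sketch only: handle -> {op_type -> dnn_op}, mirroring m_handle2ops[handle][tpinfo].
    cache = {}

    def get_dnn_op(handle, op_type):
        # first lookup creates the entry; later lookups reuse it
        return cache.setdefault(handle, {}).setdefault(op_type, object())

    u = get_dnn_op(0xDEAD, "UniformRNG")
    g = get_dnn_op(0xDEAD, "GaussianRNG")
    assert u is not g                              # distinct ops under one handle
    assert u is get_dnn_op(0xDEAD, "UniformRNG")   # reused on repeat lookup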
imperative/src/impl/ops/specializations.cpp

@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
     .fallback();
 }} // assert_equal

-namespace { namespace uniform_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const UniformRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::UniformRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(UniformRNG, UniformRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // uniform_rng
-
-namespace { namespace gaussian_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const GaussianRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::GaussianRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(GaussianRNG, GaussianRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // gaussian_rng
-
 namespace { namespace roi_align {
 VarNodeArray apply_on_var_node(
         const OpDef& def,
imperative/src/include/megbrain/imperative/ops/rng.h

@@ -2,7 +2,7 @@
  * \file imperative/src/include/megbrain/imperative/ops/rng.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an

@@ -12,84 +12,15 @@
 #pragma once

 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/ops/autogen.h"

-namespace mgb::imperative {
+namespace mgb::imperative::rng {

-class RNGMixin {
-public:
-    using Handle = size_t;
-
-    static Handle new_handle(
-            CompNode comp_node = {}, uint64_t seed = 0);
-
-    static size_t delete_handle(Handle handle);
-
-    Handle handle() const {
-        return m_handle;
-    }
-
-    uint64_t seed() const;
-
-    CompNode comp_node() const;
-
-protected:
-    RNGMixin(Handle handle): m_handle(handle) {}
-    RNGMixin(CompNode comp_node);
-
-private:
-    Handle m_handle;
-};
-
-class GaussianRNG : public OpDefImplBase<GaussianRNG>, public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    float mean = 1.0f, std = 0.0;
-    GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
-    GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
-        GaussianRNG(comp_node_) {
-        mean = mean_;
-        std = std_;
-    }
-    GaussianRNG(float mean_, float std_, Handle handle):
-        RNGMixin(handle), mean(mean_), std(std_) {}
-
-    size_t hash() const override {
-        XXHash xxhash{};
-        auto append = [&xxhash](auto field){
-            auto hash_val = HashTrait<decltype(field)>::eval(field);
-            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
-        };
-        append(dyn_typeinfo());
-        append(seed());
-        append(mean);
-        append(std);
-        return xxhash.digest();
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
-        return rhs.seed() == seed()
-            && rhs.mean == mean
-            && rhs.std == std;
-    }
-};
-
-class UniformRNG : public OpDefImplBase<UniformRNG>, public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
-    UniformRNG(Handle handle): RNGMixin(handle) {}
-
-    size_t hash() const override {
-        return hash_pair_combine(
-            mgb::hash(seed()),
-            reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const UniformRNG&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.seed() == seed();
-    }
-};
-
-void set_rng_seed(uint64_t seed);
-
-} // namespace mgb::imperative
+using Handle = size_t;
+
+Handle new_handle(CompNode comp_node, uint64_t seed);
+size_t delete_handle(Handle handle);
+void set_global_rng_seed(uint64_t seed);
+uint64_t get_global_rng_seed();
+
+} // namespace mgb::imperative::rng
imperative/src/test/rng.cpp

@@ -14,6 +14,7 @@
 using namespace mgb;
 using namespace imperative;
+using namespace imperative::rng;

 template <typename Op, typename... Args>
 void check_rng_basic(Args&& ...args) {

@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
             {3, 4, 5, 6},
             {2333}})
     for (auto&& cn : {
-            CompNode::load("cpu0"),
-            CompNode::load("xpu0")})
+            CompNode::load("xpu0"),
+            CompNode::load("xpu1")})
     {
-        auto op = Op::make(std::forward<Args>(args)..., cn);
+        Handle h = new_handle(cn, 123);
+        auto op = Op::make(std::forward<Args>(args)..., h);
         DeviceTensorND tshape_dev;
         cg::copy_shape_to_tensor_value(tshape_dev, tshape);
-        auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
+        SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
+        auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
         ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
         ASSERT_TRUE(cn == outputs[0]->comp_node());
+        // sync before delete handle
+        for (auto&& p : outputs) {
+            p->get_value();
+        }
+        delete_handle(h);
     }
 }

 TEST(TestImperative, UniformRNGBasic) {
-    check_rng_basic<UniformRNG>();
+    check_rng_basic<UniformRNG>(123);
 }

 TEST(TestImperative, GaussianRNGBasic) {
-    check_rng_basic<GaussianRNG>(2.f, 3.f);
+    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/core/include/megbrain/ir/ops.td

@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
 def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;

 def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
-  let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
-  let cmpFunction = [{return true;}];
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash($_self.handle));
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle;}];
 }

 def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
   let hashFunction = [{
     return mgb::hash_pair_combine(
       mgb::hash($_self.dyn_typeinfo()),
-      mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
+      mgb::hash_pair_combine(
+        mgb::hash($_self.handle),
+        mgb::hash_pair_combine(
+          mgb::hash($_self.mean),
+          mgb::hash($_self.std))
+      )
+    );
   }];
-  let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
+  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
 }

 def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {