Commit 3977b7aa (MegEngine / MegEngine)
Authored on Jul 21, 2021 by Megvii Engine Team

feat(mgb/shuffle): add shuffle opr

GitOrigin-RevId: 80490a6f848d524111bee097f11b591b5a3956c8
Parent: 17371e79
Showing 23 changed files with 939 additions and 115 deletions.
dnn/include/megdnn/oprs/utils.h                       +38    -1
dnn/scripts/opr_param_defs.py                          +3    -0
dnn/src/common/handle_impl.h                           +2    -0
dnn/src/common/opr_trait.h                             +2    -0
dnn/src/common/rng.cpp                                +41    -0
dnn/src/cuda/rng/kernel.cu                            +44    -0
dnn/src/cuda/rng/kernel.cuh                           +11    -0
dnn/src/cuda/rng/opr_impl.cpp                         +73    -2
dnn/src/cuda/rng/opr_impl.h                           +41    -1
dnn/src/naive/rng/opr_impl.cpp                        +77    -5
dnn/src/naive/rng/opr_impl.h                          +29    -0
dnn/test/cuda/rng.cpp                                 +78    -0
dnn/test/naive/rng.cpp                                +81    -7
imperative/python/megengine/random/__init__.py         +2    -1
imperative/python/megengine/random/rng.py             +49    -0
imperative/python/test/unit/random/test_rng.py        +40    -0
imperative/src/impl/ops/rng.cpp                      +134   -70
src/core/include/megbrain/ir/ops.td                   +13    -0
src/opr/impl/rand.cpp                                 +81    -1
src/opr/impl/rand.sereg.h                             +23    -4
src/opr/include/megbrain/opr/rand.h                   +44   -23
src/opr/test/rand.cpp                                 +32    -0
src/serialization/impl/schema.fbs                      +1    -0
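The new user-facing entry point is megengine.random.shuffle. As a quick orientation before the per-file diffs, the snippet below repeats the docstring example added in imperative/python/megengine/random/rng.py; shuffling is in place, only the first axis is permuted, and the concrete output depends on the seed.

    import numpy as np
    import megengine as mge
    import megengine.random as rand

    x = mge.tensor(np.arange(10))
    rand.shuffle(x)                      # shuffles x in place
    print(x.numpy())

    y = mge.tensor(np.arange(18)).reshape(6, 3)
    rand.shuffle(y)                      # rows move; row contents are unchanged
    print(y.numpy())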
dnn/include/megdnn/oprs/utils.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/internal/opr_header_prologue.h"

@@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase {
     void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
 };

+class ShuffleRNGForward : public OperatorBase {
+    DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
+    DEF_OPR_PARAM(ShuffleRNG);
+
+public:
+    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                      _megdnn_tensor_out indices,
+                      _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& src, TensorLayout& dst,
+                       TensorLayout& indices);
+    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
+                                          const TensorLayout& dst,
+                                          const TensorLayout& indices) = 0;
+
+protected:
+    void check_exec(const TensorLayout& src, const TensorLayout& dst,
+                    const TensorLayout& indices, size_t workspace_in_bytes);
+};
+using ShuffleRNG = ShuffleRNGForward;
+
+class ShuffleRNGBackward : public OperatorBase {
+    DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
+    DEF_OPR_PARAM(ShuffleRNG);
+
+public:
+    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
+                      _megdnn_tensor_out grad,
+                      _megdnn_workspace workspace) = 0;
+    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
+                                          const TensorLayout& indices,
+                                          const TensorLayout& grad) = 0;
+
+protected:
+    void check_exec(const TensorLayout& diff, const TensorLayout& indices,
+                    const TensorLayout& grad, size_t workspace_in_bytes);
+};
+
 /*!
  * \brief sleep for specific time on the computing device; useful for testing
  * async problems
dnn/scripts/opr_param_defs.py
@@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
         'Float32 are supported.'),
     'DTypeEnum::Int32'))

+(pdef('ShuffleRNG').
+ add_fields('uint64', 'seed', 0))
+
 (pdef('Flip').
  add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))
dnn/src/common/handle_impl.h
@@ -165,6 +165,8 @@ private:
     cb(BetaRNG) \
     cb(PoissonRNG) \
     cb(PermutationRNG) \
+    cb(ShuffleRNGForward) \
+    cb(ShuffleRNGBackward) \
     cb(SeparableConvForward) \
     cb(SeparableFilterForward) \
     cb(BNForward) \
dnn/src/common/opr_trait.h
@@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true);
 DEF(BetaRNG, 3, true, true);
 DEF(PoissonRNG, 2, true, true);
 DEF(PermutationRNG, 1, true, true);
+DEF(ShuffleRNGForward, 3, true, true);
+DEF(ShuffleRNGBackward, 3, true, false);
 DEF(ChecksumForward, 1, true, false);
 DEF(CheckHasInf, 2, true, true);
 DEF(LSQForward, 5, true, true);
dnn/src/common/rng.cpp
@@ -15,6 +15,47 @@
 namespace megdnn {

+void ShuffleRNGForward::deduce_layout(const TensorLayout& src,
+                                      TensorLayout& dst,
+                                      TensorLayout& indices) {
+    dst = src;
+    indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32());
+}
+
+void ShuffleRNGForward::check_exec(const TensorLayout& src,
+                                   const TensorLayout& dst,
+                                   const TensorLayout& indices,
+                                   size_t workspace_in_bytes) {
+    TensorLayout dst_expected, indices_expected;
+    megdnn_assert_contiguous(src);
+    deduce_layout(src, dst_expected, indices_expected);
+    megdnn_assert_eq_layout(dst_expected, dst);
+    megdnn_assert_eq_layout(indices_expected, indices);
+    megdnn_assert_contiguous(indices);
+    megdnn_assert(src.dtype == dst.dtype);
+    megdnn_assert(indices.dtype == dtype::Int32());
+
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(src, dst, indices);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+void ShuffleRNGBackward::check_exec(const TensorLayout& diff,
+                                    const TensorLayout& indices,
+                                    const TensorLayout& grad,
+                                    size_t workspace_in_bytes) {
+    megdnn_assert(
+            diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype &&
+                    indices.dtype == dtype::Int32{} && diff.is_contiguous() &&
+                    indices.is_contiguous() && grad.is_contiguous(),
+            "invalid layouts: diff=%s indices=%s grad=%s",
+            diff.to_string().c_str(), indices.to_string().c_str(),
+            grad.to_string().c_str());
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(diff, indices, grad);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
 void PermutationRNG::check_exec(
         const TensorLayout& dst, size_t workspace_in_bytes) {
     megdnn_assert((dst.dtype == dtype::Float32() ||
dnn/src/cuda/rng/kernel.cu
@@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs,
     }
 }

+template <typename T>
+__global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size,
+                                   const T* sptr, T* dptr, const int* iptr) {
+    uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < src_size) {
+        uint32_t r = idx / step;
+        dptr[idx] = sptr[iptr[r] * step + idx % step];
+    }
+}
+
+template <typename T>
+void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, size_t len, size_t step,
+                     cudaStream_t stream) {
+    uint32_t src_size = len * step;
+    shuffle_fwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
+            step, src_size, sptr, dptr, iptr);
+    after_kernel_launch();
+}
+
+template <typename T>
+__global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr,
+                                   T* dptr, const int* iptr) {
+    uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < src_size) {
+        uint32_t r = idx / step;
+        sptr[iptr[r] * step + idx % step] = dptr[idx];
+    }
+}
+
+template <typename T>
+void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, size_t len,
+                      size_t step, cudaStream_t stream) {
+    uint32_t src_size = len * step;
+    shuffle_bwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
+            step, src_size, sptr, dptr, iptr);
+    after_kernel_launch();
+}
+
 uint32_t get_permutation_bits(size_t N) {
     double uniq_rand_num_prob = 0.9;
     double thresh = std::log(uniq_rand_num_prob) * 12;

@@ -156,6 +192,14 @@
 INST_PERMUTATION(dt_float32)
 #undef INST_PERMUTATION

+#define INST_SHUFFLE(T)                                                 \
+    template void shuffle_forward<T>(T* sptr, T* dptr, dt_int32* iptr,  \
+                                     size_t len, size_t step,           \
+                                     cudaStream_t stream);              \
+    template void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,    \
+                                   size_t len, size_t step,             \
+                                   cudaStream_t stream);
+
+ARGSORT_FOREACH_CTYPE(INST_SHUFFLE)
+#undef INST_SHUFFLE
+
 }  // namespace random

 #define INST(_dtype) \
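Each CUDA thread above handles one flattened element: row r = idx / step and column idx % step, so the forward kernel computes dptr[idx] = sptr[iptr[r] * step + idx % step]. A NumPy sketch of that forward gather (the helper name is ours, not part of the commit):

    import numpy as np

    def shuffle_forward_ref(src2d, indices):
        # dst[i, j] = src[indices[i], j], which is what shuffle_fwd_kernel
        # computes once the flat index is split into (row, column).
        return src2d[indices]

    src = np.arange(6, dtype=np.float32).reshape(3, 2)   # len=3, step=2
    perm = np.array([2, 0, 1], dtype=np.int32)
    print(shuffle_forward_ref(src, perm))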
dnn/src/cuda/rng/kernel.cuh
@@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed
 size_t get_permutation_workspace_in_bytes(size_t N);

+template <typename T>
+void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, size_t len, size_t step,
+                     cudaStream_t stream);
+
+template <typename T>
+void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, size_t len,
+                      size_t step, cudaStream_t stream);
+
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
+
 }  // namespace random
 }  // namespace cuda
 }  // namespace megdnn
dnn/src/cuda/rng/opr_impl.cpp
@@ -9,11 +9,11 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#include "./opr_impl.h"
-#include "./kernel.cuh"
 #include "src/common/utils.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.h"
+#include "./opr_impl.h"
+#include "./kernel.cuh"

 using namespace megdnn;
 using namespace cuda;

@@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){
     return random::get_permutation_workspace_in_bytes(size);
 }

+ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle)
+        : ShuffleRNGForward(handle),
+          m_seed(0),
+          m_offset(0),
+          m_stream(cuda_stream(handle)) {}
+
+void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                                 _megdnn_tensor_out indices,
+                                 _megdnn_workspace workspace) {
+    check_exec(src.layout, dst.layout, indices.layout, workspace.size);
+    ensure_seed(m_param.seed);
+    auto wk = workspace.ptr<void>();
+    const auto len = indices.layout[0];
+    random::permutation_forward<dt_int32>(indices.ptr<dt_int32>(), wk, len,
+                                          m_seed, m_offset, m_stream);
+    size_t step = 0;
+    for (size_t i = 1; i < src.layout.ndim; ++i) {
+        step += src.layout[i];
+    }
+    if (step <= 0)
+        step = 1;
+    switch (src.layout.dtype.enumv()) {
+#define cb(DType)                                                             \
+    case DTypeTrait<DType>::enumv:                                            \
+        random::shuffle_forward<DTypeTrait<DType>::ctype>(                    \
+                src.ptr<DTypeTrait<DType>::ctype>(),                          \
+                dst.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
+                len, step, m_stream);                                         \
+        break;
+        ARGSORT_FOREACH_CTYPE(cb)
+#undef cb
+        default:
+            megdnn_throw("bad dtype");
+    }
+    m_offset += 8;
+}
+
+size_t ShuffleRNGForwardImpl::get_workspace_in_bytes(
+        const TensorLayout&, const TensorLayout&,
+        const TensorLayout& indices) {
+    size_t size = indices.total_nr_elems();
+    return random::get_permutation_workspace_in_bytes(size);
+}
+
+ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle)
+        : ShuffleRNGBackward(handle), m_stream(cuda_stream(handle)) {}
+
+void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
+                                  _megdnn_tensor_in indices,
+                                  _megdnn_tensor_out grad,
+                                  _megdnn_workspace workspace) {
+    const auto len = indices.layout[0];
+    auto step = 0;
+    for (size_t i = 1; i < diff.layout.ndim; ++i) {
+        step += diff.layout[i];
+    }
+    if (step <= 0)
+        step = 1;
+    switch (diff.layout.dtype.enumv()) {
+#define cb(DType)                                                              \
+    case DTypeTrait<DType>::enumv:                                             \
+        random::shuffle_backward<DTypeTrait<DType>::ctype>(                    \
+                diff.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
+                grad.ptr<DTypeTrait<DType>::ctype>(), len, step, m_stream);    \
+        break;
+        ARGSORT_FOREACH_CTYPE(cb)
+#undef cb
+        default:
+            megdnn_throw("bad dtype");
+    }
+}
+
 // vim: syntax=cpp.doxygen
dnn/src/cuda/rng/opr_impl.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once

@@ -152,6 +153,45 @@ public:
     }
 };

+class ShuffleRNGForwardImpl : public ShuffleRNGForward {
+    uint64_t m_seed, m_offset;
+    cudaStream_t m_stream;
+
+public:
+    using ShuffleRNGForward::ShuffleRNGForward;
+    ShuffleRNGForwardImpl(Handle* handle);
+    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+              _megdnn_tensor_out indices,
+              _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout& src,
+                                  const TensorLayout& dst,
+                                  const TensorLayout& indices) override;
+    void seed(uint64_t seed) { m_seed = seed; }
+    void ensure_seed(uint64_t seed) {
+        if (m_seed != seed) {
+            this->seed(seed);
+        }
+    }
+};
+
+class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
+    cudaStream_t m_stream;
+
+public:
+    using ShuffleRNGBackward::ShuffleRNGBackward;
+    ShuffleRNGBackwardImpl(Handle* handle);
+    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
+              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
+                                  const TensorLayout&) override {
+        return 0;
+    }
+};
+
 }  // namespace cuda
 }  // namespace megdnn

 // vim: syntax=cpp.doxygen
dnn/src/naive/rng/opr_impl.cpp
@@ -6,12 +6,13 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "src/naive/handle.h"
-#include "src/common/utils.h"
 #include "./opr_impl.h"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"

 #include <cmath>

@@ -229,7 +230,29 @@ namespace {
     }
 }

+template <typename T>
+void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr,
+                 const dt_int32* iptr, const size_t len,
+                 const size_t step) MEGDNN_NOEXCEPT {
+    for (size_t i = 0; i < len; ++i) {
+        for (size_t j = 0; j < step; ++j) {
+            dptr[i * step + j] = sptr[iptr[i] * step + j];
+        }
+    }
+}
+
+template <typename T>
+void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr,
+                 const dt_int32* iptr, const size_t len,
+                 const size_t step) MEGDNN_NOEXCEPT {
+    for (size_t i = 0; i < len; ++i) {
+        for (size_t j = 0; j < step; ++j) {
+            sptr[iptr[i] * step + j] = dptr[i * step + j];
+        }
+    }
+}
+
 }  // anonymous namespace

 uint64_t Splitmix64::operator()() {
     uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15));

@@ -394,5 +417,54 @@ void PermutationRNGImpl::exec(
     }
 }

+void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                                 _megdnn_tensor_out indices,
+                                 _megdnn_workspace workspace) {
+    check_exec(src.layout, dst.layout, indices.layout, workspace.size);
+    const auto len = indices.layout[0];
+    auto iptr = indices.ptr<dt_int32>();
+    auto prng = &m_rng.ensure_seed(m_param.seed);
+    fill_permutation<dt_int32>(prng, iptr, len);
+    auto step = 0;
+    for (size_t i = 1; i < src.layout.ndim; ++i) {
+        step += src.layout[i];
+    }
+    if (step <= 0)
+        step = 1;
+#define cb(DType)                                                             \
+    if (src.layout.dtype == DType()) {                                        \
+        using T = typename DTypeTrait<DType>::ctype;                          \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                         \
+                shuffle_fwd<T>(src.ptr<T>(), dst.ptr<T>(), iptr, len, step)); \
+        return;                                                               \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+}
+
+void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
+                                  _megdnn_tensor_in indices,
+                                  _megdnn_tensor_out grad,
+                                  _megdnn_workspace workspace) {
+    check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
+    const auto len = indices.layout[0];
+    auto iptr = indices.ptr<dt_int32>();
+    auto step = 0;
+    for (size_t i = 1; i < diff.layout.ndim; ++i) {
+        step += diff.layout[i];
+    }
+    if (step <= 0)
+        step = 1;
+#define cb(DType)                                                 \
+    if (diff.layout.dtype == DType()) {                           \
+        using T = typename DTypeTrait<DType>::ctype;              \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd<T>(              \
+                grad.ptr<T>(), diff.ptr<T>(), iptr, len, step));  \
+        return;                                                   \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+}
+
 // vim: syntax=cpp.doxygen
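The naive shuffle_fwd/shuffle_bwd pair is a row gather and its inverse scatter, so running the backward kernel on the forward output with the same indices recovers the input, which is what the bwd_flag branch of the tests checks. A small NumPy round-trip sketch (the reference helpers are ours, not part of the commit):

    import numpy as np

    def shuffle_fwd_ref(src2d, indices):
        return src2d[indices]                 # gather rows

    def shuffle_bwd_ref(diff2d, indices):
        grad = np.empty_like(diff2d)
        grad[indices] = diff2d                # scatter rows back
        return grad

    src = np.arange(18, dtype=np.float32).reshape(6, 3)
    perm = np.random.permutation(6)
    out = shuffle_fwd_ref(src, perm)
    assert np.array_equal(shuffle_bwd_ref(out, perm), src)  # round trip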
dnn/src/naive/rng/opr_impl.h
@@ -128,6 +128,35 @@ public:
     size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
 };

+class ShuffleRNGForwardImpl : public ShuffleRNGForward {
+    Xoroshiro128plus m_rng;
+
+public:
+    using ShuffleRNGForward::ShuffleRNGForward;
+    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+              _megdnn_tensor_out indices,
+              _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
+                                  const TensorLayout&) override {
+        return 0;
+    }
+};
+
+class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
+    Xoroshiro128plus m_rng;
+
+public:
+    using ShuffleRNGBackward::ShuffleRNGBackward;
+    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
+              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
+                                  const TensorLayout&) override {
+        return 0;
+    }
+};
+
 }  // namespace naive
 }  // namespace megdnn

 // vim: syntax=cpp.doxygen
dnn/test/cuda/rng.cpp
@@ -143,6 +143,60 @@ void run_permutation(Handle* handle) {
     }
 }

+template <typename T>
+void run_shuffle(Handle* handle, bool bwd_flag) {
+    using ctype = typename DTypeTrait<T>::ctype;
+    auto run = [&](TensorShape shape) {
+        auto opr = handle->create_operator<ShuffleRNGForward>();
+        TensorLayout srclay{shape, T()};
+        TensorLayout dstlay{shape, T()};
+        TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
+        Tensor<dt_byte> workspace(
+                handle, {TensorShape{opr->get_workspace_in_bytes(
+                                 srclay, dstlay, indexlay)},
+                         dtype::Byte()});
+        SyncedTensor<ctype> src(handle, srclay);
+        SyncedTensor<ctype> dst(handle, dstlay);
+        SyncedTensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
+        auto sptr = src.ptr_mutable_host();
+        size_t size = src.layout().total_nr_elems();
+        for (size_t j = 0; j < size; ++j) {
+            sptr[j] = j;
+        }
+        opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(),
+                  {workspace.ptr(), workspace.layout().total_nr_elems()});
+        auto dptr = dst.ptr_mutable_host();
+        auto iptr = index.ptr_mutable_host();
+        size_t len = index.layout().total_nr_elems();
+        size_t step = size / len;
+        for (size_t i = 0; i < len; ++i) {
+            for (size_t j = 0; j < step; ++j) {
+                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
+            }
+        }
+        if (bwd_flag) {
+            for (size_t j = 0; j < size; ++j) {
+                sptr[j] = 0;
+            }
+            auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
+            oprbwd->exec(dst.tensornd_dev(), index.tensornd_dev(),
+                         src.tensornd_dev(),
+                         {workspace.ptr(),
+                          workspace.layout().total_nr_elems()});
+            auto sptr_bwd = src.ptr_mutable_host();
+            for (size_t i = 0; i < len; ++i) {
+                for (size_t j = 0; j < step; ++j) {
+                    ASSERT_EQ(dptr[i * step + j],
+                              sptr_bwd[iptr[i] * step + j]);
+                }
+            }
+        }
+    };
+    run({10});
+    run({6, 3});
+}
+
 }  // anonymous namespace

 TEST_F(CUDA, UNIFORM_RNG_F32) {

@@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) {
     run_permutation<dtype::Int16>(handle_cuda());
 }

+TEST_F(CUDA, SHUFFLE_RNG_F32) {
+    run_shuffle<dtype::Float32>(handle_cuda(), false);
+}
+
+TEST_F(CUDA, SHUFFLE_RNG_INT32) {
+    run_shuffle<dtype::Int32>(handle_cuda(), false);
+}
+
+TEST_F(CUDA, SHUFFLE_RNG_F16) {
+    run_shuffle<dtype::Float16>(handle_cuda(), false);
+}
+
+TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) {
+    run_shuffle<dtype::Float32>(handle_cuda(), true);
+}
+
+TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) {
+    run_shuffle<dtype::Int32>(handle_cuda(), true);
+}
+
+TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) {
+    run_shuffle<dtype::Float16>(handle_cuda(), true);
+}
+
 }  // namespace test
 }  // namespace megdnn
dnn/test/naive/rng.cpp
@@ -6,12 +6,13 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "megdnn.h"
-#include "test/naive/fixture.h"
 #include "test/naive/rng.h"
+#include "megdnn.h"
 #include "test/common/tensor.h"
+#include "test/naive/fixture.h"

 namespace megdnn {

@@ -181,7 +182,59 @@ namespace {
             ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8);
         }
     }
 }

+template <typename T>
+void run_shuffle(Handle* handle, bool bwd_flag) {
+    using ctype = typename DTypeTrait<T>::ctype;
+    auto run = [&](TensorShape shape) {
+        auto opr = handle->create_operator<ShuffleRNGForward>();
+        TensorLayout srclay{shape, T()};
+        TensorLayout dstlay{shape, T()};
+        TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
+        Tensor<dt_byte> workspace(
+                handle, {TensorShape{opr->get_workspace_in_bytes(
+                                 srclay, dstlay, indexlay)},
+                         dtype::Byte()});
+        Tensor<ctype> src(handle, srclay);
+        Tensor<ctype> dst(handle, dstlay);
+        Tensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
+        auto sptr = src.ptr();
+        size_t size = src.layout().total_nr_elems();
+        for (size_t j = 0; j < size; ++j) {
+            sptr[j] = j;
+        }
+        opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(),
+                  {workspace.ptr(), workspace.layout().total_nr_elems()});
+        auto dptr = dst.ptr();
+        auto iptr = index.ptr();
+        size_t len = index.layout().total_nr_elems();
+        size_t step = size / len;
+        for (size_t i = 0; i < len; ++i) {
+            for (size_t j = 0; j < step; ++j) {
+                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
+            }
+        }
+        if (bwd_flag) {
+            for (size_t j = 0; j < size; ++j) {
+                sptr[j] = 0;
+            }
+            auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
+            oprbwd->exec(dst.tensornd(), index.tensornd(), src.tensornd(),
+                         {workspace.ptr(),
+                          workspace.layout().total_nr_elems()});
+            for (size_t i = 0; i < len; ++i) {
+                for (size_t j = 0; j < step; ++j) {
+                    ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
+                }
+            }
+        }
+    };
+    run({10});
+    run({6, 3});
+}
+
 }  // namespace

 TEST_F(NAIVE, UNIFORM_RNG_F32) {
     run_uniform<dtype::Float32>(handle());

@@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) {
     run_permutation<dtype::Int16>(handle());
 }

+TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) {
+    run_shuffle<dtype::Float32>(handle(), false);
+}
+
+TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) {
+    run_shuffle<dtype::Int32>(handle(), false);
+}
+
+TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) {
+    run_shuffle<dtype::Float16>(handle(), false);
+}
+
+TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) {
+    run_shuffle<dtype::Float32>(handle(), true);
+}
+
+TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) {
+    run_shuffle<dtype::Int32>(handle(), true);
+}
+
+TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) {
+    run_shuffle<dtype::Float16>(handle(), true);
+}
+
 }  // namespace test
 }  // namespace megdnn

 // vim: syntax=cpp.doxygen
imperative/python/megengine/random/__init__.py
@@ -6,7 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform
+from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform

 __all__ = [
     "RNG",

@@ -17,6 +17,7 @@ __all__ = [
     "poisson",
     "seed",
     "uniform",
+    "shuffle",
 ]
 # pylint: disable=undefined-variable
 del rng  # type: ignore[name-defined]
imperative/python/megengine/random/rng.py
@@ -27,6 +27,7 @@ from ..core.ops.builtin import (
     GaussianRNG,
     PermutationRNG,
     PoissonRNG,
+    ShuffleRNG,
     UniformRNG,
 )
 from ..core.tensor import utils

@@ -41,6 +42,7 @@ __all__ = [
     "beta",
     "poisson",
     "permutation",
+    "shuffle",
 ]

 _rng = None

@@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten
     return output


+def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
+    assert inp.size > 0, "size needs to be greater than 0"
+    op = ShuffleRNG(seed=seed, handle=handle)
+    output, _ = apply(op, inp)
+    inp._reset(output)
+
+
 class RNG:
     r""":class:`RNG` exposes a number of methods for generating random numbers.

@@ -581,6 +590,45 @@ class RNG:
             n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
         )

+    def shuffle(self, inp: Tensor):
+        r"""Modify a sequence in-place by shuffling its contents.
+        This function only shuffles the Tensor along the first axis of a
+        multi-dimensional Tensor. The order of sub-Tensors is changed but
+        their contents remain the same.
+
+        Args:
+            inp: input tensor.
+
+        Examples:
+
+        .. testcode::
+
+            import numpy as np
+            import megengine as mge
+            import megengine.random as rand
+
+            x = mge.tensor(np.arange(10))
+            rand.shuffle(x)
+            print(x.numpy())
+            y = mge.tensor(np.arange(18)).reshape(6,3)
+            rand.shuffle(y)
+            print(y.numpy())
+
+        Outputs:
+
+        .. testoutput::
+            :options: +SKIP
+
+            [7 9 3 0 8 2 4 5 6 1]
+            [[12. 13. 14.]
+             [ 3.  4.  5.]
+             [15. 16. 17.]
+             [ 0.  1.  2.]
+             [ 9. 10. 11.]
+             [ 6.  7.  8.]]
+        """
+        _seed = self._seed() if callable(self._seed) else self._seed
+        _shuffle(inp=inp, seed=_seed, handle=self._handle)
+
     def __del__(self):
         if self._handle != 0:
             _delete_rng_handle(self._handle)

@@ -599,6 +647,7 @@ gamma = _default_handle.gamma
 beta = _default_handle.beta
 poisson = _default_handle.poisson
 permutation = _default_handle.permutation
+shuffle = _default_handle.shuffle


 def _random_seed_generator():
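As the RNG.shuffle method above shows, a seeded RNG handle gives reproducible in-place shuffles. A usage sketch mirroring the checks in test_rng.py below; the xpu device string is illustrative and requires such a device to be present:

    import numpy as np
    import megengine as mge
    from megengine.random import RNG

    rng = RNG(seed=111, device="xpu0")
    x = mge.tensor(np.arange(18), dtype="float32").reshape(6, 3)
    rng.shuffle(x)     # same seed and device order => same permutation
    print(x.numpy())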
imperative/python/test/unit/random/test_rng.py
@@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import (
     get_global_rng_seed,
     new_rng_handle,
 )
+from megengine.core.autodiff.grad import Grad
 from megengine.core.ops.builtin import (
     BetaRNG,
     GammaRNG,

@@ -397,6 +398,45 @@ def test_PermutationRNG():
     assert sum_result(out, np.sort) == 1000


+@pytest.mark.skipif(
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
+)
+def test_ShuffleRNG():
+    g = []
+
+    def cb(grad):
+        g.append(grad)
+
+    n, m = 6, 3
+    arr = np.arange(n * m)
+    out0 = Tensor(arr, dtype="float32")
+    grad = Grad().wrt(out0, callback=cb)
+    random.shuffle(out0)
+    grad(out0, F.ones_like(out0))
+
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = Tensor(arr, dtype="float32", device="xpu0")
+    out2 = Tensor(arr, dtype="float32", device="xpu1")
+    out3 = Tensor(arr, dtype="float32", device="xpu0")
+    m1.shuffle(out1)
+    m2.shuffle(out2)
+    m3.shuffle(out3)
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+
+    out = Tensor(arr, dtype="float32").reshape(n, m)
+    m1.shuffle(out)
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (n, m)
+    else:
+        assert all(out.shape.numpy() == np.array([n, m]))
+
+
 def test_seed():
     set_global_seed(10)
     out1 = uniform(size=[10, 10])
imperative/src/impl/ops/rng.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "megbrain/imperative/ops/rng.h"

@@ -14,8 +15,8 @@
 #include "megbrain/graph/helper.h"
 #include "megbrain/opr/rand.h"

-#include "../op_trait.h"
 #include "../dnn_op_helper.h"
+#include "../op_trait.h"

 namespace mgb::imperative::rng {

@@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> {
     }
 };

+template <>
+struct OpMeth<ShuffleRNG> {
+    using DnnOp = megdnn::ShuffleRNG;
+    using Param = DnnOp::Param;
+    using OpNode = mgb::opr::ShuffleRNG;
+    static Param make_param(const ShuffleRNG& rng) {
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                   "inconsistent rng seed: rng op: %lu handle: %lu",
+                   handle_seed, rng.seed);
+        return {handle_seed};
+    }
+};
+
 template <bool>
 struct _InferLayout;

 template <int nr_in>
 struct _RNGOprMaker;

-template <int nr_in>
+template <int nr_in, int nr_out>
 struct _RNGOprInvoker;

 template <>

@@ -316,50 +331,63 @@ struct _InferLayout<false>
         return inp.layout;
     }
 };

-#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \
-template<> \
-struct _RNGOprInvoker<DNN_NR_INPUTS> { \
-    template<typename Opr> \
-    static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \
-        size_t wk_size = 0; \
-        wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \
-        auto workspace = Blob::make(dest->comp_node(), wk_size); \
-        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
-        dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
-                        dest->dev_tensor().as_megdnn(), dnn_wk); \
-    } \
-};
+#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                    \
+    template <>                                                              \
+    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                   \
+        template <typename Opr>                                              \
+        static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs,  \
+                         const SmallVector<TensorPtr>& outputs) {            \
+            size_t wk_size = 0;                                              \
+            wk_size = dnn_op->get_workspace_in_bytes(                        \
+                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));     \
+            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);   \
+            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);   \
+            dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn())            \
+                                 _FOR_EACH_OUT(->dev_tensor().as_megdnn()),  \
+                         dnn_wk);                                            \
+        }                                                                    \
+    };

-#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
-template<> \
-struct _RNGOprMaker<MGB_NR_INPUTS> { \
-    template<typename Op> \
-    static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \
-        auto param = OpMeth<Op>::make_param(rng); \
-        OperatorNodeConfig config; \
-        if (rng.handle) { \
-            config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
-        } else { \
-            config = {rng.make_name()}; \
-        } \
-        return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
-    } \
-};
+#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                       \
+    template <>                                                              \
+    struct _RNGOprMaker<MGB_NR_INPUTS> {                                     \
+        template <typename Op>                                               \
+        static auto make(const VarNodeArray& inputs, const Op& rng) {        \
+            auto param = OpMeth<Op>::make_param(rng);                        \
+            OperatorNodeConfig config;                                       \
+            if (rng.handle) {                                                \
+                config = {rng.make_name(),                                   \
+                          RNGDnnOpManager::get_comp_node(rng.handle)};       \
+            } else {                                                         \
+                config = {rng.make_name()};                                  \
+            }                                                                \
+            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);   \
+        }                                                                    \
+    };

 #define _FOR_EACH_IN(subfix)
-_INST_RNG_INVOLKER(0)
+#define _FOR_EACH_OUT(subfix) outputs[0] subfix
+_INST_RNG_INVOLKER(0, 1)
+#undef _FOR_EACH_OUT
 #undef _FOR_EACH_IN

 #define _FOR_EACH_IN(subfix) inputs[0] subfix,
-_INST_RNG_INVOLKER(1)
+#define _FOR_EACH_OUT(subfix) outputs[0] subfix
+_INST_RNG_INVOLKER(1, 1)
+#undef _FOR_EACH_OUT
+
+#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
+_INST_RNG_INVOLKER(1, 2)
 _INST_RNG_MAKER(1)
+#undef _FOR_EACH_OUT
 #undef _FOR_EACH_IN

 #define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
-_INST_RNG_INVOLKER(2)
+#define _FOR_EACH_OUT(subfix) outputs[0] subfix
+_INST_RNG_INVOLKER(2, 1)
 _INST_RNG_MAKER(2)
+#undef _FOR_EACH_OUT
 #undef _FOR_EACH_IN

 #undef _INST_RNG_INVOLKER

@@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
                    handle_seed, dnn_op->param().seed);
     }
     dnn_op->param() = OpMeth<Op>::make_param(rng);
-    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op, inputs, dest);
+    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS,
+                   OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs,
+                                                        outputs);
 }

 template <typename Op>

@@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
     return {dest};
 }

+template <>
+SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
+        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(2);
+    auto&& rng = op.cast_final_safe<ShuffleRNG>();
+    auto handle = rng.handle;
+    if (handle) {
+        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
+        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
+    } else {
+        dests[0].comp_node = inputs[0]->comp_node();
+        dests[1].comp_node = inputs[0]->comp_node();
+    }
+    dests[0].layout = TensorLayout(inputs[0]->layout());
+    dests[0].layout.dtype = inputs[0]->layout().dtype;
+    dests[1].layout = TensorLayout(TensorShape({inputs[0]->layout()[0]}),
+                                   dtype::Int32());
+    return dests;
+}
+
 template <typename Op>
 std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
 infer_output_mem_desc(const OpDef& def,
                       const SmallVector<TensorPtr>& inputs_tensors,
                       const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& dest = infer_output_attrs<Op>(def, inputs_tensors);
-    SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node,
-                                        StorageIdentifier::make(1)}};
+    auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
+    SmallVector<MemoryDesc> outputs;
+    for (size_t i = 0; i < dests.size(); ++i) {
+        outputs.push_back({dests[i].layout, 0, dests[i].comp_node,
+                           StorageIdentifier::make(i + 1)});
+    }
     return {outputs, {}};
 }

 template <typename Op>
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     SmallVector<TensorPtr> outputs;
-    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
+    SmallVector<LogicalTensorDesc> desc;
+    desc = infer_output_attrs<Op>(def, inputs);
     for (auto&& i : desc) {
         outputs.push_back(Tensor::make(i.layout, i.comp_node));
     }

@@ -454,10 +505,8 @@ void execute(
     exec<Op>(def, inputs, outputs, {});
 }

-template <typename Op>
-SymbolVar apply_on_var_node(
+template <typename Op, typename Output>
+Output apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
     constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
     auto&& rng = def.cast_final_safe<Op>();

@@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{dest}, true};
 }

+template <>
+std::tuple<SmallVector<LogicalTensorDesc>, bool>
+infer_output_attrs_fallible<ShuffleRNG>(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(2);
+    dests[0].comp_node = inputs[0].comp_node;
+    dests[0].layout = TensorLayout(inputs[0].layout);
+    dests[0].layout.dtype = inputs[0].layout.dtype;
+    dests[1].comp_node = inputs[0].comp_node;
+    dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}),
+                                   dtype::Int32());
+    return {dests, true};
+}
+
 }  // anonymous namespace

 Handle new_handle(CompNode comp_node, uint64_t seed) {
     return RNGDnnOpManager::inst().new_handle(comp_node, seed);

@@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){
     return RNGDnnOpManager::get_comp_node(handle);
 }

-#define REG_RNG_OP(NAME)\
+#define REG_RNG_OP(NAME, Output)                                        \
 namespace { \
 OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
-    .apply_on_var_node(apply_on_var_node<NAME>) \
+    .apply_on_var_node(apply_on_var_node<NAME, Output>) \
     .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
     .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
     .infer_output_mem_desc(infer_output_mem_desc<NAME>) \
     .execute(execute<NAME>) \
     .fallback(); \
 } \

-REG_RNG_OP(UniformRNG)
-REG_RNG_OP(GaussianRNG)
-REG_RNG_OP(GammaRNG)
-REG_RNG_OP(PermutationRNG)
-REG_RNG_OP(PoissonRNG)
-REG_RNG_OP(BetaRNG)
+REG_RNG_OP(UniformRNG, SymbolVar)
+REG_RNG_OP(GaussianRNG, SymbolVar)
+REG_RNG_OP(GammaRNG, SymbolVar)
+REG_RNG_OP(PermutationRNG, SymbolVar)
+REG_RNG_OP(PoissonRNG, SymbolVar)
+REG_RNG_OP(BetaRNG, SymbolVar)
+REG_RNG_OP(ShuffleRNG, SymbolVarArray)
 #undef REG_RNG_OP

 }  // namespace mgb::imperative::rng
src/core/include/megbrain/ir/ops.td
@@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
   let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
 }

+def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash($_self.handle)
+    );
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle;}];
+}
+
 def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
   let extraArguments = (ins
     MgbCompNodeAttr:$comp_node
src/opr/impl/rand.cpp
@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>;
 template class RNGOprBase<::megdnn::PermutationRNG>;
 template class RNGOprBase<::megdnn::BetaRNG>;
 template class RNGOprBase<::megdnn::PoissonRNG>;
+template class RNGOprBase<::megdnn::ShuffleRNGForward>;
+template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
 #if MGB_ENABLE_GRAD
 IMPL(GaussianRNG);
 IMPL(UniformRNG);

@@ -200,9 +202,87 @@ IMPL(PoissonRNG);
 IMPL(PermutationRNG);
 IMPL(BetaRNG);
 #endif
 }  // namespace intl

+/* ================= ShuffleRNGForward ================= */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward);
+
+ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
+                                     const OperatorNodeConfig& config)
+        : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
+    add_input({data});
+    add_output(None)->dtype(data->dtype());
+    add_output(None)->dtype(dtype::Int32{});
+    cg::add_workspace_output(this);
+    add_equivalence_component<ScalarHash<void*>>(this);
+}
+
+SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param,
+                                       const OperatorNodeConfig& config) {
+    auto node = in_tensor.node()->owner_graph()->insert_opr(
+            std::make_unique<ShuffleRNGForward>(in_tensor.node(), param,
+                                                config));
+    mgb_assert(node->output().size() == 3);
+    return {node->output(0), node->output(1)};
+}
+
+void ShuffleRNGForward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+
+    mgr.register_shape_infer(output(0),
+                             ShapeInferDesc::make_identity(input(0)));
+
+    auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) {
+        TensorLayout o0, o1;
+        m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0,
+                                 o1);
+        dest = o1;
+        return true;
+    };
+    mgr.register_shape_infer(
+            output(1),
+            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1});
+
+    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
+        ensure_megdnn_opr();
+        dest.ndim = 1;
+        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
+                {inp.val[0].shape(), input(0)->dtype()},
+                {output(0)->shape(), output(0)->dtype()},
+                {output(1)->shape(), output(1)->dtype()});
+        return true;
+    };
+    mgr.register_shape_infer(
+            output(2),
+            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk});
+}
+
+void ShuffleRNGForward::add_input_layout_constraint() {
+    input(0)->add_layout_constraint_contiguous();
+};
+
+void ShuffleRNGForward::scn_do_execute() {
+    m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(),
+                    output(0)->dev_tensor().as_megdnn(),
+                    output(1)->dev_tensor().as_megdnn(),
+                    get_megdnn_workspace_from_var(output(2)));
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
+    mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
+    if (!out_grad[0])
+        return nullptr;
+    return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0))
+            .node();
+}
+#endif
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward);
+MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true)
+
 }  // namespace opr
 }  // namespace mgb

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/rand.sereg.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "megbrain/opr/rand.h"

@@ -14,6 +15,23 @@
 namespace mgb {

+namespace serialization {
+template <>
+struct OprMaker<opr::ShuffleRNG, 1> {
+    using Opr = opr::ShuffleRNG;
+    using Param = Opr::Param;
+    static cg::OperatorNodeBase* make(const Param& param,
+                                      const cg::VarNodeArray& inputs,
+                                      ComputingGraph& graph,
+                                      const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        auto out = Opr::make(inputs[0], param, config);
+        return out[0].node()->owner_opr();
+    }
+};
+}  // namespace serialization
+
 namespace opr {

 using UniformRNGV1 = opr::UniformRNG;

@@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2);
 MGB_SEREG_OPR(PoissonRNG, 1);
 MGB_SEREG_OPR(PermutationRNG, 1);
 MGB_SEREG_OPR(BetaRNG, 2);
+MGB_SEREG_OPR(ShuffleRNG, 1);
+MGB_SEREG_OPR(ShuffleRNGBackward, 3);
 }  // namespace opr
 }  // namespace mgb

 // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/include/megbrain/opr/rand.h
@@ -6,14 +6,15 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megbrain/graph.h"
-#include "megbrain/opr/internal/out_shape_by_sym_var.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megbrain/opr/internal/out_shape_by_sym_var.h"
 #include "megdnn/oprs.h"

 namespace mgb {

@@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
 };

 /* ================= RNG with shape ================= */
-#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
-MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \
-    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
-    public: \
-        RNG(VarNode *shape, const Param &param, const OperatorNodeConfig &config); \
-        static SymbolVar make(SymbolVar shape, const Param &param = {}, \
-                        const OperatorNodeConfig &config = {}); \
-        static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \
-                        const OperatorNodeConfig &config, \
-                        const Param &param = {}) { \
-            return make(var_from_tensor_shape(graph, config, "rng", shape), \
-                    param, config); \
-        } \
-        void init_output_static_infer_desc() override; \
-        void scn_do_execute() override; \
-};
+#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG)                                  \
+    MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)                         \
+        cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;    \
+                                                                               \
+    public:                                                                    \
+        RNG(VarNode* shape, const Param& param,                                \
+            const OperatorNodeConfig& config);                                 \
+        static SymbolVar make(SymbolVar shape, const Param& param = {},        \
+                              const OperatorNodeConfig& config = {});          \
+        static SymbolVar make(ComputingGraph& graph, const TensorShape& shape, \
+                              const OperatorNodeConfig& config,                \
+                              const Param& param = {}) {                       \
+            return make(var_from_tensor_shape(graph, config, "rng", shape),    \
+                        param, config);                                        \
+        }                                                                      \
+        void init_output_static_infer_desc() override;                         \
+        void scn_do_execute() override;                                        \
+    };

 _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
 _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)

@@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
     public: \
         RNG(_INPUTS(VarNode*), const Param &param, \
             const OperatorNodeConfig &config); \
-        static SymbolVar make(_INPUTS(SymbolVar),const Param &param = {}, \
+        static _OUTPUTS make(_INPUTS(SymbolVar),const Param &param = {}, \
                               const OperatorNodeConfig &config = {}); \
         void init_output_static_infer_desc() override; \
         void scn_do_execute() override; \

@@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
 /* ================= 1 input ================= */
 #define _INPUTS(preifx) preifx i0
+#define _OUTPUTS SymbolVar
 _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
+#undef _OUTPUTS
+#define _OUTPUTS SymbolVarArray
+_DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward)
+#undef _OUTPUTS
 #undef _INPUTS

 /* ================= 2 input ================= */
 #define _INPUTS(preifx) preifx i0, preifx i1
+#define _OUTPUTS SymbolVar
 _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
 _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
+#undef _OUTPUTS
 #undef _INPUTS
 #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
 }  // intl

 using UniformRNG = intl::UniformRNG;
 using GaussianRNG = intl::GaussianRNG;

@@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG;
 using PermutationRNG = intl::PermutationRNG;
 using PoissonRNG = intl::PoissonRNG;
 using BetaRNG = intl::BetaRNG;
+using ShuffleRNG = intl::ShuffleRNGForward;
+
+MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward,
+                     intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // {
+public:
+    ShuffleRNGBackward(VarNode* out_diff, VarNode* indices,
+                       VarNode* result_shape, const Param& param,
+                       const OperatorNodeConfig& config);
+    static SymbolVar make(SymbolVar out_diff, SymbolVar indices,
+                          SymbolVar result_shape, const Param& param = {},
+                          const OperatorNodeConfig& config = {});
+};
+
 }  // namespace opr
 }  // namespace mgb

 // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/test/rand.cpp
@@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) {
 }

+TEST(TestOprRand, ShuffleForward) {
+    auto run = [&](TensorShape shape) {
+        std::shared_ptr<HostTensorND> src_host(new HostTensorND{
+                CompNode::load("xpux"), shape, dtype::Float32()});
+        auto sptr = src_host->ptr<dt_float32>();
+        auto size = shape.total_nr_elems();
+        for (size_t i = 0; i < size; ++i) {
+            sptr[i] = i;
+        }
+        auto graph = ComputingGraph::make();
+        auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host);
+        auto rec = opr::ShuffleRNG::make(src_sym, {10});
+        HostTensorND host_y, host_index;
+        auto func = graph->compile({make_callback_copy(rec[0], host_y),
+                                    make_callback_copy(rec[1], host_index)});
+        func->execute();
+        auto dptr = host_y.ptr<dt_float32>();
+        auto iptr = host_index.ptr<dt_int32>();
+        size_t len = shape[0];
+        size_t step = size / len;
+        for (size_t i = 0; i < len; ++i) {
+            for (size_t j = 0; j < step; ++j) {
+                assert(dptr[i * step + j] == sptr[iptr[i] * step + j]);
+            }
+        }
+    };
+    run({10});
+    run({6, 3});
+    run({1, 1});
+}
+
 TEST(TestOprRand, UniformReprod) {
     static constexpr size_t SIZE = 123;
     auto graph = ComputingGraph::make();
src/serialization/impl/schema.fbs
@@ -114,6 +114,7 @@ union OperatorParam {
     param.BetaRNG = 80,
     param.SlidingWindowTranspose = 81,
     param.Padding = 82,
+    param.ShuffleRNG = 83,
 }

 table Operator {