MegEngine 天元 / MegEngine — Commit 9f352b1c

Authored Jun 05, 2020 by Megvii Engine Team
Committed by Xu Xinran on Jun 19, 2020

feat(megbrain/dnn): add indexing remap int32 for naive and cuda

GitOrigin-RevId: 5f66d51de4751d77fc05e2849388fc6dfbae4a53
Parent: 5dbf218d
Showing 6 changed files with 265 additions and 196 deletions (+265 −196):
dnn/src/common/tensor_remap.cpp              +4    -2
dnn/src/cuda/tensor_remap/opr_impl.cpp       +35   -16
dnn/src/cuda/tensor_remap/tensor_remap.cu    +95   -78
dnn/src/cuda/tensor_remap/tensor_remap.cuh   +14   -16
dnn/src/naive/tensor_remap/opr_impl.cpp      +88   -56
dnn/test/cuda/tensor_remap.cpp               +29   -28
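Taken together, the change follows a single pattern across all six files: each kernel and helper that was hard-coded to float32 becomes a template over the element type, and every entry point dispatches on the tensor's runtime dtype. A minimal, self-contained sketch of that pattern (the names below are illustrative, not MegEngine's):

#include <stdexcept>

// Illustrative stand-in for megdnn's runtime dtype tag.
enum class DTypeEnum { Float32, Int32 };

// Worker templated on the element type, like forward_kernel<ctype>.
template <typename ctype>
void remap_copy(const ctype* src, ctype* dst, int n) {
    for (int i = 0; i < n; ++i)
        dst[i] = src[i];  // the real kernels gather through an index map
}

// Entry point sees type-erased pointers plus a dtype tag and dispatches.
void exec(const void* src, void* dst, int n, DTypeEnum dt) {
    switch (dt) {
        case DTypeEnum::Float32:
            remap_copy(static_cast<const float*>(src),
                       static_cast<float*>(dst), n);
            return;
        case DTypeEnum::Int32:
            remap_copy(static_cast<const int*>(src),
                       static_cast<int*>(dst), n);
            return;
    }
    throw std::runtime_error("unsupported dtype");
}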
dnn/src/common/tensor_remap.cpp

@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
     }
     megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
-    megdnn_assert(src.dtype == dtype::Float32());
+    megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
+                  "indexing remap only support float32/int32, got %s",
+                  src.dtype.name());
+    megdnn_assert(dst.dtype == src.dtype);
     megdnn_assert(map.dtype == dtype::Int32());
-    megdnn_assert(dst.dtype == dtype::Float32());
 }
 
 void IndexingRemapForward::deduce_layout(const TensorLayout &src,
dnn/src/cuda/tensor_remap/opr_impl.cpp

@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
     for (size_t i = 0_z; i < dst.layout.ndim; ++i) {
         dshape.data[i] = dst.layout.shape[i];
     }
     // Invoke kernel
-    tensor_remap::forward(src.ptr<dt_float32>(), map.ptr<dt_int32>(),
-                          dst.ptr<dt_float32>(), src.layout.ndim,
-                          dst.layout.ndim, sstride, dstride, dshape,
-                          cuda_stream(handle()));
+#define cb(dt)                                                              \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                \
+        using ctype = DTypeTrait<dt>::ctype;                                \
+        tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(), \
+                                     dst.ptr<ctype>(), src.layout.ndim,     \
+                                     dst.layout.ndim, sstride, dstride,     \
+                                     dshape, cuda_stream(handle()));        \
+        return;                                                             \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(ssprintf("cuda indexing remap forward only support "
+                          "float32/int32 dtype, got %s",
+                          src.layout.dtype.name()));
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,

@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
     for (size_t i = 0_z; i < diff.layout.ndim; ++i) {
         dshape.data[i] = diff.layout.shape[i];
     }
     // Invoke kernel
-    tensor_remap::backward(diff.ptr<dt_float32>(), map.ptr<dt_int32>(),
-                           grad.ptr<dt_float32>(), grad.layout.ndim,
-                           diff.layout.ndim, sstride, dstride, sshape, dshape,
-                           param().is_non_overlapping, cuda_stream(handle()));
+#define cb(dt)                                                                \
+    if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                 \
+        using ctype = DTypeTrait<dt>::ctype;                                  \
+        tensor_remap::backward<ctype>(                                        \
+                diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(),    \
+                grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \
+                dshape, param().is_non_overlapping, cuda_stream(handle()));   \
+        return;                                                               \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(ssprintf("cuda indexing remap forward only support "
+                          "float32/int32 dtype, got %s",
+                          diff.layout.dtype.name()));
 }
 
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
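The `cb(dt)` macro here is a small X-macro dispatcher: each expansion emits an `if` block that compares the runtime dtype against `DTypeTrait<dt>::enumv`, calls the kernel wrapper instantiated for the matching `ctype`, and returns, so control only reaches `megdnn_throw` when no expansion matched. Expanded by hand, `cb(dtype::Float32)` in the forward path is roughly (a sketch of the preprocessor output, reformatted):

if (src.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) {
    using ctype = DTypeTrait<dtype::Float32>::ctype;  // dt_float32
    tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(),
                                 dst.ptr<ctype>(), src.layout.ndim,
                                 dst.layout.ndim, sstride, dstride, dshape,
                                 cuda_stream(handle()));
    return;
}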
dnn/src/cuda/tensor_remap/tensor_remap.cu

@@ -6,28 +6,29 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "src/cuda/tensor_remap/tensor_remap.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "src/cuda/tensor_remap/tensor_remap.cuh"
 
 namespace megdnn {
 namespace cuda {
-namespace tensor_remap {
+namespace {
 
-__global__ void forward_kernel(const float *src, const int *map, float *dst,
-                               uint32_t sdim, uint32_t ddim,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                               uint32_t total)
-{
+template <typename ctype>
+__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst,
+                               uint32_t sdim, uint32_t ddim,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                               uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
             uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];

@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst,
     }
 }
 
-void forward(const float *src, const int *map, float *dst,
-             uint32_t sdim, uint32_t ddim,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-             cudaStream_t stream)
-{
-    uint32_t total = 1u;
-    for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
-    uint32_t threads = query_blocksize_for_kernel((void*)&forward_kernel);
-    uint32_t blocks = DIVUP(total, threads);
-    forward_kernel<<<blocks, threads, 0, stream>>>(
-            src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
-    after_kernel_launch();
-}
-
-__global__ void fill_zero_kernel(float *a, uint32_t dim,
-                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
-                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
-                                 uint32_t total)
-{
+template <typename ctype>
+__global__ void fill_zero_kernel(ctype* a, uint32_t dim,
+                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
+                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
+                                 uint32_t total) {
     uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx_cont < total) {
         uint32_t idx = 0u;
         for (uint32_t j = dim; j > 0u; --j) {
             uint32_t i = j - 1u;
             uint32_t idx_cur = idx_cont % shape.data[i];
             idx_cont /= shape.data[i];
             idx += idx_cur * stride.data[i];

@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim,
     }
 }
 
-__global__ void backward_kernel(const float *diff, const int *map, float *grad,
-                                uint32_t sdim, uint32_t ddim,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                                uint32_t total)
-{
+template <typename ctype>
+__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad,
+                                uint32_t sdim, uint32_t ddim,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                                uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
             uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];

@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad,
     }
 }
 
+template <typename ctype>
 __global__ void backward_kernel_non_overlapping(
-        const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+        const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+        uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
         array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
             uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];

@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping(
     }
 }
 
-void backward(const float *diff, const int *map, float *grad,
-              uint32_t sdim, uint32_t ddim,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-              bool is_non_overlapping, cudaStream_t stream)
-{
+} // anonymous namespace
+
+namespace tensor_remap {
+
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream) {
+    uint32_t total = 1u;
+    for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
+    uint32_t threads =
+            query_blocksize_for_kernel((void*)&forward_kernel<ctype>);
+    uint32_t blocks = DIVUP(total, threads);
+    forward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+            src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
+    after_kernel_launch();
+}
+
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream) {
     {
         // Fill grad with zeros.
         uint32_t total = 1u;
         for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i];
-        uint32_t threads = query_blocksize_for_kernel((void*)&fill_zero_kernel);
+        uint32_t threads =
+                query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>);
         uint32_t blocks = DIVUP(total, threads);
-        fill_zero_kernel<<<blocks, threads, 0, stream>>>(
+        fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>(
                 grad, sdim, sstride, sshape, total);
         after_kernel_launch();
     }
     {
         // Update grad.
         uint32_t total = 1u;
         for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
         if (is_non_overlapping) {
             uint32_t threads = query_blocksize_for_kernel(
-                    (void*)&backward_kernel_non_overlapping);
+                    (void*)&backward_kernel_non_overlapping<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>(
-                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
-                    total);
+            backward_kernel_non_overlapping<ctype>
+                    <<<blocks, threads, 0, stream>>>(diff, map, grad, sdim,
+                                                     ddim, sstride, dstride,
+                                                     dshape, total);
         } else {
-            uint32_t threads = query_blocksize_for_kernel(
-                    (void*)&backward_kernel);
+            uint32_t threads =
+                    query_blocksize_for_kernel((void*)&backward_kernel<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel<<<blocks, threads, 0, stream>>>(
-                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
-                    total);
+            backward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
+                    total);
         }
         after_kernel_launch();
     }
 }
 
-} // namespace tensor_remap
-} // namespace cuda
-} // namespace megdnn
+#define INST(T)                                                                \
+    template void forward<T>(                                                  \
+            const T* src, const int* map, T* dst, uint32_t sdim,               \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            cudaStream_t stream);                                              \
+    template void backward<T>(                                                 \
+            const T* diff, const int* map, T* grad, uint32_t sdim,             \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,            \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            bool is_non_overlapping, cudaStream_t stream);
+
+INST(dt_float32)
+INST(dt_int32)
+#undef INST
 
-// vim: syntax=cpp.doxygen
+} // namespace tensor_remap
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
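`forward` and `backward` are now templates whose definitions live only in this .cu file, while callers such as opr_impl.cpp see just the declarations in the header. The compiler therefore cannot instantiate them at the call site, and the `INST(T)` block supplies explicit instantiations for every supported element type so the linker can resolve them. A minimal sketch of the same linkage pattern, with hypothetical names:

// scale.h — declaration only; translation units that call scale<T>
// include this and never see the definition.
template <typename T>
void scale(T* data, int n, T factor);

// scale.cu — out-of-line definition...
template <typename T>
void scale(T* data, int n, T factor) {
    for (int i = 0; i < n; ++i)
        data[i] *= factor;
}

// ...plus one explicit instantiation per supported type; without these
// the linker reports undefined references at the call sites.
template void scale<float>(float*, int, float);
template void scale<int>(int*, int, int);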
dnn/src/cuda/tensor_remap/tensor_remap.cuh

@@ -17,25 +17,23 @@ namespace megdnn {
 namespace cuda {
 namespace tensor_remap {
 
-void forward(const float *src, const int *map, float *dst,
-             uint32_t sdim, uint32_t ddim,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-             cudaStream_t stream);
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream);
 
-void backward(const float *diff, const int *map, float *grad,
-              uint32_t sdim, uint32_t ddim,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-              bool is_non_overlapping, cudaStream_t stream);
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream);
 
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
dnn/src/naive/tensor_remap/opr_impl.cpp

@@ -6,75 +6,107 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/naive/tensor_remap/opr_impl.h"
 
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
 
-namespace megdnn {
-namespace naive {
+using namespace megdnn;
+using namespace naive;
+
+namespace {
+
+template <typename ctype>
+void forward(const TensorND& src, const TensorND& map, const TensorND& dst) {
+    auto&& sshape = src.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = dst.layout;
+    // Last element is zero to facilitate maddr calculation.
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
+        dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+template <typename ctype>
+void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) {
+    auto&& sshape = grad.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = diff.layout;
+    std::vector<size_t> sidx(sshape.ndim, 0_z);
+    {
+        // Set grad to zero.
+        do {
+            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+            grad.ptr<ctype>()[saddr] = 0.0f;
+        } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
+    }
+    // Last element is zero to facilitate maddr calculation.
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
+        grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+} // anonymous namespace
 
 void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
                                     _megdnn_tensor_in map,
                                     _megdnn_tensor_out dst,
                                     _megdnn_workspace workspace) {
     check_exec(src.layout, map.layout, dst.layout, workspace.size);
-    auto kern = [=]() {
-        auto&& sshape = src.layout;
-        auto&& mshape = map.layout;
-        auto&& dshape = dst.layout;
-        // Last element is zero to facilitate maddr calculation.
-        std::vector<size_t> didx(dshape.ndim + 1, 0_z);
-        do {
-            auto maddr =
-                    get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr + i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
-            dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (src.layout.dtype.enumv()) {
+#define cb(dt)                                                  \
+    case DTypeTrait<dt>::enumv:                                 \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                           \
+                forward<DTypeTrait<dt>::ctype>(src, map, dst)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(ssprintf("unsupported dtype %s in indexing "
+                                  "remap forward naive\n",
+                                  src.layout.dtype.name()));
+    }
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
                                      _megdnn_tensor_in map,
                                      _megdnn_tensor_out grad,
                                      _megdnn_workspace workspace) {
     check_exec(diff.layout, map.layout, grad.layout, workspace.size);
-    auto kern = [=]() {
-        auto&& sshape = grad.layout;
-        auto&& mshape = map.layout;
-        auto&& dshape = diff.layout;
-        std::vector<size_t> sidx(sshape.ndim, 0_z);
-        {
-            // Set grad to zero.
-            do {
-                auto saddr =
-                        get_linear_addr_noncont(sidx.data(), grad.layout);
-                grad.ptr<dt_float32>()[saddr] = 0.0f;
-            } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
-        }
-        std::vector<size_t> didx(dshape.ndim + 1, 0_z);
-        do {
-            auto maddr =
-                    get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr + i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
-            grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (diff.layout.dtype.enumv()) {
+#define cb(dt)                                                    \
+    case DTypeTrait<dt>::enumv:                                   \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                             \
+                backward<DTypeTrait<dt>::ctype>(diff, map, grad)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype %s in indexing remap backward naive\n",
+                    diff.layout.dtype.name()));
+    }
 }
 
-} // namespace naive
-} // namespace megdnn
-
 // vim: syntax=cpp.doxygen
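The naive helpers enumerate every destination coordinate with an odometer-style loop: `didx` holds one index per axis, `get_next_addr` advances it with carries, and `get_linear_addr_noncont` folds an index vector with a layout's (possibly non-contiguous) strides into a flat offset. A self-contained sketch of that iteration scheme, under the assumption that those helpers behave as their names suggest:

#include <cstddef>
#include <vector>

// Advance a multi-dimensional index like an odometer; returns false once
// every coordinate has wrapped (a sketch of get_next_addr's contract).
bool next_index(std::vector<size_t>& idx, const std::vector<size_t>& shape) {
    for (size_t j = shape.size(); j > 0; --j) {
        size_t i = j - 1;
        if (++idx[i] < shape[i])
            return true;   // no carry needed; a new coordinate was produced
        idx[i] = 0;        // carry into the next-more-significant axis
    }
    return false;          // wrapped around: the whole shape was visited
}

// Fold an index vector with per-axis strides into a flat element offset
// (a sketch of get_linear_addr_noncont for strided layouts).
size_t linear_addr(const std::vector<size_t>& idx,
                   const std::vector<size_t>& stride) {
    size_t addr = 0;
    for (size_t i = 0; i < idx.size(); ++i)
        addr += idx[i] * stride[i];
    return addr;
}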
dnn/test/cuda/tensor_remap.cpp

@@ -16,39 +16,42 @@
 namespace megdnn {
 namespace test {
 
 TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
     Checker<IndexingRemapForward> checker(handle_cuda());
-    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
     checker.set_dtype(1, dtype::Int32());
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
     }
 }
 
 TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
     Checker<IndexingRemapBackward> checker(handle_cuda());
-    checker.set_dtype(1, dtype::Int32());
     TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
+    checker.set_dtype(1, dtype::Int32());
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
     }
 }
 
@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen