Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
68cdabd2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
68cdabd2
编写于
10月 13, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(opr): indexing_multi_axis_vec support nd index
GitOrigin-RevId: 07b1248bdcaa8d12c91220eb482090ece16a0a10
上级
05ee6038
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
338 addition
and
122 deletion
+338
-122
dnn/include/megdnn/oprs/general.h
dnn/include/megdnn/oprs/general.h
+6
-4
dnn/src/common/indexing_multi_axis_vec.cpp
dnn/src/common/indexing_multi_axis_vec.cpp
+49
-33
dnn/src/cuda/indexing_multi_axis_vec/kern.cuh
dnn/src/cuda/indexing_multi_axis_vec/kern.cuh
+18
-5
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl
...rc/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl
+15
-3
dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu
dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu
+34
-9
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
+31
-13
dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp
dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp
+20
-11
dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip
dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip
+13
-5
dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
...c/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
+9
-3
dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
...rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
+31
-10
dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
+29
-11
dnn/test/common/indexing_multi_axis_vec.h
dnn/test/common/indexing_multi_axis_vec.h
+15
-3
dnn/test/common/mesh_indexing.h
dnn/test/common/mesh_indexing.h
+2
-2
dnn/test/cuda/indexing_multi_axis_vec.cpp
dnn/test/cuda/indexing_multi_axis_vec.cpp
+19
-0
imperative/python/test/unit/core/test_indexing_op.py
imperative/python/test/unit/core/test_indexing_op.py
+16
-0
src/opr/impl/indexing.cpp
src/opr/impl/indexing.cpp
+30
-9
src/opr/include/megbrain/opr/indexing.h
src/opr/include/megbrain/opr/indexing.h
+1
-1
未找到文件。
dnn/include/megdnn/oprs/general.h
浏览文件 @
68cdabd2
...
...
@@ -1115,7 +1115,7 @@ public:
* access *data*; stride of layout on that axis would be zero, and
* strides on other axes correspond to the strides in *data*
*/
static
std
::
pair
<
TensorLayout
,
size_t
>
get_value_iter_optimized_layout
(
static
std
::
tuple
<
TensorLayout
,
size_t
,
TensorShape
>
get_value_iter_optimized_layout
(
const
TensorLayout
&
data
,
const
TensorLayout
&
value
,
const
IndexDesc
&
index
,
size_t
idx_axis
);
...
...
@@ -1159,7 +1159,8 @@ public:
* \brief get workspace size based on output shape and indexing axes
*/
size_t
get_workspace_in_bytes
(
const
TensorShape
&
dst
,
const
size_t
*
axes
,
size_t
nr_axes
);
const
TensorShape
&
dst
,
const
size_t
*
axes
,
size_t
nr_axes
,
size_t
idx_ndim
);
static
void
deduce_layout
(
const
TensorLayout
&
data
,
const
IndexDescLayoutOnly
&
index
,
...
...
@@ -1193,7 +1194,8 @@ public:
* axes
*/
size_t
get_workspace_in_bytes
(
const
TensorShape
&
value
,
const
size_t
*
axes
,
size_t
nr_axes
);
const
TensorShape
&
value
,
const
size_t
*
axes
,
size_t
nr_axes
,
size_t
idx_ndim
);
protected:
ExecInfo
check_exec
(
...
...
@@ -1223,7 +1225,7 @@ public:
using
AxisIndexerLayoutOnly
=
IndexingMultiAxisVecBase
::
AxisIndexerLayoutOnly
;
using
IndexDescLayoutOnly
=
IndexingMultiAxisVecBase
::
IndexDescLayoutOnly
;
size_t
get_workspace_in_bytes
(
const
TensorShape
&
,
const
size_t
*
,
size_t
)
{
size_t
get_workspace_in_bytes
(
const
TensorShape
&
,
const
size_t
*
,
size_t
,
size_t
)
{
return
0
;
}
...
...
dnn/src/common/indexing_multi_axis_vec.cpp
浏览文件 @
68cdabd2
...
...
@@ -15,8 +15,10 @@
using
namespace
megdnn
;
namespace
{
// we need a workspace to store offset base table, which has same size with index
size_t
get_index_size_for_workspace
(
const
TensorShape
&
shp
,
const
size_t
*
axes
,
size_t
nr_axes
)
{
const
TensorShape
&
shp
,
const
size_t
*
axes
,
size_t
nr_axes
,
size_t
idx_ndim
)
{
size_t
idx_axis
=
axes
[
0
];
megdnn_assert
(
shp
.
ndim
&&
nr_axes
);
for
(
size_t
i
=
1
;
i
<
nr_axes
;
++
i
)
{
...
...
@@ -29,7 +31,11 @@ size_t get_index_size_for_workspace(
megdnn_assert
(
shp
.
ndim
>
idx_axis
,
"index on the %zuth axis; but shape is %s"
,
idx_axis
,
shp
.
to_string
().
c_str
());
return
shp
.
shape
[
idx_axis
];
size_t
idx_size
=
1
;
for
(
size_t
i
=
0
;
i
<
idx_ndim
;
++
i
)
{
idx_size
*=
shp
.
shape
[
idx_axis
+
i
];
}
return
idx_size
;
}
}
// anonymous namespace
...
...
@@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
const
TensorLayout
&
data
,
const
IndexDescLayoutOnly
&
index
,
TensorLayout
&
dst
)
{
megdnn_assert
(
!
index
.
empty
());
megdnn_assert
(
data
.
ndim
>=
index
.
size
());
dst
.
ndim
=
data
.
ndim
-
index
.
size
()
+
1
;
dst
.
shape
[
0
]
=
1
;
dst
.
ndim
=
data
.
ndim
-
index
.
size
();
dst
.
dtype
=
data
.
dtype
;
TensorShapeArray
index_shapes
;
auto
brdcast
=
[
&
](
const
TensorLayout
&
ly
)
{
if
(
ly
.
ndim
!=
1
)
return
false
;
if
(
dst
.
shape
[
0
]
==
ly
.
shape
[
0
])
return
true
;
if
(
dst
.
shape
[
0
]
==
1
)
{
dst
.
shape
[
0
]
=
ly
.
shape
[
0
];
return
true
;
}
return
ly
.
shape
[
0
]
==
1
;
megdnn_assert
(
ly
.
dtype
==
dtype
::
Int32
{});
index_shapes
.
push_back
(
ly
);
};
size_t
dst_axis
=
1
;
size_t
dst_axis
=
0
;
ptrdiff_t
prev_axis
=
-
1
;
for
(
size_t
axis
=
0
;
axis
<
index
.
size
();
++
axis
)
{
auto
&&
idx
=
index
[
axis
];
...
...
@@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
megdnn_assert
(
idx
.
axis
<
data
.
ndim
&&
static_cast
<
ptrdiff_t
>
(
idx
.
axis
)
>
prev_axis
,
"index %zu requests invalid axis %zu"
,
axis
,
idx
.
axis
);
auto
brd_succ
=
brdcast
(
idx
.
layout
);
megdnn_assert
(
brd_succ
,
"invalid layout at index %zu: %s"
,
axis
,
idx
.
layout
.
to_string
().
c_str
());
brdcast
(
idx
.
layout
);
for
(
size_t
i
=
prev_axis
+
1
;
i
<
idx
.
axis
;
++
i
)
{
dst
.
shape
[
dst_axis
++
]
=
data
.
shape
[
i
];
...
...
@@ -99,13 +96,16 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
}
}
if
(
contig_idx
)
{
auto
shp0
=
dst
.
shape
[
0
];
idx_axis
=
index
[
0
].
axis
;
for
(
size_t
i
=
0
;
i
<
idx_axis
;
++
i
)
{
dst
.
shape
[
i
]
=
dst
.
shape
[
i
+
1
];
}
dst
.
shape
[
idx_axis
]
=
shp0
;
}
TensorShape
index_shape
;
Elemwise
::
deduce_shape
(
index_shapes
,
index_shape
);
for
(
size_t
i
=
0
;
i
<
index_shape
.
ndim
;
++
i
)
{
dst
.
add_axis_inplace
(
idx_axis
+
i
,
1
,
0
);
dst
.
shape
[
idx_axis
+
i
]
=
index_shape
.
shape
[
i
];
}
dst
.
init_contiguous_stride
();
...
...
@@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp
return
ret
;
}
std
::
pair
<
TensorLayout
,
size_t
>
IndexingMultiAxisVecBase
::
std
::
tuple
<
TensorLayout
,
size_t
,
TensorShape
>
IndexingMultiAxisVecBase
::
get_value_iter_optimized_layout
(
const
TensorLayout
&
data
,
const
TensorLayout
&
value
,
const
IndexDesc
&
index
,
size_t
idx_axis
)
{
size_t
data_axes
[
TensorLayout
::
MAX_NDIM
],
nr_axes
=
get_nonindex_axes
(
data
.
ndim
,
index
,
data_axes
);
// broadcast index shapes
TensorLayout
index_shape
;
{
TensorShapeArray
index_shapes
;
for
(
auto
&
idx
:
index
)
{
megdnn_assert
(
idx
.
vec
.
layout
.
dtype
==
dtype
::
Int32
{});
index_shapes
.
push_back
(
idx
.
vec
.
layout
);
}
Elemwise
::
deduce_shape
(
index_shapes
,
index_shape
);
}
megdnn_assert
(
nr_axes
==
value
.
ndim
-
1
&&
idx_axis
<
value
.
ndim
&&
nr_axes
==
value
.
ndim
-
index_shape
.
ndim
&&
idx_axis
<
value
.
ndim
&&
nr_axes
+
index
.
size
()
==
data
.
ndim
);
TensorLayout
ret
;
...
...
@@ -165,10 +176,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
}
ret
=
ret
.
collapse_contiguous
();
}
ret
.
shape
[
ret
.
ndim
]
=
value
.
shape
[
idx_axis
];
ret
.
stride
[
ret
.
ndim
]
=
0
;
size_t
ret_idx_axis
=
ret
.
ndim
;
for
(
size_t
i
=
0
;
i
<
index_shape
.
ndim
;
++
i
)
{
ret
.
shape
[
ret
.
ndim
]
=
value
.
shape
[
idx_axis
+
i
];
ret
.
stride
[
ret
.
ndim
]
=
0
;
++
ret
.
ndim
;
}
if
(
idx_axis
<
nr_axes
)
{
TensorLayout
tail
;
...
...
@@ -185,12 +199,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
}
}
return
{
ret
,
ret_idx_axis
}
;
return
std
::
make_tuple
(
ret
,
ret_idx_axis
,
index_shape
)
;
}
size_t
IndexingMultiAxisVec
::
get_workspace_in_bytes
(
const
TensorShape
&
dst
,
const
size_t
*
axes
,
size_t
nr_axes
)
{
return
get_workspace_in_bytes
(
get_index_size_for_workspace
(
dst
,
axes
,
nr_axes
));
const
TensorShape
&
dst
,
const
size_t
*
axes
,
size_t
nr_axes
,
size_t
idx_ndim
)
{
return
get_workspace_in_bytes
(
get_index_size_for_workspace
(
dst
,
axes
,
nr_axes
,
idx_ndim
));
}
IndexingMultiAxisVec
::
ExecInfo
IndexingMultiAxisVec
::
check_exec
(
...
...
@@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
}
size_t
IndexingModifyMultiAxisVecBase
::
get_workspace_in_bytes
(
const
TensorShape
&
value
,
const
size_t
*
axes
,
size_t
nr_axes
)
{
return
get_workspace_in_bytes
(
get_index_size_for_workspace
(
value
,
axes
,
nr_axes
));
const
TensorShape
&
value
,
const
size_t
*
axes
,
size_t
nr_axes
,
size_t
idx_ndim
)
{
return
get_workspace_in_bytes
(
get_index_size_for_workspace
(
value
,
axes
,
nr_axes
,
idx_ndim
));
}
IndexingModifyMultiAxisVecBase
::
ExecInfo
IndexingModifyMultiAxisVecBase
::
check_exec
(
...
...
dnn/src/cuda/indexing_multi_axis_vec/kern.cuh
浏览文件 @
68cdabd2
...
...
@@ -21,17 +21,24 @@ namespace cuda {
namespace
indexing_multi_axis_vec
{
//! AxisIndexer equiv in kernel
template
<
int
idx_ndim
>
struct
KAxisIndexer
{
int
stride
;
int
stride
[
idx_ndim
];
#ifdef WIN32
Uint32Fastdiv
shape
[
idx_ndim
];
#else
// original shape[0] not storaged
Uint32Fastdiv
shape
[
idx_ndim
-
1
];
#endif
const
int
*
ptr
;
};
//! param for gen_offset_base
template
<
int
nidx
>
template
<
int
nidx
,
int
idx_ndim
>
struct
GenOffsetBaseParam
{
uint32_t
size
;
//!< number of outputs; also size of each index
int
*
output
;
//!< output ptr
KAxisIndexer
indexer
[
nidx
];
KAxisIndexer
<
idx_ndim
>
indexer
[
nidx
];
uint32_t
data_shape
[
nidx
];
int
data_stride
[
nidx
];
...
...
@@ -59,7 +66,12 @@ struct ApplyOprParam {
const
int
*
offset_base
;
ctype
*
data
,
*
value
;
// first idx axis
int
idx_axis
;
// last idx axis + 1
int
idx_axis_end
;
// number of elements for idx shape
int
idx_nelems
;
int
value_stride
;
...
...
@@ -68,8 +80,9 @@ struct ApplyOprParam {
};
//! generate offset bases for first axis in the output
template
<
int
nidx
>
void
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
>&
param
,
cudaStream_t
stream
);
template
<
int
nidx
,
int
idx_ndim
>
void
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
,
idx_ndim
>&
param
,
cudaStream_t
stream
);
struct
OprAtomicIncr
{
#if MEGDNN_CC_CUDA
...
...
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl
浏览文件 @
68cdabd2
...
...
@@ -29,11 +29,23 @@ namespace {
uint32_t
oidx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
oidx
<
param
.
tot_size
)
{
int
offset
=
0
,
coidx
=
oidx
;
int
all_ax_idx
[
ndim
];
// offset in index
int
idx_flat
=
0
;
// for non-indexed axes get offset
#pragma unroll
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
--
i
)
{
int
next_coidx
,
ax_idx
;
// [..., indexed_axes... |, ...]
if
(
i
+
1
==
param
.
idx_axis_end
)
{
idx_flat
=
coidx
;
}
// [... |, indexed_axes..., ...]
if
(
i
+
1
==
param
.
idx_axis
)
{
idx_flat
-=
coidx
*
param
.
idx_nelems
;
}
// shape[i] was storaged at shape[i-1]
if
(
i
)
{
// fast divide
next_coidx
=
coidx
/
param
.
value_ly_on_data
.
shape
[
i
-
1
];
ax_idx
=
coidx
-
...
...
@@ -44,9 +56,9 @@ namespace {
ax_idx
=
coidx
;
}
offset
+=
param
.
value_ly_on_data
.
stride
[
i
]
*
ax_idx
;
all_ax_idx
[
i
]
=
ax_idx
;
}
offset
+=
param
.
offset_base
[
all_ax_idx
[
param
.
idx_axis
]];
// offset from index, which was generated before
offset
+=
param
.
offset_base
[
idx_flat
];
Opr
::
apply
(
param
.
data
[
offset
],
param
.
value
[
oidx
*
param
.
value_stride
]);
...
...
dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu
浏览文件 @
68cdabd2
...
...
@@ -18,14 +18,29 @@ using namespace cuda;
using
namespace
indexing_multi_axis_vec
;
namespace
{
template
<
int
nidx
>
__global__
void
kgen_offset_base
(
GenOffsetBaseParam
<
nidx
>
param
)
{
template
<
int
nidx
,
int
idx_ndim
>
__global__
void
kgen_offset_base
(
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
param
)
{
int
oidx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
oidx
<
param
.
size
)
{
int
offset
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
nidx
;
++
i
)
{
int
data_idx
=
param
.
indexer
[
i
].
ptr
[
param
.
indexer
[
i
].
stride
*
oidx
];
auto
&
indexer
=
param
.
indexer
[
i
];
// index in index
int
idx_flat
=
0
,
coidx
=
oidx
;
#pragma unroll
for
(
int
j
=
idx_ndim
-
1
;
j
>=
0
;
--
j
)
{
int
ax_idx
;
if
(
j
)
{
int
next_coidx
=
coidx
/
indexer
.
shape
[
j
-
1
];
ax_idx
=
coidx
-
(
next_coidx
*
indexer
.
shape
[
j
-
1
].
divisor
());
coidx
=
next_coidx
;
}
else
{
ax_idx
=
coidx
;
}
idx_flat
+=
indexer
.
stride
[
j
]
*
ax_idx
;
}
int
data_idx
=
indexer
.
ptr
[
idx_flat
];
data_idx
+=
(
data_idx
<
0
?
param
.
data_shape
[
i
]
:
0
);
if
(
static_cast
<
uint32_t
>
(
data_idx
)
>=
param
.
data_shape
[
i
])
{
// cast to uint32 to handle both negative and overflow
...
...
@@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
i
,
data_idx
,
param
.
data_shape
[
i
]);
data_idx
=
0
;
}
// calculate offset from current index
offset
+=
data_idx
*
param
.
data_stride
[
i
];
}
// sum offsets and store at offset table
param
.
output
[
oidx
]
=
offset
;
}
}
}
// namespace
template
<
int
nidx
>
template
<
int
nidx
,
int
idx_ndim
>
void
indexing_multi_axis_vec
::
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
>&
param
,
cudaStream_t
stream
)
{
void
(
*
kptr
)(
GenOffsetBaseParam
<
nidx
>
)
=
kgen_offset_base
<
nidx
>
;
const
GenOffsetBaseParam
<
nidx
,
idx_ndim
>&
param
,
cudaStream_t
stream
)
{
void
(
*
kptr
)(
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
)
=
kgen_offset_base
<
nidx
,
idx_ndim
>
;
int
bsize
=
query_blocksize_for_kernel
(
kptr
);
(
*
kptr
)
<<<
DIVUP
(
param
.
size
,
bsize
),
bsize
,
0
,
stream
>>>
(
param
);
}
...
...
@@ -55,9 +72,17 @@ namespace megdnn {
namespace
cuda
{
namespace
indexing_multi_axis_vec
{
#define INST(_n) \
template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t);
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
)
#define INST(_m, _n) \
template void gen_offset_base(const GenOffsetBaseParam<_m, _n>&, cudaStream_t);
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
1
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
2
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
3
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
4
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
5
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
6
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
7
)
#undef INST
}
// namespace indexing_multi_axis_vec
...
...
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
浏览文件 @
68cdabd2
...
...
@@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec;
namespace
{
class
ExecImplHelper
{
template
<
int
nidx
,
int
idx_ndim
>
void
dispatch_gen_offset_base_nidx_ndim
();
template
<
int
nidx
>
void
dispatch_gen_offset_base_nidx
();
void
dispatch_gen_offset_base
();
protected:
...
...
@@ -38,6 +39,7 @@ protected:
int
*
const
m_offset_base
;
TensorLayout
m_value_layout_on_data
;
size_t
m_idx_axis
;
TensorShape
m_idx_shape
;
int
m_value_stride
;
public:
...
...
@@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper(
m_exec_info
{
&
exec_info
},
m_offset_base
{
workspace
.
ptr
<
int
>
()}
{
safe_size_in_kern
(
data
.
layout
.
total_nr_elems
());
dispatch_gen_offset_base
();
std
::
tie
(
m_value_layout_on_data
,
m_idx_axis
)
=
std
::
tie
(
m_value_layout_on_data
,
m_idx_axis
,
m_idx_shape
)
=
IndexingMultiAxisVec
::
get_value_iter_optimized_layout
(
data
.
layout
,
value
.
layout
,
index
,
exec_info
.
idx_axis
);
dispatch_gen_offset_base
();
m_value_stride
=
exec_info
.
value_stride
;
}
template
<
int
nidx
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
()
{
GenOffsetBaseParam
<
nidx
>
param
;
param
.
size
=
m_
value
->
layout
.
shape
[
m_exec_info
->
idx_axis
]
;
template
<
int
nidx
,
int
idx_ndim
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
_ndim
()
{
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
param
;
param
.
size
=
m_
idx_shape
.
total_nr_elems
()
;
param
.
output
=
m_offset_base
;
param
.
error_tracker
=
m_exec_info
->
error_tracker
;
param
.
error_info
=
m_exec_info
->
error_info
;
megdnn_assert
(
m_idx_shape
.
ndim
==
idx_ndim
);
for
(
int
i
=
0
;
i
<
nidx
;
++
i
)
{
auto
&&
dst
=
param
.
indexer
[
i
];
auto
&&
src
=
m_index
->
operator
[](
i
);
megdnn_assert
(
src
.
vec
.
layout
.
ndim
==
1
);
dst
.
stride
=
src
.
vec
.
layout
.
stride
[
0
];
if
(
src
.
vec
.
layout
.
shape
[
0
]
==
1
)
{
dst
.
stride
=
0
;
auto
&&
src
=
m_index
->
at
(
i
);
auto
src_layout
=
src
.
vec
.
layout
.
broadcast
(
m_idx_shape
);
for
(
size_t
i
=
0
;
i
<
idx_ndim
;
++
i
)
{
if
(
i
)
{
dst
.
shape
[
i
-
1
]
=
src_layout
.
shape
[
i
];
}
dst
.
stride
[
i
]
=
src_layout
.
stride
[
i
];
}
dst
.
ptr
=
src
.
vec
.
ptr
<
int
>
();
param
.
data_shape
[
i
]
=
m_data
->
layout
.
shape
[
src
.
axis
];
...
...
@@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base
(
param
,
m_stream
);
}
template
<
int
nidx
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
()
{
switch
(
m_idx_shape
.
ndim
)
{
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM
(
cb
)
#undef cb
}
megdnn_throw
(
"bad index ndim"
);
}
void
ExecImplHelper
::
dispatch_gen_offset_base
()
{
switch
(
m_index
->
size
())
{
#define cb(_n) \
...
...
@@ -153,6 +169,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param
.
data
=
m_data
->
ptr
<
ctype
>
();
param
.
value
=
m_value
->
ptr
<
ctype
>
();
param
.
idx_axis
=
m_idx_axis
;
param
.
idx_axis_end
=
m_idx_axis
+
m_idx_shape
.
ndim
;
param
.
idx_nelems
=
m_idx_shape
.
total_nr_elems
();
param
.
value_stride
=
m_value_stride
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
param
.
value_ly_on_data
.
stride
[
i
]
=
m_value_layout_on_data
.
stride
[
i
];
...
...
dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp
浏览文件 @
68cdabd2
...
...
@@ -33,37 +33,46 @@ void do_exec(
auto
data_layout
=
data
.
layout
;
auto
data_ptr
=
data
.
ptr
<
data_type
>
();
std
::
tuple
<
size_t
,
const
idx_type
*
,
ptrdiff_
t
>
index_raw
[
TensorLayout
::
MAX_NDIM
];
std
::
tuple
<
size_t
,
const
idx_type
*
,
TensorLayou
t
>
index_raw
[
TensorLayout
::
MAX_NDIM
];
size_t
nr_index
=
index
.
size
();
TensorShape
idx_shape
;
{
TensorShapeArray
idx_shapes
;
for
(
size_t
i
=
0
;
i
<
nr_index
;
++
i
)
{
idx_shapes
.
push_back
(
index
[
i
].
vec
.
layout
);
}
Elemwise
::
deduce_shape
(
idx_shapes
,
idx_shape
);
}
for
(
size_t
i
=
0
;
i
<
nr_index
;
++
i
)
{
auto
&&
s
=
index
[
i
];
index_raw
[
i
]
=
std
::
make_tuple
(
s
.
axis
,
s
.
vec
.
ptr
<
idx_type
>
(),
s
.
vec
.
layout
.
stride
[
0
]);
if
(
s
.
vec
.
layout
.
shape
[
0
]
==
1
)
std
::
get
<
2
>
(
index_raw
[
i
])
=
0
;
index_raw
[
i
]
=
std
::
make_tuple
(
s
.
axis
,
s
.
vec
.
ptr
<
idx_type
>
(),
s
.
vec
.
layout
.
broadcast
(
idx_shape
));
}
auto
value_iter
=
tensor_iter
<
data_type
>
(
value
).
begin
();
for
(
size_t
_
=
0
,
_t
=
value
.
layout
.
total_nr_elems
();
_
<
_t
;
++
_
)
{
ptrdiff_t
offset
=
0
;
auto
index_idx
=
value_iter
.
idx
()[
exec_info
.
idx_axis
]
;
auto
*
index_idx
=
value_iter
.
idx
()
+
exec_info
.
idx_axis
;
for
(
size_t
i
=
0
;
i
<
nr_index
;
++
i
)
{
size_t
axis
=
std
::
get
<
0
>
(
index_raw
[
i
]),
data_shape
=
data_layout
.
shape
[
axis
];
ptrdiff_t
data_stride
=
data_layout
.
stride
[
axis
];
idx_type
data_idx
=
std
::
get
<
1
>
(
index_raw
[
i
])[
std
::
get
<
2
>
(
index_raw
[
i
])
*
index_idx
];
size_t
index_offset
=
0
;
TensorLayout
&
index_layout
=
std
::
get
<
2
>
(
index_raw
[
i
]);
for
(
size_t
i
=
0
;
i
<
index_layout
.
ndim
;
++
i
)
{
index_offset
+=
index_idx
[
i
]
*
index_layout
.
stride
[
i
];
}
idx_type
data_idx
=
std
::
get
<
1
>
(
index_raw
[
i
])[
index_offset
];
if
(
data_idx
<
0
)
data_idx
+=
data_shape
;
megdnn_assert
(
data_idx
>=
0
&&
static_cast
<
size_t
>
(
data_idx
)
<
data_shape
,
"bad index value for index %zu at output %zu"
,
i
,
index_idx
);
"bad index value for index %zu at output %zu"
,
i
,
*
index_idx
);
offset
+=
data_stride
*
data_idx
;
}
for
(
size_t
i
=
0
;
i
<
nr_nonidx_axes
;
++
i
)
{
auto
stride
=
data_layout
.
stride
[
nonidx_axes
[
i
]];
auto
idx
=
value_iter
.
idx
()[
i
+
(
i
>=
exec_info
.
idx_axis
)];
auto
idx
=
value_iter
.
idx
()[
i
+
(
i
>=
exec_info
.
idx_axis
)
*
idx_shape
.
ndim
];
offset
+=
stride
*
idx
;
}
Opr
::
apply
(
data_ptr
[
offset
],
*
value_iter
);
...
...
dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip
浏览文件 @
68cdabd2
...
...
@@ -21,17 +21,23 @@ namespace rocm {
namespace
indexing_multi_axis_vec
{
//! AxisIndexer equiv in kernel
template
<
int
idx_ndim
>
struct
KAxisIndexer
{
int
stride
;
int
stride
[
idx_ndim
];
#ifdef WIN32
Uint32Fastdiv
shape
[
idx_ndim
];
#else
Uint32Fastdiv
shape
[
idx_ndim
-
1
];
#endif
const
int
*
ptr
;
};
//! param for gen_offset_base
template
<
int
nidx
>
template
<
int
nidx
,
int
idx_ndim
>
struct
GenOffsetBaseParam
{
uint32_t
size
;
//!< number of outputs; also size of each index
int
*
output
;
//!< output ptr
KAxisIndexer
indexer
[
nidx
];
KAxisIndexer
<
idx_ndim
>
indexer
[
nidx
];
uint32_t
data_shape
[
nidx
];
int
data_stride
[
nidx
];
...
...
@@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec {
ctype
*
data
,
*
value
;
int
idx_axis
;
int
idx_axis_end
;
int
idx_nelems
;
int
value_stride
;
...
...
@@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec {
};
//! generate offset bases for first axis in the output
template
<
int
nidx
>
void
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
>
&
param
,
template
<
int
nidx
,
int
idx_ndim
>
void
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
&
param
,
hipStream_t
stream
);
struct
OprAtomicIncr
{
...
...
dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
浏览文件 @
68cdabd2
...
...
@@ -30,10 +30,17 @@ namespace {
uint32_t
oidx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
oidx
<
param
.
tot_size
)
{
int
offset
=
0
,
coidx
=
oidx
;
int
all_ax_idx
[
ndim
]
;
int
idx_flat
=
0
;
#pragma unroll
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
--
i
)
{
int
next_coidx
,
ax_idx
;
if
(
i
+
1
==
param
.
idx_axis_end
)
{
idx_flat
=
coidx
;
}
// may not trigger
if
(
i
+
1
==
param
.
idx_axis
)
{
idx_flat
-=
coidx
*
param
.
idx_nelems
;
}
if
(
i
)
{
next_coidx
=
coidx
/
param
.
value_ly_on_data
.
shape
[
i
-
1
];
ax_idx
=
...
...
@@ -45,9 +52,8 @@ namespace {
ax_idx
=
coidx
;
}
offset
+=
param
.
value_ly_on_data
.
stride
[
i
]
*
ax_idx
;
all_ax_idx
[
i
]
=
ax_idx
;
}
offset
+=
param
.
offset_base
[
all_ax_idx
[
param
.
idx_axis
]
];
offset
+=
param
.
offset_base
[
idx_flat
];
Opr
::
apply
(
param
.
data
[
offset
],
param
.
value
[
oidx
*
param
.
value_stride
]);
...
...
dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
浏览文件 @
68cdabd2
...
...
@@ -21,15 +21,28 @@ using namespace rocm;
using
namespace
indexing_multi_axis_vec
;
namespace
{
template
<
int
nidx
>
__global__
void
kgen_offset_base
(
GenOffsetBaseParam
<
nidx
>
param
)
{
template
<
int
nidx
,
int
idx_ndim
>
__global__
void
kgen_offset_base
(
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
param
)
{
int
oidx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
oidx
<
param
.
size
)
{
int
offset
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
nidx
;
++
i
)
{
int
data_idx
=
param
.
indexer
[
i
].
ptr
[
param
.
indexer
[
i
].
stride
*
oidx
];
auto
&
indexer
=
param
.
indexer
[
i
];
int
offset2
=
0
,
coidx
=
oidx
;
#pragma unroll
for
(
int
j
=
idx_ndim
-
1
;
j
>=
0
;
--
j
)
{
int
ax_idx
;
if
(
j
)
{
int
next_coidx
=
coidx
/
indexer
.
shape
[
j
-
1
];
ax_idx
=
coidx
-
(
next_coidx
*
indexer
.
shape
[
j
-
1
].
divisor
());
coidx
=
next_coidx
;
}
else
{
ax_idx
=
coidx
;
}
offset2
+=
indexer
.
stride
[
j
]
*
ax_idx
;
}
int
data_idx
=
indexer
.
ptr
[
offset2
];
data_idx
+=
(
data_idx
<
0
?
param
.
data_shape
[
i
]
:
0
);
if
(
static_cast
<
uint32_t
>
(
data_idx
)
>=
param
.
data_shape
[
i
])
{
// cast to uint32 to handle both negative and overflow
...
...
@@ -50,20 +63,28 @@ namespace megdnn {
namespace
rocm
{
namespace
indexing_multi_axis_vec
{
#define INST(_n) \
#define INST(_
m, _
n) \
template void gen_offset_base( \
const GenOffsetBaseParam<_n> &, hipStream_t);
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
)
const GenOffsetBaseParam<_m, _n> &, hipStream_t);
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
1
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
2
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
3
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
4
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
5
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
6
)
MEGDNN_FOREACH_TENSOR_NDIM
(
INST
,
7
)
#undef INST
}
// namespace indexing_multi_axis_vec
}
// namespace rocm
}
// namespace megdnn
template
<
int
nidx
>
template
<
int
nidx
,
int
idx_ndim
>
void
indexing_multi_axis_vec
::
gen_offset_base
(
const
GenOffsetBaseParam
<
nidx
>
&
param
,
hipStream_t
stream
)
{
void
(
*
kptr
)(
GenOffsetBaseParam
<
nidx
>
)
=
kgen_offset_base
<
nidx
>
;
const
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
&
param
,
hipStream_t
stream
)
{
void
(
*
kptr
)(
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
)
=
kgen_offset_base
<
nidx
,
idx_ndim
>
;
int
bsize
=
256
;
hipLaunchKernelGGL
(
kptr
,
DIVUP
(
param
.
size
,
bsize
),
bsize
,
0
,
stream
,
...
...
dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
浏览文件 @
68cdabd2
...
...
@@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec;
namespace
{
class
ExecImplHelper
{
template
<
int
nidx
,
int
idx_ndim
>
void
dispatch_gen_offset_base_nidx_ndim
();
template
<
int
nidx
>
void
dispatch_gen_offset_base_nidx
();
void
dispatch_gen_offset_base
();
protected:
...
...
@@ -39,6 +40,7 @@ protected:
int
*
const
m_offset_base
;
TensorLayout
m_value_layout_on_data
;
size_t
m_idx_axis
;
TensorShape
m_idx_shape
;
int
m_value_stride
;
public:
...
...
@@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper(
m_exec_info
{
&
exec_info
},
m_offset_base
{
workspace
.
ptr
<
int
>
()}
{
safe_size_in_kern
(
data
.
layout
.
total_nr_elems
());
dispatch_gen_offset_base
();
std
::
tie
(
m_value_layout_on_data
,
m_idx_axis
)
=
std
::
tie
(
m_value_layout_on_data
,
m_idx_axis
,
m_idx_shape
)
=
IndexingMultiAxisVec
::
get_value_iter_optimized_layout
(
data
.
layout
,
value
.
layout
,
index
,
exec_info
.
idx_axis
);
dispatch_gen_offset_base
();
m_value_stride
=
exec_info
.
value_stride
;
}
template
<
int
nidx
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
()
{
GenOffsetBaseParam
<
nidx
>
param
;
param
.
size
=
m_
value
->
layout
.
shape
[
m_exec_info
->
idx_axis
]
;
template
<
int
nidx
,
int
idx_ndim
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
_ndim
()
{
GenOffsetBaseParam
<
nidx
,
idx_ndim
>
param
;
param
.
size
=
m_
idx_shape
.
total_nr_elems
()
;
param
.
output
=
m_offset_base
;
param
.
error_tracker
=
m_exec_info
->
error_tracker
;
param
.
error_info
=
m_exec_info
->
error_info
;
...
...
@@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
auto
&&
dst
=
param
.
indexer
[
i
];
auto
&&
src
=
m_index
->
operator
[](
i
);
megdnn_assert
(
src
.
vec
.
layout
.
ndim
==
1
);
dst
.
stride
=
src
.
vec
.
layout
.
stride
[
0
];
if
(
src
.
vec
.
layout
.
shape
[
0
]
==
1
)
{
dst
.
stride
=
0
;
auto
src_layout
=
src
.
vec
.
layout
.
broadcast
(
m_idx_shape
);
for
(
size_t
i
=
0
;
i
<
idx_ndim
;
++
i
)
{
if
(
i
)
{
dst
.
shape
[
i
-
1
]
=
src_layout
.
shape
[
i
];
}
dst
.
stride
[
i
]
=
src_layout
.
stride
[
i
];
}
dst
.
ptr
=
src
.
vec
.
ptr
<
int
>
();
param
.
data_shape
[
i
]
=
m_data
->
layout
.
shape
[
src
.
axis
];
...
...
@@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base
(
param
,
m_stream
);
}
template
<
int
nidx
>
void
ExecImplHelper
::
dispatch_gen_offset_base_nidx
()
{
switch
(
m_idx_shape
.
ndim
)
{
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM
(
cb
)
#undef cb
}
megdnn_throw
(
"bad index ndim"
);
}
void
ExecImplHelper
::
dispatch_gen_offset_base
()
{
switch
(
m_index
->
size
())
{
#define cb(_n) \
...
...
@@ -154,6 +170,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param
.
data
=
m_data
->
ptr
<
ctype
>
();
param
.
value
=
m_value
->
ptr
<
ctype
>
();
param
.
idx_axis
=
m_idx_axis
;
param
.
idx_axis_end
=
m_idx_axis
+
m_idx_shape
.
ndim
;
param
.
idx_nelems
=
m_idx_shape
.
total_nr_elems
();
param
.
value_stride
=
m_value_stride
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
param
.
value_ly_on_data
.
stride
[
i
]
=
m_value_layout_on_data
.
stride
[
i
];
...
...
dnn/test/common/indexing_multi_axis_vec.h
浏览文件 @
68cdabd2
...
...
@@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper {
return
ret
;
}
size_t
get_index_ndim
(
const
TensorNDArray
&
tensors
)
const
{
megdnn_assert
(
tensors
.
size
()
>=
3
);
size_t
ndim
=
0
;
for
(
size_t
i
=
2
;
i
<
tensors
.
size
();
++
i
)
{
ndim
=
std
::
max
(
tensors
[
i
].
layout
.
ndim
,
ndim
);
}
return
ndim
;
}
IndexingMultiAxisVec
::
IndexDescLayoutOnly
make_index_layout
(
const
TensorLayoutArray
&
layouts
)
const
{
megdnn_assert
(
layouts
.
size
()
>=
3
);
...
...
@@ -65,7 +74,8 @@ struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelpe
void
exec
(
IndexingMultiAxisVec
*
opr
,
const
TensorNDArray
&
tensors
)
const
{
WorkspaceWrapper
W
(
opr
->
handle
(),
opr
->
get_workspace_in_bytes
(
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
));
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
,
get_index_ndim
(
tensors
)));
opr
->
exec
(
tensors
[
0
],
make_index_desc
(
tensors
),
tensors
[
1
],
W
.
workspace
());
}
...
...
@@ -81,7 +91,8 @@ struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecH
void
exec
(
IndexingIncrMultiAxisVec
*
opr
,
const
TensorNDArray
&
tensors
)
const
{
WorkspaceWrapper
W
(
opr
->
handle
(),
opr
->
get_workspace_in_bytes
(
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
));
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
,
get_index_ndim
(
tensors
)));
opr
->
exec
(
tensors
[
0
],
tensors
[
1
],
make_index_desc
(
tensors
),
W
.
workspace
());
}
...
...
@@ -95,7 +106,8 @@ struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHe
void
exec
(
IndexingSetMultiAxisVec
*
opr
,
const
TensorNDArray
&
tensors
)
const
{
WorkspaceWrapper
W
(
opr
->
handle
(),
opr
->
get_workspace_in_bytes
(
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
));
tensors
[
1
].
layout
,
axes
,
tensors
.
size
()
-
2
,
get_index_ndim
(
tensors
)));
opr
->
exec
(
tensors
[
0
],
tensors
[
1
],
make_index_desc
(
tensors
),
W
.
workspace
());
}
...
...
dnn/test/common/mesh_indexing.h
浏览文件 @
68cdabd2
...
...
@@ -27,7 +27,7 @@ namespace test {
WorkspaceWrapper W( \
opr->handle(), \
opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2
));
\
tensors[1].layout, axes, tensors.size() - 2
, 1));
\
opr->exec( \
tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \
} \
...
...
@@ -46,7 +46,7 @@ namespace test {
WorkspaceWrapper W( \
opr->handle(), \
opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2
));
\
tensors[1].layout, axes, tensors.size() - 2
, 1));
\
opr->exec( \
tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \
} \
...
...
dnn/test/cuda/indexing_multi_axis_vec.cpp
浏览文件 @
68cdabd2
...
...
@@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
TensorLayout
{
TensorShape
{
9
},
{
-
1
},
dtype
::
Int32
()}});
}
TEST_F
(
CUDA
,
INDEXING_MULTI_AXIS_VEC_ND_INDEX
)
{
run_check
<
IndexingMultiAxisVec
>
(
handle_cuda
());
Checker
<
IndexingMultiAxisVec
>
checker
(
handle_cuda
());
OrderedRNG
rng
;
checker
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Int32
())
.
set_dtype
(
3
,
dtype
::
Int32
())
.
set_dtype
(
4
,
dtype
::
Int32
())
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
.
set_rng
(
3
,
&
rng
)
.
set_rng
(
4
,
&
rng
);
checker
.
set_proxy
({{
1
,
2
,
3
}})
.
execs
({{
5
,
5
,
6
,
7
,
3
},
{
5
,
2
,
3
,
4
,
3
},
{
3
,
1
},
{
2
,
1
,
1
},
{
1
,
4
}});
}
TEST_F
(
CUDA
,
INDEXING_INCR_MULTI_AXIS_VEC
)
{
run_check
<
IndexingIncrMultiAxisVec
>
(
handle_cuda
());
Checker
<
IndexingIncrMultiAxisVec
>
checker
(
handle_cuda
());
...
...
imperative/python/test/unit/core/test_indexing_op.py
浏览文件 @
68cdabd2
...
...
@@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic):
run_test
((
10
,
10
,
0
),
test4
)
run_test
((
10
,
10
,
10
),
test3
)
run_test
((
10
,
10
,
10
),
test4
)
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
def
test_nd_int_indexing
(
symbolic
):
inp
=
np
.
arange
(
11
)
idx
=
np
.
random
.
randint
(
11
,
size
=
(
5
,
7
))
def
run_test
(
args
,
fn
):
npy_out
=
fn
(
*
args
)
if
symbolic
:
fn
=
jit
.
trace
(
symbolic
=
symbolic
)(
fn
)
for
_
in
range
(
3
):
out
=
fn
(
*
[
Tensor
(
arg
)
for
arg
in
args
])
np
.
testing
.
assert_equal
(
out
.
numpy
(),
npy_out
)
run_test
([
inp
,
idx
],
lambda
inp
,
idx
:
inp
[
idx
])
src/opr/impl/indexing.cpp
浏览文件 @
68cdabd2
...
...
@@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
template
<
class
Opr
>
void
mixin
::
IndexingMultiAxisVecMegDNNOprHolder
<
Opr
>::
register_workspace_infer
(
const
indexing
::
IndexDesc
&
index_desc
,
cg
::
SingleCNOperatorNodeBase
&
opr
,
VarNode
*
data
,
VarNode
*
value
)
{
VarNode
*
data
,
VarNode
*
value
,
VarNodeArray
idx_arr
)
{
using
namespace
cg
::
static_infer
;
auto
infer_shape
=
[
this
,
&
index_desc
,
&
opr
](
TensorShape
&
dest
,
const
InpVal
&
inp
)
{
DepVal
deps
=
{{
data
,
DepType
::
SHAPE
},
{
value
,
DepType
::
SHAPE
}};
for
(
auto
&&
idx
:
idx_arr
)
{
deps
.
push_back
({
idx
,
DepType
::
SHAPE
});
}
auto
infer_shape
=
[
this
,
&
index_desc
,
&
opr
,
nr_idx
=
idx_arr
.
size
()](
TensorShape
&
dest
,
const
InpVal
&
inp
)
{
size_t
axes
[
TensorShape
::
MAX_NDIM
],
nr_axes
=
0
;
auto
ndim
=
inp
.
val
[
0
].
shape
().
ndim
;
for
(
auto
&&
i
:
reverse_adaptor
(
index_desc
))
{
...
...
@@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
axes
[
nr_axes
++
]
=
i
.
axis
.
get
(
ndim
);
}
}
mgb_assert
(
nr_axes
==
nr_idx
);
if
(
!
nr_axes
)
{
dest
=
{
0
};
}
else
{
size_t
idx_ndim
=
0
;
for
(
size_t
i
=
0
;
i
<
nr_idx
;
++
i
)
{
idx_ndim
=
std
::
max
(
idx_ndim
,
inp
.
val
[
2
+
i
].
shape
().
ndim
);
}
mgb_assert
(
idx_ndim
>
0
);
dest
=
{
megdnn_opr
(
opr
).
get_workspace_in_bytes
(
inp
.
val
[
1
].
shape
(),
axes
,
nr_axes
)};
inp
.
val
[
1
].
shape
(),
axes
,
nr_axes
,
idx_ndim
)};
}
return
true
;
};
opr
.
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
opr
.
output
(
1
),
{
SourceType
::
DEP
,
{{
data
,
DepType
::
SHAPE
},
{
value
,
DepType
::
SHAPE
}},
infer_shape
});
opr
.
output
(
1
),
{
SourceType
::
DEP
,
deps
,
infer_shape
});
}
template
<
class
Opr
>
...
...
@@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
};
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
output
(
0
),
{
SourceType
::
DEP
,
deps
,
infer_shape
});
this
->
register_workspace_infer
(
index_desc
(),
*
this
,
input
(
0
),
output
(
0
));
VarNodeArray
idx_arr
;
for
(
size_t
i
=
1
;
i
<
m_input2idxonly_axis_indexer
.
size
();
++
i
)
{
if
(
m_input2idxonly_axis_indexer
[
i
])
{
idx_arr
.
push_back
(
input
(
i
));
}
}
this
->
register_workspace_infer
(
index_desc
(),
*
this
,
input
(
0
),
output
(
0
),
idx_arr
);
}
template
<
class
Opr
>
...
...
@@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc(
this
->
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
this
->
output
(
0
),
ShapeInferDesc
::
make_identity
(
this
->
input
(
0
)));
this
->
register_workspace_infer
(
index_desc
(),
*
this
,
input
(
0
),
input
(
1
));
VarNodeArray
idx_arr
;
for
(
size_t
i
=
1
;
i
<
m_input2idxonly_axis_indexer
.
size
();
++
i
)
{
if
(
m_input2idxonly_axis_indexer
[
i
])
{
idx_arr
.
push_back
(
input
(
i
));
}
}
this
->
register_workspace_infer
(
index_desc
(),
*
this
,
input
(
0
),
input
(
1
),
idx_arr
);
}
template
<
class
Opr
>
...
...
src/opr/include/megbrain/opr/indexing.h
浏览文件 @
68cdabd2
...
...
@@ -96,7 +96,7 @@ protected:
void
register_workspace_infer
(
const
indexing
::
IndexDesc
&
index_desc
,
cg
::
SingleCNOperatorNodeBase
&
opr
,
VarNode
*
data
,
VarNode
*
value
);
VarNode
*
data
,
VarNode
*
value
,
VarNodeArray
idx_arr
);
void
record_megdnn_opr
(
mgb
::
cg
::
GraphExecutable
::
ExecDependencyArray
&
deps
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录