Commit 68cdabd2 authored by Megvii Engine Team

feat(opr): indexing_multi_axis_vec support nd index

GitOrigin-RevId: 07b1248bdcaa8d12c91220eb482090ece16a0a10
Parent 05ee6038
...@@ -1115,7 +1115,7 @@ public: ...@@ -1115,7 +1115,7 @@ public:
* access *data*; stride of layout on that axis would be zero, and * access *data*; stride of layout on that axis would be zero, and
* strides on other axes correspond to the strides in *data* * strides on other axes correspond to the strides in *data*
*/ */
static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout( static std::tuple<TensorLayout, size_t, TensorShape> get_value_iter_optimized_layout(
const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
size_t idx_axis); size_t idx_axis);
...@@ -1159,7 +1159,8 @@ public: ...@@ -1159,7 +1159,8 @@ public:
* \brief get workspace size based on output shape and indexing axes * \brief get workspace size based on output shape and indexing axes
*/ */
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorShape& dst, const size_t* axes, size_t nr_axes); const TensorShape& dst, const size_t* axes, size_t nr_axes,
size_t idx_ndim);
static void deduce_layout( static void deduce_layout(
const TensorLayout& data, const IndexDescLayoutOnly& index, const TensorLayout& data, const IndexDescLayoutOnly& index,
...@@ -1193,7 +1194,8 @@ public: ...@@ -1193,7 +1194,8 @@ public:
* axes * axes
*/ */
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorShape& value, const size_t* axes, size_t nr_axes); const TensorShape& value, const size_t* axes, size_t nr_axes,
size_t idx_ndim);
protected: protected:
ExecInfo check_exec( ExecInfo check_exec(
...@@ -1223,7 +1225,7 @@ public: ...@@ -1223,7 +1225,7 @@ public:
using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly; using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly; using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;
size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) { size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t, size_t) {
return 0; return 0;
} }
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
using namespace megdnn; using namespace megdnn;
namespace { namespace {
// we need a workspace to store the offset base table, which has the same number of elements as the broadcast index
size_t get_index_size_for_workspace( size_t get_index_size_for_workspace(
const TensorShape& shp, const size_t* axes, size_t nr_axes) { const TensorShape& shp, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
size_t idx_axis = axes[0]; size_t idx_axis = axes[0];
megdnn_assert(shp.ndim && nr_axes); megdnn_assert(shp.ndim && nr_axes);
for (size_t i = 1; i < nr_axes; ++i) { for (size_t i = 1; i < nr_axes; ++i) {
...@@ -29,7 +31,11 @@ size_t get_index_size_for_workspace( ...@@ -29,7 +31,11 @@ size_t get_index_size_for_workspace(
megdnn_assert( megdnn_assert(
shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis, shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis,
shp.to_string().c_str()); shp.to_string().c_str());
return shp.shape[idx_axis]; size_t idx_size = 1;
for (size_t i = 0; i < idx_ndim; ++i) {
idx_size *= shp.shape[idx_axis + i];
}
return idx_size;
} }
} // anonymous namespace } // anonymous namespace
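The helper above now sizes the offset-base workspace as the product of the idx_ndim output dimensions starting at idx_axis, i.e. one table entry per element of the broadcast index. A minimal standalone sketch of that computation, assuming the indexed axes are contiguous as asserted above; index_workspace_elems and its parameters are illustrative names, not megdnn API:

    #include <cstddef>

    size_t index_workspace_elems(
            const size_t* out_shape, size_t out_ndim,
            size_t idx_axis, size_t idx_ndim) {
        // one offset-table entry per element of the broadcast index, whose
        // extents occupy out_shape[idx_axis .. idx_axis + idx_ndim)
        size_t elems = 1;
        for (size_t i = 0; i < idx_ndim && idx_axis + i < out_ndim; ++i)
            elems *= out_shape[idx_axis + i];
        return elems;
    }

For the dst shape {5, 2, 3, 4, 3} exercised by the CUDA test added below (idx_axis = 1, idx_ndim = 3) this gives 2 * 3 * 4 = 24 table entries.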
...@@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( ...@@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) { const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) {
megdnn_assert(!index.empty()); megdnn_assert(!index.empty());
megdnn_assert(data.ndim >= index.size()); megdnn_assert(data.ndim >= index.size());
dst.ndim = data.ndim - index.size() + 1; dst.ndim = data.ndim - index.size();
dst.shape[0] = 1;
dst.dtype = data.dtype; dst.dtype = data.dtype;
TensorShapeArray index_shapes;
auto brdcast = [&](const TensorLayout& ly) { auto brdcast = [&](const TensorLayout& ly) {
if (ly.ndim != 1) megdnn_assert(ly.dtype == dtype::Int32{});
return false; index_shapes.push_back(ly);
if (dst.shape[0] == ly.shape[0])
return true;
if (dst.shape[0] == 1) {
dst.shape[0] = ly.shape[0];
return true;
}
return ly.shape[0] == 1;
}; };
size_t dst_axis = 1; size_t dst_axis = 0;
ptrdiff_t prev_axis = -1; ptrdiff_t prev_axis = -1;
for (size_t axis = 0; axis < index.size(); ++axis) { for (size_t axis = 0; axis < index.size(); ++axis) {
auto&& idx = index[axis]; auto&& idx = index[axis];
...@@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( ...@@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
megdnn_assert( megdnn_assert(
idx.axis<data.ndim&& static_cast<ptrdiff_t>(idx.axis)> prev_axis, idx.axis<data.ndim&& static_cast<ptrdiff_t>(idx.axis)> prev_axis,
"index %zu requests invalid axis %zu", axis, idx.axis); "index %zu requests invalid axis %zu", axis, idx.axis);
auto brd_succ = brdcast(idx.layout); brdcast(idx.layout);
megdnn_assert(
brd_succ, "invalid layout at index %zu: %s", axis,
idx.layout.to_string().c_str());
for (size_t i = prev_axis + 1; i < idx.axis; ++i) { for (size_t i = prev_axis + 1; i < idx.axis; ++i) {
dst.shape[dst_axis++] = data.shape[i]; dst.shape[dst_axis++] = data.shape[i];
...@@ -99,15 +96,18 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( ...@@ -99,15 +96,18 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
} }
} }
if (contig_idx) { if (contig_idx) {
auto shp0 = dst.shape[0];
idx_axis = index[0].axis; idx_axis = index[0].axis;
for (size_t i = 0; i < idx_axis; ++i) {
dst.shape[i] = dst.shape[i + 1];
}
dst.shape[idx_axis] = shp0;
} }
} }
TensorShape index_shape;
Elemwise::deduce_shape(index_shapes, index_shape);
for (size_t i = 0; i < index_shape.ndim; ++i) {
dst.add_axis_inplace(idx_axis + i, 1, 0);
dst.shape[idx_axis + i] = index_shape.shape[i];
}
dst.init_contiguous_stride(); dst.init_contiguous_stride();
return idx_axis; return idx_axis;
} }
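The reworked deduce_layout_fwd drops every indexed axis from data, broadcasts all index shapes through Elemwise::deduce_shape, and splices the broadcast shape back in at idx_axis. A hedged sketch of just the shape arithmetic for the contiguous-axes case, using plain std::vector shapes; deduce_broadcast and indexed_shape are stand-ins, not megdnn functions:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // numpy-style broadcast of several index shapes
    std::vector<size_t> deduce_broadcast(const std::vector<std::vector<size_t>>& shapes) {
        size_t ndim = 0;
        for (auto& s : shapes) ndim = std::max(ndim, s.size());
        std::vector<size_t> out(ndim, 1);
        for (auto& s : shapes)
            for (size_t i = 0; i < s.size(); ++i) {
                size_t& d = out[ndim - s.size() + i];
                assert(d == 1 || s[i] == 1 || d == s[i]);
                d = std::max(d, s[i]);
            }
        return out;
    }

    // output shape of nd fancy indexing with contiguous indexed axes
    std::vector<size_t> indexed_shape(
            std::vector<size_t> data, size_t idx_axis, size_t nr_idx,
            const std::vector<size_t>& index_shape) {
        data.erase(data.begin() + idx_axis, data.begin() + idx_axis + nr_idx);
        data.insert(data.begin() + idx_axis, index_shape.begin(), index_shape.end());
        return data;
    }

With the shapes from the new CUDA test below, data {5, 5, 6, 7, 3} indexed on axes 1..3 by indices of shape {3, 1}, {2, 1, 1} and {1, 4}, the indices broadcast to {2, 3, 4} and the deduced output is {5, 2, 3, 4, 3}.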
...@@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp ...@@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp
return ret; return ret;
} }
std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: std::tuple<TensorLayout, size_t, TensorShape> IndexingMultiAxisVecBase::
get_value_iter_optimized_layout( get_value_iter_optimized_layout(
const TensorLayout& data, const TensorLayout& value, const TensorLayout& data, const TensorLayout& value,
const IndexDesc& index, size_t idx_axis) { const IndexDesc& index, size_t idx_axis) {
size_t data_axes[TensorLayout::MAX_NDIM], size_t data_axes[TensorLayout::MAX_NDIM],
nr_axes = get_nonindex_axes(data.ndim, index, data_axes); nr_axes = get_nonindex_axes(data.ndim, index, data_axes);
// broadcast index shapes
TensorLayout index_shape;
{
TensorShapeArray index_shapes;
for (auto& idx : index) {
megdnn_assert(idx.vec.layout.dtype == dtype::Int32{});
index_shapes.push_back(idx.vec.layout);
}
Elemwise::deduce_shape(index_shapes, index_shape);
}
megdnn_assert( megdnn_assert(
nr_axes == value.ndim - 1 && idx_axis < value.ndim && nr_axes == value.ndim - index_shape.ndim && idx_axis < value.ndim &&
nr_axes + index.size() == data.ndim); nr_axes + index.size() == data.ndim);
TensorLayout ret; TensorLayout ret;
...@@ -165,10 +176,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: ...@@ -165,10 +176,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
} }
ret = ret.collapse_contiguous(); ret = ret.collapse_contiguous();
} }
ret.shape[ret.ndim] = value.shape[idx_axis];
ret.stride[ret.ndim] = 0;
size_t ret_idx_axis = ret.ndim; size_t ret_idx_axis = ret.ndim;
++ret.ndim; for (size_t i = 0; i < index_shape.ndim; ++i) {
ret.shape[ret.ndim] = value.shape[idx_axis + i];
ret.stride[ret.ndim] = 0;
++ret.ndim;
}
if (idx_axis < nr_axes) { if (idx_axis < nr_axes) {
TensorLayout tail; TensorLayout tail;
...@@ -185,12 +199,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: ...@@ -185,12 +199,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
} }
} }
return {ret, ret_idx_axis}; return std::make_tuple(ret, ret_idx_axis, index_shape);
} }
size_t IndexingMultiAxisVec::get_workspace_in_bytes( size_t IndexingMultiAxisVec::get_workspace_in_bytes(
const TensorShape& dst, const size_t* axes, size_t nr_axes) { const TensorShape& dst, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(get_index_size_for_workspace(dst, axes, nr_axes)); return get_workspace_in_bytes(
get_index_size_for_workspace(dst, axes, nr_axes, idx_ndim));
} }
IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
...@@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( ...@@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
} }
size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes( size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes(
const TensorShape& value, const size_t* axes, size_t nr_axes) { const TensorShape& value, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(get_index_size_for_workspace(value, axes, nr_axes)); return get_workspace_in_bytes(
get_index_size_for_workspace(value, axes, nr_axes, idx_ndim));
} }
IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec( IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec(
......
...@@ -21,17 +21,24 @@ namespace cuda { ...@@ -21,17 +21,24 @@ namespace cuda {
namespace indexing_multi_axis_vec { namespace indexing_multi_axis_vec {
//! AxisIndexer equiv in kernel //! AxisIndexer equiv in kernel
template <int idx_ndim>
struct KAxisIndexer { struct KAxisIndexer {
int stride; int stride[idx_ndim];
#ifdef WIN32
Uint32Fastdiv shape[idx_ndim];
#else
// the original shape[0] is not stored
Uint32Fastdiv shape[idx_ndim - 1];
#endif
const int* ptr; const int* ptr;
}; };
//! param for gen_offset_base //! param for gen_offset_base
template <int nidx> template <int nidx, int idx_ndim>
struct GenOffsetBaseParam { struct GenOffsetBaseParam {
uint32_t size; //!< number of outputs; also size of each index uint32_t size; //!< number of outputs; also size of each index
int* output; //!< output ptr int* output; //!< output ptr
KAxisIndexer indexer[nidx]; KAxisIndexer<idx_ndim> indexer[nidx];
uint32_t data_shape[nidx]; uint32_t data_shape[nidx];
int data_stride[nidx]; int data_stride[nidx];
...@@ -59,7 +66,12 @@ struct ApplyOprParam { ...@@ -59,7 +66,12 @@ struct ApplyOprParam {
const int* offset_base; const int* offset_base;
ctype *data, *value; ctype *data, *value;
// first indexed axis
int idx_axis; int idx_axis;
// one past the last indexed axis
int idx_axis_end;
// number of elements in the broadcast index shape
int idx_nelems;
int value_stride; int value_stride;
...@@ -68,8 +80,9 @@ struct ApplyOprParam { ...@@ -68,8 +80,9 @@ struct ApplyOprParam {
}; };
//! generate offset bases for first axis in the output //! generate offset bases for first axis in the output
template <int nidx> template <int nidx, int idx_ndim>
void gen_offset_base(const GenOffsetBaseParam<nidx>& param, cudaStream_t stream); void gen_offset_base(
const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream);
struct OprAtomicIncr { struct OprAtomicIncr {
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
......
...@@ -29,11 +29,23 @@ namespace { ...@@ -29,11 +29,23 @@ namespace {
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.tot_size) { if (oidx < param.tot_size) {
int offset = 0, coidx = oidx; int offset = 0, coidx = oidx;
int all_ax_idx[ndim]; // flat offset within the broadcast index
int idx_flat = 0;
// walk all axes; the indexed axes have stride 0 here, so only non-indexed axes add to the data offset
#pragma unroll #pragma unroll
for (int i = ndim - 1; i >= 0; -- i) { for (int i = ndim - 1; i >= 0; -- i) {
int next_coidx, ax_idx; int next_coidx, ax_idx;
// crossing the right boundary of the indexed axes: [..., indexed_axes... |, ...]
if (i + 1 == param.idx_axis_end) {
idx_flat = coidx;
}
// crossing the left boundary of the indexed axes: [... |, indexed_axes..., ...]
if (i + 1 == param.idx_axis) {
idx_flat -= coidx * param.idx_nelems;
}
// the extent of axis i is stored at shape[i - 1]
if (i) { if (i) {
// fast integer division with a precomputed divisor
next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; next_coidx = coidx / param.value_ly_on_data.shape[i - 1];
ax_idx = ax_idx =
coidx - coidx -
...@@ -44,9 +56,9 @@ namespace { ...@@ -44,9 +56,9 @@ namespace {
ax_idx = coidx; ax_idx = coidx;
} }
offset += param.value_ly_on_data.stride[i] * ax_idx; offset += param.value_ly_on_data.stride[i] * ax_idx;
all_ax_idx[i] = ax_idx;
} }
offset += param.offset_base[all_ax_idx[param.idx_axis]]; // add the offset contributed by the indices, taken from the precomputed table
offset += param.offset_base[idx_flat];
Opr::apply( Opr::apply(
param.data[offset], param.data[offset],
param.value[oidx * param.value_stride]); param.value[oidx * param.value_stride]);
......
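The apply kernel above needs the flat position of each output element within the block of indexed axes [idx_axis, idx_axis_end), which it accumulates into idx_flat at the two boundary checks while unravelling the flat output id; that position is then used to look up offset_base. The same arithmetic on the host over a plain row-major shape (the kernel actually works on the collapsed value_ly_on_data layout, so the function below is only a sketch with illustrative names):

    #include <cstddef>

    // flat position of an element inside the sub-block of axes
    // [idx_axis, idx_axis_end) of a row-major shape, given its flat id
    size_t flat_pos_in_index_block(
            size_t flat, const size_t* shape, size_t ndim,
            size_t idx_axis, size_t idx_axis_end) {
        size_t tail = 1, block = 1;
        for (size_t i = idx_axis_end; i < ndim; ++i)
            tail *= shape[i];                  // axes after the indexed block
        for (size_t i = idx_axis; i < idx_axis_end; ++i)
            block *= shape[i];                 // the indexed block itself (== idx_nelems)
        size_t lead_and_block = flat / tail;   // coidx at the idx_axis_end boundary
        size_t lead = lead_and_block / block;  // coidx at the idx_axis boundary
        return lead_and_block - lead * block;  // == idx_flat in the kernel
    }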
...@@ -18,14 +18,29 @@ using namespace cuda; ...@@ -18,14 +18,29 @@ using namespace cuda;
using namespace indexing_multi_axis_vec; using namespace indexing_multi_axis_vec;
namespace { namespace {
template <int nidx> template <int nidx, int idx_ndim>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { __global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) {
int oidx = threadIdx.x + blockDim.x * blockIdx.x; int oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.size) { if (oidx < param.size) {
int offset = 0; int offset = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
int data_idx = param.indexer[i].ptr[param.indexer[i].stride * oidx]; auto& indexer = param.indexer[i];
// flat offset into this (broadcast) index tensor
int idx_flat = 0, coidx = oidx;
#pragma unroll
for (int j = idx_ndim - 1; j >= 0; --j) {
int ax_idx;
if (j) {
int next_coidx = coidx / indexer.shape[j - 1];
ax_idx = coidx - (next_coidx * indexer.shape[j - 1].divisor());
coidx = next_coidx;
} else {
ax_idx = coidx;
}
idx_flat += indexer.stride[j] * ax_idx;
}
int data_idx = indexer.ptr[idx_flat];
data_idx += (data_idx < 0 ? param.data_shape[i] : 0); data_idx += (data_idx < 0 ? param.data_shape[i] : 0);
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) { if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) {
// cast to uint32 to handle both negative and overflow // cast to uint32 to handle both negative and overflow
...@@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { ...@@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
i, data_idx, param.data_shape[i]); i, data_idx, param.data_shape[i]);
data_idx = 0; data_idx = 0;
} }
// calculate offset from current index
offset += data_idx * param.data_stride[i]; offset += data_idx * param.data_stride[i];
} }
// store the accumulated offset in the offset base table
param.output[oidx] = offset; param.output[oidx] = offset;
} }
} }
} // namespace } // namespace
template <int nidx> template <int nidx, int idx_ndim>
void indexing_multi_axis_vec::gen_offset_base( void indexing_multi_axis_vec::gen_offset_base(
const GenOffsetBaseParam<nidx>& param, cudaStream_t stream) { const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>; void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>;
int bsize = query_blocksize_for_kernel(kptr); int bsize = query_blocksize_for_kernel(kptr);
(*kptr)<<<DIVUP(param.size, bsize), bsize, 0, stream>>>(param); (*kptr)<<<DIVUP(param.size, bsize), bsize, 0, stream>>>(param);
} }
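Each KAxisIndexer now stores, per index axis, the stride of that index tensor after broadcasting to the common index shape (broadcast axes get stride 0) and a Uint32Fastdiv divisor for every axis except the outermost, whose extent is never divided by; the WIN32 branch presumably keeps the full-size array only to avoid a zero-length array when idx_ndim == 1. A plain CPU sketch of what kgen_offset_base does for one indexer, with ordinary division in place of Uint32Fastdiv and illustrative names:

    #include <cstdint>

    // read one int32 entry of an index tensor that has been broadcast to the
    // common index shape; 'flat' is the flat position within that shape,
    // shape[j] is the extent of axis j (shape[0] is never divided by),
    // stride[j] is 0 on broadcast axes
    int read_index(const int* ptr, uint32_t flat,
                   const uint32_t* shape, const int* stride, int idx_ndim) {
        int off = 0;
        for (int j = idx_ndim - 1; j >= 0; --j) {
            int ax_idx;
            if (j) {
                uint32_t next = flat / shape[j];              // Uint32Fastdiv in the kernel
                ax_idx = static_cast<int>(flat - next * shape[j]);
                flat = next;
            } else {
                ax_idx = static_cast<int>(flat);              // outermost axis: remaining quotient
            }
            off += stride[j] * ax_idx;
        }
        return ptr[off];
    }

Because broadcast axes carry stride 0, every position along them reads the same int32 entry, which reproduces numpy-style index broadcasting.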
...@@ -55,9 +72,17 @@ namespace megdnn { ...@@ -55,9 +72,17 @@ namespace megdnn {
namespace cuda { namespace cuda {
namespace indexing_multi_axis_vec { namespace indexing_multi_axis_vec {
#define INST(_n) \ #define INST(_m, _n) \
template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t); template void gen_offset_base(const GenOffsetBaseParam<_m, _n>&, cudaStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7)
#undef INST #undef INST
} // namespace indexing_multi_axis_vec } // namespace indexing_multi_axis_vec
......
...@@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec; ...@@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec;
namespace { namespace {
class ExecImplHelper { class ExecImplHelper {
template <int nidx, int idx_ndim>
void dispatch_gen_offset_base_nidx_ndim();
template <int nidx> template <int nidx>
void dispatch_gen_offset_base_nidx(); void dispatch_gen_offset_base_nidx();
void dispatch_gen_offset_base(); void dispatch_gen_offset_base();
protected: protected:
...@@ -38,6 +39,7 @@ protected: ...@@ -38,6 +39,7 @@ protected:
int* const m_offset_base; int* const m_offset_base;
TensorLayout m_value_layout_on_data; TensorLayout m_value_layout_on_data;
size_t m_idx_axis; size_t m_idx_axis;
TensorShape m_idx_shape;
int m_value_stride; int m_value_stride;
public: public:
...@@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper( ...@@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper(
m_exec_info{&exec_info}, m_exec_info{&exec_info},
m_offset_base{workspace.ptr<int>()} { m_offset_base{workspace.ptr<int>()} {
safe_size_in_kern(data.layout.total_nr_elems()); safe_size_in_kern(data.layout.total_nr_elems());
dispatch_gen_offset_base(); std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) =
std::tie(m_value_layout_on_data, m_idx_axis) =
IndexingMultiAxisVec::get_value_iter_optimized_layout( IndexingMultiAxisVec::get_value_iter_optimized_layout(
data.layout, value.layout, index, exec_info.idx_axis); data.layout, value.layout, index, exec_info.idx_axis);
dispatch_gen_offset_base();
m_value_stride = exec_info.value_stride; m_value_stride = exec_info.value_stride;
} }
template <int nidx> template <int nidx, int idx_ndim>
void ExecImplHelper::dispatch_gen_offset_base_nidx() { void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() {
GenOffsetBaseParam<nidx> param; GenOffsetBaseParam<nidx, idx_ndim> param;
param.size = m_value->layout.shape[m_exec_info->idx_axis]; param.size = m_idx_shape.total_nr_elems();
param.output = m_offset_base; param.output = m_offset_base;
param.error_tracker = m_exec_info->error_tracker; param.error_tracker = m_exec_info->error_tracker;
param.error_info = m_exec_info->error_info; param.error_info = m_exec_info->error_info;
megdnn_assert(m_idx_shape.ndim == idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
auto&& dst = param.indexer[i]; auto&& dst = param.indexer[i];
auto&& src = m_index->operator[](i); auto&& src = m_index->at(i);
megdnn_assert(src.vec.layout.ndim == 1); auto src_layout = src.vec.layout.broadcast(m_idx_shape);
dst.stride = src.vec.layout.stride[0]; for (size_t i = 0; i < idx_ndim; ++i) {
if (src.vec.layout.shape[0] == 1) { if (i) {
dst.stride = 0; dst.shape[i - 1] = src_layout.shape[i];
}
dst.stride[i] = src_layout.stride[i];
} }
dst.ptr = src.vec.ptr<int>(); dst.ptr = src.vec.ptr<int>();
param.data_shape[i] = m_data->layout.shape[src.axis]; param.data_shape[i] = m_data->layout.shape[src.axis];
...@@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { ...@@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base(param, m_stream); gen_offset_base(param, m_stream);
} }
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
switch (m_idx_shape.ndim) {
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
megdnn_throw("bad index ndim");
}
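dispatch_gen_offset_base_nidx above turns the runtime index ndim into the compile-time idx_ndim template argument through a switch generated by MEGDNN_FOREACH_TENSOR_NDIM. A generic sketch of this dispatch pattern, not the megdnn macro itself:

    #include <stdexcept>

    template <int NDIM>
    void run_impl() {
        // body specialized on the compile-time NDIM
    }

    void run(int ndim) {
        switch (ndim) {
    #define CASE(n) \
        case n:     \
            return run_impl<n>();
            CASE(1) CASE(2) CASE(3) CASE(4) CASE(5) CASE(6) CASE(7)
    #undef CASE
        }
        throw std::runtime_error("bad index ndim");
    }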
void ExecImplHelper::dispatch_gen_offset_base() { void ExecImplHelper::dispatch_gen_offset_base() {
switch (m_index->size()) { switch (m_index->size()) {
#define cb(_n) \ #define cb(_n) \
...@@ -153,6 +169,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() { ...@@ -153,6 +169,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param.data = m_data->ptr<ctype>(); param.data = m_data->ptr<ctype>();
param.value = m_value->ptr<ctype>(); param.value = m_value->ptr<ctype>();
param.idx_axis = m_idx_axis; param.idx_axis = m_idx_axis;
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim;
param.idx_nelems = m_idx_shape.total_nr_elems();
param.value_stride = m_value_stride; param.value_stride = m_value_stride;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
......
...@@ -33,37 +33,46 @@ void do_exec( ...@@ -33,37 +33,46 @@ void do_exec(
auto data_layout = data.layout; auto data_layout = data.layout;
auto data_ptr = data.ptr<data_type>(); auto data_ptr = data.ptr<data_type>();
std::tuple<size_t, const idx_type*, ptrdiff_t> index_raw[TensorLayout::MAX_NDIM]; std::tuple<size_t, const idx_type*, TensorLayout> index_raw[TensorLayout::MAX_NDIM];
size_t nr_index = index.size(); size_t nr_index = index.size();
TensorShape idx_shape;
{
TensorShapeArray idx_shapes;
for (size_t i = 0; i < nr_index; ++i) {
idx_shapes.push_back(index[i].vec.layout);
}
Elemwise::deduce_shape(idx_shapes, idx_shape);
}
for (size_t i = 0; i < nr_index; ++i) { for (size_t i = 0; i < nr_index; ++i) {
auto&& s = index[i]; auto&& s = index[i];
index_raw[i] = index_raw[i] = std::make_tuple(
std::make_tuple(s.axis, s.vec.ptr<idx_type>(), s.vec.layout.stride[0]); s.axis, s.vec.ptr<idx_type>(), s.vec.layout.broadcast(idx_shape));
if (s.vec.layout.shape[0] == 1)
std::get<2>(index_raw[i]) = 0;
} }
auto value_iter = tensor_iter<data_type>(value).begin(); auto value_iter = tensor_iter<data_type>(value).begin();
for (size_t _ = 0, _t = value.layout.total_nr_elems(); _ < _t; ++_) { for (size_t _ = 0, _t = value.layout.total_nr_elems(); _ < _t; ++_) {
ptrdiff_t offset = 0; ptrdiff_t offset = 0;
auto index_idx = value_iter.idx()[exec_info.idx_axis]; auto* index_idx = value_iter.idx() + exec_info.idx_axis;
for (size_t i = 0; i < nr_index; ++i) { for (size_t i = 0; i < nr_index; ++i) {
size_t axis = std::get<0>(index_raw[i]), size_t axis = std::get<0>(index_raw[i]),
data_shape = data_layout.shape[axis]; data_shape = data_layout.shape[axis];
ptrdiff_t data_stride = data_layout.stride[axis]; ptrdiff_t data_stride = data_layout.stride[axis];
idx_type data_idx = size_t index_offset = 0;
std::get<1>(index_raw[i])[std::get<2>(index_raw[i]) * index_idx]; TensorLayout& index_layout = std::get<2>(index_raw[i]);
for (size_t i = 0; i < index_layout.ndim; ++i) {
index_offset += index_idx[i] * index_layout.stride[i];
}
idx_type data_idx = std::get<1>(index_raw[i])[index_offset];
if (data_idx < 0) if (data_idx < 0)
data_idx += data_shape; data_idx += data_shape;
megdnn_assert( megdnn_assert(
data_idx >= 0 && static_cast<size_t>(data_idx) < data_shape, data_idx >= 0 && static_cast<size_t>(data_idx) < data_shape,
"bad index value for index %zu at output %zu", i, index_idx); "bad index value for index %zu at output %zu", i, *index_idx);
offset += data_stride * data_idx; offset += data_stride * data_idx;
} }
for (size_t i = 0; i < nr_nonidx_axes; ++i) { for (size_t i = 0; i < nr_nonidx_axes; ++i) {
auto stride = data_layout.stride[nonidx_axes[i]]; auto stride = data_layout.stride[nonidx_axes[i]];
auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis)]; auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis) * idx_shape.ndim];
offset += stride * idx; offset += stride * idx;
} }
Opr::apply(data_ptr[offset], *value_iter); Opr::apply(data_ptr[offset], *value_iter);
......
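The naive reference path above now broadcasts each index layout to the common index shape and walks it with strides, so broadcast axes simply reread the same entry. A self-contained toy version of the gather semantics for a single indexed axis and no broadcasting; gather_rows is a hypothetical helper, not megdnn code:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // data: shape (d0, d1), idx: shape (i0, i1), result: shape (i0, i1, d1),
    // all row-major; result[a][b][c] = data[idx[a][b]][c]
    std::vector<float> gather_rows(
            const std::vector<float>& data, size_t d0, size_t d1,
            const std::vector<int>& idx, size_t i0, size_t i1) {
        std::vector<float> out(i0 * i1 * d1);
        for (size_t a = 0; a < i0; ++a)
            for (size_t b = 0; b < i1; ++b) {
                int r = idx[a * i1 + b];
                if (r < 0)
                    r += static_cast<int>(d0);   // negative indices wrap, as checked above
                assert(r >= 0 && static_cast<size_t>(r) < d0);
                for (size_t c = 0; c < d1; ++c)
                    out[(a * i1 + b) * d1 + c] = data[static_cast<size_t>(r) * d1 + c];
            }
        return out;
    }

The Python test added later in this commit checks the same behaviour at the framework level: for a 1-D inp and a 2-D idx, inp[idx][a][b] equals inp[idx[a][b]].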
...@@ -21,17 +21,23 @@ namespace rocm { ...@@ -21,17 +21,23 @@ namespace rocm {
namespace indexing_multi_axis_vec { namespace indexing_multi_axis_vec {
//! AxisIndexer equiv in kernel //! AxisIndexer equiv in kernel
template <int idx_ndim>
struct KAxisIndexer { struct KAxisIndexer {
int stride; int stride[idx_ndim];
#ifdef WIN32
Uint32Fastdiv shape[idx_ndim];
#else
Uint32Fastdiv shape[idx_ndim - 1];
#endif
const int *ptr; const int *ptr;
}; };
//! param for gen_offset_base //! param for gen_offset_base
template<int nidx> template<int nidx, int idx_ndim>
struct GenOffsetBaseParam { struct GenOffsetBaseParam {
uint32_t size; //!< number of outputs; also size of each index uint32_t size; //!< number of outputs; also size of each index
int *output; //!< output ptr int *output; //!< output ptr
KAxisIndexer indexer[nidx]; KAxisIndexer<idx_ndim> indexer[nidx];
uint32_t data_shape[nidx]; uint32_t data_shape[nidx];
int data_stride[nidx]; int data_stride[nidx];
...@@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec { ...@@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec {
ctype *data, *value; ctype *data, *value;
int idx_axis; int idx_axis;
int idx_axis_end;
int idx_nelems;
int value_stride; int value_stride;
...@@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec { ...@@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec {
}; };
//! generate offset bases for first axis in the output //! generate offset bases for first axis in the output
template<int nidx> template<int nidx, int idx_ndim>
void gen_offset_base(const GenOffsetBaseParam<nidx> &param, void gen_offset_base(const GenOffsetBaseParam<nidx, idx_ndim> &param,
hipStream_t stream); hipStream_t stream);
struct OprAtomicIncr { struct OprAtomicIncr {
......
...@@ -30,10 +30,17 @@ namespace { ...@@ -30,10 +30,17 @@ namespace {
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.tot_size) { if (oidx < param.tot_size) {
int offset = 0, coidx = oidx; int offset = 0, coidx = oidx;
int all_ax_idx[ndim]; int idx_flat = 0;
#pragma unroll #pragma unroll
for (int i = ndim - 1; i >= 0; -- i) { for (int i = ndim - 1; i >= 0; -- i) {
int next_coidx, ax_idx; int next_coidx, ax_idx;
if (i + 1 == param.idx_axis_end) {
idx_flat = coidx;
}
// never taken when idx_axis == 0 (i + 1 is always >= 1)
if (i + 1 == param.idx_axis) {
idx_flat -= coidx * param.idx_nelems;
}
if (i) { if (i) {
next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; next_coidx = coidx / param.value_ly_on_data.shape[i - 1];
ax_idx = ax_idx =
...@@ -45,9 +52,8 @@ namespace { ...@@ -45,9 +52,8 @@ namespace {
ax_idx = coidx; ax_idx = coidx;
} }
offset += param.value_ly_on_data.stride[i] * ax_idx; offset += param.value_ly_on_data.stride[i] * ax_idx;
all_ax_idx[i] = ax_idx;
} }
offset += param.offset_base[all_ax_idx[param.idx_axis]]; offset += param.offset_base[idx_flat];
Opr::apply( Opr::apply(
param.data[offset], param.data[offset],
param.value[oidx * param.value_stride]); param.value[oidx * param.value_stride]);
......
...@@ -21,15 +21,28 @@ using namespace rocm; ...@@ -21,15 +21,28 @@ using namespace rocm;
using namespace indexing_multi_axis_vec; using namespace indexing_multi_axis_vec;
namespace { namespace {
template<int nidx> template<int nidx, int idx_ndim>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { __global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) {
int oidx = threadIdx.x + blockDim.x * blockIdx.x; int oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.size) { if (oidx < param.size) {
int offset = 0; int offset = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < nidx; ++ i) { for (int i = 0; i < nidx; ++ i) {
int data_idx = param.indexer[i].ptr[ auto& indexer = param.indexer[i];
param.indexer[i].stride * oidx]; int offset2 = 0, coidx = oidx;
#pragma unroll
for (int j = idx_ndim-1; j >= 0; --j) {
int ax_idx;
if (j) {
int next_coidx = coidx / indexer.shape[j-1];
ax_idx = coidx - (next_coidx * indexer.shape[j-1].divisor());
coidx = next_coidx;
} else {
ax_idx = coidx;
}
offset2 += indexer.stride[j] * ax_idx;
}
int data_idx = indexer.ptr[offset2];
data_idx += (data_idx < 0 ? param.data_shape[i] : 0); data_idx += (data_idx < 0 ? param.data_shape[i] : 0);
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) { if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) {
// cast to uint32 to handle both negative and overflow // cast to uint32 to handle both negative and overflow
...@@ -50,20 +63,28 @@ namespace megdnn { ...@@ -50,20 +63,28 @@ namespace megdnn {
namespace rocm { namespace rocm {
namespace indexing_multi_axis_vec { namespace indexing_multi_axis_vec {
#define INST(_n) \ #define INST(_m, _n) \
template void gen_offset_base( \ template void gen_offset_base( \
const GenOffsetBaseParam<_n> &, hipStream_t); const GenOffsetBaseParam<_m, _n> &, hipStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7)
#undef INST #undef INST
} // namespace indexing_multi_axis_vec } // namespace indexing_multi_axis_vec
} // namespace rocm } // namespace rocm
} // namespace megdnn } // namespace megdnn
template<int nidx> template<int nidx, int idx_ndim>
void indexing_multi_axis_vec::gen_offset_base( void indexing_multi_axis_vec::gen_offset_base(
const GenOffsetBaseParam<nidx> &param, hipStream_t stream) { const GenOffsetBaseParam<nidx, idx_ndim> &param, hipStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>; void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>;
int bsize = 256; int bsize = 256;
hipLaunchKernelGGL(kptr, hipLaunchKernelGGL(kptr,
DIVUP(param.size, bsize), bsize, 0, stream, DIVUP(param.size, bsize), bsize, 0, stream,
......
...@@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec; ...@@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec;
namespace { namespace {
class ExecImplHelper { class ExecImplHelper {
template <int nidx, int idx_ndim>
void dispatch_gen_offset_base_nidx_ndim();
template <int nidx> template <int nidx>
void dispatch_gen_offset_base_nidx(); void dispatch_gen_offset_base_nidx();
void dispatch_gen_offset_base(); void dispatch_gen_offset_base();
protected: protected:
...@@ -39,6 +40,7 @@ protected: ...@@ -39,6 +40,7 @@ protected:
int* const m_offset_base; int* const m_offset_base;
TensorLayout m_value_layout_on_data; TensorLayout m_value_layout_on_data;
size_t m_idx_axis; size_t m_idx_axis;
TensorShape m_idx_shape;
int m_value_stride; int m_value_stride;
public: public:
...@@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper( ...@@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper(
m_exec_info{&exec_info}, m_exec_info{&exec_info},
m_offset_base{workspace.ptr<int>()} { m_offset_base{workspace.ptr<int>()} {
safe_size_in_kern(data.layout.total_nr_elems()); safe_size_in_kern(data.layout.total_nr_elems());
dispatch_gen_offset_base(); std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) =
std::tie(m_value_layout_on_data, m_idx_axis) =
IndexingMultiAxisVec::get_value_iter_optimized_layout( IndexingMultiAxisVec::get_value_iter_optimized_layout(
data.layout, value.layout, index, exec_info.idx_axis); data.layout, value.layout, index, exec_info.idx_axis);
dispatch_gen_offset_base();
m_value_stride = exec_info.value_stride; m_value_stride = exec_info.value_stride;
} }
template <int nidx> template <int nidx, int idx_ndim>
void ExecImplHelper::dispatch_gen_offset_base_nidx() { void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() {
GenOffsetBaseParam<nidx> param; GenOffsetBaseParam<nidx, idx_ndim> param;
param.size = m_value->layout.shape[m_exec_info->idx_axis]; param.size = m_idx_shape.total_nr_elems();
param.output = m_offset_base; param.output = m_offset_base;
param.error_tracker = m_exec_info->error_tracker; param.error_tracker = m_exec_info->error_tracker;
param.error_info = m_exec_info->error_info; param.error_info = m_exec_info->error_info;
...@@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { ...@@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
auto&& dst = param.indexer[i]; auto&& dst = param.indexer[i];
auto&& src = m_index->operator[](i); auto&& src = m_index->operator[](i);
megdnn_assert(src.vec.layout.ndim == 1); megdnn_assert(src.vec.layout.ndim == 1);
dst.stride = src.vec.layout.stride[0]; auto src_layout = src.vec.layout.broadcast(m_idx_shape);
if (src.vec.layout.shape[0] == 1) { for (size_t i = 0; i < idx_ndim; ++i) {
dst.stride = 0; if (i) {
dst.shape[i - 1] = src_layout.shape[i];
}
dst.stride[i] = src_layout.stride[i];
} }
dst.ptr = src.vec.ptr<int>(); dst.ptr = src.vec.ptr<int>();
param.data_shape[i] = m_data->layout.shape[src.axis]; param.data_shape[i] = m_data->layout.shape[src.axis];
...@@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { ...@@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base(param, m_stream); gen_offset_base(param, m_stream);
} }
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
switch (m_idx_shape.ndim) {
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
megdnn_throw("bad index ndim");
}
void ExecImplHelper::dispatch_gen_offset_base() { void ExecImplHelper::dispatch_gen_offset_base() {
switch (m_index->size()) { switch (m_index->size()) {
#define cb(_n) \ #define cb(_n) \
...@@ -154,6 +170,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() { ...@@ -154,6 +170,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param.data = m_data->ptr<ctype>(); param.data = m_data->ptr<ctype>();
param.value = m_value->ptr<ctype>(); param.value = m_value->ptr<ctype>();
param.idx_axis = m_idx_axis; param.idx_axis = m_idx_axis;
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim;
param.idx_nelems = m_idx_shape.total_nr_elems();
param.value_stride = m_value_stride; param.value_stride = m_value_stride;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
......
...@@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper { ...@@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper {
return ret; return ret;
} }
size_t get_index_ndim(const TensorNDArray& tensors) const {
megdnn_assert(tensors.size() >= 3);
size_t ndim = 0;
for (size_t i = 2; i < tensors.size(); ++i) {
ndim = std::max(tensors[i].layout.ndim, ndim);
}
return ndim;
}
IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout( IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout(
const TensorLayoutArray& layouts) const { const TensorLayoutArray& layouts) const {
megdnn_assert(layouts.size() >= 3); megdnn_assert(layouts.size() >= 3);
...@@ -65,7 +74,8 @@ struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelpe ...@@ -65,7 +74,8 @@ struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelpe
void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const { void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W( WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes( opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2)); tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace());
} }
...@@ -81,7 +91,8 @@ struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecH ...@@ -81,7 +91,8 @@ struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecH
void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const { void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W( WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes( opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2)); tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
} }
...@@ -95,7 +106,8 @@ struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHe ...@@ -95,7 +106,8 @@ struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHe
void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const { void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W( WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes( opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2)); tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
} }
......
...@@ -27,7 +27,7 @@ namespace test { ...@@ -27,7 +27,7 @@ namespace test {
WorkspaceWrapper W( \ WorkspaceWrapper W( \
opr->handle(), \ opr->handle(), \
opr->get_workspace_in_bytes( \ opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2)); \ tensors[1].layout, axes, tensors.size() - 2, 1)); \
opr->exec( \ opr->exec( \
tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \ tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \
} \ } \
...@@ -46,7 +46,7 @@ namespace test { ...@@ -46,7 +46,7 @@ namespace test {
WorkspaceWrapper W( \ WorkspaceWrapper W( \
opr->handle(), \ opr->handle(), \
opr->get_workspace_in_bytes( \ opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2)); \ tensors[1].layout, axes, tensors.size() - 2, 1)); \
opr->exec( \ opr->exec( \
tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \ tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \
} \ } \
......
...@@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) { ...@@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}}); TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
} }
TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) {
run_check<IndexingMultiAxisVec>(handle_cuda());
Checker<IndexingMultiAxisVec> checker(handle_cuda());
OrderedRNG rng;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32())
.set_dtype(4, dtype::Int32())
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_rng(4, &rng);
checker.set_proxy({{1, 2, 3}})
.execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}});
}
TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) { TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
run_check<IndexingIncrMultiAxisVec>(handle_cuda()); run_check<IndexingIncrMultiAxisVec>(handle_cuda());
Checker<IndexingIncrMultiAxisVec> checker(handle_cuda()); Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
......
...@@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic): ...@@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic):
run_test((10, 10, 0), test4) run_test((10, 10, 0), test4)
run_test((10, 10, 10), test3) run_test((10, 10, 10), test3)
run_test((10, 10, 10), test4) run_test((10, 10, 10), test4)
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_nd_int_indexing(symbolic):
inp = np.arange(11)
idx = np.random.randint(11, size=(5, 7))
def run_test(args, fn):
npy_out = fn(*args)
if symbolic:
fn = jit.trace(symbolic=symbolic)(fn)
for _ in range(3):
out = fn(*[Tensor(arg) for arg in args])
np.testing.assert_equal(out.numpy(), npy_out)
run_test([inp, idx], lambda inp, idx: inp[idx])
...@@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr( ...@@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
template <class Opr> template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer( void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
VarNode* data, VarNode* value) { VarNode* data, VarNode* value, VarNodeArray idx_arr) {
using namespace cg::static_infer; using namespace cg::static_infer;
auto infer_shape = [this, &index_desc, &opr](TensorShape& dest, const InpVal& inp) { DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}};
for (auto&& idx : idx_arr) {
deps.push_back({idx, DepType::SHAPE});
}
auto infer_shape = [this, &index_desc, &opr, nr_idx = idx_arr.size()](
TensorShape& dest, const InpVal& inp) {
size_t axes[TensorShape::MAX_NDIM], nr_axes = 0; size_t axes[TensorShape::MAX_NDIM], nr_axes = 0;
auto ndim = inp.val[0].shape().ndim; auto ndim = inp.val[0].shape().ndim;
for (auto&& i : reverse_adaptor(index_desc)) { for (auto&& i : reverse_adaptor(index_desc)) {
...@@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer( ...@@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
axes[nr_axes++] = i.axis.get(ndim); axes[nr_axes++] = i.axis.get(ndim);
} }
} }
mgb_assert(nr_axes == nr_idx);
if (!nr_axes) { if (!nr_axes) {
dest = {0}; dest = {0};
} else { } else {
size_t idx_ndim = 0;
for (size_t i = 0; i < nr_idx; ++i) {
idx_ndim = std::max(idx_ndim, inp.val[2 + i].shape().ndim);
}
mgb_assert(idx_ndim > 0);
dest = {megdnn_opr(opr).get_workspace_in_bytes( dest = {megdnn_opr(opr).get_workspace_in_bytes(
inp.val[1].shape(), axes, nr_axes)}; inp.val[1].shape(), axes, nr_axes, idx_ndim)};
} }
return true; return true;
}; };
opr.owner_graph()->static_infer_manager().register_shape_infer( opr.owner_graph()->static_infer_manager().register_shape_infer(
opr.output(1), {SourceType::DEP, opr.output(1), {SourceType::DEP, deps, infer_shape});
{{data, DepType::SHAPE}, {value, DepType::SHAPE}},
infer_shape});
} }
template <class Opr> template <class Opr>
...@@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() { ...@@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
}; };
owner_graph()->static_infer_manager().register_shape_infer( owner_graph()->static_infer_manager().register_shape_infer(
output(0), {SourceType::DEP, deps, infer_shape}); output(0), {SourceType::DEP, deps, infer_shape});
VarNodeArray idx_arr;
this->register_workspace_infer(index_desc(), *this, input(0), output(0)); for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (m_input2idxonly_axis_indexer[i]) {
idx_arr.push_back(input(i));
}
}
this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr);
} }
template <class Opr> template <class Opr>
...@@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc( ...@@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc(
this->owner_graph()->static_infer_manager().register_shape_infer( this->owner_graph()->static_infer_manager().register_shape_infer(
this->output(0), ShapeInferDesc::make_identity(this->input(0))); this->output(0), ShapeInferDesc::make_identity(this->input(0)));
this->register_workspace_infer(index_desc(), *this, input(0), input(1)); VarNodeArray idx_arr;
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (m_input2idxonly_axis_indexer[i]) {
idx_arr.push_back(input(i));
}
}
this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr);
} }
template <class Opr> template <class Opr>
......
...@@ -96,7 +96,7 @@ protected: ...@@ -96,7 +96,7 @@ protected:
void register_workspace_infer( void register_workspace_infer(
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
VarNode* data, VarNode* value); VarNode* data, VarNode* value, VarNodeArray idx_arr);
void record_megdnn_opr(mgb::cg::GraphExecutable::ExecDependencyArray& deps); void record_megdnn_opr(mgb::cg::GraphExecutable::ExecDependencyArray& deps);
}; };
......