diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 95dddd9108d5c41b8f301736db56e577d03836a6..8de736e43d60f2ae4324fcd57ad42da3939a440f 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1115,7 +1115,7 @@ public: * access *data*; stride of layout on that axis would be zero, and * strides on other axes correspond to the strides in *data* */ - static std::pair get_value_iter_optimized_layout( + static std::tuple get_value_iter_optimized_layout( const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, size_t idx_axis); @@ -1159,7 +1159,8 @@ public: * \brief get workspace size based on output shape and indexing axes */ size_t get_workspace_in_bytes( - const TensorShape& dst, const size_t* axes, size_t nr_axes); + const TensorShape& dst, const size_t* axes, size_t nr_axes, + size_t idx_ndim); static void deduce_layout( const TensorLayout& data, const IndexDescLayoutOnly& index, @@ -1193,7 +1194,8 @@ public: * axes */ size_t get_workspace_in_bytes( - const TensorShape& value, const size_t* axes, size_t nr_axes); + const TensorShape& value, const size_t* axes, size_t nr_axes, + size_t idx_ndim); protected: ExecInfo check_exec( @@ -1223,7 +1225,7 @@ public: using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly; using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly; - size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) { + size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t, size_t) { return 0; } diff --git a/dnn/src/common/indexing_multi_axis_vec.cpp b/dnn/src/common/indexing_multi_axis_vec.cpp index a4cc8b0bf89e71b7aefe29adfc040fe2f0cb7135..e306993a4db12afc65bef1f3152a622e2d2318de 100644 --- a/dnn/src/common/indexing_multi_axis_vec.cpp +++ b/dnn/src/common/indexing_multi_axis_vec.cpp @@ -15,8 +15,10 @@ using namespace megdnn; namespace { + +// we need a workspace to store offset base table, which has same size with index size_t get_index_size_for_workspace( - const TensorShape& shp, const size_t* axes, size_t nr_axes) { + const TensorShape& shp, const size_t* axes, size_t nr_axes, size_t idx_ndim) { size_t idx_axis = axes[0]; megdnn_assert(shp.ndim && nr_axes); for (size_t i = 1; i < nr_axes; ++i) { @@ -29,7 +31,11 @@ size_t get_index_size_for_workspace( megdnn_assert( shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis, shp.to_string().c_str()); - return shp.shape[idx_axis]; + size_t idx_size = 1; + for (size_t i = 0; i < idx_ndim; ++i) { + idx_size *= shp.shape[idx_axis + i]; + } + return idx_size; } } // anonymous namespace @@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) { megdnn_assert(!index.empty()); megdnn_assert(data.ndim >= index.size()); - dst.ndim = data.ndim - index.size() + 1; - dst.shape[0] = 1; + dst.ndim = data.ndim - index.size(); dst.dtype = data.dtype; + TensorShapeArray index_shapes; + auto brdcast = [&](const TensorLayout& ly) { - if (ly.ndim != 1) - return false; - if (dst.shape[0] == ly.shape[0]) - return true; - if (dst.shape[0] == 1) { - dst.shape[0] = ly.shape[0]; - return true; - } - return ly.shape[0] == 1; + megdnn_assert(ly.dtype == dtype::Int32{}); + index_shapes.push_back(ly); }; - size_t dst_axis = 1; + size_t dst_axis = 0; ptrdiff_t prev_axis = -1; for (size_t axis = 0; axis < index.size(); ++axis) { auto&& idx = index[axis]; @@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( megdnn_assert( idx.axis(idx.axis)> prev_axis, "index %zu requests invalid axis %zu", axis, idx.axis); - auto brd_succ = brdcast(idx.layout); - megdnn_assert( - brd_succ, "invalid layout at index %zu: %s", axis, - idx.layout.to_string().c_str()); + brdcast(idx.layout); for (size_t i = prev_axis + 1; i < idx.axis; ++i) { dst.shape[dst_axis++] = data.shape[i]; @@ -99,15 +96,18 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( } } if (contig_idx) { - auto shp0 = dst.shape[0]; idx_axis = index[0].axis; - for (size_t i = 0; i < idx_axis; ++i) { - dst.shape[i] = dst.shape[i + 1]; - } - dst.shape[idx_axis] = shp0; } } + TensorShape index_shape; + Elemwise::deduce_shape(index_shapes, index_shape); + + for (size_t i = 0; i < index_shape.ndim; ++i) { + dst.add_axis_inplace(idx_axis + i, 1, 0); + dst.shape[idx_axis + i] = index_shape.shape[i]; + } + dst.init_contiguous_stride(); return idx_axis; } @@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp return ret; } -std::pair IndexingMultiAxisVecBase:: +std::tuple IndexingMultiAxisVecBase:: get_value_iter_optimized_layout( const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, size_t idx_axis) { size_t data_axes[TensorLayout::MAX_NDIM], nr_axes = get_nonindex_axes(data.ndim, index, data_axes); + // broadcast index shapes + TensorLayout index_shape; + { + TensorShapeArray index_shapes; + for (auto& idx : index) { + megdnn_assert(idx.vec.layout.dtype == dtype::Int32{}); + index_shapes.push_back(idx.vec.layout); + } + Elemwise::deduce_shape(index_shapes, index_shape); + } + megdnn_assert( - nr_axes == value.ndim - 1 && idx_axis < value.ndim && + nr_axes == value.ndim - index_shape.ndim && idx_axis < value.ndim && nr_axes + index.size() == data.ndim); TensorLayout ret; @@ -165,10 +176,13 @@ std::pair IndexingMultiAxisVecBase:: } ret = ret.collapse_contiguous(); } - ret.shape[ret.ndim] = value.shape[idx_axis]; - ret.stride[ret.ndim] = 0; + size_t ret_idx_axis = ret.ndim; - ++ret.ndim; + for (size_t i = 0; i < index_shape.ndim; ++i) { + ret.shape[ret.ndim] = value.shape[idx_axis + i]; + ret.stride[ret.ndim] = 0; + ++ret.ndim; + } if (idx_axis < nr_axes) { TensorLayout tail; @@ -185,12 +199,13 @@ std::pair IndexingMultiAxisVecBase:: } } - return {ret, ret_idx_axis}; + return std::make_tuple(ret, ret_idx_axis, index_shape); } size_t IndexingMultiAxisVec::get_workspace_in_bytes( - const TensorShape& dst, const size_t* axes, size_t nr_axes) { - return get_workspace_in_bytes(get_index_size_for_workspace(dst, axes, nr_axes)); + const TensorShape& dst, const size_t* axes, size_t nr_axes, size_t idx_ndim) { + return get_workspace_in_bytes( + get_index_size_for_workspace(dst, axes, nr_axes, idx_ndim)); } IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( @@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( } size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes( - const TensorShape& value, const size_t* axes, size_t nr_axes) { - return get_workspace_in_bytes(get_index_size_for_workspace(value, axes, nr_axes)); + const TensorShape& value, const size_t* axes, size_t nr_axes, size_t idx_ndim) { + return get_workspace_in_bytes( + get_index_size_for_workspace(value, axes, nr_axes, idx_ndim)); } IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec( diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh b/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh index be26b47ecab6e8eb8297d9dbcb9cd2589bce8403..950c5311a096911d81a55557d33e3f5bf70a8b03 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh @@ -21,17 +21,24 @@ namespace cuda { namespace indexing_multi_axis_vec { //! AxisIndexer equiv in kernel +template struct KAxisIndexer { - int stride; + int stride[idx_ndim]; +#ifdef WIN32 + Uint32Fastdiv shape[idx_ndim]; +#else + // original shape[0] not storaged + Uint32Fastdiv shape[idx_ndim - 1]; +#endif const int* ptr; }; //! param for gen_offset_base -template +template struct GenOffsetBaseParam { uint32_t size; //!< number of outputs; also size of each index int* output; //!< output ptr - KAxisIndexer indexer[nidx]; + KAxisIndexer indexer[nidx]; uint32_t data_shape[nidx]; int data_stride[nidx]; @@ -59,7 +66,12 @@ struct ApplyOprParam { const int* offset_base; ctype *data, *value; + // first idx axis int idx_axis; + // last idx axis + 1 + int idx_axis_end; + // number of elements for idx shape + int idx_nelems; int value_stride; @@ -68,8 +80,9 @@ struct ApplyOprParam { }; //! generate offset bases for first axis in the output -template -void gen_offset_base(const GenOffsetBaseParam& param, cudaStream_t stream); +template +void gen_offset_base( + const GenOffsetBaseParam& param, cudaStream_t stream); struct OprAtomicIncr { #if MEGDNN_CC_CUDA diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl index 6d9e1032a3c5e8627de40388aa9284cb7930c87f..e146a6334d9e2e55e90c4cf121052a14a0d54eb8 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl @@ -29,11 +29,23 @@ namespace { uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; if (oidx < param.tot_size) { int offset = 0, coidx = oidx; - int all_ax_idx[ndim]; + // offset in index + int idx_flat = 0; + // for non-indexed axes get offset #pragma unroll for (int i = ndim - 1; i >= 0; -- i) { int next_coidx, ax_idx; + // [..., indexed_axes... |, ...] + if (i + 1 == param.idx_axis_end) { + idx_flat = coidx; + } + // [... |, indexed_axes..., ...] + if (i + 1 == param.idx_axis) { + idx_flat -= coidx * param.idx_nelems; + } + // shape[i] was storaged at shape[i-1] if (i) { + // fast divide next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; ax_idx = coidx - @@ -44,9 +56,9 @@ namespace { ax_idx = coidx; } offset += param.value_ly_on_data.stride[i] * ax_idx; - all_ax_idx[i] = ax_idx; } - offset += param.offset_base[all_ax_idx[param.idx_axis]]; + // offset from index, which was generated before + offset += param.offset_base[idx_flat]; Opr::apply( param.data[offset], param.value[oidx * param.value_stride]); diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu index 1dc1bfc8e746c7388b4737568a9047bbe3f98779..ebab2b52a75de6071cf66b2015300cd3a18d6f49 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu @@ -18,14 +18,29 @@ using namespace cuda; using namespace indexing_multi_axis_vec; namespace { -template -__global__ void kgen_offset_base(GenOffsetBaseParam param) { +template +__global__ void kgen_offset_base(GenOffsetBaseParam param) { int oidx = threadIdx.x + blockDim.x * blockIdx.x; if (oidx < param.size) { int offset = 0; #pragma unroll for (int i = 0; i < nidx; ++i) { - int data_idx = param.indexer[i].ptr[param.indexer[i].stride * oidx]; + auto& indexer = param.indexer[i]; + // index in index + int idx_flat = 0, coidx = oidx; +#pragma unroll + for (int j = idx_ndim - 1; j >= 0; --j) { + int ax_idx; + if (j) { + int next_coidx = coidx / indexer.shape[j - 1]; + ax_idx = coidx - (next_coidx * indexer.shape[j - 1].divisor()); + coidx = next_coidx; + } else { + ax_idx = coidx; + } + idx_flat += indexer.stride[j] * ax_idx; + } + int data_idx = indexer.ptr[idx_flat]; data_idx += (data_idx < 0 ? param.data_shape[i] : 0); if (static_cast(data_idx) >= param.data_shape[i]) { // cast to uint32 to handle both negative and overflow @@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam param) { i, data_idx, param.data_shape[i]); data_idx = 0; } + // calculate offset from current index offset += data_idx * param.data_stride[i]; } + // sum offsets and store at offset table param.output[oidx] = offset; } } } // namespace -template +template void indexing_multi_axis_vec::gen_offset_base( - const GenOffsetBaseParam& param, cudaStream_t stream) { - void (*kptr)(GenOffsetBaseParam) = kgen_offset_base; + const GenOffsetBaseParam& param, cudaStream_t stream) { + void (*kptr)(GenOffsetBaseParam) = kgen_offset_base; int bsize = query_blocksize_for_kernel(kptr); (*kptr)<<>>(param); } @@ -55,9 +72,17 @@ namespace megdnn { namespace cuda { namespace indexing_multi_axis_vec { -#define INST(_n) \ - template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t); -MEGDNN_FOREACH_TENSOR_NDIM(INST) +#define INST(_m, _n) \ + template void gen_offset_base(const GenOffsetBaseParam<_m, _n>&, cudaStream_t); + +MEGDNN_FOREACH_TENSOR_NDIM(INST, 1) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 2) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 3) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 4) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 5) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 6) +MEGDNN_FOREACH_TENSOR_NDIM(INST, 7) + #undef INST } // namespace indexing_multi_axis_vec diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 88833ef91aec812ad110c914e4330fa044917494..997a84b2727f971ba08dc16caac14c2959414f2a 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec; namespace { class ExecImplHelper { + template + void dispatch_gen_offset_base_nidx_ndim(); template void dispatch_gen_offset_base_nidx(); - void dispatch_gen_offset_base(); protected: @@ -38,6 +39,7 @@ protected: int* const m_offset_base; TensorLayout m_value_layout_on_data; size_t m_idx_axis; + TensorShape m_idx_shape; int m_value_stride; public: @@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper( m_exec_info{&exec_info}, m_offset_base{workspace.ptr()} { safe_size_in_kern(data.layout.total_nr_elems()); - dispatch_gen_offset_base(); - - std::tie(m_value_layout_on_data, m_idx_axis) = + std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) = IndexingMultiAxisVec::get_value_iter_optimized_layout( data.layout, value.layout, index, exec_info.idx_axis); + dispatch_gen_offset_base(); m_value_stride = exec_info.value_stride; } -template -void ExecImplHelper::dispatch_gen_offset_base_nidx() { - GenOffsetBaseParam param; - param.size = m_value->layout.shape[m_exec_info->idx_axis]; +template +void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() { + GenOffsetBaseParam param; + param.size = m_idx_shape.total_nr_elems(); param.output = m_offset_base; param.error_tracker = m_exec_info->error_tracker; param.error_info = m_exec_info->error_info; + megdnn_assert(m_idx_shape.ndim == idx_ndim); for (int i = 0; i < nidx; ++i) { auto&& dst = param.indexer[i]; - auto&& src = m_index->operator[](i); - megdnn_assert(src.vec.layout.ndim == 1); - dst.stride = src.vec.layout.stride[0]; - if (src.vec.layout.shape[0] == 1) { - dst.stride = 0; + auto&& src = m_index->at(i); + auto src_layout = src.vec.layout.broadcast(m_idx_shape); + for (size_t i = 0; i < idx_ndim; ++i) { + if (i) { + dst.shape[i - 1] = src_layout.shape[i]; + } + dst.stride[i] = src_layout.stride[i]; } dst.ptr = src.vec.ptr(); param.data_shape[i] = m_data->layout.shape[src.axis]; @@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { gen_offset_base(param, m_stream); } +template +void ExecImplHelper::dispatch_gen_offset_base_nidx() { + switch (m_idx_shape.ndim) { +#define cb(_n) \ + case _n: \ + return dispatch_gen_offset_base_nidx_ndim(); + MEGDNN_FOREACH_TENSOR_NDIM(cb) +#undef cb + } + megdnn_throw("bad index ndim"); +} + void ExecImplHelper::dispatch_gen_offset_base() { switch (m_index->size()) { #define cb(_n) \ @@ -153,6 +169,8 @@ void ExecImpl::dispatch_exec_ctype_ndim() { param.data = m_data->ptr(); param.value = m_value->ptr(); param.idx_axis = m_idx_axis; + param.idx_axis_end = m_idx_axis + m_idx_shape.ndim; + param.idx_nelems = m_idx_shape.total_nr_elems(); param.value_stride = m_value_stride; for (int i = 0; i < ndim; ++i) { param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; diff --git a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp index d8f17c67c9b10c6fa3380609000a4ad3300548b0..d30597575fa1ed7821e76b38f74b57b75cf02c8d 100644 --- a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp @@ -33,37 +33,46 @@ void do_exec( auto data_layout = data.layout; auto data_ptr = data.ptr(); - std::tuple index_raw[TensorLayout::MAX_NDIM]; + std::tuple index_raw[TensorLayout::MAX_NDIM]; size_t nr_index = index.size(); + TensorShape idx_shape; + { + TensorShapeArray idx_shapes; + for (size_t i = 0; i < nr_index; ++i) { + idx_shapes.push_back(index[i].vec.layout); + } + Elemwise::deduce_shape(idx_shapes, idx_shape); + } for (size_t i = 0; i < nr_index; ++i) { auto&& s = index[i]; - index_raw[i] = - std::make_tuple(s.axis, s.vec.ptr(), s.vec.layout.stride[0]); - - if (s.vec.layout.shape[0] == 1) - std::get<2>(index_raw[i]) = 0; + index_raw[i] = std::make_tuple( + s.axis, s.vec.ptr(), s.vec.layout.broadcast(idx_shape)); } auto value_iter = tensor_iter(value).begin(); for (size_t _ = 0, _t = value.layout.total_nr_elems(); _ < _t; ++_) { ptrdiff_t offset = 0; - auto index_idx = value_iter.idx()[exec_info.idx_axis]; + auto* index_idx = value_iter.idx() + exec_info.idx_axis; for (size_t i = 0; i < nr_index; ++i) { size_t axis = std::get<0>(index_raw[i]), data_shape = data_layout.shape[axis]; ptrdiff_t data_stride = data_layout.stride[axis]; - idx_type data_idx = - std::get<1>(index_raw[i])[std::get<2>(index_raw[i]) * index_idx]; + size_t index_offset = 0; + TensorLayout& index_layout = std::get<2>(index_raw[i]); + for (size_t i = 0; i < index_layout.ndim; ++i) { + index_offset += index_idx[i] * index_layout.stride[i]; + } + idx_type data_idx = std::get<1>(index_raw[i])[index_offset]; if (data_idx < 0) data_idx += data_shape; megdnn_assert( data_idx >= 0 && static_cast(data_idx) < data_shape, - "bad index value for index %zu at output %zu", i, index_idx); + "bad index value for index %zu at output %zu", i, *index_idx); offset += data_stride * data_idx; } for (size_t i = 0; i < nr_nonidx_axes; ++i) { auto stride = data_layout.stride[nonidx_axes[i]]; - auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis)]; + auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis) * idx_shape.ndim]; offset += stride * idx; } Opr::apply(data_ptr[offset], *value_iter); diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip index 8d35248bdd910741e2592634e20d1e0164e4ff12..4da6291691aa77d03dd9b78db4214c8437c9614b 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip @@ -21,17 +21,23 @@ namespace rocm { namespace indexing_multi_axis_vec { //! AxisIndexer equiv in kernel + template struct KAxisIndexer { - int stride; + int stride[idx_ndim]; +#ifdef WIN32 + Uint32Fastdiv shape[idx_ndim]; +#else + Uint32Fastdiv shape[idx_ndim - 1]; +#endif const int *ptr; }; //! param for gen_offset_base - template + template struct GenOffsetBaseParam { uint32_t size; //!< number of outputs; also size of each index int *output; //!< output ptr - KAxisIndexer indexer[nidx]; + KAxisIndexer indexer[nidx]; uint32_t data_shape[nidx]; int data_stride[nidx]; @@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec { ctype *data, *value; int idx_axis; + int idx_axis_end; + int idx_nelems; int value_stride; @@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec { }; //! generate offset bases for first axis in the output - template - void gen_offset_base(const GenOffsetBaseParam ¶m, + template + void gen_offset_base(const GenOffsetBaseParam ¶m, hipStream_t stream); struct OprAtomicIncr { diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl index b5291e668a79a83822abd9c524990fac78f5d180..631dc1a13024d2aee4f8964f3ed79bedcac5b1e8 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl @@ -30,10 +30,17 @@ namespace { uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; if (oidx < param.tot_size) { int offset = 0, coidx = oidx; - int all_ax_idx[ndim]; + int idx_flat = 0; #pragma unroll for (int i = ndim - 1; i >= 0; -- i) { int next_coidx, ax_idx; + if (i + 1 == param.idx_axis_end) { + idx_flat = coidx; + } + // may not trigger + if (i + 1 == param.idx_axis) { + idx_flat -= coidx * param.idx_nelems; + } if (i) { next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; ax_idx = @@ -45,9 +52,8 @@ namespace { ax_idx = coidx; } offset += param.value_ly_on_data.stride[i] * ax_idx; - all_ax_idx[i] = ax_idx; } - offset += param.offset_base[all_ax_idx[param.idx_axis]]; + offset += param.offset_base[idx_flat]; Opr::apply( param.data[offset], param.value[oidx * param.value_stride]); diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip index 88c0b7dc0f48529f1c17ffcb37386f60320352c8..3eb03726436ab41aea92d03580b57d51ef6b584a 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip @@ -21,15 +21,28 @@ using namespace rocm; using namespace indexing_multi_axis_vec; namespace { - template - __global__ void kgen_offset_base(GenOffsetBaseParam param) { + template + __global__ void kgen_offset_base(GenOffsetBaseParam param) { int oidx = threadIdx.x + blockDim.x * blockIdx.x; if (oidx < param.size) { int offset = 0; #pragma unroll for (int i = 0; i < nidx; ++ i) { - int data_idx = param.indexer[i].ptr[ - param.indexer[i].stride * oidx]; + auto& indexer = param.indexer[i]; + int offset2 = 0, coidx = oidx; +#pragma unroll + for (int j = idx_ndim-1; j >= 0; --j) { + int ax_idx; + if (j) { + int next_coidx = coidx / indexer.shape[j-1]; + ax_idx = coidx - (next_coidx * indexer.shape[j-1].divisor()); + coidx = next_coidx; + } else { + ax_idx = coidx; + } + offset2 += indexer.stride[j] * ax_idx; + } + int data_idx = indexer.ptr[offset2]; data_idx += (data_idx < 0 ? param.data_shape[i] : 0); if (static_cast(data_idx) >= param.data_shape[i]) { // cast to uint32 to handle both negative and overflow @@ -50,20 +63,28 @@ namespace megdnn { namespace rocm { namespace indexing_multi_axis_vec { -#define INST(_n) \ +#define INST(_m, _n) \ template void gen_offset_base( \ - const GenOffsetBaseParam<_n> &, hipStream_t); - MEGDNN_FOREACH_TENSOR_NDIM(INST) + const GenOffsetBaseParam<_m, _n> &, hipStream_t); + + MEGDNN_FOREACH_TENSOR_NDIM(INST, 1) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 2) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 3) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 4) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 5) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 6) + MEGDNN_FOREACH_TENSOR_NDIM(INST, 7) + #undef INST } // namespace indexing_multi_axis_vec } // namespace rocm } // namespace megdnn -template +template void indexing_multi_axis_vec::gen_offset_base( - const GenOffsetBaseParam ¶m, hipStream_t stream) { - void (*kptr)(GenOffsetBaseParam) = kgen_offset_base; + const GenOffsetBaseParam ¶m, hipStream_t stream) { + void (*kptr)(GenOffsetBaseParam) = kgen_offset_base; int bsize = 256; hipLaunchKernelGGL(kptr, DIVUP(param.size, bsize), bsize, 0, stream, diff --git a/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp index b6d413888e12cf7ea09457ca4029ba4880a1cacd..a788da76ee2dd129ca3c634f7fcec65a31db0fc5 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp @@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec; namespace { class ExecImplHelper { + template + void dispatch_gen_offset_base_nidx_ndim(); template void dispatch_gen_offset_base_nidx(); - void dispatch_gen_offset_base(); protected: @@ -39,6 +40,7 @@ protected: int* const m_offset_base; TensorLayout m_value_layout_on_data; size_t m_idx_axis; + TensorShape m_idx_shape; int m_value_stride; public: @@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper( m_exec_info{&exec_info}, m_offset_base{workspace.ptr()} { safe_size_in_kern(data.layout.total_nr_elems()); - dispatch_gen_offset_base(); - - std::tie(m_value_layout_on_data, m_idx_axis) = + std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) = IndexingMultiAxisVec::get_value_iter_optimized_layout( data.layout, value.layout, index, exec_info.idx_axis); + dispatch_gen_offset_base(); m_value_stride = exec_info.value_stride; } -template -void ExecImplHelper::dispatch_gen_offset_base_nidx() { - GenOffsetBaseParam param; - param.size = m_value->layout.shape[m_exec_info->idx_axis]; +template +void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() { + GenOffsetBaseParam param; + param.size = m_idx_shape.total_nr_elems(); param.output = m_offset_base; param.error_tracker = m_exec_info->error_tracker; param.error_info = m_exec_info->error_info; @@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { auto&& dst = param.indexer[i]; auto&& src = m_index->operator[](i); megdnn_assert(src.vec.layout.ndim == 1); - dst.stride = src.vec.layout.stride[0]; - if (src.vec.layout.shape[0] == 1) { - dst.stride = 0; + auto src_layout = src.vec.layout.broadcast(m_idx_shape); + for (size_t i = 0; i < idx_ndim; ++i) { + if (i) { + dst.shape[i - 1] = src_layout.shape[i]; + } + dst.stride[i] = src_layout.stride[i]; } dst.ptr = src.vec.ptr(); param.data_shape[i] = m_data->layout.shape[src.axis]; @@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { gen_offset_base(param, m_stream); } +template +void ExecImplHelper::dispatch_gen_offset_base_nidx() { + switch (m_idx_shape.ndim) { +#define cb(_n) \ + case _n: \ + return dispatch_gen_offset_base_nidx_ndim(); + MEGDNN_FOREACH_TENSOR_NDIM(cb) +#undef cb + } + megdnn_throw("bad index ndim"); +} + void ExecImplHelper::dispatch_gen_offset_base() { switch (m_index->size()) { #define cb(_n) \ @@ -154,6 +170,8 @@ void ExecImpl::dispatch_exec_ctype_ndim() { param.data = m_data->ptr(); param.value = m_value->ptr(); param.idx_axis = m_idx_axis; + param.idx_axis_end = m_idx_axis + m_idx_shape.ndim; + param.idx_nelems = m_idx_shape.total_nr_elems(); param.value_stride = m_value_stride; for (int i = 0; i < ndim; ++i) { param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; diff --git a/dnn/test/common/indexing_multi_axis_vec.h b/dnn/test/common/indexing_multi_axis_vec.h index 2186bc44a7cc3c129a48282bee0db2e9cc74906f..4f1db7e8e642e000723f2bb662b7b5e055aa07b2 100644 --- a/dnn/test/common/indexing_multi_axis_vec.h +++ b/dnn/test/common/indexing_multi_axis_vec.h @@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper { return ret; } + size_t get_index_ndim(const TensorNDArray& tensors) const { + megdnn_assert(tensors.size() >= 3); + size_t ndim = 0; + for (size_t i = 2; i < tensors.size(); ++i) { + ndim = std::max(tensors[i].layout.ndim, ndim); + } + return ndim; + } + IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout( const TensorLayoutArray& layouts) const { megdnn_assert(layouts.size() >= 3); @@ -65,7 +74,8 @@ struct OprProxy : public OprProxyIndexingMultiAxisVecHelpe void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( - tensors[1].layout, axes, tensors.size() - 2)); + tensors[1].layout, axes, tensors.size() - 2, + get_index_ndim(tensors))); opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); } @@ -81,7 +91,8 @@ struct OprProxy : public OprProxyIndexingMultiAxisVecH void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( - tensors[1].layout, axes, tensors.size() - 2)); + tensors[1].layout, axes, tensors.size() - 2, + get_index_ndim(tensors))); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); } @@ -95,7 +106,8 @@ struct OprProxy : public OprProxyIndexingMultiAxisVecHe void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( - tensors[1].layout, axes, tensors.size() - 2)); + tensors[1].layout, axes, tensors.size() - 2, + get_index_ndim(tensors))); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); } diff --git a/dnn/test/common/mesh_indexing.h b/dnn/test/common/mesh_indexing.h index f17b4a7ce4eb4a8738678927c8f02fac2daae47e..e58e322793769a1e111ef9893ad13f9332e8d7b4 100644 --- a/dnn/test/common/mesh_indexing.h +++ b/dnn/test/common/mesh_indexing.h @@ -27,7 +27,7 @@ namespace test { WorkspaceWrapper W( \ opr->handle(), \ opr->get_workspace_in_bytes( \ - tensors[1].layout, axes, tensors.size() - 2)); \ + tensors[1].layout, axes, tensors.size() - 2, 1)); \ opr->exec( \ tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \ } \ @@ -46,7 +46,7 @@ namespace test { WorkspaceWrapper W( \ opr->handle(), \ opr->get_workspace_in_bytes( \ - tensors[1].layout, axes, tensors.size() - 2)); \ + tensors[1].layout, axes, tensors.size() - 2, 1)); \ opr->exec( \ tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \ } \ diff --git a/dnn/test/cuda/indexing_multi_axis_vec.cpp b/dnn/test/cuda/indexing_multi_axis_vec.cpp index 3c9bc4358e9aafe5ee549a411f59e8a58d746660..09aeb3ad5855b2807d3fa67ad705cd469b00d910 100644 --- a/dnn/test/cuda/indexing_multi_axis_vec.cpp +++ b/dnn/test/cuda/indexing_multi_axis_vec.cpp @@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) { TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}}); } +TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) { + run_check(handle_cuda()); + Checker checker(handle_cuda()); + OrderedRNG rng; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Int32()) + .set_dtype(4, dtype::Int32()) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_rng(3, &rng) + .set_rng(4, &rng); + + checker.set_proxy({{1, 2, 3}}) + .execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}}); +} + TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) { run_check(handle_cuda()); Checker checker(handle_cuda()); diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index 2d62f30f3e7d48f846b982176a1a131a7c8e30d8..a2c127a1ce3bf260682a00a8591d0b02e79f8ef7 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic): run_test((10, 10, 0), test4) run_test((10, 10, 10), test3) run_test((10, 10, 10), test4) + + +@pytest.mark.parametrize("symbolic", [True, False, None]) +def test_nd_int_indexing(symbolic): + inp = np.arange(11) + idx = np.random.randint(11, size=(5, 7)) + + def run_test(args, fn): + npy_out = fn(*args) + if symbolic: + fn = jit.trace(symbolic=symbolic)(fn) + for _ in range(3): + out = fn(*[Tensor(arg) for arg in args]) + np.testing.assert_equal(out.numpy(), npy_out) + + run_test([inp, idx], lambda inp, idx: inp[idx]) diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 8fdf47e6b44e9ff79bb0c6c96c5453e2e7f12817..99f5d0e756e194648e3d9be311014075d0360919 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder::megdnn_opr( template void mixin::IndexingMultiAxisVecMegDNNOprHolder::register_workspace_infer( const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, - VarNode* data, VarNode* value) { + VarNode* data, VarNode* value, VarNodeArray idx_arr) { using namespace cg::static_infer; - auto infer_shape = [this, &index_desc, &opr](TensorShape& dest, const InpVal& inp) { + DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}}; + + for (auto&& idx : idx_arr) { + deps.push_back({idx, DepType::SHAPE}); + } + auto infer_shape = [this, &index_desc, &opr, nr_idx = idx_arr.size()]( + TensorShape& dest, const InpVal& inp) { size_t axes[TensorShape::MAX_NDIM], nr_axes = 0; auto ndim = inp.val[0].shape().ndim; for (auto&& i : reverse_adaptor(index_desc)) { @@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder::register_workspace_infer( axes[nr_axes++] = i.axis.get(ndim); } } + mgb_assert(nr_axes == nr_idx); if (!nr_axes) { dest = {0}; } else { + size_t idx_ndim = 0; + for (size_t i = 0; i < nr_idx; ++i) { + idx_ndim = std::max(idx_ndim, inp.val[2 + i].shape().ndim); + } + mgb_assert(idx_ndim > 0); dest = {megdnn_opr(opr).get_workspace_in_bytes( - inp.val[1].shape(), axes, nr_axes)}; + inp.val[1].shape(), axes, nr_axes, idx_ndim)}; } return true; }; opr.owner_graph()->static_infer_manager().register_shape_infer( - opr.output(1), {SourceType::DEP, - {{data, DepType::SHAPE}, {value, DepType::SHAPE}}, - infer_shape}); + opr.output(1), {SourceType::DEP, deps, infer_shape}); } template @@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase::init_output_static_infer_desc() { }; owner_graph()->static_infer_manager().register_shape_infer( output(0), {SourceType::DEP, deps, infer_shape}); - - this->register_workspace_infer(index_desc(), *this, input(0), output(0)); + VarNodeArray idx_arr; + for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) { + if (m_input2idxonly_axis_indexer[i]) { + idx_arr.push_back(input(i)); + } + } + this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr); } template @@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper::init_output_static_infer_desc( this->owner_graph()->static_infer_manager().register_shape_infer( this->output(0), ShapeInferDesc::make_identity(this->input(0))); - this->register_workspace_infer(index_desc(), *this, input(0), input(1)); + VarNodeArray idx_arr; + for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) { + if (m_input2idxonly_axis_indexer[i]) { + idx_arr.push_back(input(i)); + } + } + this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr); } template diff --git a/src/opr/include/megbrain/opr/indexing.h b/src/opr/include/megbrain/opr/indexing.h index c2f072c5df74771cf5bd07d19ab7096dcc586b0e..a8993830d1e7a8fcb0962a1d798fd8c07b4e256a 100644 --- a/src/opr/include/megbrain/opr/indexing.h +++ b/src/opr/include/megbrain/opr/indexing.h @@ -96,7 +96,7 @@ protected: void register_workspace_infer( const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, - VarNode* data, VarNode* value); + VarNode* data, VarNode* value, VarNodeArray idx_arr); void record_megdnn_opr(mgb::cg::GraphExecutable::ExecDependencyArray& deps); };