提交 68cdabd2 编写于 作者: M Megvii Engine Team

feat(opr): indexing_multi_axis_vec support nd index

GitOrigin-RevId: 07b1248bdcaa8d12c91220eb482090ece16a0a10
上级 05ee6038
......@@ -1115,7 +1115,7 @@ public:
* access *data*; stride of layout on that axis would be zero, and
* strides on other axes correspond to the strides in *data*
*/
static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
static std::tuple<TensorLayout, size_t, TensorShape> get_value_iter_optimized_layout(
const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
size_t idx_axis);
......@@ -1159,7 +1159,8 @@ public:
* \brief get workspace size based on output shape and indexing axes
*/
size_t get_workspace_in_bytes(
const TensorShape& dst, const size_t* axes, size_t nr_axes);
const TensorShape& dst, const size_t* axes, size_t nr_axes,
size_t idx_ndim);
static void deduce_layout(
const TensorLayout& data, const IndexDescLayoutOnly& index,
......@@ -1193,7 +1194,8 @@ public:
* axes
*/
size_t get_workspace_in_bytes(
const TensorShape& value, const size_t* axes, size_t nr_axes);
const TensorShape& value, const size_t* axes, size_t nr_axes,
size_t idx_ndim);
protected:
ExecInfo check_exec(
......@@ -1223,7 +1225,7 @@ public:
using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;
size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) {
size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t, size_t) {
return 0;
}
......
......@@ -15,8 +15,10 @@
using namespace megdnn;
namespace {
// we need a workspace to store the offset base table, which has the same size as the index
size_t get_index_size_for_workspace(
const TensorShape& shp, const size_t* axes, size_t nr_axes) {
const TensorShape& shp, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
size_t idx_axis = axes[0];
megdnn_assert(shp.ndim && nr_axes);
for (size_t i = 1; i < nr_axes; ++i) {
......@@ -29,7 +31,11 @@ size_t get_index_size_for_workspace(
megdnn_assert(
shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis,
shp.to_string().c_str());
return shp.shape[idx_axis];
size_t idx_size = 1;
for (size_t i = 0; i < idx_ndim; ++i) {
idx_size *= shp.shape[idx_axis + i];
}
return idx_size;
}
} // anonymous namespace
......@@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) {
megdnn_assert(!index.empty());
megdnn_assert(data.ndim >= index.size());
dst.ndim = data.ndim - index.size() + 1;
dst.shape[0] = 1;
dst.ndim = data.ndim - index.size();
dst.dtype = data.dtype;
TensorShapeArray index_shapes;
auto brdcast = [&](const TensorLayout& ly) {
if (ly.ndim != 1)
return false;
if (dst.shape[0] == ly.shape[0])
return true;
if (dst.shape[0] == 1) {
dst.shape[0] = ly.shape[0];
return true;
}
return ly.shape[0] == 1;
megdnn_assert(ly.dtype == dtype::Int32{});
index_shapes.push_back(ly);
};
size_t dst_axis = 1;
size_t dst_axis = 0;
ptrdiff_t prev_axis = -1;
for (size_t axis = 0; axis < index.size(); ++axis) {
auto&& idx = index[axis];
......@@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
megdnn_assert(
idx.axis<data.ndim&& static_cast<ptrdiff_t>(idx.axis)> prev_axis,
"index %zu requests invalid axis %zu", axis, idx.axis);
auto brd_succ = brdcast(idx.layout);
megdnn_assert(
brd_succ, "invalid layout at index %zu: %s", axis,
idx.layout.to_string().c_str());
brdcast(idx.layout);
for (size_t i = prev_axis + 1; i < idx.axis; ++i) {
dst.shape[dst_axis++] = data.shape[i];
......@@ -99,13 +96,16 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
}
}
if (contig_idx) {
auto shp0 = dst.shape[0];
idx_axis = index[0].axis;
for (size_t i = 0; i < idx_axis; ++i) {
dst.shape[i] = dst.shape[i + 1];
}
dst.shape[idx_axis] = shp0;
}
TensorShape index_shape;
Elemwise::deduce_shape(index_shapes, index_shape);
for (size_t i = 0; i < index_shape.ndim; ++i) {
dst.add_axis_inplace(idx_axis + i, 1, 0);
dst.shape[idx_axis + i] = index_shape.shape[i];
}
dst.init_contiguous_stride();
......@@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp
return ret;
}
std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
std::tuple<TensorLayout, size_t, TensorShape> IndexingMultiAxisVecBase::
get_value_iter_optimized_layout(
const TensorLayout& data, const TensorLayout& value,
const IndexDesc& index, size_t idx_axis) {
size_t data_axes[TensorLayout::MAX_NDIM],
nr_axes = get_nonindex_axes(data.ndim, index, data_axes);
// broadcast index shapes
TensorLayout index_shape;
{
TensorShapeArray index_shapes;
for (auto& idx : index) {
megdnn_assert(idx.vec.layout.dtype == dtype::Int32{});
index_shapes.push_back(idx.vec.layout);
}
Elemwise::deduce_shape(index_shapes, index_shape);
}
megdnn_assert(
nr_axes == value.ndim - 1 && idx_axis < value.ndim &&
nr_axes == value.ndim - index_shape.ndim && idx_axis < value.ndim &&
nr_axes + index.size() == data.ndim);
TensorLayout ret;
......@@ -165,10 +176,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
}
ret = ret.collapse_contiguous();
}
ret.shape[ret.ndim] = value.shape[idx_axis];
ret.stride[ret.ndim] = 0;
size_t ret_idx_axis = ret.ndim;
for (size_t i = 0; i < index_shape.ndim; ++i) {
ret.shape[ret.ndim] = value.shape[idx_axis + i];
ret.stride[ret.ndim] = 0;
++ret.ndim;
}
if (idx_axis < nr_axes) {
TensorLayout tail;
......@@ -185,12 +199,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase::
}
}
return {ret, ret_idx_axis};
return std::make_tuple(ret, ret_idx_axis, index_shape);
}
size_t IndexingMultiAxisVec::get_workspace_in_bytes(
const TensorShape& dst, const size_t* axes, size_t nr_axes) {
return get_workspace_in_bytes(get_index_size_for_workspace(dst, axes, nr_axes));
const TensorShape& dst, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(
get_index_size_for_workspace(dst, axes, nr_axes, idx_ndim));
}
IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
......@@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
}
size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes(
const TensorShape& value, const size_t* axes, size_t nr_axes) {
return get_workspace_in_bytes(get_index_size_for_workspace(value, axes, nr_axes));
const TensorShape& value, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(
get_index_size_for_workspace(value, axes, nr_axes, idx_ndim));
}
IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec(
......
......@@ -21,17 +21,24 @@ namespace cuda {
namespace indexing_multi_axis_vec {
//! AxisIndexer equiv in kernel
template <int idx_ndim>
struct KAxisIndexer {
int stride;
int stride[idx_ndim];
#ifdef WIN32
Uint32Fastdiv shape[idx_ndim];
#else
// original shape[0] is not stored
Uint32Fastdiv shape[idx_ndim - 1];
#endif
const int* ptr;
};
//! param for gen_offset_base
template <int nidx>
template <int nidx, int idx_ndim>
struct GenOffsetBaseParam {
uint32_t size; //!< number of outputs; also size of each index
int* output; //!< output ptr
KAxisIndexer indexer[nidx];
KAxisIndexer<idx_ndim> indexer[nidx];
uint32_t data_shape[nidx];
int data_stride[nidx];
......@@ -59,7 +66,12 @@ struct ApplyOprParam {
const int* offset_base;
ctype *data, *value;
// first idx axis
int idx_axis;
// last idx axis + 1
int idx_axis_end;
// number of elements for idx shape
int idx_nelems;
int value_stride;
......@@ -68,8 +80,9 @@ struct ApplyOprParam {
};
//! generate offset bases for first axis in the output
template <int nidx>
void gen_offset_base(const GenOffsetBaseParam<nidx>& param, cudaStream_t stream);
template <int nidx, int idx_ndim>
void gen_offset_base(
const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream);
struct OprAtomicIncr {
#if MEGDNN_CC_CUDA
......
......@@ -29,11 +29,23 @@ namespace {
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.tot_size) {
int offset = 0, coidx = oidx;
int all_ax_idx[ndim];
// offset in index
int idx_flat = 0;
// for non-indexed axes get offset
#pragma unroll
for (int i = ndim - 1; i >= 0; -- i) {
int next_coidx, ax_idx;
// [..., indexed_axes... |, ...]
if (i + 1 == param.idx_axis_end) {
idx_flat = coidx;
}
// [... |, indexed_axes..., ...]
if (i + 1 == param.idx_axis) {
idx_flat -= coidx * param.idx_nelems;
}
// shape[i] is stored at shape[i-1]
if (i) {
// fast divide
next_coidx = coidx / param.value_ly_on_data.shape[i - 1];
ax_idx =
coidx -
......@@ -44,9 +56,9 @@ namespace {
ax_idx = coidx;
}
offset += param.value_ly_on_data.stride[i] * ax_idx;
all_ax_idx[i] = ax_idx;
}
offset += param.offset_base[all_ax_idx[param.idx_axis]];
// offset from index, which was generated before
offset += param.offset_base[idx_flat];
Opr::apply(
param.data[offset],
param.value[oidx * param.value_stride]);
......
......@@ -18,14 +18,29 @@ using namespace cuda;
using namespace indexing_multi_axis_vec;
namespace {
template <int nidx>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
template <int nidx, int idx_ndim>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) {
int oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.size) {
int offset = 0;
#pragma unroll
for (int i = 0; i < nidx; ++i) {
int data_idx = param.indexer[i].ptr[param.indexer[i].stride * oidx];
auto& indexer = param.indexer[i];
// index in index
int idx_flat = 0, coidx = oidx;
#pragma unroll
for (int j = idx_ndim - 1; j >= 0; --j) {
int ax_idx;
if (j) {
int next_coidx = coidx / indexer.shape[j - 1];
ax_idx = coidx - (next_coidx * indexer.shape[j - 1].divisor());
coidx = next_coidx;
} else {
ax_idx = coidx;
}
idx_flat += indexer.stride[j] * ax_idx;
}
int data_idx = indexer.ptr[idx_flat];
data_idx += (data_idx < 0 ? param.data_shape[i] : 0);
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) {
// cast to uint32 to handle both negative and overflow
......@@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
i, data_idx, param.data_shape[i]);
data_idx = 0;
}
// calculate offset from current index
offset += data_idx * param.data_stride[i];
}
// sum offsets and store at offset table
param.output[oidx] = offset;
}
}
} // namespace
template <int nidx>
template <int nidx, int idx_ndim>
void indexing_multi_axis_vec::gen_offset_base(
const GenOffsetBaseParam<nidx>& param, cudaStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>;
const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>;
int bsize = query_blocksize_for_kernel(kptr);
(*kptr)<<<DIVUP(param.size, bsize), bsize, 0, stream>>>(param);
}
......@@ -55,9 +72,17 @@ namespace megdnn {
namespace cuda {
namespace indexing_multi_axis_vec {
#define INST(_n) \
template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST)
#define INST(_m, _n) \
template void gen_offset_base(const GenOffsetBaseParam<_m, _n>&, cudaStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7)
#undef INST
} // namespace indexing_multi_axis_vec
......
......@@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec;
namespace {
class ExecImplHelper {
template <int nidx, int idx_ndim>
void dispatch_gen_offset_base_nidx_ndim();
template <int nidx>
void dispatch_gen_offset_base_nidx();
void dispatch_gen_offset_base();
protected:
......@@ -38,6 +39,7 @@ protected:
int* const m_offset_base;
TensorLayout m_value_layout_on_data;
size_t m_idx_axis;
TensorShape m_idx_shape;
int m_value_stride;
public:
......@@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper(
m_exec_info{&exec_info},
m_offset_base{workspace.ptr<int>()} {
safe_size_in_kern(data.layout.total_nr_elems());
dispatch_gen_offset_base();
std::tie(m_value_layout_on_data, m_idx_axis) =
std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) =
IndexingMultiAxisVec::get_value_iter_optimized_layout(
data.layout, value.layout, index, exec_info.idx_axis);
dispatch_gen_offset_base();
m_value_stride = exec_info.value_stride;
}
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
GenOffsetBaseParam<nidx> param;
param.size = m_value->layout.shape[m_exec_info->idx_axis];
template <int nidx, int idx_ndim>
void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() {
GenOffsetBaseParam<nidx, idx_ndim> param;
param.size = m_idx_shape.total_nr_elems();
param.output = m_offset_base;
param.error_tracker = m_exec_info->error_tracker;
param.error_info = m_exec_info->error_info;
megdnn_assert(m_idx_shape.ndim == idx_ndim);
for (int i = 0; i < nidx; ++i) {
auto&& dst = param.indexer[i];
auto&& src = m_index->operator[](i);
megdnn_assert(src.vec.layout.ndim == 1);
dst.stride = src.vec.layout.stride[0];
if (src.vec.layout.shape[0] == 1) {
dst.stride = 0;
auto&& src = m_index->at(i);
auto src_layout = src.vec.layout.broadcast(m_idx_shape);
for (size_t i = 0; i < idx_ndim; ++i) {
if (i) {
dst.shape[i - 1] = src_layout.shape[i];
}
dst.stride[i] = src_layout.stride[i];
}
dst.ptr = src.vec.ptr<int>();
param.data_shape[i] = m_data->layout.shape[src.axis];
......@@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base(param, m_stream);
}
//! Second-level dispatch for offset-base generation: with the number of
//! indexers (nidx) already fixed at compile time, select the
//! dispatch_gen_offset_base_nidx_ndim<nidx, idx_ndim> instantiation whose
//! idx_ndim equals the runtime rank of the broadcast index shape
//! (m_idx_shape.ndim).
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
switch (m_idx_shape.ndim) {
// one `case` per value supplied by MEGDNN_FOREACH_TENSOR_NDIM
// (presumably every supported tensor rank -- confirm against the macro
// definition); each case returns immediately after dispatching
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
// reached only when m_idx_shape.ndim matched none of the generated cases
megdnn_throw("bad index ndim");
}
void ExecImplHelper::dispatch_gen_offset_base() {
switch (m_index->size()) {
#define cb(_n) \
......@@ -153,6 +169,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param.data = m_data->ptr<ctype>();
param.value = m_value->ptr<ctype>();
param.idx_axis = m_idx_axis;
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim;
param.idx_nelems = m_idx_shape.total_nr_elems();
param.value_stride = m_value_stride;
for (int i = 0; i < ndim; ++i) {
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
......
......@@ -33,37 +33,46 @@ void do_exec(
auto data_layout = data.layout;
auto data_ptr = data.ptr<data_type>();
std::tuple<size_t, const idx_type*, ptrdiff_t> index_raw[TensorLayout::MAX_NDIM];
std::tuple<size_t, const idx_type*, TensorLayout> index_raw[TensorLayout::MAX_NDIM];
size_t nr_index = index.size();
TensorShape idx_shape;
{
TensorShapeArray idx_shapes;
for (size_t i = 0; i < nr_index; ++i) {
idx_shapes.push_back(index[i].vec.layout);
}
Elemwise::deduce_shape(idx_shapes, idx_shape);
}
for (size_t i = 0; i < nr_index; ++i) {
auto&& s = index[i];
index_raw[i] =
std::make_tuple(s.axis, s.vec.ptr<idx_type>(), s.vec.layout.stride[0]);
if (s.vec.layout.shape[0] == 1)
std::get<2>(index_raw[i]) = 0;
index_raw[i] = std::make_tuple(
s.axis, s.vec.ptr<idx_type>(), s.vec.layout.broadcast(idx_shape));
}
auto value_iter = tensor_iter<data_type>(value).begin();
for (size_t _ = 0, _t = value.layout.total_nr_elems(); _ < _t; ++_) {
ptrdiff_t offset = 0;
auto index_idx = value_iter.idx()[exec_info.idx_axis];
auto* index_idx = value_iter.idx() + exec_info.idx_axis;
for (size_t i = 0; i < nr_index; ++i) {
size_t axis = std::get<0>(index_raw[i]),
data_shape = data_layout.shape[axis];
ptrdiff_t data_stride = data_layout.stride[axis];
idx_type data_idx =
std::get<1>(index_raw[i])[std::get<2>(index_raw[i]) * index_idx];
size_t index_offset = 0;
TensorLayout& index_layout = std::get<2>(index_raw[i]);
for (size_t i = 0; i < index_layout.ndim; ++i) {
index_offset += index_idx[i] * index_layout.stride[i];
}
idx_type data_idx = std::get<1>(index_raw[i])[index_offset];
if (data_idx < 0)
data_idx += data_shape;
megdnn_assert(
data_idx >= 0 && static_cast<size_t>(data_idx) < data_shape,
"bad index value for index %zu at output %zu", i, index_idx);
"bad index value for index %zu at output %zu", i, *index_idx);
offset += data_stride * data_idx;
}
for (size_t i = 0; i < nr_nonidx_axes; ++i) {
auto stride = data_layout.stride[nonidx_axes[i]];
auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis)];
auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis) * idx_shape.ndim];
offset += stride * idx;
}
Opr::apply(data_ptr[offset], *value_iter);
......
......@@ -21,17 +21,23 @@ namespace rocm {
namespace indexing_multi_axis_vec {
//! AxisIndexer equiv in kernel
template <int idx_ndim>
struct KAxisIndexer {
int stride;
int stride[idx_ndim];
#ifdef WIN32
Uint32Fastdiv shape[idx_ndim];
#else
Uint32Fastdiv shape[idx_ndim - 1];
#endif
const int *ptr;
};
//! param for gen_offset_base
template<int nidx>
template<int nidx, int idx_ndim>
struct GenOffsetBaseParam {
uint32_t size; //!< number of outputs; also size of each index
int *output; //!< output ptr
KAxisIndexer indexer[nidx];
KAxisIndexer<idx_ndim> indexer[nidx];
uint32_t data_shape[nidx];
int data_stride[nidx];
......@@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec {
ctype *data, *value;
int idx_axis;
int idx_axis_end;
int idx_nelems;
int value_stride;
......@@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec {
};
//! generate offset bases for first axis in the output
template<int nidx>
void gen_offset_base(const GenOffsetBaseParam<nidx> &param,
template<int nidx, int idx_ndim>
void gen_offset_base(const GenOffsetBaseParam<nidx, idx_ndim> &param,
hipStream_t stream);
struct OprAtomicIncr {
......
......@@ -30,10 +30,17 @@ namespace {
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.tot_size) {
int offset = 0, coidx = oidx;
int all_ax_idx[ndim];
int idx_flat = 0;
#pragma unroll
for (int i = ndim - 1; i >= 0; -- i) {
int next_coidx, ax_idx;
if (i + 1 == param.idx_axis_end) {
idx_flat = coidx;
}
// may not trigger
if (i + 1 == param.idx_axis) {
idx_flat -= coidx * param.idx_nelems;
}
if (i) {
next_coidx = coidx / param.value_ly_on_data.shape[i - 1];
ax_idx =
......@@ -45,9 +52,8 @@ namespace {
ax_idx = coidx;
}
offset += param.value_ly_on_data.stride[i] * ax_idx;
all_ax_idx[i] = ax_idx;
}
offset += param.offset_base[all_ax_idx[param.idx_axis]];
offset += param.offset_base[idx_flat];
Opr::apply(
param.data[offset],
param.value[oidx * param.value_stride]);
......
......@@ -21,15 +21,28 @@ using namespace rocm;
using namespace indexing_multi_axis_vec;
namespace {
template<int nidx>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
template<int nidx, int idx_ndim>
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) {
int oidx = threadIdx.x + blockDim.x * blockIdx.x;
if (oidx < param.size) {
int offset = 0;
#pragma unroll
for (int i = 0; i < nidx; ++ i) {
int data_idx = param.indexer[i].ptr[
param.indexer[i].stride * oidx];
auto& indexer = param.indexer[i];
int offset2 = 0, coidx = oidx;
#pragma unroll
for (int j = idx_ndim-1; j >= 0; --j) {
int ax_idx;
if (j) {
int next_coidx = coidx / indexer.shape[j-1];
ax_idx = coidx - (next_coidx * indexer.shape[j-1].divisor());
coidx = next_coidx;
} else {
ax_idx = coidx;
}
offset2 += indexer.stride[j] * ax_idx;
}
int data_idx = indexer.ptr[offset2];
data_idx += (data_idx < 0 ? param.data_shape[i] : 0);
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) {
// cast to uint32 to handle both negative and overflow
......@@ -50,20 +63,28 @@ namespace megdnn {
namespace rocm {
namespace indexing_multi_axis_vec {
#define INST(_n) \
#define INST(_m, _n) \
template void gen_offset_base( \
const GenOffsetBaseParam<_n> &, hipStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST)
const GenOffsetBaseParam<_m, _n> &, hipStream_t);
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6)
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7)
#undef INST
} // namespace indexing_multi_axis_vec
} // namespace rocm
} // namespace megdnn
template<int nidx>
template<int nidx, int idx_ndim>
void indexing_multi_axis_vec::gen_offset_base(
const GenOffsetBaseParam<nidx> &param, hipStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>;
const GenOffsetBaseParam<nidx, idx_ndim> &param, hipStream_t stream) {
void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>;
int bsize = 256;
hipLaunchKernelGGL(kptr,
DIVUP(param.size, bsize), bsize, 0, stream,
......
......@@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec;
namespace {
class ExecImplHelper {
template <int nidx, int idx_ndim>
void dispatch_gen_offset_base_nidx_ndim();
template <int nidx>
void dispatch_gen_offset_base_nidx();
void dispatch_gen_offset_base();
protected:
......@@ -39,6 +40,7 @@ protected:
int* const m_offset_base;
TensorLayout m_value_layout_on_data;
size_t m_idx_axis;
TensorShape m_idx_shape;
int m_value_stride;
public:
......@@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper(
m_exec_info{&exec_info},
m_offset_base{workspace.ptr<int>()} {
safe_size_in_kern(data.layout.total_nr_elems());
dispatch_gen_offset_base();
std::tie(m_value_layout_on_data, m_idx_axis) =
std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) =
IndexingMultiAxisVec::get_value_iter_optimized_layout(
data.layout, value.layout, index, exec_info.idx_axis);
dispatch_gen_offset_base();
m_value_stride = exec_info.value_stride;
}
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
GenOffsetBaseParam<nidx> param;
param.size = m_value->layout.shape[m_exec_info->idx_axis];
template <int nidx, int idx_ndim>
void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() {
GenOffsetBaseParam<nidx, idx_ndim> param;
param.size = m_idx_shape.total_nr_elems();
param.output = m_offset_base;
param.error_tracker = m_exec_info->error_tracker;
param.error_info = m_exec_info->error_info;
......@@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
auto&& dst = param.indexer[i];
auto&& src = m_index->operator[](i);
megdnn_assert(src.vec.layout.ndim == 1);
dst.stride = src.vec.layout.stride[0];
if (src.vec.layout.shape[0] == 1) {
dst.stride = 0;
auto src_layout = src.vec.layout.broadcast(m_idx_shape);
for (size_t i = 0; i < idx_ndim; ++i) {
if (i) {
dst.shape[i - 1] = src_layout.shape[i];
}
dst.stride[i] = src_layout.stride[i];
}
dst.ptr = src.vec.ptr<int>();
param.data_shape[i] = m_data->layout.shape[src.axis];
......@@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() {
gen_offset_base(param, m_stream);
}
//! Second-level dispatch for offset-base generation: with the number of
//! indexers (nidx) already fixed at compile time, select the
//! dispatch_gen_offset_base_nidx_ndim<nidx, idx_ndim> instantiation whose
//! idx_ndim equals the runtime rank of the broadcast index shape
//! (m_idx_shape.ndim).
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
switch (m_idx_shape.ndim) {
// one `case` per value supplied by MEGDNN_FOREACH_TENSOR_NDIM
// (presumably every supported tensor rank -- confirm against the macro
// definition); each case returns immediately after dispatching
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
// reached only when m_idx_shape.ndim matched none of the generated cases
megdnn_throw("bad index ndim");
}
void ExecImplHelper::dispatch_gen_offset_base() {
switch (m_index->size()) {
#define cb(_n) \
......@@ -154,6 +170,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
param.data = m_data->ptr<ctype>();
param.value = m_value->ptr<ctype>();
param.idx_axis = m_idx_axis;
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim;
param.idx_nelems = m_idx_shape.total_nr_elems();
param.value_stride = m_value_stride;
for (int i = 0; i < ndim; ++i) {
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
......
......@@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper {
return ret;
}
//! Largest rank among the index tensors, i.e. tensors[2] and onwards.
//! The first two entries are not inspected; at least three tensors are
//! required so that one index tensor exists.
size_t get_index_ndim(const TensorNDArray& tensors) const {
    megdnn_assert(tensors.size() >= 3);
    size_t max_ndim = 0;
    const size_t nr_tensors = tensors.size();
    for (size_t pos = 2; pos < nr_tensors; ++pos) {
        const size_t cur_ndim = tensors[pos].layout.ndim;
        if (cur_ndim > max_ndim) {
            max_ndim = cur_ndim;
        }
    }
    return max_ndim;
}
IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout(
const TensorLayoutArray& layouts) const {
megdnn_assert(layouts.size() >= 3);
......@@ -65,7 +74,8 @@ struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelpe
void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2));
tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace());
}
......@@ -81,7 +91,8 @@ struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecH
void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2));
tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
}
......@@ -95,7 +106,8 @@ struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHe
void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const {
WorkspaceWrapper W(
opr->handle(), opr->get_workspace_in_bytes(
tensors[1].layout, axes, tensors.size() - 2));
tensors[1].layout, axes, tensors.size() - 2,
get_index_ndim(tensors)));
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
}
......
......@@ -27,7 +27,7 @@ namespace test {
WorkspaceWrapper W( \
opr->handle(), \
opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2)); \
tensors[1].layout, axes, tensors.size() - 2, 1)); \
opr->exec( \
tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \
} \
......@@ -46,7 +46,7 @@ namespace test {
WorkspaceWrapper W( \
opr->handle(), \
opr->get_workspace_in_bytes( \
tensors[1].layout, axes, tensors.size() - 2)); \
tensors[1].layout, axes, tensors.size() - 2, 1)); \
opr->exec( \
tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \
} \
......
......@@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
}
//! Checks IndexingMultiAxisVec on CUDA with multi-dimensional index tensors
//! that must be broadcast against each other.
TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) {
// shared checks first (run_check is defined outside this hunk)
run_check<IndexingMultiAxisVec>(handle_cuda());
Checker<IndexingMultiAxisVec> checker(handle_cuda());
OrderedRNG rng;
// tensors 0 and 1 are float32; tensors 2..4 are the int32 index tensors
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Int32())
.set_dtype(3, dtype::Int32())
.set_dtype(4, dtype::Int32())
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_rng(4, &rng);
// {1, 2, 3} is presumably the list of indexed axes -- confirm against the
// proxy. The three index shapes {3, 1}, {2, 1, 1} and {1, 4} broadcast to
// {2, 3, 4}, which takes the place of the three middle axes of the
// {5, 5, 6, 7, 3} tensor and yields the {5, 2, 3, 4, 3} second tensor.
checker.set_proxy({{1, 2, 3}})
.execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}});
}
TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
run_check<IndexingIncrMultiAxisVec>(handle_cuda());
Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
......
......@@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic):
run_test((10, 10, 0), test4)
run_test((10, 10, 10), test3)
run_test((10, 10, 10), test4)
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_nd_int_indexing(symbolic):
    """Index a 1-d array with a 2-d integer index and compare with NumPy.

    ``symbolic`` selects the execution mode: ``None`` runs the function
    untraced, while ``True`` / ``False`` wrap it in ``jit.trace`` with that
    ``symbolic`` setting.
    """
    inp = np.arange(11)
    idx = np.random.randint(11, size=(5, 7))

    def run_test(args, fn):
        # reference result computed by NumPy on the raw arrays
        npy_out = fn(*args)
        # Compare against None rather than truthiness: a plain `if symbolic:`
        # would skip tracing for symbolic=False and silently duplicate the
        # untraced case instead of exercising trace(symbolic=False).
        if symbolic is not None:
            fn = jit.trace(symbolic=symbolic)(fn)
        # run several times so that traced replays are checked as well
        for _ in range(3):
            out = fn(*[Tensor(arg) for arg in args])
            np.testing.assert_equal(out.numpy(), npy_out)

    run_test([inp, idx], lambda inp, idx: inp[idx])
......@@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
VarNode* data, VarNode* value) {
VarNode* data, VarNode* value, VarNodeArray idx_arr) {
using namespace cg::static_infer;
auto infer_shape = [this, &index_desc, &opr](TensorShape& dest, const InpVal& inp) {
DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}};
for (auto&& idx : idx_arr) {
deps.push_back({idx, DepType::SHAPE});
}
auto infer_shape = [this, &index_desc, &opr, nr_idx = idx_arr.size()](
TensorShape& dest, const InpVal& inp) {
size_t axes[TensorShape::MAX_NDIM], nr_axes = 0;
auto ndim = inp.val[0].shape().ndim;
for (auto&& i : reverse_adaptor(index_desc)) {
......@@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
axes[nr_axes++] = i.axis.get(ndim);
}
}
mgb_assert(nr_axes == nr_idx);
if (!nr_axes) {
dest = {0};
} else {
size_t idx_ndim = 0;
for (size_t i = 0; i < nr_idx; ++i) {
idx_ndim = std::max(idx_ndim, inp.val[2 + i].shape().ndim);
}
mgb_assert(idx_ndim > 0);
dest = {megdnn_opr(opr).get_workspace_in_bytes(
inp.val[1].shape(), axes, nr_axes)};
inp.val[1].shape(), axes, nr_axes, idx_ndim)};
}
return true;
};
opr.owner_graph()->static_infer_manager().register_shape_infer(
opr.output(1), {SourceType::DEP,
{{data, DepType::SHAPE}, {value, DepType::SHAPE}},
infer_shape});
opr.output(1), {SourceType::DEP, deps, infer_shape});
}
template <class Opr>
......@@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
};
owner_graph()->static_infer_manager().register_shape_infer(
output(0), {SourceType::DEP, deps, infer_shape});
this->register_workspace_infer(index_desc(), *this, input(0), output(0));
VarNodeArray idx_arr;
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (m_input2idxonly_axis_indexer[i]) {
idx_arr.push_back(input(i));
}
}
this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr);
}
template <class Opr>
......@@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc(
this->owner_graph()->static_infer_manager().register_shape_infer(
this->output(0), ShapeInferDesc::make_identity(this->input(0)));
this->register_workspace_infer(index_desc(), *this, input(0), input(1));
VarNodeArray idx_arr;
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
if (m_input2idxonly_axis_indexer[i]) {
idx_arr.push_back(input(i));
}
}
this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr);
}
template <class Opr>
......
......@@ -96,7 +96,7 @@ protected:
void register_workspace_infer(
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
VarNode* data, VarNode* value);
VarNode* data, VarNode* value, VarNodeArray idx_arr);
void record_megdnn_opr(mgb::cg::GraphExecutable::ExecDependencyArray& deps);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册