提交 f4784f4a 编写于 作者: M Megvii Engine Team

feat(dnn/rocm): add argsort opr

GitOrigin-RevId: b4c3eb4707bb4def8739d8ece8c26d4aa5af4147
上级 6082c353
/**
* \file dnn/src/rocm/argsort/argsort.cpp.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/utils.h.hip"
#include "./argsort.h.hip"
#include "./bitonic_sort.h.hip"
#include "megdnn/basic_types.h"
#include "hipcub/device/device_radix_sort.hpp"
#include "hipcub/device/device_segmented_radix_sort.hpp"
using namespace megdnn;
using namespace rocm;
namespace {
struct StridedOffsetIterator {
int bias, stride;
StridedOffsetIterator(int bias_, int stride_)
: bias(bias_), stride(stride_) {}
__device__ __forceinline__ int operator[](int i) const {
return stride * i + bias;
}
};
bool use_bitonic(uint32_t /*M*/, uint32_t N) {
// bitonic sort is preferred when N is small (alwyas faster than radix sort)
return N <= BITONIC_SORT_MAX_LENGTH;
}
bool use_segmented(uint32_t M, uint32_t /*N*/) {
// an empirical value:
// sort(1, 1e6): 0.574ms
// segsort({1,2,8,16}, 1e6): 7-8ms
// sort(1, 1e7): 3.425ms
// segsort({1,2,8,16}, 1e7): 71-84ms
//
// segsort is about 7x-10x slower than sort on small batches, so we can
// expect it to be faster than sort when batch is large enough.
return M >= 8;
}
__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < n) {
dst[i] = i % mod;
}
}
template <typename ctype>
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
if (use_bitonic(M, N)) {
return 0;
}
return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL,
M, N, 0, sizeof(float)*8, NULL);
}
} // anonymous namespace
template <typename KeyType, typename ValueType>
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
bool is_ascending, void* workspace, size_t workspace_size,
const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream){
hipError_t err;
if (use_segmented(M, N)) {
if (is_ascending) {
err = hipcub::DeviceSegmentedRadixSort::SortPairs(
workspace, workspace_size, keys_in, keys_out, values_in,
values_out, N * M, M, StridedOffsetIterator(0, N),
StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
hip_check(err);
} else {
err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
workspace, workspace_size, keys_in, keys_out, values_in,
values_out, N * M, M, StridedOffsetIterator(0, N),
StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
hip_check(err);
}
} else {
if (is_ascending) {
for (size_t i = 0; i < M; ++i) {
err = hipcub::DeviceRadixSort::SortPairs(
workspace, workspace_size, keys_in + N * i,
keys_out + N * i, values_in + N * i, values_out + N * i,
N, begin_bit, end_bit, stream);
hip_check(err);
if (!keys_in) {
return workspace_size;
}
}
} else {
for (size_t i = 0; i < M; ++i) {
err = hipcub::DeviceRadixSort::SortPairsDescending(
workspace, workspace_size, keys_in + N * i,
keys_out + N * i, values_in + N * i, values_out + N * i,
N, begin_bit, end_bit, stream);
hip_check(err);
if (!keys_in) {
return workspace_size;
}
}
}
}
return workspace_size;
}
size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
bool is_ascending,
bool iptr_src_given) {
size_t size = 0;
switch (dtype.enumv().ev) {
#define cb(ctype) \
case DTypeTrait<ctype>::enumv: \
size = get_sort_workspace<ctype>(M, N, is_ascending); \
break;
ARGSORT_FOREACH_CTYPE(cb)
#undef cb
default:
megdnn_throw("argsort only supports float, int32 and float16");
}
if (!iptr_src_given) {
size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
}
return size;
}
template <typename dtype>
void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
void* workspace, uint32_t M, uint32_t N,
bool is_ascending, hipStream_t stream,
const int* iptr_src) {
size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending);
if (!iptr_src) {
int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) +
DIVUP(wk_size, sizeof(float)) *
sizeof(float));
kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N);
iptr_src = ptr;
}
if (use_bitonic(M, N)) {
hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending,
stream));
} else {
cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
iptr, M, N, 0, sizeof(float)*8, stream);
}
}
namespace megdnn {
namespace rocm {
#define INST_CUB_SORT(dtype) \
template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool, \
void*, size_t, const dtype*, dtype*, \
const dtype*, dtype*, uint32_t, uint32_t,\
int, int, hipStream_t);
#define INST_FORWARD(dtype) \
template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
uint32_t, uint32_t, bool, hipStream_t, \
const int*);
ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT
#undef INST_FORWARD
}
} // namespace megdnn
// vim: ft=rocm syntax=rocm.doxygen
/**
* \file dnn/src/rocm/argsort/argsort.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "hcc_detail/hcc_defs_prologue.h"
#include "hip_header.h"
#include <stddef.h>
#include <stdint.h>
#include "megdnn/dtype.h"
namespace megdnn {
namespace rocm {
namespace argsort {
size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
bool is_ascending,
bool iptr_src_given = false);
template <typename KeyType, typename ValueType>
size_t cub_sort_pairs(
bool is_ascending, void* workspace, size_t workspace_size,
const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream);
/*!
* \param iptr_src pointer to indices; a range would be generated if it is null
*/
template <typename dtype>
void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
uint32_t M, uint32_t N, bool is_ascending, hipStream_t stream,
const int* iptr_src = NULL);
//! iterate over all supported data types
#define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16))
} // namespace argsort
} // namespace rocm
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/rocm/argsort/backward.cpp.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/utils.h.hip"
#include "./argsort.h.hip"
#include "./backward.h.hip"
// #include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
using namespace argsort;
namespace {
template <typename T>
__global__ void backward_kernel(uint32_t dst_w, uint32_t src_w,
uint32_t src_size, T* dst, const T* src_data,
const int* src_idx) {
uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < src_size) {
uint32_t r = idx / src_w;
dst[r * dst_w + src_idx[idx]] = src_data[idx];
}
}
} // namespace
template <typename T>
void argsort::backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w,
T* dst, const T* src_data, const int* src_idx,
hipStream_t stream) {
if (dst_w != src_w) {
hipMemsetAsync(dst, 0, dst_h * dst_w * sizeof(T), stream);
}
uint32_t src_size = dst_h * src_w;
backward_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
dst_w, src_w, src_size, dst, src_data, src_idx);
after_kernel_launch();
}
namespace megdnn {
namespace rocm {
namespace argsort {
#define INST(T) \
template void backward_proxy(uint32_t dst_h, uint32_t dst_w, \
uint32_t src_w, T* dst, const T* src_data, \
const int* src_idx, hipStream_t stream);
ARGSORT_FOREACH_CTYPE(INST)
#undef INST
} // namespace argsort
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/argsort/backward.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "hip_header.h"
#include <stdint.h>
namespace megdnn {
namespace rocm {
namespace argsort {
template <typename T>
void backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst,
const T* src_data, const int* src_idx, hipStream_t stream);
} // namespace argsort
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/argsort/bitonic_sort.cpp.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "./bitonic_sort.h.hip"
// #include "src/cuda/query_blocksize.cuh"
// #include "megdnn/dtype.h"
// #if __CUDACC_VER_MAJOR__ < 9
// #pragma message "warp sync disabled due to insufficient cuda version"
#define __syncwarp __syncthreads
// #endif
#include <algorithm>
#include <cmath>
using namespace megdnn;
using namespace rocm;
namespace bitonic_sort_impl {
//! load keys and init idx
template <class CompareLess, typename T>
__device__ __forceinline__ void safe_load0(T* dst, uint16_t* idx, const T* src,
uint32_t id, uint32_t size) {
dst[id] = id < size ? src[id] : CompareLess::template max<T>();
idx[id] = id;
}
//! load values
template <typename T>
__device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id,
uint32_t size) {
// broadcast last value to avoid out-of-bound values (for example, when
// input contains NaN)
dst[id] = src[min(id, size - 1)];
}
//! write keys
template <typename T>
__device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id,
uint32_t size) {
if (id < size) {
dst[id] = src[id];
}
}
//! write values
template <typename T>
__device__ __forceinline__ void safe_write1(T* dst, const T* src,
const uint16_t* remap, uint32_t id,
uint32_t size) {
if (id < size) {
dst[id] = src[remap[id]];
}
}
struct SyncWarp {
static __device__ __forceinline__ void s() { __syncwarp(); }
};
struct SyncBlock {
static __device__ __forceinline__ void s() { __syncthreads(); }
};
template <typename T>
struct NumTrait;
template <>
struct NumTrait<float> {
static __device__ __forceinline__ float max() { return INFINITY; }
static __device__ __forceinline__ float min() { return -INFINITY; }
};
template <>
struct NumTrait<int32_t> {
static __device__ __forceinline__ int32_t max() { return INT_MAX; }
static __device__ __forceinline__ int32_t min() { return INT_MIN; }
};
// #if !MEGDNN_DISABLE_FLOAT16
// template <>
// struct NumTrait<dt_float16> {
// static __device__ __forceinline__ dt_float16 max() {
// return std::numeric_limits<dt_float16>::max();
// }
// static __device__ __forceinline__ dt_float16 min() {
// return std::numeric_limits<dt_float16>::lowest();
// }
// };
// #endif
struct LessThan {
template <typename Key, typename Value>
static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
Value v1) {
return (k0 < k1) | ((k0 == k1) & (v0 < v1));
}
template <typename T>
static __device__ __forceinline__ T max() {
return NumTrait<T>::max();
}
};
struct GreaterThan {
template <typename Key, typename Value>
static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
Value v1) {
return (k0 > k1) | ((k0 == k1) & (v0 < v1));
}
template <typename T>
static __device__ __forceinline__ T max() {
return NumTrait<T>::min();
}
};
template <typename Key, typename Value>
union KVUnion {
Key key;
Value value;
};
template <typename Key, typename Value>
static int get_shmem(int block_size, void* = NULL) {
return (sizeof(KVUnion<Key, Value>) + sizeof(uint16_t)) * block_size * 4;
}
/*!
* \brief batched bitonic sort (M, N) for small N
*
* launch configuration:
* grid(X)
* block(N/4, Y)
*
* where N / 4 == 1 << nr_th_log2
*/
template <class Sync, typename Key, typename Value, class CompareLess,
uint32_t nr_th_log2>
static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp,
const Value* value_inp, Key* key_out,
Value* value_out) {
const uint32_t nr_th = 1 << nr_th_log2;
// 24KiB shared memory for 4-byte keys for 1024 threads
extern __shared__ uint8_t smem_storage[];
uint16_t* idx_storage = reinterpret_cast<uint16_t*>(smem_storage);
KVUnion<Key, Value>* keys_storage = reinterpret_cast<KVUnion<Key, Value>*>(
idx_storage + blockDim.y * (nr_th * 4));
uint32_t cur_batch = blockIdx.x * blockDim.y + threadIdx.y,
off = cur_batch * length;
key_inp += off;
key_out += off;
value_inp += off;
value_out += off;
uint32_t storage_offset = threadIdx.y * (nr_th * 4);
uint16_t* values = idx_storage + storage_offset;
Key* keys = reinterpret_cast<Key*>(keys_storage + storage_offset);
uint32_t tid0 = threadIdx.x, tid1 = tid0 + nr_th,
cur_length = cur_batch < batch ? length : 0;
safe_load0<CompareLess>(keys, values, key_inp, tid0, cur_length);
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th, cur_length);
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 2,
cur_length);
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 3,
cur_length);
Sync::s();
#define WORK(_idx, _asc) \
do { \
uint32_t _id0 = (_idx), _id1 = _id0 + step; \
Key _k0 = keys[_id0], _k1 = keys[_id1]; \
uint16_t _v0 = values[_id0], _v1 = values[_id1]; \
if (CompareLess::cmp(_k0, _v0, _k1, _v1) != _asc) { \
keys[_id0] = _k1; \
keys[_id1] = _k0; \
values[_id0] = _v1; \
values[_id1] = _v0; \
} \
} while (0)
#pragma unroll
for (uint32_t slen_log = 0; slen_log <= (nr_th_log2 + 1); ++slen_log) {
// log2 of half of current bitonic sequence (i.e. length of its
// monotonic part)
uint32_t asc0 = !((tid0 >> slen_log) & 1),
asc1 = !((tid1 >> slen_log) & 1);
#pragma unroll
for (uint32_t j = 0; j <= slen_log; ++j) {
uint32_t step = 1 << (slen_log - j), xmask = step - 1,
ymask = ~xmask;
WORK((tid0 & xmask) + ((tid0 & ymask) << 1), asc0);
WORK((tid1 & xmask) + ((tid1 & ymask) << 1), asc1);
Sync::s();
}
}
#undef WORK
if (cur_batch < batch) {
safe_write0(key_out, keys, tid0, length);
safe_write0(key_out, keys, tid0 + nr_th, length);
safe_write0(key_out, keys, tid0 + nr_th * 2, length);
safe_write0(key_out, keys, tid0 + nr_th * 3, length);
// permute values according to sorted indices
Value* copied_values = reinterpret_cast<Value*>(keys);
safe_load1(copied_values, value_inp, tid0, cur_length);
safe_load1(copied_values, value_inp, tid0 + nr_th, cur_length);
safe_load1(copied_values, value_inp, tid0 + nr_th * 2, cur_length);
safe_load1(copied_values, value_inp, tid0 + nr_th * 3, cur_length);
Sync::s();
safe_write1(value_out, copied_values, values, tid0, length);
safe_write1(value_out, copied_values, values, tid0 + nr_th, length);
safe_write1(value_out, copied_values, values, tid0 + nr_th * 2, length);
safe_write1(value_out, copied_values, values, tid0 + nr_th * 3, length);
}
}
} // namespace bitonic_sort_impl
template <typename Key, typename Value>
hipError_t rocm::bitonic_sort(uint32_t batch, uint32_t length,
const Key* key_inp, const Value* value_inp,
Key* key_out, Value* value_out, bool ascending,
hipStream_t stream) {
using namespace bitonic_sort_impl;
if (length == 1) {
if (key_inp != key_out) {
hipMemcpyAsync(key_out, key_inp, sizeof(Key) * batch,
hipMemcpyDeviceToDevice, stream);
}
if (value_inp != value_out) {
hipMemcpyAsync(value_out, value_inp, sizeof(Value) * batch,
hipMemcpyDeviceToDevice, stream);
}
return hipGetLastError();
}
void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) =
NULL;
uint32_t l4 = (length + 3) / 4;
dim3 block;
#define chk(s) \
do { \
if (!kptr && l4 <= (1 << s)) { \
block.x = 1 << s; \
if ((1 << s) <= 32) { \
if (ascending) { \
kptr = kern<SyncWarp, Key, Value, LessThan, s>; \
} else { \
kptr = kern<SyncWarp, Key, Value, GreaterThan, s>; \
} \
} else { \
if (ascending) { \
kptr = kern<SyncBlock, Key, Value, LessThan, s>; \
} else { \
kptr = kern<SyncBlock, Key, Value, GreaterThan, s>; \
} \
} \
} \
} while (0)
chk(0);
chk(1);
chk(2);
chk(3);
chk(4);
chk(5);
chk(6);
chk(7);
chk(8);
chk(9);
if (!kptr) {
return hipErrorInvalidConfiguration;
}
// TODO: this is randomly choosed
int suggested_block_size = 128;
// query_launch_config_for_kernel(reinterpret_cast<void*>(kptr),
// get_shmem<Key, Value>)
// .block_size;
block.y = std::max<int>(suggested_block_size / block.x, 1);
int shmem = get_shmem<Key, Value>(block.y * block.x);
kptr<<<(batch - 1) / block.y + 1, block, shmem, stream>>>(
batch, length, key_inp, value_inp, key_out, value_out);
return hipGetLastError();
}
namespace megdnn {
namespace rocm {
#define INST(k, v) \
template hipError_t bitonic_sort<k, v>(uint32_t, uint32_t, const k*, \
const v*, k*, v*, bool, \
hipStream_t)
INST(float, int);
INST(int32_t, int);
// DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST
} // namespace megdnn
} // namespace megdnn
// vim: ft=rocm syntax=rocm.doxygen
/**
* \file dnn/src/rocm/argsort/bitonic_sort.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "hip_header.h"
#include <stdint.h>
namespace megdnn {
namespace rocm {
const uint32_t BITONIC_SORT_MAX_LENGTH = 1024;
// cub radix sort seems to be faster with lengths > 1024
/*!
* \brief bitonic sort for k/v pairs
*
* Requires \p length no larger than 4 times of cuda thread num. \p key_inp
* and \p key_out can be identical, and so are \p value_inp and \p value_out.
*/
template <typename Key, typename Value>
hipError_t bitonic_sort(uint32_t batch, uint32_t length, const Key* key_inp,
const Value* value_inp, Key* key_out, Value* value_out,
bool ascending, hipStream_t stream);
} // namespace rocm
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/rocm/argsort/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "./argsort.h.hip"
#include "./backward.h.hip"
#include "src/common/utils.h"
#include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, indices.layout, workspace.size);
auto M = src.layout.shape[0], N = src.layout.shape[1];
auto iptr = indices.ptr<dt_int32>();
auto wptr = static_cast<void*>(workspace.raw_ptr);
bool is_ascending = (param().order == Order::ASCENDING);
auto stream = hip_stream(handle());
switch (src.layout.dtype.enumv()) {
#define cb(t) \
case DTypeTrait<t>::enumv: \
argsort::forward(src.ptr<t>(), dst.ptr<t>(), iptr, wptr, M, N, \
is_ascending, stream); \
break;
ARGSORT_FOREACH_CTYPE(cb);
#undef cb
default:
megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s",
src.layout.dtype.name()));
}
}
size_t ArgsortForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout&,
const TensorLayout&) {
megdnn_assert(src.ndim == 2, "invalid src layout: %s",
src.to_string().c_str());
auto M = src.shape[0], N = src.shape[1];
auto&& dtype = src.dtype;
megdnn_assert(std::max(M, N) <=
static_cast<size_t>(std::numeric_limits<int>::max()));
return argsort::get_fwd_workspace_in_bytes(
M, N, dtype, param().order == Param::Order::ASCENDING);
}
void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in indices,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
auto stream = hip_stream(handle());
switch (diff.layout.dtype.enumv()) {
#define cb(t) \
case DTypeTrait<t>::enumv: \
argsort::backward_proxy(grad.layout[0], grad.layout[1], \
diff.layout[1], grad.ptr<t>(), diff.ptr<t>(), \
indices.ptr<int>(), stream); \
break;
ARGSORT_FOREACH_CTYPE(cb);
#undef cb
default:
megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s",
diff.layout.dtype.name()));
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/argsort/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace rocm {
class ArgsortForwardImpl final: public ArgsortForward {
public:
using ArgsortForward::ArgsortForward;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &dst,
const TensorLayout &indices) override;
};
class ArgsortBackwardImpl final: public ArgsortBackward {
public:
using ArgsortBackward::ArgsortBackward;
void exec(_megdnn_tensor_in diff,
_megdnn_tensor_in indices,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &,
const TensorLayout &,
const TensorLayout &) override {
return 0;
}
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -33,6 +33,7 @@
#include "src/rocm/powc/opr_impl.h"
#include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/argsort/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
......@@ -148,6 +149,8 @@ bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter);
......
/**
* \file dnn/test/rocm/argsort.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/rocm/fixture.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "../src/rocm/argsort/opr_impl.h"
using namespace megdnn;
using namespace test;
namespace {
class ArgsortRNG final : public RNG {
bool m_rev_order = false;
DType m_dtype;
template <typename T>
void fill(T* ptr, int n) {
if (m_rev_order) {
for (int i = 0; i < n; ++i)
ptr[i] = static_cast<T>(n / 2 - i);
} else {
for (int i = 0; i < n; ++i)
ptr[i] = static_cast<T>(i - n / 2);
COMPAT_RANDOM(ptr, ptr + n);
}
}
void gen(const TensorND& tensor) override {
auto n = tensor.layout.total_nr_elems();
if (m_dtype == dtype::Float32{}) {
fill(tensor.ptr<dt_float32>(), n);
} else {
megdnn_assert(m_dtype == dtype::Int32{});
fill(tensor.ptr<dt_int32>(), n);
}
}
public:
ArgsortRNG(DType dt) : m_dtype{dt} {}
void set_rev_order(bool flag) { m_rev_order = flag; }
};
void run_forward_test(Handle* handle, DType dtype) {
Checker<ArgsortForward> checker(handle);
using Param = Argsort::Param;
using Order = Param::Order;
ArgsortRNG rng{dtype};
checker.set_dtype(2, dtype::Int32());
checker.set_dtype(0, dtype).set_rng(0, &rng);
for (size_t i = 3; i < 10240; i *= 2) {
Param param;
param.order = Order::ASCENDING;
checker.set_param(param).execs({{3, i + 1}, {}, {}});
param.order = Order::DESCENDING;
checker.set_param(param).execs({{3, i - 1}, {}, {}});
checker.set_param(param).execs({{13, i + 3}, {}, {}});
}
{
// reverse sort large array
constexpr size_t N = 200003;
rng.set_rev_order(true);
Param param;
param.order = Order::ASCENDING;
checker.set_param(param).execs({{1, N}, {}, {}});
}
}
void run_backward_test(Handle* handle, DType dtype) {
class IdxRng final : public RNG {
void gen(const TensorND& tensor) override {
auto ptr = tensor.ptr<dt_int32>();
auto m = tensor.layout[0], n = tensor.layout[1];
for (size_t i = 0; i < m; ++i) {
for (size_t j = 0; j < n; ++j) {
ptr[j] = j;
}
COMPAT_RANDOM(ptr, ptr + n);
ptr += n;
}
}
} rng;
Checker<ArgsortBackward> checker(handle);
checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng);
checker.set_dtype(0, dtype);
checker.set_dtype(2, dtype);
for (size_t i = 16; i < 4096; i *= 2) {
checker.execs({{3, i}, {3, i}, {3, i}});
checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}});
checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}});
}
}
} // anonymous namespace
TEST_F(ROCM, ARGSORT_FORWARD_F32) {
run_forward_test(handle_rocm(), dtype::Float32{});
}
TEST_F(ROCM, ARGSORT_FORWARD_I32) {
run_forward_test(handle_rocm(), dtype::Int32{});
}
TEST_F(ROCM, ARGSORT_BACKWARD_F32) {
run_backward_test(handle_rocm(), dtype::Float32{});
}
TEST_F(ROCM, ARGSORT_BACKWARD_I32) {
run_backward_test(handle_rocm(), dtype::Int32{});
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册