提交 323a4642 编写于 作者: M Megvii Engine Team

feat(dnn/rocm): add topk opr

GitOrigin-RevId: 5ecb07985491359bb8063427cc142fbcec3da943
上级 f4784f4a
......@@ -24,6 +24,7 @@
#include "src/rocm/pooling/opr_impl.h"
#include "src/rocm/reduce/opr_impl.h"
#include "src/rocm/type_cvt/opr_impl.h"
#include "src/rocm/topk/opr_impl.h"
#include "src/rocm/add_update/opr_impl.h"
#include "src/rocm/matrix_mul/opr_impl.h"
#include "src/rocm/batched_matrix_mul/opr_impl.h"
......@@ -161,6 +162,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
......
/**
* \file dnn/src/rocm/topk/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "./topk_radix.h.hip"
#include "src/common/utils.h"
#include "src/rocm/argsort/argsort.h.hip"
#include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
template <typename ctype>
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
const ctype* data, ctype* values,
int* indices, void* workspace) {
auto _handle = concrete_handle(handle());
auto stream = _handle->stream();
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
switch (param().mode) {
case Param::Mode::KTH_ONLY:
hip_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
n, lda, k, grid_dim_y_limit,
stream));
return;
case Param::Mode::VALUE_IDX_NOSORT: {
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
auto thresh = static_cast<ctype*>(wk_bundle.get(0));
auto real_wk = wk_bundle.get(1);
hip_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, grid_dim_y_limit,
stream));
hip_check(topk::topk_select<ctype>(data, thresh, values, indices,
real_wk, m, n, lda, k,
grid_dim_y_limit, stream));
return;
}
case Param::Mode::VALUE_IDX_SORTED: {
WorkspaceBundle wk_bundle{
workspace,
{m * sizeof(ctype), m * std::abs(k) * sizeof(ctype),
m * std::abs(k) * sizeof(int32_t), 1}};
auto thresh = static_cast<ctype*>(wk_bundle.get(0)),
nosort_values = static_cast<ctype*>(wk_bundle.get(1));
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
auto real_wk = wk_bundle.get(3);
hip_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, grid_dim_y_limit,
stream));
hip_check(topk::topk_select<ctype>(data, thresh, nosort_values,
nosort_idx, real_wk, m, n, lda,
k, grid_dim_y_limit, stream));
argsort::forward(nosort_values, values, indices, real_wk, m,
std::abs(k), k > 0, stream, nosort_idx);
return;
}
}
megdnn_throw("bad topk mode");
}
void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
int32_t* indices, _megdnn_workspace workspace) {
switch (data.layout.dtype.enumv()) {
case DTypeEnum::Float32:
dispatch_with_ctype<float>(k, data.layout[0], data.layout[1],
data.layout.stride[0], data.ptr<float>(),
values.ptr<float>(), indices,
workspace.raw_ptr);
return;
case DTypeEnum::Int32:
dispatch_with_ctype<int32_t>(k, data.layout[0], data.layout[1],
data.layout.stride[0], data.ptr<int32_t>(),
values.ptr<int32_t>(), indices,
workspace.raw_ptr);
return;
// #if !MEGDNN_DISABLE_FLOAT16
// case DTypeEnum::Float16:
// dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
// data.layout.stride[0], data.ptr<dt_float16>(),
// values.ptr<dt_float16>(), indices,
// workspace.raw_ptr);
// return;
// #endif
default:
megdnn_throw(
ssprintf("only float32, int32 and float16 supported for "
"cuda topk, got: %s",
data.layout.dtype.name()));
}
}
size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
const TensorLayout& values,
const TensorLayout& indices) {
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(indices);
size_t m = data[0], n = data[1];
size_t kabs = std::abs(k);
size_t grid_dim_y_limit =
concrete_handle(handle())->device_prop().maxGridSize[1];
megdnn_assert(std::max(m, n) <=
static_cast<size_t>(std::numeric_limits<int>::max()));
size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
sel = topk::topk_select_workspace(m, n);
auto ctsize = data.dtype.size();
switch (param().mode) {
case Param::Mode::KTH_ONLY:
return kth;
case Param::Mode::VALUE_IDX_NOSORT:
return WorkspaceBundle{nullptr, {m * ctsize, std::max(kth, sel)}}
.total_size_in_bytes();
case Param::Mode::VALUE_IDX_SORTED:
return WorkspaceBundle{
nullptr,
{m * ctsize, m * kabs * ctsize, m * kabs * sizeof(int32_t),
std::max(std::max(kth, sel),
argsort::get_fwd_workspace_in_bytes(
m, kabs, data.dtype, k > 0, true))}}
.total_size_in_bytes();
}
megdnn_throw("bad topk mode");
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/topk/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs/general.h"
namespace megdnn {
namespace rocm {
class TopKImpl : public TopK {
protected:
template <typename ctype>
void dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
const ctype* data, ctype* values, int* indices,
void* workspace);
void do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
int32_t* indices, _megdnn_workspace workspace) override;
public:
using TopK::TopK;
size_t get_workspace_in_bytes(int k, const TensorLayout& data,
const TensorLayout& values,
const TensorLayout& indices) override;
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
此差异已折叠。
/**
* \file dnn/src/rocm/topk/topk_radix.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/rocm/utils.h.hip"
#include <stdint.h>
namespace megdnn {
namespace rocm {
namespace topk {
namespace internal {
template <typename ctype>
struct RadixConverter;
template <>
struct RadixConverter<float> {
union FIunion {
float fv;
uint32_t iv;
};
static __forceinline__ __device__ __host__ uint32_t to_radix(float val) {
FIunion fi;
fi.fv = val;
return fi.iv ^ (((!(fi.iv >> 31u)) - 1u) | 0x80000000u);
}
static __forceinline__ __device__ __host__ float from_radix(uint32_t val) {
FIunion fi;
// do not write as to_radix() to work around a compiler bug in cuda-9.0
uint32_t m = 0x80000000u;
fi.iv = val ^ (m | (m - !(val >> 31u)));
return fi.fv;
}
};
template <>
struct RadixConverter<int32_t> {
union SUUnion {
int32_t sv;
uint32_t uv;
};
static __forceinline__ __device__ __host__ uint32_t to_radix(int32_t val) {
SUUnion su;
su.sv = val;
return su.uv ^ (1u << 31u);
}
static __forceinline__ __device__ __host__ int32_t
from_radix(uint32_t val) {
SUUnion su;
su.uv = val;
return su.sv ^ (1u << 31u);
}
};
// #if !MEGDNN_DISABLE_FLOAT16
// template <>
// struct RadixConverter<dt_float16> {
// union FIunion {
// FIunion() {}
// dt_float16 fv;
// uint16_t iv;
// };
// static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
// FIunion fi;
// fi.fv = val;
// return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
// }
// static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
// FIunion fi;
// // do not write as to_radix() to work around a compiler bug in cuda-9.0
// uint16_t m = 0x8000u;
// fi.iv = val ^ (m | (m - !(val >> 15u)));
// return fi.fv;
// }
// };
// #endif
} // namespace internal
/*!
* \brief find the k'th values of a (batch, length) matrix along the length
* dimension
*
* \param input input matrix, shape [batch, length], contiguous
* \param lda distance of contiguous rows in \p input, measured in num of
* elements in \p ctype
* \param k if positive, return the smallest top-k; otherwise return the
* largest top-k
* \param output top-k values of each batch, shape [batch]
*/
template <typename ctype>
hipError_t find_kth_radix(const ctype* input, ctype* output, void* workspace,
uint32_t batch, uint32_t length, int32_t lda,
int32_t k, uint32_t grid_dim_y_limit,
hipStream_t stream);
//! get workspace in bytes
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length,
uint32_t grid_dim_y_limit);
/*!
* \brief select values from rows of input that compare to thresh as specified
* \param k if k > 0, select values <= thresh; otherwise select values >=
* thresh. Its absolute value specifies output width.
*/
template <typename ctype>
hipError_t topk_select(const ctype* input, const ctype* thresh,
ctype* output_value, int32_t* output_idx,
void* workspace, uint32_t batch, uint32_t length,
int32_t lda, int32_t k, uint32_t batch_upper_limit,
hipStream_t stream);
uint32_t topk_select_workspace(uint32_t batch, uint32_t length);
} // namespace topk
} // namespace rocm
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/test/rocm/topk.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "test/common/topk.h"
#include "test/rocm/fixture.h"
using namespace megdnn;
using namespace test;
/*
* !!!!!!!!!!!!!!!! IMPORTANT NOTE !!!!!!!!!!!!!!!!
* The kernels are indepedently developed and tested in the
* MegDNN/expr/cuda_topk directory. Here we only check some common cases.
*/
TEST_F(ROCM, TOP_K) {
run_topk_test<dtype::Float32>(handle_rocm());
}
TEST_F(ROCM, TOP_K_I32) {
run_topk_test<dtype::Int32>(handle_rocm());
}
// #if !MEGDNN_DISABLE_FLOAT16
// TEST_F(ROCM, TOP_K_F16) {
// run_topk_test<dtype::Float16>(handle_rocm());
// }
// #endif
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册