提交 8c415f4e 编写于 作者: M Megvii Engine Team

feat(dnn): cuda nhwc nearest resize support not 1 or 3 channel

GitOrigin-RevId: 764504c34162221af49b34a3eac7d20caa58f215
上级 04475744
......@@ -150,11 +150,11 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) {
}
}
template <typename T, size_t CH>
template <typename T>
__global__ void resize_nearest_vector_kernel(
const T* src, T* dst, const size_t dst_rows, const size_t dst_cols,
const size_t src_step, const size_t dst_step, const float row_scale,
const float col_scale) {
const float col_scale, size_t CH) {
size_t dc = blockIdx.x * blockDim.x + threadIdx.x;
size_t dr = blockIdx.y * blockDim.y * ELEMENTS_PER_THREADS + threadIdx.y;
......@@ -178,11 +178,11 @@ __global__ void resize_nearest_vector_kernel(
}
}
template <typename T, size_t CH>
template <typename T>
__global__ void resize_nearest_kernel(
const T* __restrict__ src, T* dst, const size_t dst_rows, const size_t dst_cols,
const size_t src_step, const size_t dst_step, const float row_scale,
const float col_scale) {
const float col_scale, size_t CH) {
size_t dc = blockIdx.x * blockDim.x + threadIdx.x;
size_t dr = blockIdx.y * blockDim.y + threadIdx.y;
if (dr < dst_rows && dc < dst_cols) {
......@@ -196,23 +196,24 @@ __global__ void resize_nearest_kernel(
}
}
template <typename T, size_t CH>
template <typename T>
void resize_nearest_proxy(
const T* src, T* dst, const size_t src_rows, const size_t src_cols,
const size_t dst_rows, const size_t dst_cols, const size_t src_step,
const size_t dst_step, void* workspace, cudaStream_t stream) {
const size_t dst_step, void* workspace, cudaStream_t stream, size_t CH) {
MEGDNN_MARK_USED_VAR(workspace);
float row_scale = (float)src_rows / dst_rows;
float col_scale = (float)src_cols / dst_cols;
if (CH == 3 && sizeof(T) == 4 && (dst_cols * dst_rows <= src_cols * src_rows)) {
if (CH > 1 && sizeof(T) == 4 && (dst_cols * dst_rows <= src_cols * src_rows)) {
dim3 THREADS(32, 8, 1);
dim3 BLOCKS(DIVUP(dst_cols, THREADS.x), DIVUP(dst_rows, THREADS.y));
cudaDeviceSetCacheConfig(cudaFuncCachePreferL1);
resize_nearest_kernel<T, CH><<<BLOCKS, THREADS, 0, stream>>>(
src, dst, dst_rows, dst_cols, src_step, dst_step, row_scale, col_scale);
resize_nearest_kernel<T><<<BLOCKS, THREADS, 0, stream>>>(
src, dst, dst_rows, dst_cols, src_step, dst_step, row_scale, col_scale,
CH);
} else {
dim3 THREADS(32, 8, 1);
......@@ -220,11 +221,12 @@ void resize_nearest_proxy(
DIVUP(dst_cols, THREADS.x),
DIVUP(dst_rows, THREADS.y * ELEMENTS_PER_THREADS));
if (CH == 3 && sizeof(T) == 1)
if (CH > 1 && sizeof(T) == 1)
cudaDeviceSetCacheConfig(cudaFuncCachePreferL1);
resize_nearest_vector_kernel<T, CH><<<BLOCKS, THREADS, 0, stream>>>(
src, dst, dst_rows, dst_cols, src_step, dst_step, row_scale, col_scale);
resize_nearest_vector_kernel<T><<<BLOCKS, THREADS, 0, stream>>>(
src, dst, dst_rows, dst_cols, src_step, dst_step, row_scale, col_scale,
CH);
}
}
......@@ -1594,6 +1596,12 @@ void megdnn::cuda::resize::resize_cv(
const size_t dst_rows, const size_t dst_cols, const size_t src_step,
const size_t dst_step, size_t ch, InterpolationMode imode, void* workspace,
cudaStream_t stream) {
if (imode == INTER_NEAREST) {
resize_nearest_proxy<T>(
src, dst, src_rows, src_cols, dst_rows, dst_cols, src_step, dst_step,
workspace, stream, ch);
return;
}
megdnn_assert(ch == 1 || ch == 3);
#define cb(_mode, _MODE) \
case INTER_##_MODE: { \
......@@ -1610,7 +1618,6 @@ void megdnn::cuda::resize::resize_cv(
}
switch (imode) {
cb(nearest, NEAREST);
cb(linear, LINEAR);
cb(cubic, CUBIC);
cb(lanczos4, LANCZOS4);
......
......@@ -178,6 +178,10 @@ static inline std::vector<TestArg> get_cv_args() {
cur_param, TensorShape{1, i, i, 3}, TensorShape{1, i / 2, i / 2, 3});
args.emplace_back(cur_param, TensorShape{1, i, i, 1}, TensorShape{1, 8, 8, 1});
args.emplace_back(
cur_param, TensorShape{1, i, i, 6}, TensorShape{1, i / 2, i / 2, 6});
args.emplace_back(cur_param, TensorShape{1, i, i, 6}, TensorShape{1, 8, 8, 6});
cur_param.imode = param::Resize::InterpolationMode::INTER_AREA;
args.emplace_back(cur_param, TensorShape{1, i, i, 3}, TensorShape{1, 8, 8, 3});
......@@ -193,6 +197,9 @@ static inline std::vector<TestArg> get_cv_args() {
args.emplace_back(cur_param, TensorShape{1, 3, 3, 1}, TensorShape{1, 500, 600, 1});
cur_param.imode = param::Resize::InterpolationMode::INTER_LANCZOS4;
args.emplace_back(cur_param, TensorShape{1, 3, 3, 1}, TensorShape{1, 500, 600, 1});
cur_param.imode = param::Resize::InterpolationMode::INTER_NEAREST;
args.emplace_back(cur_param, TensorShape{1, 3, 3, 1}, TensorShape{1, 500, 600, 1});
args.emplace_back(cur_param, TensorShape{1, 3, 3, 4}, TensorShape{1, 500, 600, 4});
return args;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册