Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8c415f4e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
11 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8c415f4e
编写于
3月 15, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): cuda nhwc nearest resize support not 1 or 3 channel
GitOrigin-RevId: 764504c34162221af49b34a3eac7d20caa58f215
上级
04475744
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
27 addition
and
13 deletion
+27
-13
dnn/src/cuda/resize/resize_cv.cu
dnn/src/cuda/resize/resize_cv.cu
+20
-13
dnn/test/common/resize.h
dnn/test/common/resize.h
+7
-0
未找到文件。
dnn/src/cuda/resize/resize_cv.cu
浏览文件 @
8c415f4e
...
...
@@ -150,11 +150,11 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) {
}
}
template
<
typename
T
,
size_t
CH
>
template
<
typename
T
>
__global__
void
resize_nearest_vector_kernel
(
const
T
*
src
,
T
*
dst
,
const
size_t
dst_rows
,
const
size_t
dst_cols
,
const
size_t
src_step
,
const
size_t
dst_step
,
const
float
row_scale
,
const
float
col_scale
)
{
const
float
col_scale
,
size_t
CH
)
{
size_t
dc
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
size_t
dr
=
blockIdx
.
y
*
blockDim
.
y
*
ELEMENTS_PER_THREADS
+
threadIdx
.
y
;
...
...
@@ -178,11 +178,11 @@ __global__ void resize_nearest_vector_kernel(
}
}
template
<
typename
T
,
size_t
CH
>
template
<
typename
T
>
__global__
void
resize_nearest_kernel
(
const
T
*
__restrict__
src
,
T
*
dst
,
const
size_t
dst_rows
,
const
size_t
dst_cols
,
const
size_t
src_step
,
const
size_t
dst_step
,
const
float
row_scale
,
const
float
col_scale
)
{
const
float
col_scale
,
size_t
CH
)
{
size_t
dc
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
size_t
dr
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
dr
<
dst_rows
&&
dc
<
dst_cols
)
{
...
...
@@ -196,23 +196,24 @@ __global__ void resize_nearest_kernel(
}
}
template
<
typename
T
,
size_t
CH
>
template
<
typename
T
>
void
resize_nearest_proxy
(
const
T
*
src
,
T
*
dst
,
const
size_t
src_rows
,
const
size_t
src_cols
,
const
size_t
dst_rows
,
const
size_t
dst_cols
,
const
size_t
src_step
,
const
size_t
dst_step
,
void
*
workspace
,
cudaStream_t
stream
)
{
const
size_t
dst_step
,
void
*
workspace
,
cudaStream_t
stream
,
size_t
CH
)
{
MEGDNN_MARK_USED_VAR
(
workspace
);
float
row_scale
=
(
float
)
src_rows
/
dst_rows
;
float
col_scale
=
(
float
)
src_cols
/
dst_cols
;
if
(
CH
==
3
&&
sizeof
(
T
)
==
4
&&
(
dst_cols
*
dst_rows
<=
src_cols
*
src_rows
))
{
if
(
CH
>
1
&&
sizeof
(
T
)
==
4
&&
(
dst_cols
*
dst_rows
<=
src_cols
*
src_rows
))
{
dim3
THREADS
(
32
,
8
,
1
);
dim3
BLOCKS
(
DIVUP
(
dst_cols
,
THREADS
.
x
),
DIVUP
(
dst_rows
,
THREADS
.
y
));
cudaDeviceSetCacheConfig
(
cudaFuncCachePreferL1
);
resize_nearest_kernel
<
T
,
CH
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
src
,
dst
,
dst_rows
,
dst_cols
,
src_step
,
dst_step
,
row_scale
,
col_scale
);
resize_nearest_kernel
<
T
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
src
,
dst
,
dst_rows
,
dst_cols
,
src_step
,
dst_step
,
row_scale
,
col_scale
,
CH
);
}
else
{
dim3
THREADS
(
32
,
8
,
1
);
...
...
@@ -220,11 +221,12 @@ void resize_nearest_proxy(
DIVUP
(
dst_cols
,
THREADS
.
x
),
DIVUP
(
dst_rows
,
THREADS
.
y
*
ELEMENTS_PER_THREADS
));
if
(
CH
==
3
&&
sizeof
(
T
)
==
1
)
if
(
CH
>
1
&&
sizeof
(
T
)
==
1
)
cudaDeviceSetCacheConfig
(
cudaFuncCachePreferL1
);
resize_nearest_vector_kernel
<
T
,
CH
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
src
,
dst
,
dst_rows
,
dst_cols
,
src_step
,
dst_step
,
row_scale
,
col_scale
);
resize_nearest_vector_kernel
<
T
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
src
,
dst
,
dst_rows
,
dst_cols
,
src_step
,
dst_step
,
row_scale
,
col_scale
,
CH
);
}
}
...
...
@@ -1594,6 +1596,12 @@ void megdnn::cuda::resize::resize_cv(
const
size_t
dst_rows
,
const
size_t
dst_cols
,
const
size_t
src_step
,
const
size_t
dst_step
,
size_t
ch
,
InterpolationMode
imode
,
void
*
workspace
,
cudaStream_t
stream
)
{
if
(
imode
==
INTER_NEAREST
)
{
resize_nearest_proxy
<
T
>
(
src
,
dst
,
src_rows
,
src_cols
,
dst_rows
,
dst_cols
,
src_step
,
dst_step
,
workspace
,
stream
,
ch
);
return
;
}
megdnn_assert
(
ch
==
1
||
ch
==
3
);
#define cb(_mode, _MODE) \
case INTER_##_MODE: { \
...
...
@@ -1610,7 +1618,6 @@ void megdnn::cuda::resize::resize_cv(
}
switch
(
imode
)
{
cb
(
nearest
,
NEAREST
);
cb
(
linear
,
LINEAR
);
cb
(
cubic
,
CUBIC
);
cb
(
lanczos4
,
LANCZOS4
);
...
...
dnn/test/common/resize.h
浏览文件 @
8c415f4e
...
...
@@ -178,6 +178,10 @@ static inline std::vector<TestArg> get_cv_args() {
cur_param
,
TensorShape
{
1
,
i
,
i
,
3
},
TensorShape
{
1
,
i
/
2
,
i
/
2
,
3
});
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
i
,
i
,
1
},
TensorShape
{
1
,
8
,
8
,
1
});
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
i
,
i
,
6
},
TensorShape
{
1
,
i
/
2
,
i
/
2
,
6
});
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
i
,
i
,
6
},
TensorShape
{
1
,
8
,
8
,
6
});
cur_param
.
imode
=
param
::
Resize
::
InterpolationMode
::
INTER_AREA
;
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
i
,
i
,
3
},
TensorShape
{
1
,
8
,
8
,
3
});
...
...
@@ -193,6 +197,9 @@ static inline std::vector<TestArg> get_cv_args() {
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
3
,
3
,
1
},
TensorShape
{
1
,
500
,
600
,
1
});
cur_param
.
imode
=
param
::
Resize
::
InterpolationMode
::
INTER_LANCZOS4
;
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
3
,
3
,
1
},
TensorShape
{
1
,
500
,
600
,
1
});
cur_param
.
imode
=
param
::
Resize
::
InterpolationMode
::
INTER_NEAREST
;
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
3
,
3
,
1
},
TensorShape
{
1
,
500
,
600
,
1
});
args
.
emplace_back
(
cur_param
,
TensorShape
{
1
,
3
,
3
,
4
},
TensorShape
{
1
,
500
,
600
,
4
});
return
args
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录