Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0558b212
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0558b212
编写于
8月 06, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/opr): add interpolate nearest mode
GitOrigin-RevId: d384b87f504c7dd2731bb3c618f35f8b70d00ed2
上级
171d6915
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
315 addition
and
151 deletion
+315
-151
dnn/include/megdnn/oprs/cv.h
dnn/include/megdnn/oprs/cv.h
+3
-0
dnn/src/common/resize.cpp
dnn/src/common/resize.cpp
+9
-3
dnn/src/cuda/resize/backward.cpp
dnn/src/cuda/resize/backward.cpp
+3
-2
dnn/src/cuda/resize/backward.cu
dnn/src/cuda/resize/backward.cu
+35
-7
dnn/src/cuda/resize/common.cuh
dnn/src/cuda/resize/common.cuh
+4
-0
dnn/src/cuda/resize/common.h
dnn/src/cuda/resize/common.h
+6
-5
dnn/src/cuda/resize/forward.cpp
dnn/src/cuda/resize/forward.cpp
+11
-6
dnn/src/cuda/resize/forward.cu
dnn/src/cuda/resize/forward.cu
+53
-16
dnn/src/fallback/resize/opr_impl.cpp
dnn/src/fallback/resize/opr_impl.cpp
+3
-1
dnn/src/naive/resize/opr_impl.cpp
dnn/src/naive/resize/opr_impl.cpp
+89
-23
dnn/src/naive/resize/opr_impl.h
dnn/src/naive/resize/opr_impl.h
+3
-0
dnn/test/common/resize.h
dnn/test/common/resize.h
+10
-7
dnn/test/cuda/resize.cpp
dnn/test/cuda/resize.cpp
+66
-56
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+20
-25
未找到文件。
dnn/include/megdnn/oprs/cv.h
浏览文件 @
0558b212
...
@@ -198,6 +198,9 @@ public:
...
@@ -198,6 +198,9 @@ public:
protected:
protected:
//! get origin coord
//! get origin coord
std
::
pair
<
float
,
int
>
get_origin_coord
(
float
scale
,
int
size
,
int
idx
);
std
::
pair
<
float
,
int
>
get_origin_coord
(
float
scale
,
int
size
,
int
idx
);
//! get nearest index in src
int
get_nearest_src
(
float
scale
,
int
size
,
int
idx
);
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
};
};
...
...
dnn/src/common/resize.cpp
浏览文件 @
0558b212
...
@@ -6,9 +6,11 @@
...
@@ -6,9 +6,11 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include "src/common/utils.h"
...
@@ -26,8 +28,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
...
@@ -26,8 +28,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
errmsg
().
c_str
());
errmsg
().
c_str
());
if
(
param
().
format
==
Param
::
Format
::
NCHW
)
{
if
(
param
().
format
==
Param
::
Format
::
NCHW
)
{
megdnn_assert
(
dst
.
shape
[
1
]
==
src
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
shape
[
1
]
==
src
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
auto
imode
=
param
().
imode
;
param
::
Resize
::
InterpolationMode
::
INTER_LINEAR
);
megdnn_assert
(
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_LINEAR
||
imode
==
param
::
Resize
::
InterpolationMode
::
NEAREST
);
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC
)
{
megdnn_assert
(
dst
.
shape
[
3
]
==
src
.
shape
[
3
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
shape
[
3
]
==
src
.
shape
[
3
],
"%s"
,
errmsg
().
c_str
());
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
)
{
...
@@ -79,6 +82,9 @@ std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
...
@@ -79,6 +82,9 @@ std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
return
{
alpha
,
origin_idx
};
return
{
alpha
,
origin_idx
};
}
}
int
ResizeBase
::
get_nearest_src
(
float
scale
,
int
size
,
int
idx
)
{
return
std
::
min
(
static_cast
<
int
>
(
idx
/
scale
),
size
-
1
);
}
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/cuda/resize/backward.cpp
浏览文件 @
0558b212
...
@@ -30,8 +30,9 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
...
@@ -30,8 +30,9 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
size_t
max_batch_size
=
max_batch_x_channel
/
C
;
size_t
max_batch_size
=
max_batch_x_channel
/
C
;
while
(
N
>
0
)
{
while
(
N
>
0
)
{
size_t
curr_batch_size
=
N
>
max_batch_size
?
max_batch_size
:
N
;
size_t
curr_batch_size
=
N
>
max_batch_size
?
max_batch_size
:
N
;
resize
::
backward_data_proxy
(
diff_ptr
,
grad_ptr
,
curr_batch_size
,
C
,
IH
,
resize
::
backward_data_proxy
(
resize
::
get_imode
(
param
().
imode
),
diff_ptr
,
IW
,
OH
,
OW
,
stream
);
grad_ptr
,
curr_batch_size
,
C
,
IH
,
IW
,
OH
,
OW
,
stream
);
if
(
N
<=
max_batch_size
)
{
if
(
N
<=
max_batch_size
)
{
break
;
break
;
...
...
dnn/src/cuda/resize/backward.cu
浏览文件 @
0558b212
...
@@ -17,9 +17,9 @@ namespace megdnn {
...
@@ -17,9 +17,9 @@ namespace megdnn {
namespace
cuda
{
namespace
cuda
{
namespace
resize
{
namespace
resize
{
__global__
void
resize_bwd_
kernel
(
const
float
*
hidden
,
float
*
dst
,
int
N
,
int
C
,
__global__
void
resize_bwd_
linear_kernel
(
const
float
*
hidden
,
float
*
dst
,
int
N
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
scale_h
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
scale_w
)
{
float
scale_h
,
float
scale_w
)
{
int
n
=
blockIdx
.
z
;
int
n
=
blockIdx
.
z
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
@@ -51,8 +51,30 @@ __global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C,
...
@@ -51,8 +51,30 @@ __global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C,
}
}
}
}
void
backward_data_proxy
(
const
float
*
diff
,
float
*
grad
,
int
N
,
int
C
,
int
IH
,
__global__
void
resize_bwd_nearest_kernel
(
const
float
*
hidden
,
float
*
dst
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
)
{
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
scale_h
,
float
scale_w
)
{
int
n
=
blockIdx
.
z
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
hidden
+=
n
*
C
*
OH
*
OW
;
dst
+=
n
*
C
*
IH
*
IW
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
int
ih
=
get_nearest_src
(
scale_h
,
IH
,
oh
);
int
iw
=
get_nearest_src
(
scale_w
,
IW
,
ow
);
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
atomicAdd
(
dst
+
ih
*
IW
+
iw
,
hidden
[
oh
*
OW
+
ow
]);
hidden
+=
OH
*
OW
;
dst
+=
IH
*
IW
;
}
}
}
void
backward_data_proxy
(
InterpolationMode
imode
,
const
float
*
diff
,
float
*
grad
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
)
{
const
int
BY
=
16
,
BX
=
32
;
const
int
BY
=
16
,
BX
=
32
;
{
{
dim3
threads
(
BX
,
BY
);
dim3
threads
(
BX
,
BY
);
...
@@ -61,8 +83,14 @@ void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
...
@@ -61,8 +83,14 @@ void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
stream
));
stream
));
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
resize_bwd_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
if
(
imode
==
InterpolationMode
::
INTER_LINEAR
)
{
diff
,
grad
,
N
,
C
,
IH
,
IW
,
OH
,
OW
,
scale_h
,
scale_w
);
resize_bwd_linear_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
diff
,
grad
,
N
,
C
,
IH
,
IW
,
OH
,
OW
,
scale_h
,
scale_w
);
}
else
if
(
imode
==
InterpolationMode
::
INTER_NEAREST
)
{
resize_bwd_nearest_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
diff
,
grad
,
N
,
C
,
IH
,
IW
,
OH
,
OW
,
scale_h
,
scale_w
);
}
}
}
after_kernel_launch
();
after_kernel_launch
();
}
}
...
...
dnn/src/cuda/resize/common.cuh
浏览文件 @
0558b212
...
@@ -28,6 +28,10 @@ __device__ inline void get_origin_coord(float scale, int size, int idx,
...
@@ -28,6 +28,10 @@ __device__ inline void get_origin_coord(float scale, int size, int idx,
}
}
}
}
__device__
inline
int
get_nearest_src
(
float
scale
,
int
size
,
int
idx
)
{
return
min
(
static_cast
<
int
>
(
idx
/
scale
),
size
-
1
);
}
}
// namespace resize
}
// namespace resize
}
// namespace cuda
}
// namespace cuda
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/cuda/resize/common.h
浏览文件 @
0558b212
...
@@ -20,16 +20,17 @@ namespace resize {
...
@@ -20,16 +20,17 @@ namespace resize {
// all these kernels use bilinear interpolation
// all these kernels use bilinear interpolation
template
<
typename
ctype
>
template
<
typename
ctype
>
void
forward_proxy
(
bool
is_nhwc
,
const
ctype
*
src
,
ctype
*
dst
,
int
N
,
int
C
,
void
forward_proxy
(
bool
is_nhwc
,
InterpolationMode
imode
,
const
ctype
*
src
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
int
S_IC
,
int
S_IH
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IW
,
cudaStream_t
stream
);
int
S_I
N
,
int
S_IC
,
int
S_IH
,
int
S_I
W
,
cudaStream_t
stream
);
template
<
typename
ctype
>
template
<
typename
ctype
>
void
forward_proxy_nchw4
(
const
ctype
*
src
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
void
forward_proxy_nchw4
(
const
ctype
*
src
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
);
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
);
void
backward_data_proxy
(
const
float
*
diff
,
float
*
grad
,
int
N
,
int
C
,
int
IH
,
void
backward_data_proxy
(
InterpolationMode
imode
,
const
float
*
diff
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
);
float
*
grad
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
);
}
// namespace resize
}
// namespace resize
}
// namespace cuda
}
// namespace cuda
...
...
dnn/src/cuda/resize/forward.cpp
浏览文件 @
0558b212
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "src/common/cv/common.h"
#include "src/common/cv/common.h"
#include "src/common/cv/enums.h"
#include "src/cuda/handle.h"
#include "src/cuda/handle.h"
#include "src/cuda/resize/common.h"
#include "src/cuda/resize/common.h"
#include "src/cuda/resize/helper.h"
#include "src/cuda/resize/helper.h"
...
@@ -146,19 +147,23 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
...
@@ -146,19 +147,23 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
C
,
IH
,
IW
,
OH
,
OW
,
stream
);
C
,
IH
,
IW
,
OH
,
OW
,
stream
);
return
;
return
;
}
}
megdnn_assert
(
param
().
imode
==
Param
::
InterpolationMode
::
LINEAR
,
megdnn_assert
(
param
().
imode
==
Param
::
InterpolationMode
::
LINEAR
||
param
().
imode
==
Param
::
InterpolationMode
::
NEAREST
,
"unsupported interpolation mode for NCHW format"
);
"unsupported interpolation mode for NCHW format"
);
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
resize
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_float32
>
(),
resize
::
forward_proxy
(
is_nhwc
,
resize
::
get_imode
((
param
().
imode
)),
dst
.
ptr
<
dt_float32
>
(),
src
.
layout
[
0
],
C
,
IH
,
IW
,
src
.
ptr
<
dt_float32
>
(),
dst
.
ptr
<
dt_float32
>
(),
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
stream
);
src
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
stream
);
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Uint8
())
{
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Uint8
())
{
resize
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_uint8
>
(),
dst
.
ptr
<
dt_uint8
>
(),
resize
::
forward_proxy
(
is_nhwc
,
resize
::
get_imode
((
param
().
imode
)),
src
.
ptr
<
dt_uint8
>
(),
dst
.
ptr
<
dt_uint8
>
(),
src
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
src
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
stream
);
S_IH
,
S_IW
,
stream
);
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Int8
())
{
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Int8
())
{
resize
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_int8
>
(),
dst
.
ptr
<
dt_int8
>
(),
resize
::
forward_proxy
(
is_nhwc
,
resize
::
get_imode
((
param
().
imode
)),
src
.
ptr
<
dt_int8
>
(),
dst
.
ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
src
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
stream
);
S_IH
,
S_IW
,
stream
);
}
else
{
}
else
{
...
...
dnn/src/cuda/resize/forward.cu
浏览文件 @
0558b212
...
@@ -32,9 +32,10 @@ struct DirectSrcVisitor {
...
@@ -32,9 +32,10 @@ struct DirectSrcVisitor {
};
};
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general
(
SrcVisitor
src
,
ctype
*
__restrict
dst
,
int
C
,
__global__
void
kern_general_linear
(
SrcVisitor
src
,
ctype
*
__restrict
dst
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
int
S_IC
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IH
,
int
S_IW
,
float
scale_h
,
float
scale_w
)
{
int
S_IN
,
int
S_IC
,
int
S_IH
,
int
S_IW
,
float
scale_h
,
float
scale_w
)
{
OutputConverter
output_converter
;
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
@@ -64,6 +65,31 @@ __global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C,
...
@@ -64,6 +65,31 @@ __global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C,
}
}
}
}
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general_nearest
(
SrcVisitor
src
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
int
S_IC
,
int
S_IH
,
int
S_IW
,
float
scale_h
,
float
scale_w
)
{
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
src
.
get
(
blockIdx
.
z
,
S_IN
);
dst
+=
blockIdx
.
z
*
C
*
OH
*
OW
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
int
ih
=
get_nearest_src
(
scale_h
,
IH
,
oh
);
int
iw
=
get_nearest_src
(
scale_w
,
IW
,
ow
);
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
dst
[
oh
*
OW
+
ow
]
=
output_converter
(
sptr
[
ih
*
S_IH
+
iw
*
S_IW
]);
sptr
+=
S_IC
;
dst
+=
OH
*
OW
;
}
}
}
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general_nhwc
(
SrcVisitor
src
,
ctype
*
__restrict
dst
,
int
C
,
__global__
void
kern_general_nhwc
(
SrcVisitor
src
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
scale_h
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
scale_h
,
...
@@ -94,9 +120,10 @@ __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
...
@@ -94,9 +120,10 @@ __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
}
}
template
<
typename
ctype
,
typename
SrcVisitor
>
template
<
typename
ctype
,
typename
SrcVisitor
>
void
dispatch_with_visitor
(
bool
is_nhwc
,
SrcVisitor
src
,
ctype
*
dst
,
int
N
,
void
dispatch_with_visitor
(
bool
is_nhwc
,
InterpolationMode
imode
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
SrcVisitor
src
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
int
S_IC
,
int
S_IH
,
int
S_IW
,
cudaStream_t
stream
)
{
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
int
S_IC
,
int
S_IH
,
int
S_IW
,
cudaStream_t
stream
)
{
const
int
BY
=
16
,
BX
=
32
;
const
int
BY
=
16
,
BX
=
32
;
const
int
max_batch_size
=
65535
;
const
int
max_batch_size
=
65535
;
...
@@ -113,10 +140,19 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N,
...
@@ -113,10 +140,19 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N,
<<<
blocks
,
threads
,
0
,
stream
>>>
(
src
,
dst
,
C
,
IH
,
IW
,
OH
,
<<<
blocks
,
threads
,
0
,
stream
>>>
(
src
,
dst
,
C
,
IH
,
IW
,
OH
,
OW
,
scale_h
,
scale_w
);
OW
,
scale_h
,
scale_w
);
}
else
{
}
else
{
kern_general
<
ctype
,
SrcVisitor
,
rounding
::
RoundingConverter
<
ctype
>>
if
(
imode
==
InterpolationMode
::
INTER_LINEAR
)
{
<<<
blocks
,
threads
,
0
,
stream
>>>
(
src
,
dst
,
C
,
IH
,
IW
,
OH
,
kern_general_linear
<
ctype
,
SrcVisitor
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
rounding
::
RoundingConverter
<
ctype
>>
scale_h
,
scale_w
);
<<<
blocks
,
threads
,
0
,
stream
>>>
(
src
,
dst
,
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
scale_h
,
scale_w
);
}
else
if
(
imode
==
InterpolationMode
::
INTER_NEAREST
)
{
kern_general_nearest
<
ctype
,
SrcVisitor
,
rounding
::
RoundingConverter
<
ctype
>>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
src
,
dst
,
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
S_IC
,
S_IH
,
S_IW
,
scale_h
,
scale_w
);
}
}
}
N
-=
curr_batch_size
;
N
-=
curr_batch_size
;
src
.
move_batch
(
curr_batch_size
,
C
*
IH
*
IW
);
src
.
move_batch
(
curr_batch_size
,
C
*
IH
*
IW
);
...
@@ -194,13 +230,14 @@ namespace cuda {
...
@@ -194,13 +230,14 @@ namespace cuda {
namespace
resize
{
namespace
resize
{
template
<
typename
ctype
>
template
<
typename
ctype
>
void
forward_proxy
(
bool
is_nhwc
,
const
ctype
*
src
,
ctype
*
dst
,
int
N
,
int
C
,
void
forward_proxy
(
bool
is_nhwc
,
InterpolationMode
imode
,
const
ctype
*
src
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IN
,
int
S_IC
,
int
S_IH
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
int
S_IW
,
cudaStream_t
stream
)
{
int
S_IN
,
int
S_IC
,
int
S_IH
,
int
S_IW
,
cudaStream_t
stream
)
{
DirectSrcVisitor
<
ctype
>
visitor
;
DirectSrcVisitor
<
ctype
>
visitor
;
visitor
.
ptr
=
src
;
visitor
.
ptr
=
src
;
dispatch_with_visitor
(
is_nhwc
,
visitor
,
dst
,
N
,
C
,
IH
,
IW
,
OH
,
OW
,
S_IN
,
dispatch_with_visitor
(
is_nhwc
,
imode
,
visitor
,
dst
,
N
,
C
,
IH
,
IW
,
OH
,
OW
,
S_IC
,
S_IH
,
S_IW
,
stream
);
S_I
N
,
S_I
C
,
S_IH
,
S_IW
,
stream
);
after_kernel_launch
();
after_kernel_launch
();
}
}
...
@@ -214,7 +251,7 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
...
@@ -214,7 +251,7 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
}
}
#define INST(ctype) \
#define INST(ctype) \
template void forward_proxy(bool, const ctype*, ctype*, int, int, int, \
template void forward_proxy(bool,
InterpolationMode,
const ctype*, ctype*, int, int, int, \
int, int, int, int, int, int, int, \
int, int, int, int, int, int, int, \
cudaStream_t);
cudaStream_t);
INST
(
float
)
INST
(
float
)
...
...
dnn/src/fallback/resize/opr_impl.cpp
浏览文件 @
0558b212
...
@@ -116,7 +116,9 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
...
@@ -116,7 +116,9 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
void
ResizeImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
void
ResizeImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
if
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW4
)
{
if
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW4
||
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW
&&
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
NEAREST
))
{
naive
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
naive
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
return
;
return
;
}
}
...
...
dnn/src/naive/resize/opr_impl.cpp
浏览文件 @
0558b212
...
@@ -10,12 +10,14 @@
...
@@ -10,12 +10,14 @@
*/
*/
#include "src/common/rounding_converter.cuh"
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/naive/handle.h"
#include "src/naive/handle.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/resize_cv.h"
#include "src/naive/resize/resize_cv.h"
#include "midout.h"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_resize_layout
)
MIDOUT_DECL
(
megdnn_naive_resize_layout
)
MIDOUT_DECL
(
megdnn_naive_resize_layout_nearest
)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
naive
;
using
namespace
naive
;
...
@@ -86,6 +88,28 @@ INST(dt_qint8);
...
@@ -86,6 +88,28 @@ INST(dt_qint8);
INST
(
dt_quint8
);
INST
(
dt_quint8
);
#undef INST
#undef INST
template
<
typename
ctype
>
void
ResizeImpl
::
kern_nchw_nearest
(
const
KernParam
<
ctype
>&
kern_param
)
{
megdnn_assert
(
kern_param
.
format
==
Format
::
NCHW
);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE
(
kern_param
);
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
rep
(
n
,
N
)
{
rep
(
oh
,
OH
)
rep
(
ow
,
OW
)
{
auto
ih
=
get_nearest_src
(
scale_h
,
IH
,
oh
);
auto
iw
=
get_nearest_src
(
scale_w
,
IW
,
ow
);
rep
(
c
,
static_cast
<
int
>
(
C
))
{
dptr
[
c
*
OH
*
OW
+
oh
*
OW
+
ow
]
=
sptr
[
c
*
S_IC
+
ih
*
S_IH
+
iw
*
S_IW
];
}
}
sptr
+=
S_IN
;
dptr
+=
C
*
OH
*
OW
;
}
}
template
<
typename
ctype
>
template
<
typename
ctype
>
void
ResizeImpl
::
kern_naive
(
const
KernParam
<
ctype
>&
kern_param
)
{
void
ResizeImpl
::
kern_naive
(
const
KernParam
<
ctype
>&
kern_param
)
{
if
(
kern_param
.
format
==
Format
::
NHWC
)
{
if
(
kern_param
.
format
==
Format
::
NHWC
)
{
...
@@ -266,6 +290,39 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
...
@@ -266,6 +290,39 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
void
ResizeImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
void
ResizeImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
if
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW
&&
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
NEAREST
)
{
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \
midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \
} \
MIDOUT_END(); \
return; \
}
switch
(
src
.
layout
.
dtype
.
enumv
())
{
cb
(
dtype
::
Float32
,
float
,
0
);
DNN_INC_FLOAT16
(
cb
(
dtype
::
Float16
,
dt_float16
,
1
));
cb
(
dtype
::
Int8
,
int8_t
,
2
);
cb
(
dtype
::
QuantizedS8
,
int8_t
,
3
);
cb
(
dtype
::
Uint8
,
uint8_t
,
4
);
cb
(
dtype
::
Quantized8Asymm
,
uint8_t
,
5
);
default:
megdnn_throw
(
ssprintf
(
"Unsupported input DType in Resize "
"NEAREST mode: %s"
,
src
.
layout
.
dtype
.
name
())
.
c_str
());
return
;
}
#undef cb
#undef cb
}
if
((
param
().
format
==
param
::
Resize
::
Format
::
NCHW
||
if
((
param
().
format
==
param
::
Resize
::
Format
::
NCHW
||
(
src
.
layout
[
3
]
!=
1
&&
src
.
layout
[
3
]
!=
3
)
||
(
src
.
layout
[
3
]
!=
1
&&
src
.
layout
[
3
]
!=
3
)
||
!
is_nhwc_contig_wc
(
src
.
layout
))
||
!
is_nhwc_contig_wc
(
src
.
layout
))
||
...
@@ -306,8 +363,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
...
@@ -306,8 +363,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
void
ResizeBackwardImpl
::
exec
(
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
void
ResizeBackwardImpl
::
exec
(
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
check_exec
(
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
check_exec
(
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
megdnn_assert
(
param
().
format
==
param
::
WarpPerspectiv
e
::
Format
::
NCHW
,
megdnn_assert
(
param
().
format
==
param
::
Resiz
e
::
Format
::
NCHW
,
"invalid
warp_perspectiv
e format"
);
"invalid
resiz
e format"
);
const
int
N
=
grad
.
layout
.
shape
[
0
],
C
=
grad
.
layout
.
shape
[
1
],
const
int
N
=
grad
.
layout
.
shape
[
0
],
C
=
grad
.
layout
.
shape
[
1
],
IH
=
grad
.
layout
.
shape
[
2
],
IW
=
grad
.
layout
.
shape
[
3
];
IH
=
grad
.
layout
.
shape
[
2
],
IW
=
grad
.
layout
.
shape
[
3
];
const
int
OH
=
diff
.
layout
.
shape
[
2
],
OW
=
diff
.
layout
.
shape
[
3
];
const
int
OH
=
diff
.
layout
.
shape
[
2
],
OW
=
diff
.
layout
.
shape
[
3
];
...
@@ -321,28 +378,37 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
...
@@ -321,28 +378,37 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
std
::
memset
(
sptr
,
0
,
sizeof
(
float
)
*
N
*
C
*
IH
*
IW
);
std
::
memset
(
sptr
,
0
,
sizeof
(
float
)
*
N
*
C
*
IH
*
IW
);
rep
(
n
,
N
)
{
rep
(
n
,
N
)
{
rep
(
oh
,
OH
)
rep
(
ow
,
OW
)
{
rep
(
oh
,
OH
)
rep
(
ow
,
OW
)
{
auto
coord_h
=
get_origin_coord
(
scale_h
,
IH
,
oh
);
if
(
param
().
imode
==
InterpolationMode
::
INTER_LINEAR
)
{
auto
coord_w
=
get_origin_coord
(
scale_w
,
IW
,
ow
);
auto
coord_h
=
get_origin_coord
(
scale_h
,
IH
,
oh
);
auto
coord_w
=
get_origin_coord
(
scale_w
,
IW
,
ow
);
float
alphah
=
coord_h
.
first
;
float
alphaw
=
coord_w
.
first
;
float
alphah
=
coord_h
.
first
;
float
alphaw
=
coord_w
.
first
;
int
ih0
=
coord_h
.
second
;
int
ih1
=
ih0
+
1
;
int
ih0
=
coord_h
.
second
;
int
iw0
=
coord_w
.
second
;
int
ih1
=
ih0
+
1
;
int
iw1
=
iw0
+
1
;
int
iw0
=
coord_w
.
second
;
int
iw1
=
iw0
+
1
;
rep
(
c
,
C
)
{
float
hidden
=
hptr
[
c
*
OH
*
OW
+
oh
*
OW
+
ow
];
rep
(
c
,
C
)
{
sptr
[
c
*
IH
*
IW
+
ih0
*
IW
+
iw0
]
+=
float
hidden
=
hptr
[
c
*
OH
*
OW
+
oh
*
OW
+
ow
];
(
1.0
f
-
alphaw
)
*
(
1.0
f
-
alphah
)
*
hidden
;
sptr
[
c
*
IH
*
IW
+
ih0
*
IW
+
iw0
]
+=
sptr
[
c
*
IH
*
IW
+
ih1
*
IW
+
iw0
]
+=
(
1.0
f
-
alphaw
)
*
(
1.0
f
-
alphah
)
*
hidden
;
(
1.0
f
-
alphaw
)
*
alphah
*
hidden
;
sptr
[
c
*
IH
*
IW
+
ih1
*
IW
+
iw0
]
+=
sptr
[
c
*
IH
*
IW
+
ih0
*
IW
+
iw1
]
+=
(
1.0
f
-
alphaw
)
*
alphah
*
hidden
;
alphaw
*
(
1.0
f
-
alphah
)
*
hidden
;
sptr
[
c
*
IH
*
IW
+
ih0
*
IW
+
iw1
]
+=
sptr
[
c
*
IH
*
IW
+
ih1
*
IW
+
iw1
]
+=
alphaw
*
(
1.0
f
-
alphah
)
*
hidden
;
alphaw
*
alphah
*
hidden
;
sptr
[
c
*
IH
*
IW
+
ih1
*
IW
+
iw1
]
+=
alphaw
*
alphah
*
hidden
;
}
}
else
if
(
param
().
imode
==
InterpolationMode
::
NEAREST
)
{
auto
ih
=
get_nearest_src
(
scale_h
,
IH
,
oh
);
auto
iw
=
get_nearest_src
(
scale_w
,
IW
,
ow
);
rep
(
c
,
static_cast
<
int
>
(
C
))
{
sptr
[
c
*
IH
*
IW
+
ih
*
IW
+
iw
]
+=
hptr
[
c
*
OH
*
OW
+
oh
*
OW
+
ow
];
}
}
}
else
megdnn_throw
(
"unsupported mode in ResizeBackwardImpl"
);
}
}
sptr
+=
C
*
IH
*
IW
;
sptr
+=
C
*
IH
*
IW
;
hptr
+=
C
*
OH
*
OW
;
hptr
+=
C
*
OH
*
OW
;
...
...
dnn/src/naive/resize/opr_impl.h
浏览文件 @
0558b212
...
@@ -46,6 +46,9 @@ private:
...
@@ -46,6 +46,9 @@ private:
template
<
typename
ctype
>
template
<
typename
ctype
>
void
kern_naive
(
const
KernParam
<
ctype
>&
kern_param
);
void
kern_naive
(
const
KernParam
<
ctype
>&
kern_param
);
template
<
typename
ctype
>
void
kern_nchw_nearest
(
const
KernParam
<
ctype
>&
kern_param
);
template
<
typename
ctype
>
template
<
typename
ctype
>
void
kern_naive_nhwc
(
const
KernParam
<
ctype
>&
kern_param
);
void
kern_naive_nhwc
(
const
KernParam
<
ctype
>&
kern_param
);
...
...
dnn/test/common/resize.h
浏览文件 @
0558b212
...
@@ -18,6 +18,8 @@ namespace megdnn {
...
@@ -18,6 +18,8 @@ namespace megdnn {
namespace
test
{
namespace
test
{
namespace
resize
{
namespace
resize
{
using
IMode
=
param
::
Resize
::
InterpolationMode
;
struct
TestArg
{
struct
TestArg
{
param
::
Resize
param
;
param
::
Resize
param
;
TensorShape
src
;
TensorShape
src
;
...
@@ -62,17 +64,18 @@ static void set_nchw_args(std::vector<TestArg>& args) {
...
@@ -62,17 +64,18 @@ static void set_nchw_args(std::vector<TestArg>& args) {
args
.
emplace_back
(
param
,
TensorShape
{
1
,
2
,
6
,
8
},
TensorShape
{
1
,
2
,
3
,
4
});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
2
,
6
,
8
},
TensorShape
{
1
,
2
,
3
,
4
});
}
}
static
inline
std
::
vector
<
TestArg
>
get_args
()
{
static
inline
std
::
vector
<
TestArg
>
get_args
(
IMode
imode
=
IMode
::
INTER_LINEAR
)
{
std
::
vector
<
TestArg
>
args
;
std
::
vector
<
TestArg
>
args
;
set_nchw_args
(
args
);
set_nchw_args
(
args
);
if
(
imode
==
IMode
::
INTER_LINEAR
)
{
//! test NHWC with ch != 1 or ch != 3
//! test NHWC with ch != 1 or ch != 3
param
::
Resize
param
;
param
::
Resize
param
;
param
.
format
=
param
::
Resize
::
Format
::
NHWC
;
param
.
format
=
param
::
Resize
::
Format
::
NHWC
;
param
.
imode
=
param
::
Resize
::
InterpolationMode
::
LINEAR
;
param
.
imode
=
imode
;
args
.
emplace_back
(
param
,
TensorShape
{
2
,
2
,
3
,
4
},
TensorShape
{
2
,
4
,
6
,
4
});
args
.
emplace_back
(
param
,
TensorShape
{
2
,
2
,
3
,
4
},
TensorShape
{
2
,
4
,
6
,
4
});
args
.
emplace_back
(
param
,
TensorShape
{
2
,
4
,
6
,
4
},
TensorShape
{
2
,
2
,
3
,
4
});
args
.
emplace_back
(
param
,
TensorShape
{
2
,
4
,
6
,
4
},
TensorShape
{
2
,
2
,
3
,
4
});
}
return
args
;
return
args
;
}
}
...
...
dnn/test/cuda/resize.cpp
浏览文件 @
0558b212
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "test/common/resize.h"
#include "test/common/resize.h"
#include "src/common/cv/enums.h"
#include "test/common/benchmarker.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
#include "test/cuda/fixture.h"
...
@@ -42,30 +43,33 @@ TEST_F(CUDA, RESIZE_CV) {
...
@@ -42,30 +43,33 @@ TEST_F(CUDA, RESIZE_CV) {
TEST_F
(
CUDA
,
RESIZE_FORWARD
)
{
TEST_F
(
CUDA
,
RESIZE_FORWARD
)
{
using
namespace
resize
;
using
namespace
resize
;
std
::
vector
<
TestArg
>
args
=
get_args
();
IMode
modes
[
2
]
=
{
IMode
::
INTER_LINEAR
,
IMode
::
NEAREST
};
Checker
<
Resize
>
checker
(
handle_cuda
());
for
(
auto
imode
:
modes
)
{
std
::
vector
<
TestArg
>
args
=
get_args
(
imode
);
for
(
auto
&&
arg
:
args
)
{
Checker
<
Resize
>
checker
(
handle_cuda
());
checker
.
set_param
(
arg
.
param
)
.
set_dtype
(
0
,
dtype
::
Uint8
())
for
(
auto
&&
arg
:
args
)
{
.
set_dtype
(
1
,
dtype
::
Uint8
())
checker
.
set_param
(
arg
.
param
)
.
execs
({
arg
.
src
,
arg
.
dst
});
.
set_dtype
(
0
,
dtype
::
Uint8
())
}
.
set_dtype
(
1
,
dtype
::
Uint8
())
.
execs
({
arg
.
src
,
arg
.
dst
});
for
(
auto
&&
arg
:
args
)
{
}
checker
.
set_param
(
arg
.
param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
for
(
auto
&&
arg
:
args
)
{
.
set_dtype
(
1
,
dtype
::
Float32
())
checker
.
set_param
(
arg
.
param
)
.
set_epsilon
(
1e-3
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
execs
({
arg
.
src
,
arg
.
dst
});
.
set_dtype
(
1
,
dtype
::
Float32
())
}
.
set_epsilon
(
1e-3
)
.
execs
({
arg
.
src
,
arg
.
dst
});
for
(
auto
&&
arg
:
args
)
{
}
checker
.
set_param
(
arg
.
param
)
.
set_dtype
(
0
,
dtype
::
Int8
())
for
(
auto
&&
arg
:
args
)
{
.
set_dtype
(
1
,
dtype
::
Int8
())
checker
.
set_param
(
arg
.
param
)
.
set_epsilon
(
1e-3
)
.
set_dtype
(
0
,
dtype
::
Int8
())
.
execs
({
arg
.
src
,
arg
.
dst
});
.
set_dtype
(
1
,
dtype
::
Int8
())
.
set_epsilon
(
1e-3
)
.
execs
({
arg
.
src
,
arg
.
dst
});
}
}
}
}
}
...
@@ -84,42 +88,48 @@ TEST_F(CUDA, RESIZE_NCHW4) {
...
@@ -84,42 +88,48 @@ TEST_F(CUDA, RESIZE_NCHW4) {
}
}
TEST_F
(
CUDA
,
RESIZE_NCHW_WITH_STRIDE
)
{
TEST_F
(
CUDA
,
RESIZE_NCHW_WITH_STRIDE
)
{
param
::
Resize
param
;
IMode
modes
[
2
]
=
{
IMode
::
INTER_LINEAR
,
IMode
::
NEAREST
};
param
.
format
=
param
::
Resize
::
Format
::
NCHW
;
for
(
auto
imode
:
modes
)
{
param
.
imode
=
param
::
Resize
::
InterpolationMode
::
LINEAR
;
param
::
Resize
param
;
Checker
<
Resize
>
checker
(
handle_cuda
());
param
.
format
=
param
::
Resize
::
Format
::
NCHW
;
checker
.
set_epsilon
(
1
+
1e-3
)
param
.
imode
=
imode
;
.
set_param
(
param
);
Checker
<
Resize
>
checker
(
handle_cuda
());
checker
.
set_epsilon
(
1
+
1e-3
)
auto
run
=
[
&
](
TensorShape
src_shape
,
std
::
vector
<
ptrdiff_t
>
src_layout
,
.
set_param
(
param
);
TensorShape
dst_shape
,
DType
dtype
)
{
checker
.
set_dtype
(
0
,
dtype
)
auto
run
=
[
&
](
TensorShape
src_shape
,
std
::
vector
<
ptrdiff_t
>
src_layout
,
.
set_dtype
(
1
,
dtype
)
TensorShape
dst_shape
,
DType
dtype
)
{
.
execl
({{
src_shape
,
src_layout
,
dtype
},
{
dst_shape
,
dtype
}});
checker
.
set_dtype
(
0
,
dtype
)
};
.
set_dtype
(
1
,
dtype
)
.
execl
({{
src_shape
,
src_layout
,
dtype
},
{
dst_shape
,
dtype
}});
for
(
DType
&
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Uint8
(),
};
dtype
::
Int8
()})
{
run
({
2
,
3
,
4
,
4
},
{
256
,
32
,
8
,
1
},
{
2
,
3
,
3
,
3
},
dtype
);
for
(
DType
&
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Uint8
(),
run
({
1
,
3
,
4
,
3
},
{
105
,
35
,
7
,
2
},
{
1
,
3
,
5
,
5
},
dtype
);
dtype
::
Int8
()})
{
run
({
1
,
3
,
40
,
40
},
{
25600
,
3200
,
80
,
1
},
{
1
,
3
,
30
,
30
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
256
,
32
,
8
,
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
-
256
,
32
,
-
8
,
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
1
,
3
,
4
,
3
},
{
105
,
35
,
7
,
2
},
{
1
,
3
,
5
,
5
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
256
,
-
32
,
8
,
-
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
1
,
3
,
40
,
40
},
{
25600
,
3200
,
80
,
1
},
{
1
,
3
,
30
,
30
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
-
256
,
-
32
,
-
8
,
-
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
-
256
,
32
,
-
8
,
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
256
,
-
32
,
8
,
-
1
},
{
2
,
3
,
3
,
3
},
dtype
);
run
({
2
,
3
,
4
,
4
},
{
-
256
,
-
32
,
-
8
,
-
1
},
{
2
,
3
,
3
,
3
},
dtype
);
}
}
}
}
}
TEST_F
(
CUDA
,
RESIZE_BACKWARD
)
{
TEST_F
(
CUDA
,
RESIZE_BACKWARD
)
{
Checker
<
ResizeBackward
>
checker
(
handle_cuda
());
IMode
modes
[
2
]
=
{
IMode
::
INTER_LINEAR
,
IMode
::
NEAREST
};
param
::
Resize
param
;
for
(
auto
imode
:
modes
)
{
param
.
format
=
param
::
Resize
::
Format
::
NCHW
;
Checker
<
ResizeBackward
>
checker
(
handle_cuda
());
param
.
imode
=
param
::
Resize
::
InterpolationMode
::
LINEAR
;
param
::
Resize
param
;
checker
.
set_param
(
param
);
param
.
format
=
param
::
Resize
::
Format
::
NCHW
;
param
.
imode
=
imode
;
checker
.
execs
({{
2
,
3
,
4
,
5
},
{
2
,
3
,
8
,
9
}});
checker
.
set_param
(
param
);
checker
.
execs
({{
2
,
5
,
8
,
9
},
{
2
,
5
,
4
,
5
}});
checker
.
execs
({{
2
,
5
,
8
,
5
},
{
2
,
5
,
4
,
9
}});
checker
.
execs
({{
2
,
3
,
4
,
5
},
{
2
,
3
,
8
,
9
}});
checker
.
execs
({{
2
,
5
,
4
,
9
},
{
2
,
5
,
8
,
5
}});
checker
.
execs
({{
2
,
5
,
8
,
9
},
{
2
,
5
,
4
,
5
}});
checker
.
execs
({{
2
,
5
,
8
,
5
},
{
2
,
5
,
4
,
9
}});
checker
.
execs
({{
2
,
5
,
4
,
9
},
{
2
,
5
,
8
,
5
}});
}
}
}
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
...
...
imperative/python/megengine/functional/vision.py
浏览文件 @
0558b212
...
@@ -522,29 +522,13 @@ def interpolate(
...
@@ -522,29 +522,13 @@ def interpolate(
if
align_corners
is
None
:
if
align_corners
is
None
:
align_corners
=
False
align_corners
=
False
if
(
size
is
not
None
and
scale_factor
is
None
and
not
align_corners
and
mode
==
"bilinear"
and
inp
.
ndim
in
[
4
,
5
]
):
# fastpath for interpolate
op
=
builtin
.
Resize
(
imode
=
"linear"
,
format
=
"NCHW"
)
shape
=
astensor1d
(
size
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
op
,
inp
,
shape
)
return
result
if
mode
==
"linear"
:
if
mode
==
"linear"
:
inp
=
expand_dims
(
inp
,
3
)
inp
=
expand_dims
(
inp
,
3
)
if
inp
.
ndim
!=
4
:
if
inp
.
ndim
!=
4
:
raise
ValueError
(
"shape of input tensor must correspond to the operartion mode"
)
raise
ValueError
(
"shape of input tensor must correspond to the operartion mode"
)
if
size
is
None
:
def
get_dsize
(
scale_factor
):
if
scale_factor
is
None
:
raise
ValueError
(
"scale_factor must not be None when size is None"
)
if
isinstance
(
scale_factor
,
(
float
,
int
)):
if
isinstance
(
scale_factor
,
(
float
,
int
)):
scale_factor
=
float
(
scale_factor
)
scale_factor
=
float
(
scale_factor
)
if
mode
==
"linear"
:
if
mode
==
"linear"
:
...
@@ -572,6 +556,13 @@ def interpolate(
...
@@ -572,6 +556,13 @@ def interpolate(
for
i
in
range
(
2
)
for
i
in
range
(
2
)
)
)
dsize
=
concat
([
dsize
[
0
],
dsize
[
1
]],
axis
=
0
)
dsize
=
concat
([
dsize
[
0
],
dsize
[
1
]],
axis
=
0
)
return
dsize
if
size
is
None
:
if
scale_factor
is
None
:
raise
ValueError
(
"scale_factor must not be None when size is None"
)
dsize
=
get_dsize
(
scale_factor
)
else
:
else
:
if
scale_factor
is
not
None
:
if
scale_factor
is
not
None
:
raise
ValueError
(
"scale_factor must be None when size is provided"
)
raise
ValueError
(
"scale_factor must be None when size is provided"
)
...
@@ -583,6 +574,15 @@ def interpolate(
...
@@ -583,6 +574,15 @@ def interpolate(
raise
ValueError
(
"under linear mode, size can only be single value"
)
raise
ValueError
(
"under linear mode, size can only be single value"
)
dsize
=
size
dsize
=
size
if
not
align_corners
and
mode
in
(
"bilinear"
,
"nearest"
)
and
inp
.
ndim
in
[
4
,
5
]:
# fastpath for interpolate
op
=
builtin
.
Resize
(
imode
=
"linear"
if
mode
==
"bilinear"
else
"nearest"
,
format
=
"NCHW"
)
shape
=
astensor1d
(
dsize
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
op
,
inp
,
shape
)
return
result
oh
,
ow
=
dsize
[
0
],
dsize
[
1
]
oh
,
ow
=
dsize
[
0
],
dsize
[
1
]
ih
,
iw
=
inp
.
shape
[
2
],
inp
.
shape
[
3
]
ih
,
iw
=
inp
.
shape
[
2
],
inp
.
shape
[
3
]
...
@@ -630,15 +630,10 @@ def interpolate(
...
@@ -630,15 +630,10 @@ def interpolate(
if
mode
==
"linear"
:
if
mode
==
"linear"
:
ret
=
reshape
(
ret
,
ret
.
shape
[
0
:
3
])
ret
=
reshape
(
ret
,
ret
.
shape
[
0
:
3
])
else
:
else
:
# only NHWC format support "cubic" and "nearest" mode
# only NHWC format support "cubic" mode
assert
mode
==
"bicubic"
inp
=
transpose
(
inp
,
(
0
,
2
,
3
,
1
))
inp
=
transpose
(
inp
,
(
0
,
2
,
3
,
1
))
ret
=
warp_perspective
(
ret
=
warp_perspective
(
inp
,
weight
,
dsize
,
format
=
"NHWC"
,
interp_mode
=
"cubic"
,)
inp
,
weight
,
dsize
,
format
=
"NHWC"
,
interp_mode
=
"cubic"
if
mode
==
"bicubic"
else
mode
,
)
ret
=
transpose
(
ret
,
(
0
,
3
,
1
,
2
))
ret
=
transpose
(
ret
,
(
0
,
3
,
1
,
2
))
return
ret
return
ret
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录