MegEngine / MegEngine · Commit 0d7ace15

Authored on June 20, 2022 by Megvii Engine Team
fix(mgb/dnn): suport fp16 for resize nhwc
GitOrigin-RevId: bb04d2a801b5cbe9c8704ce922842231c4158a4c
Parent: cfed86f9
Showing 9 changed files with 342 additions and 80 deletions (+342, -80).
dnn/src/common/resize.cpp (+6, -2)
dnn/src/cuda/resize/backward.cpp (+45, -15)
dnn/src/cuda/resize/backward.cu (+107, -32)
dnn/src/cuda/resize/common.h (+3, -2)
dnn/src/cuda/resize/forward.cpp (+5, -0)
dnn/src/cuda/resize/forward.cu (+1, -0)
dnn/src/naive/resize/opr_impl.cpp (+96, -28)
dnn/src/naive/resize/opr_impl.h (+6, -0)
dnn/test/cuda/resize.cpp (+73, -1)
dnn/src/common/resize.cpp
@@ -67,8 +67,12 @@ void ResizeBackward::check_exec(
     auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
     megdnn_assert(
-            param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(),
-            "Backward resize only supports Float32 and NCHW.");
+            (param().format == Param::Format::NCHW ||
+             param().format == Param::Format::NHWC) &&
+                    (grad.dtype == dtype::Float32()
+                             DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16())),
+            "Backward resize only supports NCHW and NHWC, the dtype only supports "
+            "Float32 and Float16.");
 }
 
 std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
...
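The fp16 clause of this assertion is wrapped in DNN_INC_FLOAT16, so it only exists in builds that compile half-precision support in at all. A minimal sketch of how such a conditional-compilation guard typically looks; the guard name MEGDNN_DISABLE_FLOAT16 and the exact definition are assumptions for illustration, not copied from the MegDNN headers:

    // Hypothetical sketch of a DNN_INC_FLOAT16-style macro: the wrapped code is
    // emitted only when fp16 support is enabled at build time.
    #if !defined(MEGDNN_DISABLE_FLOAT16)
    #define DNN_INC_FLOAT16(...) __VA_ARGS__
    #else
    #define DNN_INC_FLOAT16(...)
    #endif

With the guard disabled, the check above degenerates to the Float32-only assertion, and the DNN_INC_FLOAT16(INST(dt_float16)) instantiations later in this commit disappear as well.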
dnn/src/cuda/resize/backward.cpp
@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec(
         _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
     check_exec(diff.layout, grad.layout, workspace.size);
     auto stream = cuda_stream(this->handle());
-    auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
-         IW = grad.layout.shape[3], OH = diff.layout.shape[2],
-         OW = diff.layout.shape[3];
+    bool is_nhwc = param().format == param::Resize::Format::NHWC;
+    size_t N, C, IH, IW, OH, OW;
+    if (is_nhwc) {
+        if (param().imode != Param::InterpolationMode::LINEAR &&
+            is_nhwc_contig_wc(grad.layout)) {
+            megdnn_assert(
+                    0,
+                    "unsupport mode in resizeBackward, only support param().imode = "
+                    "LINEAR");
+        }
+        N = grad.layout.shape[0];
+        C = grad.layout.shape[3];
+        IH = grad.layout.shape[1];
+        IW = grad.layout.shape[2];
+        OH = diff.layout.shape[1];
+        OW = diff.layout.shape[2];
+    } else {
+        N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
+        IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3];
+    }
     size_t max_batch_x_channel = max_batch_x_channel_size();
-    dt_float32* diff_ptr = diff.ptr<dt_float32>();
-    dt_float32* grad_ptr = grad.ptr<dt_float32>();
     size_t max_batch_size = max_batch_x_channel / C;
     while (N > 0) {
         size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
-        resize::backward_data_proxy(
-                resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size,
-                C, IH, IW, OH, OW, stream);
-        if (N <= max_batch_size) {
-            break;
-        } else {
-            N -= max_batch_size;
-            diff_ptr += curr_batch_size * diff.layout.stride[0];
-            grad_ptr += curr_batch_size * grad.layout.stride[0];
-        }
+        switch (grad.layout.dtype.enumv()) {
+#define cb(_t)                                                                    \
+    case DTypeTrait<_t>::enumv: {                                                 \
+        typedef DTypeTrait<_t>::ctype ct;                                         \
+        ct* diff_ptr = diff.ptr<ct>();                                            \
+        ct* grad_ptr = grad.ptr<ct>();                                            \
+        resize::backward_data_proxy(                                              \
+                is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr,    \
+                curr_batch_size, C, IH, IW, OH, OW, stream);                      \
+        if (N <= max_batch_size) {                                                \
+            return;                                                               \
+        } else {                                                                  \
+            N -= max_batch_size;                                                  \
+            diff_ptr += curr_batch_size * diff.layout.stride[0];                  \
+            grad_ptr += curr_batch_size * grad.layout.stride[0];                  \
+        }                                                                         \
+        break;                                                                    \
+    }
+            cb(megdnn::dtype::Float32);
+            DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
+            default:
+                megdnn_throw(ssprintf(
+                        "unsupported dtype: %s in resize backward",
+                        grad.layout.dtype.name()));
+        }
+#undef cb
     }
 }
...
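The rewritten exec() no longer assumes dt_float32: it inspects grad.layout.dtype at run time and expands the same batched launch once per supported element type through the local cb macro. A self-contained sketch of that switch-plus-macro idiom, with placeholder types standing in for the real DTypeTrait machinery:

    #include <cstdio>

    // Toy stand-ins for megdnn's dtype enum and DTypeTrait, for illustration only.
    enum class DTypeEnum { Float32, Float16 };

    template <DTypeEnum e> struct TypeOf;
    template <> struct TypeOf<DTypeEnum::Float32> { using ctype = float; };
    template <> struct TypeOf<DTypeEnum::Float16> { using ctype = unsigned short; };  // fp16 bit pattern

    template <typename T>
    void run_backward(const T* diff, T* grad) {
        (void)diff; (void)grad;
        std::printf("backward on %zu-byte elements\n", sizeof(T));
    }

    void dispatch(DTypeEnum dt, const void* diff, void* grad) {
        switch (dt) {
    // Each cb() expansion emits one typed case, mirroring cb(_t) in exec() above.
    #define cb(_e)                                                               \
        case _e: {                                                               \
            using ct = TypeOf<_e>::ctype;                                        \
            run_backward(static_cast<const ct*>(diff), static_cast<ct*>(grad));  \
            break;                                                               \
        }
            cb(DTypeEnum::Float32);
            cb(DTypeEnum::Float16);
    #undef cb
        }
    }

Adding another dtype then only needs one more cb(...) line (guarded by DNN_INC_FLOAT16 for half precision in the real code) instead of a second copy of the while-loop batching logic.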
dnn/src/cuda/resize/backward.cu
+#include "src/common/rounding_converter.cuh"
 #include "src/cuda/resize/common.cuh"
 #include "src/cuda/resize/common.h"
...
@@ -11,9 +12,52 @@ namespace megdnn {
 namespace cuda {
 namespace resize {
 
+template <typename ctype, typename OutputConverter>
+__global__ void resize_bwd_nhwc_kernel(
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
+        float scale_h, float scale_w) {
+    OutputConverter output_converter;
+    int n = blockIdx.z;
+    int ow = blockIdx.x * blockDim.x + threadIdx.x;
+    int oh = blockIdx.y * blockDim.y + threadIdx.y;
+    hidden += n * C * OH * OW;
+    dst += n * C * IH * IW;
+    if (ow < OW && oh < OH) {
+        float alphah, alphaw;
+        int ih0, iw0;
+        get_origin_coord(scale_h, IH, oh, alphah, ih0);
+        get_origin_coord(scale_w, IW, ow, alphaw, iw0);
+        int ih1 = ih0 + 1;
+        int iw1 = iw0 + 1;
+        float nalphaw = 1.0f - alphaw;
+        float nalphah = 1.0f - alphah;
+        for (int c = 0; c < C; ++c) {
+            atomic_add(
+                    dst + (ih0 * IW + iw0) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah));
+            atomic_add(
+                    dst + (ih0 * IW + iw1) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * alphaw * nalphah));
+            atomic_add(
+                    dst + (ih1 * IW + iw0) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * nalphaw * alphah));
+            atomic_add(
+                    dst + (ih1 * IW + iw1) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * alphaw * alphah));
+        }
+    }
+}
+
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_linear_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
...
@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel(
         float nalphaw = 1.0f - alphaw;
         float nalphah = 1.0f - alphah;
         for (int c = 0; c < C; ++c) {
-            atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah);
-            atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah);
-            atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah);
-            atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah);
+            atomic_add(
+                    dst + ih0 * IW + iw0,
+                    output_converter(hidden[oh * OW + ow] * nalphaw * nalphah));
+            atomic_add(
+                    dst + ih0 * IW + iw1,
+                    output_converter(hidden[oh * OW + ow] * alphaw * nalphah));
+            atomic_add(
+                    dst + ih1 * IW + iw0,
+                    output_converter(hidden[oh * OW + ow] * nalphaw * alphah));
+            atomic_add(
+                    dst + ih1 * IW + iw1,
+                    output_converter(hidden[oh * OW + ow] * alphaw * alphah));
             hidden += OH * OW;
             dst += IH * IW;
         }
     }
 }
 
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_nearest_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
...
@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel(
         int iw = get_nearest_src(scale_w, IW, ow);
         for (int c = 0; c < C; ++c) {
-            atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]);
+            atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow]));
             hidden += OH * OW;
             dst += IH * IW;
         }
     }
 }
 
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_cubic_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
...
@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel(
             int ih = saturate(ih0 + kh, 0, IH - 1);
             for (int kw = 0; kw < ksize; kw++) {
                 int iw = saturate(iw0 + kw, 0, IW - 1);
-                atomicAdd(
-                        dst + ih * IW + iw,
-                        hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]);
+                atomic_add(
+                        dst + ih * IW + iw,
+                        output_converter(
+                                hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]));
             }
         }
...
@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel(
     }
 }
 
+template <typename ctype>
 void backward_data_proxy(
-        InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
-        int IW, int OH, int OW, cudaStream_t stream) {
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW, cudaStream_t stream) {
     const int BY = 16, BX = 32;
     {
         dim3 threads(BX, BY);
         dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);
-        cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream));
+        cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream));
         float scale_h = static_cast<float>(OH) / IH;
         float scale_w = static_cast<float>(OW) / IW;
-        switch (imode) {
-            case InterpolationMode::INTER_LINEAR: {
-                resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            case InterpolationMode::INTER_NEAREST: {
-                resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            case InterpolationMode::INTER_CUBIC: {
-                resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            default: {
-                megdnn_throw("unsupported interpolation mode");
-                break;
-            }
+        if (is_nhwc) {
+            resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>>
+                    <<<blocks, threads, 0, stream>>>(
+                            diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+        } else {
+            switch (imode) {
+                case InterpolationMode::INTER_LINEAR: {
+                    resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                case InterpolationMode::INTER_NEAREST: {
+                    resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                case InterpolationMode::INTER_CUBIC: {
+                    resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                default: {
+                    megdnn_throw("unsupported interpolation mode");
+                    break;
+                }
+            }
         }
     }
     after_kernel_launch();
 }
 
+#define INST(ctype)                                                                 \
+    template void backward_data_proxy(                                              \
+            bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
+            int, cudaStream_t);
+
+INST(dt_float32);
+DNN_INC_FLOAT16(INST(dt_float16));
+#undef INST
+
 }  // namespace resize
 }  // namespace cuda
 }  // namespace megdnn
...
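The new resize_bwd_nhwc_kernel differs from the NCHW kernels mainly in the flat-offset formula: with channels innermost, the per-channel loop touches consecutive addresses, and each contribution passes through the OutputConverter so it is rounded to the storage type (for example fp16) only at the store. A small host-side sketch of the two addressing schemes used above; the helper names are illustrative, not part of the MegDNN API:

    #include <cstddef>

    // Flat offset of element (n, c, h, w) in a contiguous NCHW tensor:
    // channels outer, spatial positions inner.
    inline size_t offset_nchw(size_t n, size_t c, size_t h, size_t w,
                              size_t C, size_t H, size_t W) {
        return ((n * C + c) * H + h) * W + w;
    }

    // Flat offset of the same element in a contiguous NHWC tensor: the channel
    // index is innermost, matching dst[(ih * IW + iw) * C + c] in
    // resize_bwd_nhwc_kernel.
    inline size_t offset_nhwc(size_t n, size_t h, size_t w, size_t c,
                              size_t C, size_t H, size_t W) {
        return ((n * H + h) * W + w) * C + c;
    }

This is also why the NHWC branches of exec() read the channel count from shape[3] and the spatial sizes from shape[1] and shape[2], while the NCHW branch keeps the old shape[1]/shape[2]/shape[3] layout.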
dnn/src/cuda/resize/common.h
@@ -20,9 +20,10 @@ void forward_proxy_nchw4(
         const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         cudaStream_t stream);
 
+template <typename ctype>
 void backward_data_proxy(
-        InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
-        int IW, int OH, int OW, cudaStream_t stream);
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
 
 }  // namespace resize
 }  // namespace cuda
...
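backward_data_proxy is now a template that is declared here but defined in backward.cu, so that translation unit must explicitly instantiate it for every element type callers may request; that is the role of the INST(dt_float32) / DNN_INC_FLOAT16(INST(dt_float16)) lines shown earlier. A generic sketch of the declare-in-header, instantiate-in-source pattern, using made-up names:

    // widget.h -- declaration only; the definition lives in another file.
    template <typename T>
    void process(const T* src, T* dst, int n);

    // widget.cpp -- definition plus explicit instantiations, mirroring the
    // INST(...) block in backward.cu.
    template <typename T>
    void process(const T* src, T* dst, int n) {
        for (int i = 0; i < n; ++i)
            dst[i] = src[i];  // trivial stand-in body
    }

    #define INST(T) template void process<T>(const T*, T*, int);
    INST(float)
    INST(double)  // stands in for the fp16 instantiation guarded by DNN_INC_FLOAT16
    #undef INST

Without these explicit instantiations the fp16 call sites in forward.cpp and backward.cpp would fail to link, since the template body is not visible to them.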
dnn/src/cuda/resize/forward.cpp
@@ -148,6 +148,11 @@ void ResizeImpl::exec(
                 is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(),
                 dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                 S_IH, S_IW, stream);
+    } else if (src.layout.dtype == dtype::Float16{}) {
+        resize::forward_proxy(
+                is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float16>(),
+                dst.ptr<dt_float16>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
+                S_IH, S_IW, stream);
     } else if (src.layout.dtype == dtype::Uint8()) {
         resize::forward_proxy(
                 is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(),
...
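Unlike the backward path, the forward exec() already branched on the runtime dtype, so fp16 support is just one more else-if that re-enters the templated forward_proxy with dt_float16 pointers. A reduced sketch of that runtime-dtype-to-template bridge, with placeholder names and a trivial stand-in body:

    #include <stdexcept>

    enum class DType { Float32, Float16 };

    // Trivial stand-in for the typed kernel launcher; the real forward_proxy
    // lives in forward.cu and launches CUDA kernels.
    template <typename T>
    void forward_proxy(const T* src, T* dst, int n) {
        for (int i = 0; i < n; ++i) dst[i] = src[i];
    }

    // Host-side bridge: inspect the runtime dtype tag once, then jump into the
    // statically typed path, as ResizeImpl::exec does above.
    void run_forward(DType dt, const void* src, void* dst, int n) {
        if (dt == DType::Float32) {
            forward_proxy(static_cast<const float*>(src), static_cast<float*>(dst), n);
        } else if (dt == DType::Float16) {
            // a real half type would be used here; unsigned short is only a bit-pattern stand-in
            forward_proxy(static_cast<const unsigned short*>(src),
                          static_cast<unsigned short*>(dst), n);
        } else {
            throw std::runtime_error("unsupported dtype");
        }
    }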
dnn/src/cuda/resize/forward.cu
@@ -298,6 +298,7 @@ void forward_proxy_nchw4(
 INST(float)
 INST(uint8_t)
 INST(int8_t)
+DNN_INC_FLOAT16(INST(dt_float16))
 #undef INST
 
 #define INST(ctype) \
...
dnn/src/naive/resize/opr_impl.cpp
@@ -387,40 +387,53 @@ void ResizeImpl::exec(
     }
 }
 
-void ResizeBackwardImpl::exec(
-        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
-    check_exec(diff.layout, grad.layout, workspace.size);
-    megdnn_assert(
-            param().format == param::Resize::Format::NCHW, "invalid resize format");
-    const int N = grad.layout.shape[0], C = grad.layout.shape[1],
-              IH = grad.layout.shape[2], IW = grad.layout.shape[3];
-    const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
-    const float* hptr_ = diff.ptr<dt_float32>();
-    float* sptr_ = grad.ptr<dt_float32>();
+// ***************************Backward*************************** //
+template <typename ctype>
+void ResizeBackwardImpl::kern_naive(
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW) {
     float scale_h = static_cast<float>(OH) / IH;
     float scale_w = static_cast<float>(OW) / IW;
+    rounding::RoundingConverter<ctype> output_converter;
     auto kern = [=]() {
-        auto hptr = hptr_;
-        auto sptr = sptr_;
-        std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
+        auto hptr = diff;
+        auto sptr = grad;
+        std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW);
         rep(n, N) {
             rep(oh, OH) rep(ow, OW) {
-                switch (param().imode) {
+                switch (imode) {
                     case InterpolationMode::INTER_LINEAR: {
                         int ih0, ih1, iw0, iw1;
                         float ah0, ah1, aw0, aw1;
-                        std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
-                                param().imode, scale_h, IH, oh);
-                        std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
-                                param().imode, scale_w, IW, ow);
-                        rep(c, C) {
-                            float hidden = hptr[c * OH * OW + oh * OW + ow];
-                            sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden;
-                            sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden;
-                            sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden;
-                            sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden;
+                        std::tie(ah0, ih0, ah1, ih1) =
+                                get_nearest_linear_coord(imode, scale_h, IH, oh);
+                        std::tie(aw0, iw0, aw1, iw1) =
+                                get_nearest_linear_coord(imode, scale_w, IW, ow);
+                        if (is_nhwc) {
+                            rep(c, C) {
+                                sptr[(ih0 * IW + iw0) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah0 * aw0);
+                                sptr[(ih0 * IW + iw1) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah0 * aw1);
+                                sptr[(ih1 * IW + iw0) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah1 * aw0);
+                                sptr[(ih1 * IW + iw1) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah1 * aw1);
+                            }
+                        } else {
+                            rep(c, C) {
+                                float hidden = hptr[c * OH * OW + oh * OW + ow];
+                                sptr[c * IH * IW + ih0 * IW + iw0] +=
+                                        output_converter(ah0 * aw0 * hidden);
+                                sptr[c * IH * IW + ih1 * IW + iw0] +=
+                                        output_converter(ah1 * aw0 * hidden);
+                                sptr[c * IH * IW + ih0 * IW + iw1] +=
+                                        output_converter(ah0 * aw1 * hidden);
+                                sptr[c * IH * IW + ih1 * IW + iw1] +=
+                                        output_converter(ah1 * aw1 * hidden);
+                            }
                         }
                         break;
                     }
...
@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec(
                         auto iw = get_nearest_src(scale_w, IW, ow);
                         rep(c, static_cast<int>(C)) {
                             sptr[c * IH * IW + ih * IW + iw] +=
-                                    hptr[c * OH * OW + oh * OW + ow];
+                                    output_converter(hptr[c * OH * OW + oh * OW + ow]);
                         }
                         break;
                     }
...
@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec(
                             int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
                             rep(kw, ksize) {
                                 int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
-                                sptr[c * IH * IW + h * IW + w] +=
-                                        hptr[c * OH * OW + oh * OW + ow] *
-                                        h_coeff[kh] * w_coeff[kw];
+                                sptr[c * IH * IW + h * IW + w] += output_converter(
+                                        hptr[c * OH * OW + oh * OW + ow] *
+                                        h_coeff[kh] * w_coeff[kw]);
                             }
                         }
                     }
...
@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec(
     MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
 }
 
+#define INST(ctype)                                                                 \
+    template void ResizeBackwardImpl::kern_naive(                                   \
+            bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
+            int);
+
+INST(dt_float32);
+DNN_INC_FLOAT16(INST(dt_float16));
+#undef INST
+
+void ResizeBackwardImpl::exec(
+        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
+    check_exec(diff.layout, grad.layout, workspace.size);
+    megdnn_assert(
+            param().format == param::Resize::Format::NCHW ||
+                    param().format == param::Resize::Format::NHWC,
+            "invalid resize format");
+    size_t N, C, IH, IW, OH, OW;
+    bool is_nhwc = param().format == param::Resize::Format::NHWC;
+    if (is_nhwc) {
+        if (param().imode != Param::InterpolationMode::LINEAR &&
+            is_nhwc_contig_wc(grad.layout)) {
+            megdnn_assert(
+                    0,
+                    "unsupport mode in resizeBackward, only support param().imode = "
+                    "LINEAR");
+        }
+        N = grad.layout.shape[0];
+        C = grad.layout.shape[3];
+        IH = grad.layout.shape[1];
+        IW = grad.layout.shape[2];
+        OH = diff.layout.shape[1];
+        OW = diff.layout.shape[2];
+    } else {
+        N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
+        IW = grad.layout.shape[3];
+        OH = diff.layout.shape[2], OW = diff.layout.shape[3];
+    }
+    switch (grad.layout.dtype.enumv()) {
+#define cb(_t)                                                                     \
+    case DTypeTrait<_t>::enumv: {                                                  \
+        typedef DTypeTrait<_t>::ctype ct;                                          \
+        ct* diff_ptr = diff.ptr<ct>();                                             \
+        ct* grad_ptr = grad.ptr<ct>();                                             \
+        ResizeBackwardImpl::kern_naive(                                            \
+                is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \
+        break;                                                                     \
+    }
+        cb(megdnn::dtype::Float32);
+        DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype: %s in resize backward",
+                    grad.layout.dtype.name()));
+    }
+#undef cb
+}
+
 // vim: syntax=cpp.doxygen
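The INTER_LINEAR branch scatters each output gradient into the four neighbouring input cells with bilinear weights (ah0*aw0, ah0*aw1, ah1*aw0, ah1*aw1), which is the transpose of the forward interpolation. A minimal single-channel sketch of that scatter, assuming the same weight/index pairs that get_nearest_linear_coord returns; the function name is illustrative:

    #include <vector>

    // Scatter one output-gradient value g at output pixel (oh, ow) into a
    // single-channel IH x IW gradient image, given neighbouring source rows
    // (ih0, ih1) with weights (ah0, ah1) and columns (iw0, iw1) with (aw0, aw1).
    // This mirrors the NCHW branch of ResizeBackwardImpl::kern_naive.
    void scatter_bilinear(std::vector<float>& grad, int IW, float g,
                          float ah0, int ih0, float ah1, int ih1,
                          float aw0, int iw0, float aw1, int iw1) {
        grad[ih0 * IW + iw0] += ah0 * aw0 * g;
        grad[ih0 * IW + iw1] += ah0 * aw1 * g;
        grad[ih1 * IW + iw0] += ah1 * aw0 * g;
        grad[ih1 * IW + iw1] += ah1 * aw1 * g;
    }

For fp16, the diff additionally routes each right-hand side through rounding::RoundingConverter<ctype>, so the product is computed in float and rounded once before being accumulated into the half-precision buffer.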
dnn/src/naive/resize/opr_impl.h
@@ -75,6 +75,12 @@ public:
     size_t get_workspace_in_bytes(
             const TensorLayout&, const TensorLayout&) override {
         return 0;
     }
+
+private:
+    template <typename ctype>
+    void kern_naive(
+            bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad,
+            int N, int C, int IH, int IW, int OH, int OW);
 };
 
 }  // namespace naive
...
dnn/test/cuda/resize.cpp
@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) {
                 .set_epsilon(1)
                 .execs({arg.src, arg.dst});
     }
+
+    for (auto&& arg : args) {
+        checker.set_param(arg.param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
 }
 
+TEST_F(CUDA, RESIZE_NHWC) {
+    using namespace resize;
+    std::vector<TestArg> args;
+
+    param::Resize param;
+    param.format = param::Resize::Format::NHWC;
+    param.imode = param::Resize::InterpolationMode::LINEAR;
+    args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5});
+    args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5});
+    args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
+
+    Checker<ResizeBackward> checkerBackward(handle_cuda());
+    for (auto&& arg : args) {
+        checkerBackward.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+    for (auto&& arg : args) {
+        checkerBackward.set_param(arg.param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+
+    Checker<ResizeForward> checkerForward(handle_cuda());
+    for (auto&& arg : args) {
+        checkerForward.set_param(arg.param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+    for (auto&& arg : args) {
+        checkerForward.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+}
+
 TEST_F(CUDA, RESIZE_NCHW4) {
     using namespace resize;
     Checker<Resize> checker(handle_cuda());
 
     auto args = get_nchw4_args();
     for (auto&& arg : args) {
         checker.set_param(arg.param)
...
@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
         param.format = param::Resize::Format::NCHW;
         param.imode = imode;
         checker.set_param(param);
+        checker.set_dtype(0, dtype::Float16());
+        checker.set_dtype(1, dtype::Float16());
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
+        checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
+        checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
+        checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
+    }
+
+    for (auto imode : modes) {
+        Checker<ResizeBackward> checker(handle_cuda());
+        param::Resize param;
+        param.format = param::Resize::Format::NCHW;
+        param.imode = imode;
+        checker.set_param(param);
+        checker.set_dtype(0, dtype::Float32());
+        checker.set_dtype(1, dtype::Float32());
         checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
         checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
...
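The new test cases rerun the same shapes in Float16 with a loosened tolerance (set_epsilon(1e-3) for forward and NHWC, 1 + 1e-3 for NCHW backward), since half precision carries only about three significant decimal digits and exact comparison against the float32 reference would fail. A hedged, standalone sketch of an equivalent tolerance check outside the megdnn Checker framework, for illustration only:

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Compare an fp16-path result (converted back to float for inspection)
    // against a float32 reference with a tolerance similar in spirit to the
    // epsilon used by the tests above.
    bool allclose(const std::vector<float>& out, const std::vector<float>& ref,
                  float eps = 1e-3f) {
        assert(out.size() == ref.size());
        for (size_t i = 0; i < out.size(); ++i) {
            float err = std::fabs(out[i] - ref[i]);
            float tol = eps * std::max(1.0f, std::fabs(ref[i]));
            if (err > tol) {
                std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
                return false;
            }
        }
        return true;
    }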