Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bde5cf35
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bde5cf35
编写于
9月 19, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add resize linear for arm
GitOrigin-RevId: 14ac5bda3f60f530ca9d42e94b1f4e401d0a1309
上级
b6142bee
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
976 addition
and
196 deletion
+976
-196
dnn/src/arm_common/resize/direct_nchwxx.cpp
dnn/src/arm_common/resize/direct_nchwxx.cpp
+105
-0
dnn/src/arm_common/resize/direct_nchwxx.h
dnn/src/arm_common/resize/direct_nchwxx.h
+36
-0
dnn/src/arm_common/resize/helper.h
dnn/src/arm_common/resize/helper.h
+134
-0
dnn/src/arm_common/resize/opr_impl.cpp
dnn/src/arm_common/resize/opr_impl.cpp
+149
-180
dnn/src/arm_common/resize/opr_impl.h
dnn/src/arm_common/resize/opr_impl.h
+0
-10
dnn/src/arm_common/resize/upsample2_nchw.cpp
dnn/src/arm_common/resize/upsample2_nchw.cpp
+228
-0
dnn/src/arm_common/resize/upsample2_nchw.h
dnn/src/arm_common/resize/upsample2_nchw.h
+36
-0
dnn/src/arm_common/resize/upsample2_nchwxx.cpp
dnn/src/arm_common/resize/upsample2_nchwxx.cpp
+197
-0
dnn/src/arm_common/resize/upsample2_nchwxx.h
dnn/src/arm_common/resize/upsample2_nchwxx.h
+36
-0
dnn/test/arm_common/resize.cpp
dnn/test/arm_common/resize.cpp
+53
-5
dnn/test/cuda/resize.cpp
dnn/test/cuda/resize.cpp
+2
-1
未找到文件。
dnn/src/arm_common/resize/direct_nchwxx.cpp
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/direct_nchwxx.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/direct_nchwxx.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
using
namespace
resize
;
namespace
{
template
<
typename
ctype
,
InterpolationMode
imode
>
void
resize_direct_nchwxx
(
const
ctype
*
sptr
,
ctype
*
dptr
,
size_t
N
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
)
{
using
simd_helper
=
SIMDHelper
<
ctype
>
;
constexpr
size_t
PC
=
simd_helper
::
simd_width
;
using
simd_type
=
typename
simd_helper
::
simd_type
;
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
{
for
(
size_t
oh
=
0
;
oh
<
OH
;
++
oh
)
{
for
(
size_t
ow
=
0
;
ow
<
OW
;
++
ow
)
{
int
ih0
,
ih1
,
iw0
,
iw1
;
float
ah0
,
ah1
,
aw0
,
aw1
;
std
::
tie
(
ah0
,
ih0
,
ah1
,
ih1
)
=
get_nearest_linear_coord
(
imode
,
scale_h
,
IH
,
oh
);
std
::
tie
(
aw0
,
iw0
,
aw1
,
iw1
)
=
get_nearest_linear_coord
(
imode
,
scale_w
,
IW
,
ow
);
simd_type
r0
=
simd_helper
::
load
(
sptr
+
(
ih0
*
IW
+
iw0
)
*
PC
);
simd_type
r1
=
simd_helper
::
load
(
sptr
+
(
ih0
*
IW
+
iw1
)
*
PC
);
simd_type
r2
=
simd_helper
::
load
(
sptr
+
(
ih1
*
IW
+
iw0
)
*
PC
);
simd_type
r3
=
simd_helper
::
load
(
sptr
+
(
ih1
*
IW
+
iw1
)
*
PC
);
// FIXME: weight fp16 may cause precision problem
ctype
a0
=
ah0
*
aw0
;
ctype
a1
=
ah0
*
aw1
;
ctype
a2
=
ah1
*
aw0
;
ctype
a3
=
ah1
*
aw1
;
simd_type
c
=
simd_helper
::
dup
(
0
);
c
=
simd_helper
::
fma
(
c
,
r0
,
a0
);
c
=
simd_helper
::
fma
(
c
,
r1
,
a1
);
c
=
simd_helper
::
fma
(
c
,
r2
,
a2
);
c
=
simd_helper
::
fma
(
c
,
r3
,
a3
);
simd_helper
::
store
(
dptr
+
(
oh
*
OW
+
ow
)
*
PC
,
c
);
}
}
sptr
+=
IH
*
IW
*
PC
;
dptr
+=
OH
*
OW
*
PC
;
}
}
}
/**
 * Nearest-neighbour resize entry point for fp32 data packed four channels per
 * pixel. Forwards to the generic direct kernel with n * c / 4 planes.
 */
void megdnn::arm_common::resize_direct_nearest_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    const auto planes = kern_param.n * kern_param.c / 4;
    resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>(
            kern_param.sptr, kern_param.dptr, planes, kern_param.ih, kern_param.iw,
            kern_param.oh, kern_param.ow);
}
/**
 * Linear resize entry point for fp32 data packed four channels per pixel.
 * Forwards to the generic direct kernel with n * c / 4 planes.
 */
void megdnn::arm_common::resize_direct_linear_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    const auto planes = kern_param.n * kern_param.c / 4;
    resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>(
            kern_param.sptr, kern_param.dptr, planes, kern_param.ih, kern_param.iw,
            kern_param.oh, kern_param.ow);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
 * Nearest-neighbour resize entry point for fp16 data packed eight channels per
 * pixel. The dt_float16 buffers are reinterpreted as native __fp16 and handed
 * to the generic direct kernel with n * c / 8 planes.
 */
void megdnn::arm_common::resize_direct_nearest_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr);
    __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr);
    const auto planes = kern_param.n * kern_param.c / 8;
    resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>(
            src, dst, planes, kern_param.ih, kern_param.iw, kern_param.oh,
            kern_param.ow);
}
/**
 * Linear resize entry point for fp16 data packed eight channels per pixel.
 * The dt_float16 buffers are reinterpreted as native __fp16 and handed to the
 * generic direct kernel with n * c / 8 planes.
 */
void megdnn::arm_common::resize_direct_linear_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr);
    __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr);
    const auto planes = kern_param.n * kern_param.c / 8;
    resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>(
            src, dst, planes, kern_param.ih, kern_param.iw, kern_param.oh,
            kern_param.ow);
}
#endif
dnn/src/arm_common/resize/direct_nchwxx.h
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/direct_nchwxx.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
namespace megdnn {
namespace arm_common {

//! Direct nearest-neighbour resize for fp32 NCHW44 data described by kern_param.
void resize_direct_nearest_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param);

//! Direct linear resize for fp32 NCHW44 data described by kern_param.
void resize_direct_linear_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param);

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

//! Direct nearest-neighbour resize for fp16 NCHW88 data described by kern_param.
void resize_direct_nearest_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param);

//! Direct linear resize for fp16 NCHW88 data described by kern_param.
void resize_direct_linear_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param);

#endif

}  // namespace arm_common
}  // namespace megdnn
dnn/src/arm_common/resize/helper.h
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/simd_macro/marm_neon.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
resize
{
using
InterpolationMode
=
Resize
::
InterpolationMode
;
// Primary template, deliberately left empty: each supported element type
// provides its own specialization of SIMDHelper with the vector types and
// load/store/fma/dup wrappers.
template
<
typename
ctype
>
struct
SIMDHelper
{};
template
<
>
struct
SIMDHelper
<
float
>
{
using
simd_type
=
float32x4_t
;
using
simd_type_x2
=
float32x4x2_t
;
using
ctype
=
float
;
static
constexpr
size_t
simd_width
=
4
;
static
inline
simd_type
load
(
const
ctype
*
src_ptr
)
{
return
vld1q_f32
(
src_ptr
);
}
static
inline
void
store
(
ctype
*
dst_ptr
,
const
simd_type
&
rdst
)
{
vst1q_f32
(
dst_ptr
,
rdst
);
}
static
inline
void
store2_interleave
(
ctype
*
dst_ptr
,
const
simd_type
&
rdst1
,
const
simd_type
&
rdst2
)
{
simd_type_x2
rdst
;
rdst
.
val
[
0
]
=
rdst1
;
rdst
.
val
[
1
]
=
rdst2
;
vst2q_f32
(
dst_ptr
,
rdst
);
}
static
inline
simd_type
fma
(
const
simd_type
&
a
,
const
simd_type
&
b
,
ctype
n
)
{
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
return
vfmaq_n_f32
(
a
,
b
,
n
);
#else
return
vmlaq_n_f32
(
a
,
b
,
n
);
#endif
}
static
inline
simd_type
fma
(
const
simd_type
&
a
,
const
simd_type
&
b
,
const
simd_type
&
c
)
{
#if defined(__ARM_FEATURE_FMA)
return
vfmaq_f32
(
a
,
b
,
c
);
#else
return
vmlaq_f32
(
a
,
b
,
c
);
#endif
}
static
inline
simd_type
dup
(
float
val
)
{
return
vdupq_n_f32
(
val
);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template
<
>
struct
SIMDHelper
<
__fp16
>
{
using
simd_type
=
float16x8_t
;
using
simd_type_x2
=
float16x8x2_t
;
using
ctype
=
__fp16
;
static
constexpr
size_t
simd_width
=
8
;
static
inline
simd_type
load
(
const
ctype
*
src_ptr
)
{
return
vld1q_f16
(
src_ptr
);
}
static
inline
void
store
(
ctype
*
dst_ptr
,
const
simd_type
&
rdst
)
{
vst1q_f16
(
dst_ptr
,
rdst
);
}
static
inline
void
store2_interleave
(
ctype
*
dst_ptr
,
const
simd_type
&
rdst1
,
const
simd_type
&
rdst2
)
{
simd_type_x2
rdst
;
rdst
.
val
[
0
]
=
rdst1
;
rdst
.
val
[
1
]
=
rdst2
;
vst2q_f16
(
dst_ptr
,
rdst
);
}
static
inline
simd_type
fma
(
const
simd_type
&
a
,
const
simd_type
&
b
,
ctype
n
)
{
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
return
vfmaq_n_f16
(
a
,
b
,
n
);
#else
return
vaddq_f16
(
a
,
vmulq_n_f16
(
b
,
n
));
#endif
}
static
inline
simd_type
fma
(
const
simd_type
&
a
,
const
simd_type
&
b
,
const
simd_type
&
c
)
{
return
vfmaq_f16
(
a
,
b
,
c
);
}
static
inline
simd_type
dup
(
float
val
)
{
return
vdupq_n_f16
(
val
);
}
};
#endif
/**
 * Source index used by nearest-neighbour sampling for output index idx.
 *
 * \param scale output size divided by input size
 * \param size number of source samples along the axis
 * \param idx output index
 * \return idx / scale truncated to int, capped at size - 1
 */
static inline int get_nearest_src(float scale, int size, int idx) {
    const int last = size - 1;
    const int candidate = static_cast<int>(idx / scale);
    return candidate < last ? candidate : last;
}
/**
 * Two-tap sampling description for one output index along one axis.
 *
 * \param imode interpolation mode; INTER_NEAREST is special-cased, every other
 *        value takes the linear path
 * \param scale output size divided by input size
 * \param size number of source samples along the axis
 * \param idx output index
 * \return (weight0, index0, weight1, index1). The weights sum to 1. When
 *         size != 1, index1 == index0 + 1 and both lie inside [0, size - 1].
 */
static inline std::tuple<float, int, float, int> get_nearest_linear_coord(
        InterpolationMode imode, float scale, int size, int idx) {
    // A single-sample axis has only one possible tap.
    if (size == 1) {
        return std::make_tuple(1.0f, 0, 0.0f, 0);
    }

    int lower;
    float frac;
    if (imode == InterpolationMode::INTER_NEAREST) {
        // All weight goes on the nearest source sample.
        lower = get_nearest_src(scale, size, idx);
        frac = 0;
    } else {
        // Map the centre of the output sample into source coordinates and
        // split it into an integer tap and a fractional weight.
        const float pos = (idx + 0.5f) / scale - 0.5f;
        lower = static_cast<int>(floor(pos));
        frac = pos - lower;
    }

    // Keep both taps inside the axis; at the borders the whole weight moves
    // onto the edge sample.
    if (lower < 0) {
        lower = 0;
        frac = 0;
    } else if (lower + 1 >= size) {
        lower = size - 2;
        frac = 1;
    }

    return std::make_tuple(1 - frac, lower, frac, lower + 1);
}
};
};
};
dnn/src/arm_common/resize/opr_impl.cpp
浏览文件 @
bde5cf35
...
...
@@ -12,212 +12,181 @@
#include "src/arm_common/resize/opr_impl.h"
#include "src/arm_common/handle.h"
#include "src/arm_common/resize/direct_nchwxx.h"
#include "src/arm_common/resize/resize_cv.h"
#include "src/arm_common/resize/upsample2_nchw.h"
#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
#include "midout.h"
MIDOUT_DECL
(
megdnn_arm_resize
)
namespace
megdnn
{
namespace
arm_common
{
void
ResizeImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
if
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW44
||
param
().
format
==
param
::
Resize
::
Format
::
NCHW88
)
{
bool
is_contiguous
=
src
.
layout
.
is_contiguous
()
&&
dst
.
layout
.
is_contiguous
();
bool
dtype_same
=
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
;
bool
nchw44_enable
=
param
().
format
==
param
::
Resize
::
Format
::
NCHW44
&&
src
.
layout
.
dtype
==
dtype
::
Float32
();
bool
nchw88_enable
=
param
().
format
==
param
::
Resize
::
Format
::
NCHW88
&&
DNN_FLOAT16_SELECT
(
src
.
layout
.
dtype
==
dtype
::
Float16
(),
false
);
bool
interp_supported
=
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_NEAREST
||
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_LINEAR
;
bool
is_upsample2
=
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_NEAREST
&&
src
.
layout
.
shape
[
2
]
*
2
==
dst
.
layout
.
shape
[
2
]
&&
src
.
layout
.
shape
[
3
]
*
2
==
dst
.
layout
.
shape
[
3
];
bool
need_fallback
=
!
is_contiguous
||
!
dtype_same
||
!
interp_supported
||
(
!
nchw44_enable
&&
!
nchw88_enable
);
bool
is_contiguous
=
src
.
layout
.
is_contiguous
()
&&
dst
.
layout
.
is_contiguous
();
bool
is_dtype_same
=
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
;
bool
is_dtype_fp32
=
src
.
layout
.
dtype
==
dtype
::
Float32
();
bool
is_dtype_fp16
=
DNN_FLOAT16_SELECT
(
src
.
layout
.
dtype
==
dtype
::
Float16
(),
false
);
bool
is_dtype_supported
=
is_dtype_same
&&
(
is_dtype_fp32
||
is_dtype_fp16
);
bool
is_nchw
=
param
().
format
==
param
::
Resize
::
Format
::
NCHW
&&
(
is_dtype_fp32
||
is_dtype_fp16
);
bool
is_nchw44_fp32
=
param
().
format
==
param
::
Resize
::
Format
::
NCHW44
&&
is_dtype_fp32
;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
bool
is_nchw88_fp16
=
param
().
format
==
param
::
Resize
::
Format
::
NCHW88
&&
is_dtype_fp16
;
#endif
if
(
need_fallback
)
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
else
if
(
nchw44_enable
)
{
auto
kern_param
=
KernParam
<
float
>::
from_tensors
(
param
().
format
,
param
().
imode
,
src
,
dst
,
workspace
);
bool
is_imode_nearest
=
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_NEAREST
;
bool
is_imode_linear
=
param
().
imode
==
param
::
Resize
::
InterpolationMode
::
INTER_LINEAR
;
bool
is_imode_supported
=
is_imode_nearest
||
is_imode_linear
;
bool
is_upsample2
=
src
.
layout
.
shape
[
2
]
*
2
==
dst
.
layout
.
shape
[
2
]
&&
src
.
layout
.
shape
[
3
]
*
2
==
dst
.
layout
.
shape
[
3
];
bool
usable
=
is_contiguous
&&
is_dtype_supported
&&
is_imode_supported
;
if
(
param
().
format
==
param
::
Resize
::
Format
::
NHWC
&&
(
src
.
layout
[
3
]
==
1
||
src
.
layout
[
3
]
==
3
)
&&
is_nhwc_contig_wc
(
src
.
layout
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_cv_exec
(
src
,
dst
,
param
().
imode
));
}
else
if
(
!
usable
)
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
else
if
(
is_dtype_fp32
)
{
auto
kern_param
=
KernParam
<
float
>::
from_tensors
(
param
().
format
,
param
().
imode
,
src
,
dst
,
workspace
);
if
(
is_nchw44_fp32
)
{
if
(
is_upsample2
)
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
kern_nearest_upsample2_pack_simd_width
(
src
,
dst
));
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
0
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_nearest_upsample2_nchw44_fp32
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
1
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_linear_upsample2_nchw44_fp32
(
kern_param
));
}
MIDOUT_END
();
}
}
else
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
kern_nchw44_fp32
(
kern_param
));
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
2
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_direct_nearest_nchw44_fp32
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
3
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_direct_linear_nchw44_fp32
(
kern_param
));
}
MIDOUT_END
();
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
else
if
(
nchw88_enable
)
{
auto
kern_param
=
KernParam
<
dt_float16
>::
from_tensors
(
param
().
format
,
param
().
imode
,
src
,
dst
,
workspace
);
}
else
if
(
is_nchw
)
{
if
(
is_upsample2
)
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
kern_nearest_upsample2_pack_simd_width
(
src
,
dst
));
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
4
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_nearest_upsample2_nchw_fp32
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
5
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_linear_upsample2_nchw_fp32
(
kern_param
));
}
MIDOUT_END
();
}
}
else
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
kern_nchw88_fp16
(
kern_param
)
);
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
#endif
}
else
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
}
else
if
(
param
().
format
==
param
::
Resize
::
Format
::
NCHW
||
(
src
.
layout
[
3
]
!=
1
&&
src
.
layout
[
3
]
!=
3
)
||
!
is_nhwc_contig_wc
(
src
.
layout
))
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
else
{
megdnn_assert
(
param
().
format
==
param
::
Resize
::
Format
::
NHWC
,
"invalid resize format"
);
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_cv_exec
(
src
,
dst
,
param
().
imode
));
}
}
template
<
typename
ctype
>
void
ResizeImpl
::
kern_nchw44_fp32
(
const
KernParam
<
ctype
>&
kern_param
)
{
UNPACK_RESIZE_FWD_KERN_PARAM
(
kern_param
);
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
{
for
(
size_t
c
=
0
;
c
<
C
/
4
;
++
c
)
{
for
(
size_t
oh
=
0
;
oh
<
OH
;
++
oh
)
{
for
(
size_t
ow
=
0
;
ow
<
OW
;
++
ow
)
{
int
ih0
,
ih1
,
iw0
,
iw1
;
float
ah0
,
ah1
,
aw0
,
aw1
;
std
::
tie
(
ah0
,
ih0
,
ah1
,
ih1
)
=
get_nearest_linear_coord
(
kern_param
.
imode
,
scale_h
,
IH
,
oh
);
std
::
tie
(
aw0
,
iw0
,
aw1
,
iw1
)
=
get_nearest_linear_coord
(
kern_param
.
imode
,
scale_w
,
IW
,
ow
);
#define SRC_ADDRESS(ih, iw) \
(sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4)
#define DST_ADDRESS(oh, ow) \
(dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4)
float32x4_t
r0
=
vld1q_f32
(
SRC_ADDRESS
(
ih0
,
iw0
));
float32_t
a0
=
ah0
*
aw0
;
float32x4_t
r1
=
vld1q_f32
(
SRC_ADDRESS
(
ih0
,
iw1
));
float32_t
a1
=
ah0
*
aw1
;
float32x4_t
r2
=
vld1q_f32
(
SRC_ADDRESS
(
ih1
,
iw0
));
float32_t
a2
=
ah1
*
aw0
;
float32x4_t
r3
=
vld1q_f32
(
SRC_ADDRESS
(
ih1
,
iw1
));
float32_t
a3
=
ah1
*
aw1
;
r0
=
vmulq_n_f32
(
r0
,
a0
);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0
=
vfmaq_n_f32
(
r0
,
r1
,
a1
);
r0
=
vfmaq_n_f32
(
r0
,
r2
,
a2
);
r0
=
vfmaq_n_f32
(
r0
,
r3
,
a3
);
#else
r0
=
vaddq_f32
(
r0
,
vmulq_n_f32
(
r1
,
a1
));
r0
=
vaddq_f32
(
r0
,
vmulq_n_f32
(
r2
,
a2
));
r0
=
vaddq_f32
(
r0
,
vmulq_n_f32
(
r3
,
a3
));
#endif
vst1q_f32
(
DST_ADDRESS
(
oh
,
ow
),
r0
);
#undef SRC_ADDRESS
#undef DST_ADDRESS
}
}
}
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template
<
typename
ctype
>
void
ResizeImpl
::
kern_nchw88_fp16
(
const
KernParam
<
ctype
>&
kern_param
)
{
UNPACK_RESIZE_FWD_KERN_PARAM
(
kern_param
);
float
scale_h
=
static_cast
<
float
>
(
OH
)
/
IH
;
float
scale_w
=
static_cast
<
float
>
(
OW
)
/
IW
;
const
float16_t
*
src_ptr
=
reinterpret_cast
<
float16_t
*>
(
sptr
);
float16_t
*
dst_ptr
=
reinterpret_cast
<
float16_t
*>
(
dptr
);
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
{
for
(
size_t
c
=
0
;
c
<
C
/
8
;
++
c
)
{
for
(
size_t
oh
=
0
;
oh
<
OH
;
++
oh
)
{
for
(
size_t
ow
=
0
;
ow
<
OW
;
++
ow
)
{
int
ih0
,
ih1
,
iw0
,
iw1
;
float
ah0
,
ah1
,
aw0
,
aw1
;
std
::
tie
(
ah0
,
ih0
,
ah1
,
ih1
)
=
get_nearest_linear_coord
(
kern_param
.
imode
,
scale_h
,
IH
,
oh
);
std
::
tie
(
aw0
,
iw0
,
aw1
,
iw1
)
=
get_nearest_linear_coord
(
kern_param
.
imode
,
scale_w
,
IW
,
ow
);
#define SRC_ADDRESS(ih, iw) \
(src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8)
#define DST_ADDRESS(oh, ow) \
(dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8)
float16x8_t
r0
=
vld1q_f16
(
SRC_ADDRESS
(
ih0
,
iw0
));
float32_t
a0
=
ah0
*
aw0
;
float16x8_t
r1
=
vld1q_f16
(
SRC_ADDRESS
(
ih0
,
iw1
));
float32_t
a1
=
ah0
*
aw1
;
float16x8_t
r2
=
vld1q_f16
(
SRC_ADDRESS
(
ih1
,
iw0
));
float32_t
a2
=
ah1
*
aw0
;
float16x8_t
r3
=
vld1q_f16
(
SRC_ADDRESS
(
ih1
,
iw1
));
float32_t
a3
=
ah1
*
aw1
;
r0
=
vmulq_n_f16
(
r0
,
a0
);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0
=
vfmaq_n_f16
(
r0
,
r1
,
a1
);
r0
=
vfmaq_n_f16
(
r0
,
r2
,
a2
);
r0
=
vfmaq_n_f16
(
r0
,
r3
,
a3
);
#else
r0
=
vaddq_f16
(
r0
,
vmulq_n_f16
(
r1
,
a1
));
r0
=
vaddq_f16
(
r0
,
vmulq_n_f16
(
r2
,
a2
));
r0
=
vaddq_f16
(
r0
,
vmulq_n_f16
(
r3
,
a3
));
#endif
vst1q_f16
(
DST_ADDRESS
(
oh
,
ow
),
r0
);
#undef SRC_ADDRESS
#undef DST_ADDRESS
}
else
if
(
is_dtype_fp16
)
{
auto
kern_param
=
KernParam
<
dt_float16
>::
from_tensors
(
param
().
format
,
param
().
imode
,
src
,
dst
,
workspace
);
if
(
is_nchw88_fp16
)
{
if
(
is_upsample2
)
{
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
6
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_nearest_upsample2_nchw88_fp16
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
7
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_linear_upsample2_nchw88_fp16
(
kern_param
));
}
MIDOUT_END
();
}
}
else
{
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
8
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_direct_nearest_nchw88_fp16
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
9
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_direct_linear_nchw88_fp16
(
kern_param
));
}
MIDOUT_END
();
}
}
}
}
}
#endif
void
ResizeImpl
::
kern_nearest_upsample2_pack_simd_width
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
)
{
const
uint8_t
*
src_ptr
=
reinterpret_cast
<
uint8_t
*>
(
src
.
raw_ptr
);
uint8_t
*
dst_ptr
=
reinterpret_cast
<
uint8_t
*>
(
dst
.
raw_ptr
);
size_t
S
=
2
;
size_t
N
=
src
.
layout
.
shape
[
0
];
size_t
IC
=
src
.
layout
.
shape
[
1
];
size_t
IH
=
src
.
layout
.
shape
[
2
];
size_t
IW
=
src
.
layout
.
shape
[
3
];
size_t
OH
=
dst
.
layout
.
shape
[
2
];
size_t
OW
=
dst
.
layout
.
shape
[
3
];
for
(
size_t
i
=
0
;
i
<
N
*
IC
;
++
i
)
{
for
(
size_t
ih
=
0
;
ih
<
IH
;
++
ih
)
{
for
(
size_t
iw
=
0
;
iw
<
IW
;
++
iw
)
{
size_t
oh
=
ih
*
S
;
size_t
ow
=
iw
*
S
;
uint8x16_t
r0
=
vld1q_u8
(
src_ptr
+
i
*
IH
*
IW
*
16
+
ih
*
IW
*
16
+
iw
*
16
);
for
(
size_t
fh
=
0
;
fh
<
S
;
++
fh
)
{
for
(
size_t
fw
=
0
;
fw
<
S
;
++
fw
)
{
vst1q_u8
(
dst_ptr
+
i
*
OH
*
OW
*
16
+
(
oh
+
fh
)
*
OW
*
16
+
(
ow
+
fw
)
*
16
,
r0
);
}
else
if
(
is_nchw
)
{
if
(
is_upsample2
)
{
if
(
is_imode_nearest
)
{
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
10
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_nearest_upsample2_nchw_fp16
(
kern_param
));
}
MIDOUT_END
();
}
else
{
megdnn_assert
(
is_imode_linear
,
"invalid imode"
);
MIDOUT_BEGIN
(
megdnn_arm_resize
,
midout_iv
(
11
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
resize_linear_upsample2_nchw_fp16
(
kern_param
));
}
MIDOUT_END
();
}
}
else
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
}
else
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
#endif
}
else
{
fallback
::
ResizeImpl
::
exec
(
src
,
dst
,
workspace
);
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/resize/opr_impl.h
浏览文件 @
bde5cf35
...
...
@@ -26,16 +26,6 @@ public:
const
TensorLayout
&
)
override
{
return
0
;
}
private:
template
<
typename
ctype
>
void
kern_nchw44_fp32
(
const
KernParam
<
ctype
>&
kern_param
);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template
<
typename
ctype
>
void
kern_nchw88_fp16
(
const
KernParam
<
ctype
>&
kern_param
);
#endif
void
kern_nearest_upsample2_pack_simd_width
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
);
};
}
// namespace arm_common
...
...
dnn/src/arm_common/resize/upsample2_nchw.cpp
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/upsample2_nchw.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/upsample2_nchw.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
using
namespace
resize
;
namespace
{
/**
 * Weighted sum of a 2x2 neighbourhood for one output element.
 *
 * \tparam fh 0 or 1; 1 swaps the two weights applied along the row direction
 * \tparam fw 0 or 1; 1 swaps the two weights applied along the column direction
 * \param src the four taps: src[0], src[1] form the first row, src[2], src[3]
 *        the second
 * \param alpha the two 1-D weights
 */
template <typename ctype, size_t fh, size_t fw>
static inline ctype compute_linear_element(const ctype src[4], const ctype alpha[2]) {
    // Weights for the first / second row and the first / second column.
    const ctype wh0 = alpha[0 ^ fh];
    const ctype wh1 = alpha[1 ^ fh];
    const ctype ww0 = alpha[0 ^ fw];
    const ctype ww1 = alpha[1 ^ fw];
    return src[0] * wh0 * ww0 + src[1] * wh0 * ww1 + src[2] * wh1 * ww0 +
           src[3] * wh1 * ww1;
}
/**
 * Vector form of the 2x2 weighted sum: accumulates src[k] * weight over the
 * four taps, starting from zero, using simd_helper::fma.
 *
 * \tparam fh 0 or 1; selects which row of alpha pairs with which row of taps
 * \tparam fw 0 or 1; selects which column of alpha pairs with which column
 * \param src the four taps: src[0], src[1] form the first row, src[2], src[3]
 *        the second
 * \param alpha per-tap weights, already broadcast into vectors
 */
template <typename simd_helper, size_t fh, size_t fw>
static inline typename simd_helper::simd_type compute_linear_element_simd(
        const typename simd_helper::simd_type src[4],
        const typename simd_helper::simd_type alpha[2][2]) {
    using simd_type = typename simd_helper::simd_type;
    simd_type acc = simd_helper::dup(0);
    // Tap k sits in row k / 2 and column k % 2 of the neighbourhood.
    for (size_t k = 0; k < 4; ++k) {
        acc = simd_helper::fma(acc, src[k], alpha[(k >> 1) ^ fh][(k & 1) ^ fw]);
    }
    return acc;
}
/**
 * Scalar worker producing up to a 2x2 patch of outputs from up to a 2x2 patch
 * of inputs.
 *
 * \tparam has_right whether src[1] may be read and dst[1] is written; when
 *         false the left input is reused as the right tap
 * \tparam has_bottom whether the row at src + IW may be read and the row at
 *         dst + OW is written; when false the upper input row is reused
 * \param src upper-left input of the patch
 * \param dst upper-left output of the patch
 * \param IW input row stride in elements
 * \param OW output row stride in elements
 * \param alpha the two 1-D interpolation weights
 */
template <typename ctype, bool has_right, bool has_bottom>
static inline void compute_linear_2x2_element(
        const ctype* src, ctype* dst, size_t IW, size_t OW, const ctype alpha[2]) {
    // Offsets of the right / lower neighbour; zero collapses onto src itself.
    const size_t dx = has_right ? 1 : 0;
    const size_t dy = has_bottom ? IW : 0;

    const ctype taps[4] = {src[0], src[dx], src[dy], src[dy + dx]};

    dst[0] = compute_linear_element<ctype, 0, 0>(taps, alpha);
    if (has_right) {
        dst[1] = compute_linear_element<ctype, 0, 1>(taps, alpha);
    }
    if (has_bottom) {
        dst[OW] = compute_linear_element<ctype, 1, 0>(taps, alpha);
    }
    if (has_right && has_bottom) {
        dst[OW + 1] = compute_linear_element<ctype, 1, 1>(taps, alpha);
    }
}
/**
 * Vector worker for interior patches: handles simd_width horizontally adjacent
 * 2x2 input patches at once and writes two output rows of 2 * simd_width
 * elements each.
 *
 * \param src upper-left input of the first patch; reads src, src + 1,
 *        src + IW and src + IW + 1, simd_width elements each
 * \param dst first output element of the upper output row
 * \param IW input row stride in elements
 * \param OW output row stride in elements
 * \param alpha per-tap weights broadcast into vectors
 */
template <typename simd_helper>
static inline void compute_linear_2x2_element_simd(
        const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
        size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
    using simd_type = typename simd_helper::simd_type;

    const simd_type taps[4] = {
            simd_helper::load(src), simd_helper::load(src + 1),
            simd_helper::load(src + IW), simd_helper::load(src + IW + 1)};

    // Even / odd output columns of the upper and lower output rows.
    const simd_type upper_even =
            compute_linear_element_simd<simd_helper, 0, 0>(taps, alpha);
    const simd_type upper_odd =
            compute_linear_element_simd<simd_helper, 0, 1>(taps, alpha);
    const simd_type lower_even =
            compute_linear_element_simd<simd_helper, 1, 0>(taps, alpha);
    const simd_type lower_odd =
            compute_linear_element_simd<simd_helper, 1, 1>(taps, alpha);

    // Interleaving places even and odd columns alternately in memory.
    simd_helper::store2_interleave(dst, upper_even, upper_odd);
    simd_helper::store2_interleave(dst + OW, lower_even, lower_odd);
}
template
<
typename
ctype
>
void
linear_upsample2_nchw
(
const
ctype
*
src_ptr
,
ctype
*
dst_ptr
,
size_t
N
,
size_t
IH
,
size_t
IW
)
{
using
simd_helper
=
SIMDHelper
<
ctype
>
;
size_t
OW
=
IW
*
2
;
constexpr
size_t
PC
=
simd_helper
::
simd_width
;
ctype
alpha
[
2
]
=
{
0.75
,
0.25
};
typename
simd_helper
::
simd_type
simd_alpha
[
2
][
2
];
simd_alpha
[
0
][
0
]
=
simd_helper
::
dup
(
0.75
*
0.75
);
simd_alpha
[
0
][
1
]
=
simd_helper
::
dup
(
0.75
*
0.25
);
simd_alpha
[
1
][
0
]
=
simd_helper
::
dup
(
0.25
*
0.75
);
simd_alpha
[
1
][
1
]
=
simd_helper
::
dup
(
0.25
*
0.25
);
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
compute_linear_2x2_element
<
ctype
,
false
,
false
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
{
for
(
size_t
iw
=
0
;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
ctype
,
true
,
false
>
(
src_ptr
+
iw
,
dst_ptr
+
(
iw
*
2
+
1
),
IW
,
OW
,
alpha
);
}
}
compute_linear_2x2_element
<
ctype
,
false
,
false
>
(
src_ptr
+
(
IW
-
1
),
dst_ptr
+
(
OW
-
1
),
IW
,
OW
,
alpha
);
dst_ptr
+=
OW
;
for
(
size_t
ih
=
0
;
ih
+
1
<
IH
;
++
ih
)
{
compute_linear_2x2_element
<
ctype
,
false
,
true
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
size_t
iw
=
0
;
for
(;
iw
+
PC
<
IW
;
iw
+=
PC
)
{
compute_linear_2x2_element_simd
<
simd_helper
>
(
src_ptr
+
iw
,
dst_ptr
+
(
iw
*
2
+
1
),
IW
,
OW
,
simd_alpha
);
}
for
(;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
ctype
,
true
,
true
>
(
src_ptr
+
iw
,
dst_ptr
+
(
iw
*
2
+
1
),
IW
,
OW
,
alpha
);
}
compute_linear_2x2_element
<
ctype
,
false
,
true
>
(
src_ptr
+
(
IW
-
1
),
dst_ptr
+
(
OW
-
1
),
IW
,
OW
,
alpha
);
src_ptr
+=
IW
;
dst_ptr
+=
2
*
OW
;
}
compute_linear_2x2_element
<
ctype
,
false
,
false
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
{
for
(
size_t
iw
=
0
;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
ctype
,
true
,
false
>
(
src_ptr
+
iw
,
dst_ptr
+
(
iw
*
2
+
1
),
IW
,
OW
,
alpha
);
}
}
compute_linear_2x2_element
<
ctype
,
false
,
false
>
(
src_ptr
+
(
IW
-
1
),
dst_ptr
+
(
OW
-
1
),
IW
,
OW
,
alpha
);
src_ptr
+=
IW
;
dst_ptr
+=
OW
;
}
}
template
<
typename
ctype
>
void
nearest_upsample2_nchw
(
const
ctype
*
src_ptr
,
ctype
*
dst_ptr
,
size_t
N
,
size_t
IH
,
size_t
IW
)
{
using
simd_helper
=
SIMDHelper
<
ctype
>
;
size_t
OW
=
IW
*
2
;
constexpr
size_t
PC
=
simd_helper
::
simd_width
;
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
for
(
size_t
ih
=
0
;
ih
<
IH
;
++
ih
)
{
size_t
iw
=
0
;
for
(;
iw
+
PC
-
1
<
IW
;
iw
+=
PC
)
{
typename
simd_helper
::
simd_type
r0
=
simd_helper
::
load
(
src_ptr
+
iw
);
simd_helper
::
store2_interleave
(
dst_ptr
+
(
iw
*
2
),
r0
,
r0
);
simd_helper
::
store2_interleave
(
dst_ptr
+
(
OW
+
iw
*
2
),
r0
,
r0
);
}
for
(;
iw
<
IW
;
iw
+=
1
)
{
ctype
v
=
src_ptr
[
iw
];
dst_ptr
[
iw
*
2
]
=
v
;
dst_ptr
[
iw
*
2
+
1
]
=
v
;
dst_ptr
[
OW
+
iw
*
2
]
=
v
;
dst_ptr
[
OW
+
iw
*
2
+
1
]
=
v
;
}
src_ptr
+=
IW
;
dst_ptr
+=
2
*
OW
;
}
}
}
}
// namespace
/**
 * fp32 entry point for linear 2x upsampling; treats the tensor as n * c
 * independent ih x iw planes.
 */
void megdnn::arm_common::resize_linear_upsample2_nchw_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    const auto planes = kern_param.n * kern_param.c;
    linear_upsample2_nchw(
            kern_param.sptr, kern_param.dptr, planes, kern_param.ih, kern_param.iw);
}
/**
 * fp32 entry point for nearest-neighbour 2x upsampling; treats the tensor as
 * n * c independent ih x iw planes.
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    const auto planes = kern_param.n * kern_param.c;
    nearest_upsample2_nchw(
            kern_param.sptr, kern_param.dptr, planes, kern_param.ih, kern_param.iw);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
 * fp16 entry point for linear 2x upsampling. The dt_float16 buffers are
 * reinterpreted as native __fp16; the tensor is treated as n * c independent
 * ih x iw planes.
 */
void megdnn::arm_common::resize_linear_upsample2_nchw_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr);
    __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr);
    const auto planes = kern_param.n * kern_param.c;
    linear_upsample2_nchw(src, dst, planes, kern_param.ih, kern_param.iw);
}
/**
 * fp16 entry point for nearest-neighbour 2x upsampling. The dt_float16 buffers
 * are reinterpreted as native __fp16; the tensor is treated as n * c
 * independent ih x iw planes.
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr);
    __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr);
    const auto planes = kern_param.n * kern_param.c;
    nearest_upsample2_nchw(src, dst, planes, kern_param.ih, kern_param.iw);
}
#endif
dnn/src/arm_common/resize/upsample2_nchw.h
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/upsample2_nchw.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
namespace megdnn {
namespace arm_common {

//! Nearest-neighbour 2x upsampling for fp32 NCHW data described by kern_param.
void resize_nearest_upsample2_nchw_fp32(
        const ResizeImpl::KernParam<float>& kern_param);

//! Linear 2x upsampling for fp32 NCHW data described by kern_param.
void resize_linear_upsample2_nchw_fp32(
        const ResizeImpl::KernParam<float>& kern_param);

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

//! Nearest-neighbour 2x upsampling for fp16 NCHW data described by kern_param.
void resize_nearest_upsample2_nchw_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param);

//! Linear 2x upsampling for fp16 NCHW data described by kern_param.
void resize_linear_upsample2_nchw_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param);

#endif

}  // namespace arm_common
}  // namespace megdnn
dnn/src/arm_common/resize/upsample2_nchwxx.cpp
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/upsample2_nchwxx.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
using
namespace
resize
;
namespace
{
/*!
 * \brief Accumulate four source vectors, each scaled by one entry of a 2x2
 *        weight table, into a single output vector.
 *
 * The accumulator starts at simd_helper::dup(0) and is updated with
 * simd_helper::fma(acc, src[k], weight) for k = 0..3, in that order
 * (presumably acc + src[k] * weight -- confirm against the helper).
 *
 * \tparam simd_helper provides simd_type, dup() and fma()
 * \tparam fh XOR-ed into the row index of \p alpha; 1 swaps the two rows
 * \tparam fw XOR-ed into the column index of \p alpha; 1 swaps the two columns
 * \param src the four source vectors; src[0] / src[1] share a weight row,
 *        src[2] / src[3] use the other one
 * \param alpha 2x2 weight table
 * \return the accumulated vector
 */
template <typename simd_helper, size_t fh, size_t fw>
static inline typename simd_helper::simd_type compute_linear_element(
        const typename simd_helper::simd_type src[4],
        const typename simd_helper::simd_type alpha[2][2]) {
    using simd_type = typename simd_helper::simd_type;

    // Resolve the (possibly mirrored) weight indices once, at compile time.
    constexpr size_t row_a = 0 ^ fh;
    constexpr size_t row_b = 1 ^ fh;
    constexpr size_t col_a = 0 ^ fw;
    constexpr size_t col_b = 1 ^ fw;

    simd_type acc = simd_helper::dup(0);
    acc = simd_helper::fma(acc, src[0], alpha[row_a][col_a]);
    acc = simd_helper::fma(acc, src[1], alpha[row_a][col_b]);
    acc = simd_helper::fma(acc, src[2], alpha[row_b][col_a]);
    acc = simd_helper::fma(acc, src[3], alpha[row_b][col_b]);
    return acc;
}
/*!
 * \brief Compute a 2x2 patch of output vectors from up to four neighbouring
 *        source vectors and store the ones enabled by the template flags.
 *
 * Four vectors are always loaded. When has_right / has_bottom is false the
 * corresponding neighbour is replaced by the vector at \p src itself, so no
 * load reaches past \p src in that direction.
 *
 * \tparam simd_helper provides ctype, simd_type, simd_width, load() and store()
 * \tparam has_right a neighbour exists simd_width elements after \p src; also
 *         enables the store at dst + simd_width
 * \tparam has_bottom a neighbour exists IW * simd_width elements after
 *         \p src; also enables the store at dst + OW * simd_width
 * \param src first source vector of the patch
 * \param dst destination of the first output vector (always written)
 * \param IW source row length in vectors
 * \param OW destination row length in vectors
 * \param alpha 2x2 weight table handed to compute_linear_element()
 */
template <typename simd_helper, bool has_right, bool has_bottom>
static inline void compute_linear_2x2_element(
        const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
        size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
    using simd_type = typename simd_helper::simd_type;
    constexpr size_t PC = simd_helper::simd_width;

    // Element offsets of the right / bottom neighbours; zero collapses the
    // neighbour onto src itself.
    const size_t right_off = has_right ? PC : 0;
    const size_t bottom_off = has_bottom ? IW * PC : 0;

    simd_type rsrc[4];
    rsrc[0] = simd_helper::load(src);
    rsrc[1] = simd_helper::load(src + right_off);
    rsrc[2] = simd_helper::load(src + bottom_off);
    rsrc[3] = simd_helper::load(src + right_off + bottom_off);

    // All four results are computed before anything is written back.
    const simd_type out_tl = compute_linear_element<simd_helper, 0, 0>(rsrc, alpha);
    const simd_type out_tr = compute_linear_element<simd_helper, 0, 1>(rsrc, alpha);
    const simd_type out_bl = compute_linear_element<simd_helper, 1, 0>(rsrc, alpha);
    const simd_type out_br = compute_linear_element<simd_helper, 1, 1>(rsrc, alpha);

    simd_helper::store(dst, out_tl);
    if (has_right) {
        simd_helper::store(dst + PC, out_tr);
    }
    if (has_bottom) {
        simd_helper::store(dst + OW * PC, out_bl);
    }
    if (has_right && has_bottom) {
        simd_helper::store(dst + (OW + 1) * PC, out_br);
    }
}
template
<
typename
ctype
>
void
linear_upsample2_nchwxx
(
const
ctype
*
src_ptr
,
ctype
*
dst_ptr
,
size_t
N
,
size_t
IH
,
size_t
IW
)
{
using
simd_helper
=
SIMDHelper
<
ctype
>
;
size_t
OW
=
IW
*
2
;
constexpr
size_t
PC
=
simd_helper
::
simd_width
;
typename
simd_helper
::
simd_type
alpha
[
2
][
2
];
alpha
[
0
][
0
]
=
simd_helper
::
dup
(
0.75
*
0.75
);
alpha
[
0
][
1
]
=
simd_helper
::
dup
(
0.75
*
0.25
);
alpha
[
1
][
0
]
=
simd_helper
::
dup
(
0.25
*
0.75
);
alpha
[
1
][
1
]
=
simd_helper
::
dup
(
0.25
*
0.25
);
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
compute_linear_2x2_element
<
simd_helper
,
false
,
false
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
{
for
(
size_t
iw
=
0
;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
simd_helper
,
true
,
false
>
(
src_ptr
+
iw
*
PC
,
dst_ptr
+
(
iw
*
2
+
1
)
*
PC
,
IW
,
OW
,
alpha
);
}
}
compute_linear_2x2_element
<
simd_helper
,
false
,
false
>
(
src_ptr
+
(
IW
-
1
)
*
PC
,
dst_ptr
+
(
OW
-
1
)
*
PC
,
IW
,
OW
,
alpha
);
dst_ptr
+=
OW
*
PC
;
for
(
size_t
ih
=
0
;
ih
+
1
<
IH
;
++
ih
)
{
compute_linear_2x2_element
<
simd_helper
,
false
,
true
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
for
(
size_t
iw
=
0
;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
simd_helper
,
true
,
true
>
(
src_ptr
+
iw
*
PC
,
dst_ptr
+
(
iw
*
2
+
1
)
*
PC
,
IW
,
OW
,
alpha
);
}
compute_linear_2x2_element
<
simd_helper
,
false
,
true
>
(
src_ptr
+
(
IW
-
1
)
*
PC
,
dst_ptr
+
(
OW
-
1
)
*
PC
,
IW
,
OW
,
alpha
);
src_ptr
+=
IW
*
PC
;
dst_ptr
+=
2
*
OW
*
PC
;
}
compute_linear_2x2_element
<
simd_helper
,
false
,
false
>
(
src_ptr
,
dst_ptr
,
IW
,
OW
,
alpha
);
{
for
(
size_t
iw
=
0
;
iw
+
1
<
IW
;
++
iw
)
{
compute_linear_2x2_element
<
simd_helper
,
true
,
false
>
(
src_ptr
+
iw
*
PC
,
dst_ptr
+
(
iw
*
2
+
1
)
*
PC
,
IW
,
OW
,
alpha
);
}
}
compute_linear_2x2_element
<
simd_helper
,
false
,
false
>
(
src_ptr
+
(
IW
-
1
)
*
PC
,
dst_ptr
+
(
OW
-
1
)
*
PC
,
IW
,
OW
,
alpha
);
src_ptr
+=
IW
*
PC
;
dst_ptr
+=
OW
*
PC
;
}
}
template
<
typename
ctype
>
void
nearest_upsample2_nchwxx
(
const
ctype
*
src_ptr
,
ctype
*
dst_ptr
,
size_t
N
,
size_t
IH
,
size_t
IW
)
{
using
simd_helper
=
SIMDHelper
<
ctype
>
;
size_t
OW
=
IW
*
2
;
constexpr
size_t
PC
=
simd_helper
::
simd_width
;
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
for
(
size_t
ih
=
0
;
ih
<
IH
;
++
ih
)
{
for
(
size_t
iw
=
0
;
iw
<
IW
;
++
iw
)
{
typename
simd_helper
::
simd_type
r0
=
simd_helper
::
load
(
src_ptr
+
iw
*
PC
);
simd_helper
::
store
(
dst_ptr
+
(
iw
*
2
)
*
PC
,
r0
);
simd_helper
::
store
(
dst_ptr
+
(
iw
*
2
+
1
)
*
PC
,
r0
);
simd_helper
::
store
(
dst_ptr
+
(
OW
+
iw
*
2
)
*
PC
,
r0
);
simd_helper
::
store
(
dst_ptr
+
(
OW
+
iw
*
2
+
1
)
*
PC
,
r0
);
}
src_ptr
+=
IW
*
PC
;
dst_ptr
+=
2
*
OW
*
PC
;
}
}
}
}
// namespace
/*!
 * \brief float32 NCHW44 entry point of the linear upsample2 resize kernel.
 *
 * Forwards the buffers of \p kern_param to linear_upsample2_nchwxx() with
 * n * c / 4 as the plane count, plus ih and iw.
 *
 * \param kern_param kernel parameters; only sptr, dptr, n, c, ih and iw are
 *        read here
 */
void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    // four channels are packed into each vector
    const auto packed_planes = kern_param.n * kern_param.c / 4;
    linear_upsample2_nchwxx(
            kern_param.sptr, kern_param.dptr, packed_planes, kern_param.ih,
            kern_param.iw);
}
/*!
 * \brief float32 NCHW44 entry point of the nearest upsample2 resize kernel.
 *
 * Forwards the buffers of \p kern_param to nearest_upsample2_nchwxx() with
 * n * c / 4 as the plane count, plus ih and iw.
 *
 * \param kern_param kernel parameters; only sptr, dptr, n, c, ih and iw are
 *        read here
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32(
        const ResizeImpl::KernParam<float>& kern_param) {
    // four channels are packed into each vector
    const auto packed_planes = kern_param.n * kern_param.c / 4;
    nearest_upsample2_nchwxx(
            kern_param.sptr, kern_param.dptr, packed_planes, kern_param.ih,
            kern_param.iw);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/*!
 * \brief float16 NCHW88 entry point of the linear upsample2 resize kernel.
 *
 * Reinterprets the dt_float16 buffers of \p kern_param as native __fp16 and
 * forwards them to linear_upsample2_nchwxx() with n * c / 8 as the plane
 * count, plus ih and iw.
 *
 * \param kern_param kernel parameters; only sptr, dptr, n, c, ih and iw are
 *        read here
 */
void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    // eight channels are packed into each vector
    const auto packed_planes = kern_param.n * kern_param.c / 8;
    linear_upsample2_nchwxx(
            reinterpret_cast<const __fp16*>(kern_param.sptr),
            reinterpret_cast<__fp16*>(kern_param.dptr), packed_planes,
            kern_param.ih, kern_param.iw);
}
/*!
 * \brief float16 NCHW88 entry point of the nearest upsample2 resize kernel.
 *
 * Reinterprets the dt_float16 buffers of \p kern_param as native __fp16 and
 * forwards them to nearest_upsample2_nchwxx() with n * c / 8 as the plane
 * count, plus ih and iw.
 *
 * \param kern_param kernel parameters; only sptr, dptr, n, c, ih and iw are
 *        read here
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16(
        const ResizeImpl::KernParam<dt_float16>& kern_param) {
    // eight channels are packed into each vector
    const auto packed_planes = kern_param.n * kern_param.c / 8;
    nearest_upsample2_nchwxx(
            reinterpret_cast<const __fp16*>(kern_param.sptr),
            reinterpret_cast<__fp16*>(kern_param.dptr), packed_planes,
            kern_param.ih, kern_param.iw);
}
#endif
dnn/src/arm_common/resize/upsample2_nchwxx.h
0 → 100644
浏览文件 @
bde5cf35
/**
* \file dnn/src/arm_common/resize/upsample2_nchwxx.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
// Declarations of the upsample2 resize entry points for the channel-packed
// layouts (NCHW44 for float32, NCHW88 for float16). Every function takes the
// kernel parameters by const reference and returns nothing.
namespace
megdnn
{
namespace
arm_common
{
// float32 / NCHW44, linear interpolation variant (as named).
void
resize_linear_upsample2_nchw44_fp32
(
const
ResizeImpl
::
KernParam
<
float
>&
kern_param
);
// float32 / NCHW44, nearest interpolation variant (as named).
void
resize_nearest_upsample2_nchw44_fp32
(
const
ResizeImpl
::
KernParam
<
float
>&
kern_param
);
// The half precision variants are only declared when the compiler reports
// fp16 vector arithmetic support.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// float16 / NCHW88, linear interpolation variant (as named).
void
resize_linear_upsample2_nchw88_fp16
(
const
ResizeImpl
::
KernParam
<
dt_float16
>&
kern_param
);
// float16 / NCHW88, nearest interpolation variant (as named).
void
resize_nearest_upsample2_nchw88_fp16
(
const
ResizeImpl
::
KernParam
<
dt_float16
>&
kern_param
);
#endif
}
// namespace arm_common
}
// namespace megdnn
dnn/test/arm_common/resize.cpp
浏览文件 @
bde5cf35
...
...
@@ -16,8 +16,25 @@
namespace
megdnn
{
namespace
test
{
using
namespace
resize
;
/*!
 * \brief Append NCHW resize test cases for interpolation mode \p imode.
 *
 * First a sweep over n, c, ih, iw, oh and ow (each loop index plus one is
 * used as the dimension; assuming rep(i, 4ul) iterates i over [0, 4), that is
 * every size from 1 to 4), then four hand-picked shapes: 10x10 -> 20x20,
 * 10x10 -> 7x9, 3x4 -> 6x8 and 6x8 -> 3x4.
 *
 * \param imode interpolation mode stored in every generated param
 * \param args output list; new cases are appended, existing ones are kept
 */
static void set_nchw_args(IMode imode, std::vector<TestArg>& args) {
    param::Resize param;
    param.format = param::Resize::Format::NCHW;
    param.imode = imode;

    // exhaustive sweep over small shapes
    rep(n, 4ul) {
        rep(c, 4ul) {
            rep(ih, 4ul) {
                rep(iw, 4ul) {
                    rep(oh, 4ul) {
                        rep(ow, 4ul) {
                            args.emplace_back(
                                    param,
                                    TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
                                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
                        }
                    }
                }
            }
        }
    }

    // hand-picked cases: 2x enlarge, arbitrary shrink, 2x enlarge with
    // batch / channel > 1, and 2x shrink
    args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
    args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
    args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
    args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
}
TEST_F
(
ARM_COMMON
,
RESIZE_CV
)
{
using
namespace
resize
;
std
::
vector
<
TestArg
>
args
=
get_cv_args
();
Checker
<
Resize
>
checker
(
handle
());
...
...
@@ -37,8 +54,38 @@ TEST_F(ARM_COMMON, RESIZE_CV) {
}
}
TEST_F
(
ARM_COMMON
,
RESIZE_NCHW44
)
{
using
namespace
resize
;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! NCHW resize with Float16 input and output: runs every case produced by
//! set_nchw_args for INTER_LINEAR followed by INTER_NEAREST, with the checker
//! epsilon set to 0.01.
TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) {
    std::vector<TestArg> args;
    for (IMode imode : {IMode::INTER_LINEAR, IMode::INTER_NEAREST}) {
        set_nchw_args(imode, args);
    }

    Checker<Resize> checker(handle());
    for (const TestArg& arg : args) {
        checker.set_param(arg.param)
                .set_epsilon(0.01)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .execs({arg.src, arg.dst});
    }
}
#endif
//! NCHW resize with Float32 input and output: runs every case produced by
//! set_nchw_args for INTER_LINEAR followed by INTER_NEAREST; the checker
//! epsilon is left untouched.
TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) {
    std::vector<TestArg> args;
    for (IMode imode : {IMode::INTER_LINEAR, IMode::INTER_NEAREST}) {
        set_nchw_args(imode, args);
    }

    Checker<Resize> checker(handle());
    for (const TestArg& arg : args) {
        checker.set_param(arg.param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .execs({arg.src, arg.dst});
    }
}
TEST_F
(
ARM_COMMON
,
RESIZE_NCHW44_FP32
)
{
std
::
vector
<
TestArg
>
args
=
get_nchw44_args
();
Checker
<
Resize
>
checker
(
handle
());
...
...
@@ -50,8 +97,8 @@ TEST_F(ARM_COMMON, RESIZE_NCHW44) {
}
}
TEST_F
(
ARM_COMMON
,
RESIZE_NCHW88
)
{
using
namespace
resize
;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON
,
RESIZE_NCHW88_FP16
)
{
std
::
vector
<
TestArg
>
args
=
get_nchw88_args
();
Checker
<
Resize
>
checker
(
handle
());
...
...
@@ -63,6 +110,7 @@ TEST_F(ARM_COMMON, RESIZE_NCHW88) {
.
execs
({
arg
.
src
,
arg
.
dst
});
}
}
#endif
}
// namespace test
}
// namespace megdnn
...
...
dnn/test/cuda/resize.cpp
浏览文件 @
bde5cf35
...
...
@@ -52,6 +52,7 @@ TEST_F(CUDA, RESIZE_FORWARD) {
checker
.
set_param
(
arg
.
param
)
.
set_dtype
(
0
,
dtype
::
Uint8
())
.
set_dtype
(
1
,
dtype
::
Uint8
())
.
set_epsilon
(
1
)
.
execs
({
arg
.
src
,
arg
.
dst
});
}
...
...
@@ -67,7 +68,7 @@ TEST_F(CUDA, RESIZE_FORWARD) {
checker
.
set_param
(
arg
.
param
)
.
set_dtype
(
0
,
dtype
::
Int8
())
.
set_dtype
(
1
,
dtype
::
Int8
())
.
set_epsilon
(
1
e-3
)
.
set_epsilon
(
1
)
.
execs
({
arg
.
src
,
arg
.
dst
});
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录