Commit 894a2407
Authored Jun 04, 2021 by Megvii Engine Team
Committed Jul 19, 2021 by huangxinda
feat(dnn/cuda): add relayout format kernel for nchw <-> nhwc
GitOrigin-RevId: e11f3e54085929ab9919fe2070fcc9e633755b69
Parent: 43c59204
Showing 10 changed files with 660 additions and 63 deletions (+660 −63)
dnn/scripts/opr_param_defs.py                              +3   −1
dnn/src/common/relayout_format.cpp                         +28  −0
dnn/src/cuda/pooling/pooling2d_qint.cu                     +2   −2
dnn/src/cuda/relayout_format/opr_impl.cpp                  +6   −2
dnn/src/cuda/relayout_format/relayout_format.cpp           +16  −0
dnn/src/cuda/relayout_format/relayout_format_kern.cuh      +228 −8
dnn/src/cuda/relayout_format/relayout_format_nchw_nhwc.cu  +211 −0
dnn/src/cuda/relayout_format/translayout.cuh               +40  −35
dnn/src/cuda/warp_perspective/forward.cu                   +15  −15
dnn/test/cuda/relayout_format.cpp                          +111 −0
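For orientation before the per-file diffs: the new kernels convert 4-bit quantized tensors between NCHW and NHWC layouts. Below is a minimal host-side reference of the index mapping the CUDA kernel has to realize, written as a plain C++ sketch under the simplifying assumption of one byte per element (the real kernels pack two 4-bit values per byte, which is what the iterator machinery in relayout_format_kern.cuh handles):

#include <cstdint>

// dst is NHWC, src is NCHW: dst[n][h][w][c] = src[n][c][h][w].
void relayout_nchw_to_nhwc_ref(const int8_t* src, int8_t* dst, int N, int C,
                               int H, int W) {
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    dst[((n * H + h) * W + w) * C + c] =
                            src[((n * C + c) * H + h) * W + w];
}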
dnn/scripts/opr_param_defs.py
@@ -1001,7 +1001,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
             'NCHW_NCHW4_WEIGHT',
             'NCHW_NCHW64',
-            'NCHW64_NCHW')
+            'NCHW64_NCHW',
+            'NCHW_NHWC',
+            'NHWC_NCHW')
 )

 (pdef('RelayoutFormat', 'Change the tensor layout format', version=1).
dnn/src/common/relayout_format.cpp
@@ -268,6 +268,22 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             dst[2] = src[2];
             dst[3] = src[3];
             break;
+        case Param::Mode::NCHW_NHWC:
+            megdnn_assert(src.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = src[0];
+            dst[1] = src[2];
+            dst[2] = src[3];
+            dst[3] = src[1];
+            break;
+        case Param::Mode::NHWC_NCHW:
+            megdnn_assert(src.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = src[0];
+            dst[1] = src[3];
+            dst[2] = src[1];
+            dst[3] = src[2];
+            break;
         default:
             megdnn_assert(0, "Invalid RelayoutFormat Mode");
             break;
@@ -375,6 +391,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
         case Param::Mode::NCHW64_NCHW:
             dst = src;
             break;
+        case Param::Mode::NCHW_NHWC:
+        case Param::Mode::NHWC_NCHW:
+            dst = src;
+            break;
         default:
             megdnn_throw("Invalid relayout format mode");
             break;
@@ -666,6 +686,14 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
             exec_src = src.dimshuffle({0, 1, 4, 2, 3});
             exec_dst = dst;
             break;
+        case Param::Mode::NCHW_NHWC:
+            exec_src = src.dimshuffle({0, 2, 3, 1});
+            exec_dst = dst;
+            break;
+        case Param::Mode::NHWC_NCHW:
+            exec_src = src.dimshuffle({0, 3, 1, 2});
+            exec_dst = dst;
+            break;
         default:
             megdnn_assert(0, "Invalid RelayoutFormat Mode");
     }
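The deduce_exec_layout cases above express the conversion as an axis permutation. A standalone sketch of the shape effect of dimshuffle({0, 2, 3, 1}), assuming the usual convention that result axis k is source axis perm[k] (illustrative C++, not the MegEngine TensorLayout API):

#include <array>
#include <cstdio>

// Permute a 4-D shape: out[k] = shape[perm[k]].
std::array<int, 4> dimshuffle(const std::array<int, 4>& shape,
                              const std::array<int, 4>& perm) {
    std::array<int, 4> out{};
    for (int k = 0; k < 4; ++k) out[k] = shape[perm[k]];
    return out;
}

int main() {
    auto nhwc = dimshuffle({8, 64, 32, 32}, {0, 2, 3, 1});  // NCHW -> NHWC
    std::printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
    // prints: 8 32 32 64
}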
dnn/src/cuda/pooling/pooling2d_qint.cu
@@ -505,7 +505,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src,
 void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(
         const int8_t* d_src, int8_t* d_dst, const Param& param,
-        cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
+        cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) {
     using Mode = megdnn::param_enumv::Pooling::Mode;
     void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
                  int zero_point);
@@ -545,7 +545,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(
 void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(
         const int8_t* d_src, int8_t* d_dst, const Param& param,
-        cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
+        cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) {
     using Mode = megdnn::param_enumv::Pooling::Mode;
     void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
                  int zero_point);
dnn/src/cuda/relayout_format/opr_impl.cpp
@@ -33,7 +33,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                           param().mode ==
                                   Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT ||
                           param().mode == Param::Mode::NCHW_NCHW64 ||
-                          param().mode == Param::Mode::NCHW64_NCHW,
+                          param().mode == Param::Mode::NCHW64_NCHW ||
+                          param().mode == Param::Mode::NCHW_NHWC ||
+                          param().mode == Param::Mode::NHWC_NCHW,
                   "relayout format of cuda only support NCHW4->CHWN4 or "
                   "CHWN4->NCHW4 or NCHW->NCHW4");
     if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
@@ -82,7 +84,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                 {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
     }
     bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 ||
-                           param().mode == Param::Mode::NCHW64_NCHW) &&
+                           param().mode == Param::Mode::NCHW64_NCHW ||
+                           param().mode == Param::Mode::NCHW_NHWC ||
+                           param().mode == Param::Mode::NHWC_NCHW) &&
                           (src_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                            src_dtype.enumv() == DTypeEnum::Quantized4Asymm);
     bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 ||
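The predicate being extended here is easier to read pulled out into a function: the fast 4-bit path now also covers the two NHWC modes. A sketch with stand-in enums (the real Param::Mode and DTypeEnum live in MegDNN headers; these definitions are illustrative):

enum class Mode { NCHW_NCHW64, NCHW64_NCHW, NCHW_NHWC, NHWC_NCHW, OTHER };
enum class DType { QuantizedS4, Quantized4Asymm, OTHER };

// True when the mode is one of the 4-bit relayout modes AND the data is a
// 4-bit quantized type, mirroring the expanded is_trans_4bits condition.
bool is_trans_4bits(Mode mode, DType dt) {
    bool mode_ok = mode == Mode::NCHW_NCHW64 || mode == Mode::NCHW64_NCHW ||
                   mode == Mode::NCHW_NHWC || mode == Mode::NHWC_NCHW;
    bool dtype_ok = dt == DType::QuantizedS4 || dt == DType::Quantized4Asymm;
    return mode_ok && dtype_ok;
}

int main() {
    return is_trans_4bits(Mode::NCHW_NHWC, DType::QuantizedS4) ? 0 : 1;
}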
dnn/src/cuda/relayout_format/relayout_format.cpp
@@ -66,6 +66,22 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
         return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale,
                                                dst_scale, src_zero_point,
                                                dst_zero_point);
+    } else if (mode == RelayoutFormat::Param::Mode::NCHW_NHWC) {
+#define CHECK(dt)                                                   \
+    megdnn_assert(dt.enumv() == DTypeEnum::Quantized4Asymm ||       \
+                  dt.enumv() == DTypeEnum::QuantizedS4)
+        CHECK(src.layout.dtype);
+        CHECK(dst.layout.dtype);
+        return relayout_format_cuda_nchw_nhwc(src, dst, stream, src_scale,
+                                              dst_scale, src_zero_point,
+                                              dst_zero_point);
+    } else if (mode == RelayoutFormat::Param::Mode::NHWC_NCHW) {
+        CHECK(src.layout.dtype);
+        CHECK(dst.layout.dtype);
+        return relayout_format_cuda_nhwc_nchw(src, dst, stream, src_scale,
+                                              dst_scale, src_zero_point,
+                                              dst_zero_point);
+#undef CHECK
     } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) {
        return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
     } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) {
dnn/src/cuda/relayout_format/relayout_format_kern.cuh
@@ -20,8 +20,17 @@ namespace relayout_format {
 namespace internal {
 using namespace memory;
+struct LayoutType {
+    static constexpr uint32_t NCHWx = 0;
+    static constexpr uint32_t NHWC = 1;
+};
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
+          int size_nbits_, uint32_t layout_type_ = LayoutType::NCHWx>
 class TensorIteratorOverChannel;
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
+          int size_nbits_, uint32_t layout_type_>
 class TensorIteratorOverChannel {
 public:
     using Type = Type_;
@@ -116,6 +125,98 @@ private:
+template <typename Type_, int pack_size_, int chan_blk_, int width_,
+          int size_nbits_>
+class TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
+                                size_nbits_, LayoutType::NHWC> {
+public:
+    using Type = Type_;
+    static constexpr int pack_size = pack_size_;
+    static constexpr int chan_blk = chan_blk_;
+    static constexpr int width = width_;
+    static constexpr int size_nbits = size_nbits_;
+    static constexpr int elements_in_type =
+            chan_blk * width * size_nbits / (8 * sizeof(Type));
+    static constexpr int pack_size_in_type =
+            pack_size * size_nbits / (8 * sizeof(Type));
+    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
+    using AccessType = array_wrapper<Type, pack_size_in_type>;
+    using Fragment = array_wrapper<Type, elements_in_type>;
+
+    MEGDNN_HOST TensorIteratorOverChannel()
+            : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {}
+    MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_,
+                                          int hw_stride_in_elements_,
+                                          int channel_, int, int)
+            : pointer{pointer_},
+              hw_stride_in_elements{hw_stride_in_elements_},
+              channel{channel_} {}
+
+    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
+        pointer += c_idx * size_nbits / (8 * sizeof(Type)) +
+                   hw_idx * hw_stride_in_elements;
+        channel -= c_idx;
+    }
+
+    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
+            size_t offset_in_type) {
+        pointer += offset_in_type;
+    }
+
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
+        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
+        Type* pointer_ = pointer;
+#pragma unroll
+        for (int i = 0; i < width; ++i) {
+#pragma unroll
+            for (int j = 0; j < chan_blk; j += pack_size) {
+                int frag_idx = i * (chan_blk / pack_size) + (j / pack_size);
+                bool guard = j < channel;
+                global_load<AccessType, pack_size_in_byte>(
+                        frag_ptr[frag_idx],
+                        reinterpret_cast<void*>(
+                                pointer_ + j * size_nbits / (8 * sizeof(Type))),
+                        guard, zero_point);
+            }
+            pointer_ += hw_stride_in_elements;
+        }
+    }
+
+    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
+        const AccessType* frag_ptr =
+                reinterpret_cast<const AccessType*>(&frag);
+        Type* pointer_ = pointer;
+#pragma unroll
+        for (int i = 0; i < width; ++i) {
+#pragma unroll
+            for (int j = 0; j < chan_blk; j += pack_size) {
+                int frag_idx = i * (chan_blk / pack_size) + (j / pack_size);
+                bool guard = j < channel;
+                global_store<AccessType, pack_size_in_byte>(
+                        frag_ptr[frag_idx],
+                        reinterpret_cast<void*>(
+                                pointer_ + j * size_nbits / (8 * sizeof(Type))),
+                        guard);
+            }
+            pointer_ += hw_stride_in_elements;
+        }
+    }
+
+    MEGDNN_DEVICE __forceinline__ void advance() {
+        pointer += chan_blk * size_nbits / (8 * sizeof(Type));
+        channel -= chan_blk;
+    }
+
+private:
+    Type* pointer;
+    int hw_stride_in_elements;
+    int channel;
+};
+
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
+          int size_nbits_, uint32_t layout_type_ = LayoutType::NCHWx>
 class MaskedTensorIteratorOverChannel;
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
+          int size_nbits_, uint32_t layout_type_>
 class MaskedTensorIteratorOverChannel {
 public:
     using Type = Type_;
@@ -243,24 +344,143 @@ private:
     size_t stride[lane_size_in_type / pack_size_in_type];
 };
+
+template <typename Type_, int pack_size_, int chan_blk_, int width_,
+          int size_nbits_>
+class MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
+                                      size_nbits_, LayoutType::NHWC> {
+public:
+    using Type = Type_;
+    static constexpr int pack_size = pack_size_;
+    static constexpr int chan_blk = chan_blk_;
+    static constexpr int width = width_;
+    static constexpr int size_nbits = size_nbits_;
+    static constexpr int elements_in_type =
+            chan_blk * width * size_nbits / (8 * sizeof(Type));
+    static constexpr int lane_size_in_type =
+            (width * pack_size * size_nbits) / (8 * sizeof(Type));
+    static constexpr int pack_size_in_type =
+            pack_size * size_nbits / (8 * sizeof(Type));
+    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
+    static constexpr int accesses = elements_in_type / pack_size_in_type;
+    static constexpr int mask_size = (accesses + 32 - 1) / 32;
+    using AccessType = array_wrapper<Type, pack_size_in_type>;
+    using Fragment = array_wrapper<Type, elements_in_type>;
+
+    MEGDNN_HOST MaskedTensorIteratorOverChannel()
+            : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {}
+    MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_,
+                                                int hw_stride_in_elements_,
+                                                int channel_, int bound_,
+                                                int div_)
+            : pointer{pointer_},
+              hw_stride_in_elements{hw_stride_in_elements_},
+              channel{channel_},
+              bound{bound_},
+              div{uint32_t(div_)} {}
+
+    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
+        pointer += c_idx * size_nbits / (8 * sizeof(Type));
+        channel -= c_idx;
+#pragma unroll
+        for (int i = 0; i < mask_size; ++i) {
+            mask[i] = 0;
+        }
+#pragma unroll
+        for (int i = 0; i < width; ++i) {
+            int offset = hw_idx + i;
+            int h = (int)((uint32_t)(offset) / div);
+            int w = (int)((uint32_t)(offset) % div);
+            stride[i] = (h * bound + w) * hw_stride_in_elements;
+#pragma unroll
+            for (int j = 0; j < chan_blk; j += pack_size) {
+                bool guard = (j < channel) && (w < bound);
+                int index = i * (chan_blk / pack_size) + (j / pack_size);
+                int mask_index = (index >> 5);
+                int mask_shift = (index & 0x1f);
+                mask[mask_index] |= (guard << mask_shift);
+            }
+        }
+    }
+
+    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
+            size_t offset_in_type) {
+        pointer += offset_in_type;
+    }
+
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
+        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
+#pragma unroll
+        for (int i = 0; i < width; ++i) {
+            Type* pointer_ = pointer + stride[i];
+#pragma unroll
+            for (int j = 0; j < chan_blk; j += pack_size) {
+                int frag_idx = i * (chan_blk / pack_size) + (j / pack_size);
+                int mask_index = (frag_idx >> 5);
+                int mask_shift = (frag_idx & 0x1f);
+                bool guard = (mask[mask_index] & (1 << mask_shift));
+                global_load<AccessType, pack_size_in_byte>(
+                        frag_ptr[frag_idx],
+                        reinterpret_cast<void*>(
+                                pointer_ + j * size_nbits / (8 * sizeof(Type))),
+                        guard, zero_point);
+            }
+        }
+    }
+
+    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
+        const AccessType* frag_ptr =
+                reinterpret_cast<const AccessType*>(&frag);
+#pragma unroll
+        for (int i = 0; i < width; ++i) {
+            Type* pointer_ = pointer + stride[i];
+#pragma unroll
+            for (int j = 0; j < chan_blk; j += pack_size) {
+                int frag_idx = i * (chan_blk / pack_size) + (j / pack_size);
+                int mask_index = (frag_idx >> 5);
+                int mask_shift = (frag_idx & 0x1f);
+                bool guard = (mask[mask_index] & (1 << mask_shift));
+                global_store<AccessType, pack_size_in_byte>(
+                        frag_ptr[frag_idx],
+                        reinterpret_cast<void*>(
+                                pointer_ + j * size_nbits / (8 * sizeof(Type))),
+                        guard);
+            }
+        }
+    }
+
+    MEGDNN_DEVICE __forceinline__ void advance() {
+        pointer += chan_blk * size_nbits / (8 * sizeof(Type));
+        channel -= chan_blk;
+    }
+
+private:
+    Type* pointer;
+    int hw_stride_in_elements;
+    int channel;
+    int bound;
+    Uint32Fastdiv div;
+    uint32_t mask[mask_size];
+    size_t stride[width];
+};
+
 template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
-          int width_, int size_nbits_>
+          int width_, int size_nbits_,
+          uint32_t layout_type_ = LayoutType::NCHWx>
 struct TensorIteratorPolicy;
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
-struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
-                            size_nbits_> {
+          int size_nbits_, uint32_t layout_type_>
+struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
+                            size_nbits_, layout_type_> {
     using TensorIterator =
             MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
-                                            width_, size_nbits_>;
+                                            width_, size_nbits_, layout_type_>;
 };
 template <typename Type_, int pack_size_, int chan_blk_, int width_,
-          int size_nbits_>
+          int size_nbits_, uint32_t layout_type_>
 struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
-                            size_nbits_> {
+                            size_nbits_, layout_type_> {
     using TensorIterator =
             TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
-                                      size_nbits_>;
+                                      size_nbits_, layout_type_>;
 };
 template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
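The iterators above scale every channel offset by size_nbits / (8 * sizeof(Type)) so that sub-byte (4-bit) data can be addressed through a wider storage type. A freestanding sketch of that arithmetic, assuming Type = char packing two 4-bit values (not MegDNN code):

#include <cstdio>

// How many Type-sized units c_idx sub-byte elements occupy; exact when
// c_idx is a multiple of 8 * sizeof(Type) / size_nbits.
template <typename Type, int size_nbits>
int channel_offset_in_type(int c_idx) {
    return c_idx * size_nbits / int(8 * sizeof(Type));
}

int main() {
    // 16 4-bit channels fit in 8 chars:
    std::printf("%d\n", channel_offset_in_type<char, 4>(16));  // prints 8
}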
dnn/src/cuda/relayout_format/relayout_format_nchw_nhwc.cu  (new file, mode 100644)

/**
 * \file dnn/src/cuda/relayout_format/relayout_format_nchw_nhwc.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/relayout_format/relayout_format_kern.cuh"

using namespace megdnn;
using namespace cuda;
using namespace relayout_format;
using namespace internal;

namespace {
template <int pack_w>
struct rwtype_helper;

template <>
struct rwtype_helper<2> {
    using InnerDtype = char;
};

template <>
struct rwtype_helper<8> {
    using InnerDtype = unsigned;
};
}  // namespace

void relayout_format::relayout_format_cuda_nchw_nhwc(
        const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
        const float src_scale, const float dst_scale,
        const uint8_t src_zero_point, const uint8_t dst_zero_point) {
    auto&& stype = src.layout.dtype;
    auto&& dtype = dst.layout.dtype;
    auto& src_layout = src.layout;
    auto& dst_layout = dst.layout;
    int n = src.layout[0];
    int ic = src.layout[1];
    int h = src.layout[2];
    int w = src.layout[3];
    int w_pad = DIVUP(w, 2) * 2;
    int hw = h * w_pad;
    int n_stride_src = src_layout.stride[0];
    int ic_stride = src_layout.stride[1];
    int n_stride_dst = dst_layout.stride[0];
    int hw_stride = dst_layout.stride[2];
    static constexpr int chan_blk = 8;
    static constexpr int pack_oc = 8;
    int problem_size = n * DIVUP(ic, chan_blk) * hw;
    int oc = dst.layout[3];
    bool same_scale = src_scale == dst_scale;
    bool padding = w % 2 != 0;
#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \
                     _src_c_type, _dst_c_type, _size_nbits)                \
    if (padding == _padding && same_scale == _same_scale &&                \
        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
        dtype.enumv().ev == DTypeEnum::Ev::_dst_type) {                    \
        using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype;   \
        using SrcIterator_ =                                               \
                TensorIteratorOverChannel<InnerDtype_, 1, chan_blk, _pack_w, \
                                          _size_nbits>;                    \
        using DstIterator_ = typename TensorIteratorPolicy<                \
                _padding, _dst_c_type, pack_oc, chan_blk, _pack_w,         \
                _size_nbits, LayoutType::NHWC>::TensorIterator;            \
        using CudaPostProcess_ =                                           \
                CudaPostProcess<dtype::_src_type, dtype::_dst_type,        \
                                _same_scale>;                              \
        using Transpose_ =                                                 \
                Translayout<_pack_w, chan_blk, InnerDtype_, dtype::_src_type, \
                            dtype::_dst_type, _same_scale>;                \
        using RelayoutProblem_ =                                           \
                RelayoutProblem<SrcIterator_, DstIterator_, Transpose_,    \
                                CudaPostProcess_>;                         \
        n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(InnerDtype_)); \
        ic_stride = ic_stride * _size_nbits / (8 * sizeof(InnerDtype_));   \
        n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \
        hw_stride = hw_stride * _size_nbits / (8 * sizeof(_dst_c_type));   \
        typename RelayoutProblem_::Param param{                            \
                SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w,  \
                             w_pad},                                       \
                DstIterator_{(_dst_c_type*)dst.raw_ptr, hw_stride, oc, w,  \
                             w_pad},                                       \
                CudaPostProcess_{src_scale, src_zero_point, dst_scale,     \
                                 dst_zero_point},                          \
                n_stride_src,                                              \
                n_stride_dst,                                              \
                n,                                                         \
                ic,                                                        \
                hw,                                                        \
                src_zero_point};                                           \
        auto kernel = relayout_kern<RelayoutProblem_>;                     \
        int nr_threads = query_blocksize_for_kernel(kernel);               \
        nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));   \
        const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w));    \
        const dim3 thread_dim(nr_threads);                                 \
        return kernel<<<block_dim, thread_dim, 0, stream>>>(param);        \
    }
#define DISPATCH_4BITS(_src_type, _dst_type)                              \
    DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4);     \
    DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4);     \
    DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4);   \
    DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4);
    DISPATCH_4BITS(QuantizedS4, QuantizedS4);
    DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
#undef DISPATCH_RAW
    megdnn_assert(false,
                  "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                  stype.name(), dtype.name(), h, w);
}

void relayout_format::relayout_format_cuda_nhwc_nchw(
        const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
        const float src_scale, const float dst_scale,
        const uint8_t src_zero_point, const uint8_t dst_zero_point) {
    auto&& stype = src.layout.dtype;
    auto&& dtype = dst.layout.dtype;
    auto& src_layout = src.layout;
    auto& dst_layout = dst.layout;
    int n = src.layout[0];
    int h = src.layout[1];
    int w = src.layout[2];
    int ic = src.layout[3];
    int w_pad = DIVUP(w, 2) * 2;
    int hw = h * w_pad;
    int n_stride_src = src_layout.stride[0];
    int hw_stride = src_layout.stride[2];
    int n_stride_dst = dst_layout.stride[0];
    int oc_stride = dst_layout.stride[1];
    static constexpr int chan_blk = 8;
    static constexpr int pack_oc = 8;
    int problem_size = n * DIVUP(ic, chan_blk) * hw;
    int oc = dst.layout[1];
    bool same_scale = src_scale == dst_scale;
    bool padding = w % 2 != 0;
#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \
                     _src_c_type, _dst_c_type, _size_nbits)                \
    if (padding == _padding && same_scale == _same_scale &&                \
        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
        dtype.enumv().ev == DTypeEnum::Ev::_dst_type) {                    \
        using SrcIterator_ = typename TensorIteratorPolicy<                \
                _padding, _src_c_type, pack_oc, chan_blk, _pack_w,         \
                _size_nbits, LayoutType::NHWC>::TensorIterator;            \
        using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype;   \
        using DstIterator_ =                                               \
                TensorIteratorOverChannel<InnerDtype_, 1, chan_blk, _pack_w, \
                                          _size_nbits>;                    \
        using CudaPostProcess_ =                                           \
                CudaPostProcess<dtype::_src_type, dtype::_dst_type,        \
                                _same_scale>;                              \
        using Transpose_ =                                                 \
                Translayout<chan_blk, _pack_w, _src_c_type, dtype::_src_type, \
                            dtype::_dst_type, _same_scale>;                \
        using RelayoutProblem_ =                                           \
                RelayoutProblem<SrcIterator_, DstIterator_, Transpose_,    \
                                CudaPostProcess_>;                         \
        n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(_src_c_type)); \
        hw_stride = hw_stride * _size_nbits / (8 * sizeof(_src_c_type));   \
        n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \
        oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_));   \
        typename RelayoutProblem_::Param param{                            \
                SrcIterator_{(_src_c_type*)src.raw_ptr, hw_stride, ic, w,  \
                             w_pad},                                       \
                DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w,  \
                             w_pad},                                       \
                CudaPostProcess_{src_scale, src_zero_point, dst_scale,     \
                                 dst_zero_point},                          \
                n_stride_src,                                              \
                n_stride_dst,                                              \
                n,                                                         \
                ic,                                                        \
                hw,                                                        \
                src_zero_point};                                           \
        auto kernel = relayout_kern<RelayoutProblem_>;                     \
        int nr_threads = query_blocksize_for_kernel(kernel);               \
        nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));   \
        const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w));    \
        const dim3 thread_dim(nr_threads);                                 \
        return kernel<<<block_dim, thread_dim, 0, stream>>>(param);        \
    }
#define DISPATCH_4BITS(_src_type, _dst_type)                              \
    DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4);     \
    DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4);     \
    DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4);   \
    DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4);    \
    DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4);
    DISPATCH_4BITS(QuantizedS4, QuantizedS4);
    DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
#undef DISPATCH_RAW
    megdnn_assert(false,
                  "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                  stype.name(), dtype.name(), h, w);
}
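The dispatch macro above sizes the grid so that each thread handles pack_w elements of a flattened problem of size n * ceil(ic / chan_blk) * hw. A host-side sketch of that arithmetic with made-up sizes, where 256 stands in for the value the real code obtains from query_blocksize_for_kernel (DIVUP is MegDNN's ceiling division, reproduced here):

#include <algorithm>
#include <cstdio>

#define DIVUP(x, y) (((x) + (y)-1) / (y))

int main() {
    int n = 16, ic = 64, hw = 56 * 56;
    const int chan_blk = 8, pack_w = 8;
    int problem_size = n * DIVUP(ic, chan_blk) * hw;
    // 256 is an assumed block size, not a queried one.
    int nr_threads = std::min(256, DIVUP(problem_size, pack_w));
    int blocks = DIVUP(problem_size, nr_threads * pack_w);
    std::printf("problem=%d threads=%d blocks=%d\n", problem_size, nr_threads,
                blocks);
}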
dnn/src/cuda/relayout_format/translayout.cuh
@@ -42,8 +42,9 @@ struct enable_qtype_b4 {
     static constexpr bool val_dst =
             std::is_same<dt_dst, dtype::QuantizedS4>::value ||
             std::is_same<dt_dst, dtype::Quantized4Asymm>::value;
-    using type = typename std::enable_if<std::is_same<dt_src, dt_dst>::value &&
-                                         val_src && val_dst>::type;
+    static constexpr bool value =
+            std::is_same<dt_src, dt_dst>::value && val_src && val_dst;
+    using type = typename std::enable_if<value>::type;
 };

 // The input fragment is stored in RowMajor order. The translayout operator
@@ -393,26 +394,32 @@ struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale,
     using Fragment = array_wrapper<SrcType, elements_in_type>;
     static inline __device__ void trans(
             Fragment& dst, const Fragment& src,
-            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
         int intermediate[8][2];
-        transform_b4x2_to_int8<signedness>(intermediate[0],
-                                           reinterpret_cast<uint8_t&>(src[0]));
-        transform_b4x2_to_int8<signedness>(intermediate[1],
-                                           reinterpret_cast<uint8_t&>(src[1]));
-        transform_b4x2_to_int8<signedness>(intermediate[2],
-                                           reinterpret_cast<uint8_t&>(src[2]));
-        transform_b4x2_to_int8<signedness>(intermediate[3],
-                                           reinterpret_cast<uint8_t&>(src[3]));
-        transform_b4x2_to_int8<signedness>(intermediate[4],
-                                           reinterpret_cast<uint8_t&>(src[4]));
-        transform_b4x2_to_int8<signedness>(intermediate[5],
-                                           reinterpret_cast<uint8_t&>(src[5]));
-        transform_b4x2_to_int8<signedness>(intermediate[6],
-                                           reinterpret_cast<uint8_t&>(src[6]));
-        transform_b4x2_to_int8<signedness>(intermediate[7],
-                                           reinterpret_cast<uint8_t&>(src[7]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[0],
+                reinterpret_cast<const uint8_t&>(src[0 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[1],
+                reinterpret_cast<const uint8_t&>(src[1 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[2],
+                reinterpret_cast<const uint8_t&>(src[2 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[3],
+                reinterpret_cast<const uint8_t&>(src[3 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[4],
+                reinterpret_cast<const uint8_t&>(src[4 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[5],
+                reinterpret_cast<const uint8_t&>(src[5 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[6],
+                reinterpret_cast<const uint8_t&>(src[6 * col_in_type]));
+        transform_b4x2_to_int8<signedness>(
+                intermediate[7],
+                reinterpret_cast<const uint8_t&>(src[7 * col_in_type]));
         int* dst_frag = reinterpret_cast<int*>(&dst);
         auto pack = [&](int idx) -> int {
             return transform_int8_to_b4x8<signedness>(
@@ -445,25 +452,24 @@ struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale,
     using Fragment = array_wrapper<SrcType, elements_in_type>;
     static inline __device__ void trans(
             Fragment& dst, const Fragment& src,
-            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
         int intermediate[8][8];
         transform_b4x8_to_int8<signedness>(
-                intermediate[0], reinterpret_cast<const int&>(src[0]));
+                intermediate[0],
+                reinterpret_cast<const int&>(src[0 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[1], reinterpret_cast<const int&>(src[1]));
+                intermediate[1],
+                reinterpret_cast<const int&>(src[1 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[2], reinterpret_cast<const int&>(src[2]));
+                intermediate[2],
+                reinterpret_cast<const int&>(src[2 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[3], reinterpret_cast<const int&>(src[3]));
+                intermediate[3],
+                reinterpret_cast<const int&>(src[3 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[4], reinterpret_cast<const int&>(src[4]));
+                intermediate[4],
+                reinterpret_cast<const int&>(src[4 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[5], reinterpret_cast<const int&>(src[5]));
+                intermediate[5],
+                reinterpret_cast<const int&>(src[5 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[6], reinterpret_cast<const int&>(src[6]));
+                intermediate[6],
+                reinterpret_cast<const int&>(src[6 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[7], reinterpret_cast<const int&>(src[7]));
+                intermediate[7],
+                reinterpret_cast<const int&>(src[7 * col_in_type]));
         int* dst_frag = reinterpret_cast<int*>(&dst);
         auto pack = [&](int idx) {
             return transform_int8_to_b4x8<signedness>(
@@ -502,13 +508,12 @@ struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale,
     using Fragment = array_wrapper<SrcType, elements_in_type>;
     static inline __device__ void trans(
             Fragment& dst, const Fragment& src,
-            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
-            const char zero_point) {
+            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
         int intermediate[2][8];
         transform_b4x8_to_int8<signedness>(
-                intermediate[0], reinterpret_cast<const int&>(src[0]));
+                intermediate[0],
+                reinterpret_cast<const int&>(src[0 * col_in_type]));
         transform_b4x8_to_int8<signedness>(
-                intermediate[1], reinterpret_cast<const int&>(src[1]));
+                intermediate[1],
+                reinterpret_cast<const int&>(src[1 * col_in_type]));
         int* dst_frag = reinterpret_cast<int*>(&dst);
         dst_frag[0] = transform_int8_to_b4x8<signedness>(
                 post_process(intermediate[0][0]),
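The transform_b4x2_to_int8 calls above unpack two 4-bit lanes from one byte before the transpose. A freestanding sketch of the signed variant (illustrative only, not the MegDNN implementation):

#include <cstdint>
#include <cstdio>

// Unpack two signed 4-bit values from one byte: shift left, then
// arithmetic-shift right to sign-extend each nibble.
void unpack_s4x2(int out[2], uint8_t packed) {
    out[0] = int8_t(packed << 4) >> 4;  // low nibble
    out[1] = int8_t(packed) >> 4;       // high nibble
}

int main() {
    int v[2];
    unpack_s4x2(v, 0xF2);               // low = 2, high = -1
    std::printf("%d %d\n", v[0], v[1]);
}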
dnn/src/cuda/warp_perspective/forward.cu
@@ -508,7 +508,7 @@ struct KernCoreNHWC<ctype, OutputConverter, 8> {
                       "assert qu4 or q4");
         constexpr bool signedness = std::is_same<ctype, dt_qint4>::value;
         int8_t bval_4 = bval.as_storage() & 0xF;
-        const int bval_int = transform_int8_to_bit4x8<signedness>(
+        const int bval_int = transform_int8_to_b4x8<signedness>(
                 bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
         int src_ori[4];
         src_ori[0] = src0_ok ? *(int*)(src_ptr0 + offset) : bval_int;
@@ -516,10 +516,10 @@ struct KernCoreNHWC<ctype, OutputConverter, 8> {
         src_ori[2] = src2_ok ? *(int*)(src_ptr2 + offset) : bval_int;
         src_ori[3] = src3_ok ? *(int*)(src_ptr3 + offset) : bval_int;
         int src[4][8];
-        transform_bit4x8_to_int8<signedness>(src[0], src_ori[0]);
-        transform_bit4x8_to_int8<signedness>(src[1], src_ori[1]);
-        transform_bit4x8_to_int8<signedness>(src[2], src_ori[2]);
-        transform_bit4x8_to_int8<signedness>(src[3], src_ori[3]);
+        transform_b4x8_to_int8<signedness>(src[0], src_ori[0]);
+        transform_b4x8_to_int8<signedness>(src[1], src_ori[1]);
+        transform_b4x8_to_int8<signedness>(src[2], src_ori[2]);
+        transform_b4x8_to_int8<signedness>(src[3], src_ori[3]);
         int res = pack_output_func<signedness>(output_converter, src[0],
                                                src[1], src[2], src[3], w00,
                                                w01, w10, w11);
@@ -542,7 +542,7 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> {
                       "assert qu4 or q4");
         constexpr bool signedness = std::is_same<ctype, dt_qint4>::value;
         int8_t bval_4 = bval.as_storage() & 0xF;
-        const int bval_int_temp = transform_int8_to_bit4x8<signedness>(
+        const int bval_int_temp = transform_int8_to_b4x8<signedness>(
                 bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
         const int2 bval_int{bval_int_temp, bval_int_temp};
@@ -552,15 +552,15 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> {
         src_ori[2] = src2_ok ? *(int2*)(src_ptr2 + offset) : bval_int;
         src_ori[3] = src3_ok ? *(int2*)(src_ptr3 + offset) : bval_int;
         int src[8][8];
-        transform_bit4x8_to_int8<signedness>(src[0], src_ori[0].x);
-        transform_bit4x8_to_int8<signedness>(src[1], src_ori[1].x);
-        transform_bit4x8_to_int8<signedness>(src[2], src_ori[2].x);
-        transform_bit4x8_to_int8<signedness>(src[3], src_ori[3].x);
-        transform_bit4x8_to_int8<signedness>(src[4], src_ori[0].y);
-        transform_bit4x8_to_int8<signedness>(src[5], src_ori[1].y);
-        transform_bit4x8_to_int8<signedness>(src[6], src_ori[2].y);
-        transform_bit4x8_to_int8<signedness>(src[7], src_ori[3].y);
+        transform_b4x8_to_int8<signedness>(src[0], src_ori[0].x);
+        transform_b4x8_to_int8<signedness>(src[1], src_ori[1].x);
+        transform_b4x8_to_int8<signedness>(src[2], src_ori[2].x);
+        transform_b4x8_to_int8<signedness>(src[3], src_ori[3].x);
+        transform_b4x8_to_int8<signedness>(src[4], src_ori[0].y);
+        transform_b4x8_to_int8<signedness>(src[5], src_ori[1].y);
+        transform_b4x8_to_int8<signedness>(src[6], src_ori[2].y);
+        transform_b4x8_to_int8<signedness>(src[7], src_ori[3].y);
         int2 res;
         res.x = pack_output_func<signedness>(output_converter, src[0], src[1],
dnn/test/cuda/relayout_format.cpp
@@ -325,6 +325,91 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
     }
 }

+TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NHWC) {
+    Checker<RelayoutFormat> checker(handle_cuda());
+    UniformIntRNG s4{-8, 7};
+    UniformIntRNG u4{0, 15};
+    param::RelayoutFormat param;
+    param.mode = param::RelayoutFormat::Mode::NCHW_NHWC;
+    for (size_t n : {1, 3}) {
+        for (size_t c : {8, 16}) {
+            for (size_t h : {7, 14, 16, 28}) {
+                for (size_t w : {2, 3, 7, 8, 16, 31}) {
+                    checker.set_dtype(0, dtype::QuantizedS4{2.f})
+                            .set_dtype(1, dtype::QuantizedS4{2.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.2f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
+                            .set_dtype(1, dtype::QuantizedS4{1.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .execs({{n, c, h, w}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.19990307f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, c, h, w}, {}});
+                }
+            }
+        }
+    }
+    checker.execs({{1, 256, 384, 640}, {}});
+}
+
+TEST_F(CUDA, RELAYOUT_FORMAT_NHWC_NCHW) {
+    Checker<RelayoutFormat> checker(handle_cuda());
+    UniformIntRNG s4{-8, 7};
+    UniformIntRNG u4{0, 15};
+    param::RelayoutFormat param;
+    param.mode = param::RelayoutFormat::Mode::NHWC_NCHW;
+    for (size_t n : {1, 3}) {
+        for (size_t c : {8, 16}) {
+            for (size_t h : {7, 14, 16, 28}) {
+                for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
+                    checker.set_dtype(0, dtype::QuantizedS4{2.f})
+                            .set_dtype(1, dtype::QuantizedS4{2.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, h, w, c}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, h, w, c}, {}});
+                    checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
+                            .set_dtype(1, dtype::QuantizedS4{1.f})
+                            .set_rng(0, &s4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, h, w, c}, {}});
+                    checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
+                            .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
+                            .set_rng(0, &u4)
+                            .set_param(param)
+                            .set_epsilon(1e-3)
+                            .execs({{n, h, w, c}, {}});
+                }
+            }
+        }
+    }
+    checker.execs({{1, 384, 640, 256}, {}});
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) {
     using Param = RelayoutFormat::Param;
@@ -393,6 +478,7 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
         }
     };
     printf("nchw -> nchw64\n");
     {
         TensorShapeArray shapes = {{1, 64, 56, 56},
                                    {16, 64, 56, 56},
                                    {64, 64, 56, 56},
@@ -403,6 +489,18 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
         run(shapes, param);
     }
+    printf("nchw -> nhwc\n");
+    {
+        TensorShapeArray shapes = {
+                {1, 64, 56, 56},    {16, 64, 56, 56},  {64, 64, 56, 56},
+                {1, 64, 56, 55},    {16, 64, 56, 55},  {64, 64, 56, 55},
+                {1, 256, 384, 640}, {16, 16, 384, 640},
+        };
+        Param param;
+        param.mode = param::RelayoutFormat::Mode::NCHW_NHWC;
+        run(shapes, param);
+    }
     printf("nchw64 -> nchw\n");
     {
         TensorShapeArray shapes = {{64, 1, 56, 56, 64},
@@ -415,6 +513,19 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
         param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
         run(shapes, param);
     }
+    printf("nhwc -> nchw\n");
+    {
+        TensorShapeArray shapes = {
+                {64, 56, 56, 64},
+                {1, 7, 7, 64 * 32},
+                {16, 7, 7, 64 * 32},
+                {64, 7, 7, 64 * 32},
+                {1, 384, 640, 64 * 4},
+        };
+        Param param;
+        param.mode = param::RelayoutFormat::Mode::NHWC_NCHW;
+        run(shapes, param);
+    }
 }
 #endif
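A note on the tests above: the cases that change the quantization scale (for example QuantizedS4{1.19990307f} to QuantizedS4{1.f}) requantize every element, so the result can differ from a float reference by up to one quantization step, which is presumably why those cases call set_epsilon(1e-3). A plain C++ sketch of that requantization rounding (not Checker code):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Map a signed-4-bit quantized value from src_scale to dst_scale,
// rounding to the nearest step and clamping to the s4 range.
int requantize_s4(int q_src, float src_scale, float dst_scale) {
    float real = q_src * src_scale;
    int q_dst = (int)std::lround(real / dst_scale);
    return std::max(-8, std::min(7, q_dst));
}

int main() {
    std::printf("%d\n", requantize_s4(5, 1.19990307f, 1.f));  // prints 6
}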