MegEngine 天元 / MegEngine

Commit 96d90be1
Authored June 01, 2022 by Megvii Engine Team
Parent: eef0308b

feat(dnn): fallback support int4 relayout

GitOrigin-RevId: 3625f5847055940e646358654f296922f05afa93

Showing 4 changed files with 409 additions and 76 deletions:

- dnn/include/megdnn/dtype.h (+13, -11)
- dnn/src/fallback/relayout/opr_impl.cpp (+341, -64)
- dnn/test/common/checker.cpp (+1, -1)
- dnn/test/fallback/relayout.cpp (+54, -0)
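The new code throughout this commit relies on the packing convention for 4-bit quantized tensors: two elements share one byte, with the element at an even index stored in the low nibble and the element at an odd index in the high nibble (this is what the 0xf/0xf0 masks and the >> 1 offset arithmetic below encode). As a rough standalone sketch of that packing, independent of the MegDNN headers (the helper names here are made up for illustration):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative only: read/write the 4-bit element at index `idx` of a
// packed buffer; even index -> low nibble, odd index -> high nibble.
static uint8_t load_int4(const uint8_t* buf, size_t idx) {
    uint8_t byte = buf[idx >> 1];
    return idx % 2 == 0 ? byte & 0xf : (byte >> 4) & 0xf;
}

static void store_int4(uint8_t* buf, size_t idx, uint8_t val) {
    uint8_t& byte = buf[idx >> 1];
    byte = idx % 2 == 0 ? (byte & 0xf0) | (val & 0xf)
                        : (byte & 0x0f) | ((val & 0xf) << 4);
}

int main() {
    uint8_t packed[2] = {0, 0};  // room for four int4 values
    for (size_t i = 0; i < 4; ++i)
        store_int4(packed, i, uint8_t(i + 1));
    for (size_t i = 0; i < 4; ++i)
        printf("%zu -> %u\n", i, load_int4(packed, i));  // prints 1 2 3 4
}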
dnn/include/megdnn/dtype.h
@@ -625,6 +625,7 @@ struct log<1> {
             ::megdnn::dtype::log<sizeof(ctype)>::value;                       \
     static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name;               \
     static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;                         \
+    static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits; \
     static MEGDNN_CONSTEXPR bool has_param = _has_param
 #else
 #define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param) \
@@ -632,7 +633,8 @@ struct log<1> {
     typedef ::megdnn::dtype::_name dtype;                                     \
     static const uint16_t size_log = ::megdnn::dtype::log<sizeof(ctype)>::value; \
     static MEGDNN_CONSTEXPR int enumv = DTypeEnum::_name;                     \
-    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits
+    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;                         \
+    static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits;
 #endif  // MEGDNN_CC_HOST
 #define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \
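The `bits` field added here records the element width in bits: the parameterized low-bit dtypes pass a non-zero `_bits` (4 for the 4-bit types), while ordinary types pass 0 and fall back to `sizeof(_ctype) * 8`. The relayout code below uses it to turn element strides into byte strides without special-casing sub-byte types. A minimal standalone illustration of that arithmetic (not MegDNN code; the template is only a stand-in for the trait):

#include <cstdio>

// Stand-in for the trait logic: a type either declares its width in bits
// explicitly (e.g. 4) or derives it from its storage size.
template <typename Ctype, unsigned Bits = 0>
struct bits_of {
    static constexpr unsigned value = Bits == 0 ? sizeof(Ctype) * 8 : Bits;
};

int main() {
    constexpr unsigned stride = 6;  // stride measured in elements
    printf("int32 row bytes: %u\n", stride * bits_of<int>::value / 8);             // 24
    printf("int4  row bytes: %u\n", stride * bits_of<signed char, 4>::value / 8);  // 3
}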
dnn/src/fallback/relayout/opr_impl.cpp
@@ -8,12 +8,129 @@
 using namespace megdnn;
 using namespace fallback;

+namespace megdnn {
+namespace relayout {
+namespace transpose_fallback {
+
+template <>
+struct transpose_traits<dt_qint4> {
+    static constexpr size_t block_size = BLOCK_LINE_SIZE_BYTES;
+};
+
+template <>
+void transpose_block_fallback<dt_qint4>(
+        const dt_qint4* src, dt_qint4* dst, const size_t src_stride,
+        const size_t dst_stride, size_t block_h, size_t block_w) {
+    constexpr size_t block_size = transpose_traits<dt_qint4>::block_size;
+    uint8_t block[block_size][block_size];
+    uint8_t* src_ptr = (uint8_t*)src;
+    uint8_t* dst_ptr = (uint8_t*)dst;
+    for (size_t i = 0; i < block_h; ++i) {
+        size_t src_offset_base = i * src_stride;
+        for (size_t j = 0; j < block_w; ++j) {
+            size_t src_offset = src_offset_base + j;
+            size_t src_byte_offset = src_offset >> 1;
+            if (src_offset % 2 == 0) {
+                block[j][i] = src_ptr[src_byte_offset] & 0xf;
+            } else {
+                block[j][i] = ((src_ptr[src_byte_offset] & 0xf0) >> 4) & 0xf;
+            }
+        }
+    }
+    for (size_t i = 0; i < block_w; ++i) {
+        size_t dst_offset_base = i * dst_stride;
+        for (size_t j = 0; j < block_h; ++j) {
+            size_t dst_offset = dst_offset_base + j;
+            size_t dst_byte_offset = dst_offset >> 1;
+            uint8_t dst_temp = dst_ptr[dst_byte_offset];
+            uint8_t src_temp = block[i][j];
+            if (dst_offset % 2 == 0) {
+                dst_temp = (dst_temp & 0xf0) | src_temp;
+            } else {
+                dst_temp = (dst_temp & 0xf) | (src_temp << 4);
+            }
+            dst_ptr[dst_byte_offset] = dst_temp;
+        }
+    }
+}
+
+template <>
+void transpose<dt_qint4>(
+        size_t batch, size_t m, size_t n, dt_qint4* src, dt_qint4* dst,
+        size_t stride_m) {
+    if (stride_m == 0) {
+        stride_m = n;
+    }
+    uint8_t* batch_src = (uint8_t*)(src);
+    uint8_t* batch_dst = (uint8_t*)(dst);
+    constexpr size_t B = transpose_traits<dt_qint4>::block_size;
+    auto work_block = [m, stride_m, &batch_src, &batch_dst](
+                              const size_t i, const size_t j, const size_t h,
+                              const size_t w) {
+        size_t src_offset = i * stride_m + j;
+        size_t dst_offset = j * m + i;
+        megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
+        auto src = batch_src + (src_offset >> 1);
+        auto dst = batch_dst + (dst_offset >> 1);
+        MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
+            if (h == B && w == B) {
+                transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m);
+            } else {
+                transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m, h, w);
+            }
+        }
+        MIDOUT_END();
+    };
+    auto work_row = [&work_block, n](size_t i, size_t h) {
+        size_t j = 0;
+        for (; j + B <= n; j += B) {
+            work_block(i, j, h, B);
+        }
+        if (j < n) {
+            work_block(i, j, h, n - j);
+        }
+    };
+    for (size_t b = 0; b < batch; ++b) {
+        size_t i = 0;
+        for (; i + B <= m; i += B) {
+            work_row(i, B);
+        }
+        if (i < m) {
+            work_row(i, m - i);
+        }
+        size_t src_offset = m * stride_m;
+        size_t dst_offset = m * n;
+        megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
+        batch_src += (src_offset >> 1);
+        batch_dst += (dst_offset >> 1);
+    }
+}
+
+}  // namespace transpose_fallback
+}  // namespace relayout
+}  // namespace megdnn
+
 namespace {
+
 bool is_lastdim_contig(const TensorLayout& layout) {
     return layout.ndim <= 3 && layout.stride[layout.ndim - 1] == 1;
 }
+
+bool is_int4(const TensorLayout& layout) {
+    return layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
+           layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
+}
+
+inline bool check_dtype_support_transparam(
+        bool trans, bool is_bit4, const relayout::TransposeParam& param) {
+    if (trans && is_bit4) {
+        auto c = param.c;
+        return c == 1 || c == 2 || c == 4 || c == 8;
+    }
+    return trans;
+}
+
 template <size_t sz, typename T0 = char>
 struct equiv_ctype_storage {
     T0 _[sz];
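The dt_qint4 specializations above follow the existing blocked-transpose scheme: transpose_block_fallback unpacks a block of nibbles into a byte-per-element scratch array, writes it back transposed, and transpose<dt_qint4> walks the matrix block by block while asserting that every block starts on a byte boundary (hence the even-offset checks). A rough standalone sketch of the same unpack/transpose/repack idea on a whole packed matrix, independent of MegDNN (illustrative only):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Transpose an m x n row-major matrix of packed 4-bit values
// (even element index in the low nibble of its byte).
static std::vector<uint8_t> transpose_int4(
        const std::vector<uint8_t>& src, size_t m, size_t n) {
    auto load = [&](size_t idx) -> uint8_t {
        uint8_t b = src[idx >> 1];
        return idx % 2 ? (b >> 4) & 0xf : b & 0xf;
    };
    std::vector<uint8_t> dst((m * n + 1) / 2, 0);
    auto store = [&](size_t idx, uint8_t v) {
        uint8_t& b = dst[idx >> 1];
        b = idx % 2 ? uint8_t((b & 0x0f) | (v << 4)) : uint8_t((b & 0xf0) | v);
    };
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            store(j * m + i, load(i * n + j));
    return dst;
}

int main() {
    // 2x4 matrix holding 1..8, packed two values per byte.
    std::vector<uint8_t> src = {0x21, 0x43, 0x65, 0x87};
    auto dst = transpose_int4(src, 2, 4);  // result is 4x2
    for (uint8_t b : dst)
        printf("%02x ", b);  // expect: 51 62 73 84
    printf("\n");
}

The real implementation differs in that it only ever transposes fixed-size blocks (so the scratch array stays cache-resident) and leaves blocking and tail handling to the surrounding transpose<dt_qint4> driver.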
@@ -26,16 +143,111 @@ struct equiv_ctype {
             alignof(typename DTypeTrait<dtype>::ctype)>;
 };

-typedef void (*memcpy_policy_t)(void* cont, void* non_cont, size_t);
+typedef void (*memcpy_policy_t)(
+        void* cont, void* non_cont, size_t src_offset, size_t dst_offset,
+        size_t size);

-void memcpy_cont2noncont(void* cont, void* non_cont, size_t size) {
+void memcpy_cont2noncont(void* cont, void* non_cont, size_t, size_t, size_t size) {
     memcpy(non_cont, cont, size);
 }

-void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) {
+void memcpy_noncont2cont(void* cont, void* non_cont, size_t, size_t, size_t size) {
     memcpy(cont, non_cont, size);
 }

+void memcpy_4bit(
+        void* cont, void* nocont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    if (size == 0)
+        return;
+    uint8_t* cont_u8 = (uint8_t*)cont;
+    uint8_t* nocont_u8 = (uint8_t*)nocont;
+    size_t cont_bytes = cont_offset >> 1;
+    size_t nocont_bytes = nocont_offset >> 1;
+    size_t size_byte = size >> 1;
+    void* cont_ptr = cont_u8 + cont_bytes;
+    void* nocont_ptr = nocont_u8 + nocont_bytes;
+    bool size_align = size % 2 == 0;
+    bool cont_align = cont_offset % 2 == 0;
+    bool nocont_align = nocont_offset % 2 == 0;
+    if (cont_align && nocont_align) {
+        memcpy(cont_ptr, nocont_ptr, size_byte);
+        if (!size_align) {
+            uint8_t* dst_ptr = (uint8_t*)cont_ptr + size_byte;
+            uint8_t* src_ptr = (uint8_t*)nocont_ptr + size_byte;
+            *dst_ptr = (*src_ptr) & 0xf;
+        }
+    } else if (!cont_align && nocont_align) {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        for (size_t i = 0; i < size_byte; ++i) {
+            uint8_t dst_low = *dst_ptr;
+            uint8_t src_all = *src_ptr;
+            uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
+            uint8_t now = ((src_all & 0xf0) >> 4) & 0xf;
+            *dst_ptr = last;
+            ++dst_ptr;
+            *dst_ptr = now;
+            ++src_ptr;
+        }
+        if (!size_align) {
+            uint8_t dst_low = *dst_ptr;
+            uint8_t src_all = *src_ptr;
+            uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
+            *dst_ptr = last;
+        }
+    } else if (cont_align && !nocont_align) {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        for (size_t i = 0; i < size_byte; ++i) {
+            uint8_t src_last_high = *src_ptr;
+            ++src_ptr;
+            uint8_t src_low = *src_ptr;
+            uint8_t rst = (src_low & 0xf) << 4 | ((src_last_high >> 4) & 0xf);
+            *dst_ptr = rst;
+            ++dst_ptr;
+        }
+        if (!size_align) {
+            uint8_t src_last_high = *src_ptr;
+            *dst_ptr = ((src_last_high >> 4) & 0xf);
+        }
+    } else {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        {
+            uint8_t src_last_high = *src_ptr;
+            uint8_t dst_last_low = *dst_ptr;
+            uint8_t rst = (dst_last_low & 0xf) | (src_last_high & 0xf0);
+            *dst_ptr = rst;
+            ++dst_ptr;
+            ++src_ptr;
+        }
+        if (!size_align) {
+            memcpy(dst_ptr, src_ptr, size_byte);
+        } else {
+            if (size_byte > 1) {
+                size_t align_size = size_byte - 1;
+                memcpy(dst_ptr, src_ptr, align_size);
+                dst_ptr += align_size;
+                src_ptr += align_size;
+            }
+            uint8_t src_last_low = *src_ptr;
+            *dst_ptr = src_last_low & 0xf;
+        }
+    }
+}
+
+void memcpy_cont2noncont_4bit(
+        void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    memcpy_4bit(non_cont, cont, nocont_offset, cont_offset, size);
+}
+
+void memcpy_noncont2cont_4bit(
+        void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    memcpy_4bit(cont, non_cont, cont_offset, nocont_offset, size);
+}
+
 template <typename T>
 void call_transpose(
         size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
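memcpy_4bit above splits the copy into four alignment cases (both nibble offsets even, only the destination odd, only the source odd, both odd) so that the common aligned case degenerates to a plain memcpy and only the leading/trailing nibbles need masking. A much slower but easy-to-check reference that copies one nibble at a time would look roughly like this (a sketch for comparison, not the MegDNN implementation):

#include <cstddef>
#include <cstdint>

// Reference nibble copy: move `size` 4-bit elements from nibble offset
// `src_off` in `src` to nibble offset `dst_off` in `dst`.
static void copy_int4_reference(
        uint8_t* dst, size_t dst_off, const uint8_t* src, size_t src_off,
        size_t size) {
    for (size_t k = 0; k < size; ++k) {
        size_t s = src_off + k, d = dst_off + k;
        uint8_t v = s % 2 ? (src[s >> 1] >> 4) & 0xf : src[s >> 1] & 0xf;
        uint8_t& b = dst[d >> 1];
        b = d % 2 ? uint8_t((b & 0x0f) | (v << 4)) : uint8_t((b & 0xf0) | v);
    }
}

int main() {
    uint8_t src[2] = {0x21, 0x43};  // values 1, 2, 3, 4
    uint8_t dst[2] = {0, 0};
    copy_int4_reference(dst, 1, src, 0, 3);
    // dst is now {0x10, 0x32}: values _, 1, 2, 3 starting at nibble 1
    return 0;
}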
@@ -46,7 +258,7 @@ void call_transpose(
 }

 //! one operand contiguous, and the other non-contiguous
-template <typename ctype>
+template <int bits>
 void dispatch_on_dtype_cont(
         Handle* handle, const TensorND& cont, const TensorND& nonc,
         memcpy_policy_t mcp_pol) {
@@ -54,13 +266,13 @@ void dispatch_on_dtype_cont(
     switch (nonc.layout.ndim) {
         case 2: {
             auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
-            auto strd0_n = nonc.layout.stride[0] * sizeof(ctype);
-            auto strd0_c = shp1 * sizeof(ctype);
+            auto strd0_n = nonc.layout.stride[0] * bits / 8;
+            auto strd0_c = shp1 * bits / 8;
             kern = [=]() {
                 auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
                 auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
                 for (size_t i = 0; i < shp0; ++i) {
-                    mcp_pol(cur_ctptr, cur_ncptr, strd0_c);
+                    mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd0_c);
                     cur_ctptr += strd0_c;
                     cur_ncptr += strd0_n;
                 }
@@ -70,16 +282,16 @@ void dispatch_on_dtype_cont(
         case 3: {
             auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
                  shp2 = nonc.layout.shape[2];
-            auto strd0_n = nonc.layout.stride[0] * sizeof(ctype),
-                 strd1_n = nonc.layout.stride[1] * sizeof(ctype);
-            auto strd1_c = shp2 * sizeof(ctype);
+            auto strd0_n = nonc.layout.stride[0] * bits / 8,
+                 strd1_n = nonc.layout.stride[1] * bits / 8;
+            auto strd1_c = shp2 * bits / 8;
             kern = [=]() {
                 auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
                 auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
                 for (size_t i = 0; i < shp0; ++i) {
                     auto cur_ncptr = ncptr_row;
                     for (size_t j = 0; j < shp1; ++j) {
-                        mcp_pol(cur_ctptr, cur_ncptr, strd1_c);
+                        mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd1_c);
                         cur_ctptr += strd1_c;
                         cur_ncptr += strd1_n;
                     }
@@ -95,13 +307,64 @@ void dispatch_on_dtype_cont(
     static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
 }

+template <>
+void dispatch_on_dtype_cont<4>(
+        Handle* handle, const TensorND& cont, const TensorND& nonc,
+        memcpy_policy_t mcp_pol) {
+    thin_function<void()> kern;
+    switch (nonc.layout.ndim) {
+        case 2: {
+            auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
+            auto strd0_n = nonc.layout.stride[0];
+            auto strd0_c = shp1;
+            kern = [=]() {
+                auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
+                auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
+                size_t c_cnt = 0;
+                size_t n_cnt = 0;
+                for (size_t i = 0; i < shp0; ++i) {
+                    mcp_pol(cur_ctptr, cur_ncptr, c_cnt, n_cnt, strd0_c);
+                    c_cnt += strd0_c;
+                    n_cnt += strd0_n;
+                }
+            };
+            break;
+        }
+        case 3: {
+            auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
+                 shp2 = nonc.layout.shape[2];
+            auto strd0_n = nonc.layout.stride[0], strd1_n = nonc.layout.stride[1];
+            auto strd1_c = shp2;
+            kern = [=]() {
+                auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
+                auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
+                size_t c_cnt = 0;
+                size_t n_cnt = 0;
+                for (size_t i = 0; i < shp0; ++i) {
+                    n_cnt = i * strd0_n;
+                    for (size_t j = 0; j < shp1; ++j) {
+                        mcp_pol(cur_ctptr, ncptr_row, c_cnt, n_cnt, strd1_c);
+                        c_cnt += strd1_c;
+                        n_cnt += strd1_n;
+                    }
+                }
+            };
+            break;
+        }
+        default:
+            megdnn_assert(0);
+    }
+    static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
+}
+
 void dispatch_cont(
         Handle* handle, const TensorND& cont, const TensorND& nonc,
         memcpy_policy_t mcp_pol) {
     switch (cont.layout.dtype.enumv()) {
 #define cb(_dt)                                                       \
     case DTypeTrait<dtype::_dt>::enumv:                               \
-        return dispatch_on_dtype_cont<equiv_ctype<dtype::_dt>::type>( \
+        return dispatch_on_dtype_cont<DTypeTrait<dtype::_dt>::bits>(  \
                 handle, cont, nonc, mcp_pol);
         MEGDNN_FOREACH_DTYPE_NAME(cb)
         MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
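Note why this specialization threads c_cnt/n_cnt (counted in elements, i.e. nibbles) through to the copy policy instead of advancing byte pointers the way the generic version does: with a 4-bit payload a row may begin in the middle of a byte, which a uint8_t* alone cannot express. A worked case (not taken from the diff), written as a comment:

// Non-contiguous int4 rows with a stride of 15 elements:
//   row 0 starts at nibble  0 -> byte 0, low nibble
//   row 1 starts at nibble 15 -> byte 7, high nibble  (15 >> 1 == 7, 15 % 2 == 1)
//   row 2 starts at nibble 30 -> byte 15, low nibble
// Only the element offset preserves the "which half of the byte" information,
// so that is what memcpy_noncont2cont_4bit / memcpy_cont2noncont_4bit receive.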
@@ -110,8 +373,8 @@ void dispatch_cont(
     }
 }

-const size_t BLOCK_SIZE = 16,
-             TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
+const size_t BLOCK_SIZE = 16;
+const size_t TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;

 /*!
  * \tparam ctype The type of the data
@@ -221,28 +484,34 @@ void RelayoutForwardImpl::exec(
         return;
     }

-    // FIXME: optimize for lowbit cases
-    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
-        src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
-        NaiveRelayoutForwardImpl::do_exec(src, dst);
-        return;
-    }
+    bool is_bit4 = is_int4(src.layout);
+    bool allow_nocontig = !is_bit4;
     relayout::TransposeParam trans_param;
-    bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
+    bool trans = relayout::is_transpose(
+            src.layout, dst.layout, trans_param, allow_nocontig);
+    trans = check_dtype_support_transparam(trans, is_bit4, trans_param);
     exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
 }

 void RelayoutForwardImpl::exec_after_preprocess(
         const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) {
     if (transpose) {
-        auto kernel = [tparam = *transpose, src, dst]() {
+        bool is_bit4 = is_int4(src.layout);
+        auto kernel = [tparam = *transpose, src, dst, is_bit4]() {
             auto t = tparam;
-            auto dsize = src.layout.dtype.size() * t.c;
             void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) =
                     nullptr;
             auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr()),
                  dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr());
+            size_t dsize = 0;
+            if (is_bit4) {
+                dsize = t.c >> 1;
+            } else {
+                dsize = src.layout.dtype.size() * t.c;
+            }
+            if (is_bit4 && dsize == 0) {
+                kptr = call_transpose<dt_qint4>;
+            } else {
                 if (dsize == 1) {
                     megdnn_assert(t.c == 1);
                     kptr = call_transpose<uint8_t>;
@@ -285,6 +554,7 @@ void RelayoutForwardImpl::exec_after_preprocess(
                 }
                 megdnn_assert(kptr);
+            }
             }
             if (kptr) {
                 auto sptr = src.raw_ptr();
@@ -305,13 +575,20 @@ void RelayoutForwardImpl::exec_after_preprocess(
         MEGDNN_DISPATCH_CPU_KERN_OPR(memcpy(dst.raw_ptr(), src.raw_ptr(), sz));
         return;
     }

+    memcpy_policy_t cpy_noncont2cont = memcpy_noncont2cont;
+    memcpy_policy_t cpy_cont2noncont = memcpy_cont2noncont;
+    bool is_bit4 = src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
+                   src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
+    if (is_bit4) {
+        cpy_noncont2cont = memcpy_noncont2cont_4bit;
+        cpy_cont2noncont = memcpy_cont2noncont_4bit;
+    }
     if (is_contig(dst.layout) && is_lastdim_contig(src.layout)) {
-        return dispatch_cont(handle(), dst, src, memcpy_noncont2cont);
+        return dispatch_cont(handle(), dst, src, cpy_noncont2cont);
     }
     if (is_contig(src.layout) && is_lastdim_contig(dst.layout)) {
-        return dispatch_cont(handle(), src, dst, memcpy_cont2noncont);
+        return dispatch_cont(handle(), src, dst, cpy_cont2noncont);
     }
     NaiveRelayoutForwardImpl::do_exec(src, dst);
 }
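The last hunk keeps the byte-wise policies for every other dtype and swaps in the 4-bit ones only for QuantizedS4/Quantized4Asymm; the surrounding dispatch machinery is untouched because the policy is just a function pointer with the widened signature. A trimmed-down sketch of that selection pattern (hypothetical names, not the MegDNN code):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Mirrors the widened memcpy_policy_t: contiguous ptr, non-contiguous ptr,
// contiguous offset, non-contiguous offset, element count.
typedef void (*copy_policy_t)(void*, void*, size_t, size_t, size_t);

static void copy_bytes(void* cont, void* nonc, size_t, size_t, size_t n) {
    memcpy(cont, nonc, n);  // offsets are irrelevant for byte-addressable dtypes
}

static void copy_nibbles(void*, void*, size_t co, size_t no, size_t n) {
    // stand-in for the 4-bit policy; the real one masks and shifts nibbles
    printf("nibble copy of %zu elements (offsets %zu / %zu)\n", n, co, no);
}

int main() {
    bool is_bit4 = true;  // in the real code this comes from the dtype enum
    copy_policy_t policy = is_bit4 ? copy_nibbles : copy_bytes;
    uint8_t cont[4] = {0};
    uint8_t nonc[4] = {0x21, 0x43, 0x65, 0x87};
    policy(cont, nonc, 0, 0, is_bit4 ? 8u : 4u);
}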
dnn/test/common/checker.cpp
@@ -98,7 +98,7 @@ template <typename Impl>
 void copy_tensors(
         const CheckerHelper::TensorValueArray& dest,
         const CheckerHelper::TensorValueArray& src, const Impl& copy_impl) {
-    megdnn_assert(dest.size() == src.size());
+    megdnn_assert(dest.size() == src.size(), "%zu != %zu", dest.size(), src.size());
     for (size_t i = 0; i < src.size(); i++) {
         auto&& tensor = src[i];
         if (tensor.layout.ndim == 0)
dnn/test/fallback/relayout.cpp
@@ -34,6 +34,60 @@ TEST_F(FALLBACK, RELAYOUT_RECORD) {
     checker.exec({{2, 2, 2}, {2, 2, 2}});
 }

+TEST_F(FALLBACK, RELAYOUT_Q4) {
+    Checker<Relayout> checker(handle());
+    UniformIntRNG rng_int4{-7, 7};
+    checker.set_rng(0, &rng_int4)
+            .set_rng(1, &rng_int4)
+            .set_dtype(0, dtype::QuantizedS4(1.f))
+            .set_dtype(1, dtype::QuantizedS4(1.f))
+            .execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
+            .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
+            .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
+            .execl(TensorLayoutArray{
+                    {{6400}, {1}, dtype::QuantizedS4{1.f}},
+                    {{20, 320}, {1024, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{156}, {1}, dtype::QuantizedS4{1.f}},
+                    {{13, 3, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{48}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{84}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 7}, {28, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{336}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 7, 4}, {112, 4, 16, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{54}, {1}, dtype::QuantizedS4{1.f}},
+                    {{6, 3, 3}, {16, 4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}},
+                    {{20, 60, 3}, {256, 4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{20, 20, 3, 3}, {256, 12, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 16, 7, 7, 4}, {3136, 196, 28, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 16, 7, 7, 4}, {3136, 4, 448, 64, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 7, 7, 16, 4}, {3136, 448, 64, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 7, 7, 16, 4}, {3136, 28, 4, 196, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 2, 7, 7, 32}, {3136, 1568, 224, 32, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 2, 7, 7, 32}, {3136, 32, 448, 64, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 7, 7, 2, 32}, {3136, 448, 64, 32, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 7, 7, 2, 32}, {3136, 224, 32, 1568, 1}, dtype::QuantizedS4{1.f}}});
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) {
     relayout::run_cv_benchmark(handle());