Commit 71c2f612
Authored on Mar 16, 2021 by Megvii Engine Team

feat(dnn/cuda): add relayout format to support layout transform between NCHW and NCHW64

GitOrigin-RevId: 1445ecfabe106eee57494b6bee4df14b9b81556b
Parent: df009e89
9 changed files with 1144 additions and 221 deletions:
- dnn/include/megdnn/tensor_format.h (+26, -0)
- dnn/scripts/opr_param_defs.py (+3, -0)
- dnn/src/common/relayout_format.cpp (+36, -1)
- dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp (+24, -22)
- dnn/src/cuda/relayout_format/opr_impl.cpp (+21, -17)
- dnn/src/cuda/relayout_format/relayout_format.cpp (+21, -14)
- dnn/src/cuda/relayout_format/relayout_format.cu (+872, -157)
- dnn/src/cuda/relayout_format/relayout_format.cuh (+12, -8)
- dnn/src/cuda/utils.cuh (+129, -2)
dnn/include/megdnn/tensor_format.h

```diff
@@ -196,6 +196,32 @@ public:
             const TensorLayout& layout) const override;
 };
 using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
+
+///*!
+// * \brief used for tensors with lowbit data type
+// *
+// * \p SIZE_NBITS is the size in bits of element of the tensor.
+// *
+// */
+//template <size_t SIZE_NBITS_>
+//class LowbitTensorFormat : public TensorFormat::ImplBase {
+//    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
+//    size_t m_align_size_in_bits;
+//
+//protected:  //?
+//    LowbitTensorFormat(Type type, size_t m_align_size_in_bits);
+//
+//public:
+//    size_t align_size_in_bits() const {
+//        return m_align_size_in_bits;
+//    }
+//
+//    std::string to_string() const override;
+//
+//    void serialize_append(
+//
+//
+//};
 }  // namespace detail

 /*!
```
dnn/scripts/opr_param_defs.py

```diff
@@ -895,6 +895,7 @@ Relayout mode.
 * ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
 * ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
 * ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
+* ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

 **Float weight transformation definitions**
@@ -969,6 +970,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
         'NCHW_NCHW4',
         'NCHW4_NCHW',
         'NCHW_NCHW4_WEIGHT',
+        'NCHW_NCHW64',
+        'NCHW64_NCHW',
         )
     )
```
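The new mode names line up with the layout table above: ``NCHW64`` packs 64 consecutive channels into the innermost dimension, so a 4-bit tensor stores each inner block in 32 contiguous bytes. As a sanity check, here is a minimal standalone sketch of the coordinate mapping the mode implies (plain C++; the helper names are illustrative, not MegDNN API):

```cpp
#include <cassert>
#include <cstddef>

// Map an NCHW coordinate to its NCHW64 coordinate:
// {N, C, H, W} -> {N, C/64, H, W, 64}. Purely illustrative; MegDNN
// realizes this via TensorLayout reshape/dimshuffle, not a helper like this.
struct Nchw64Coord {
    size_t n, c_outer, h, w, c_inner;
};

inline Nchw64Coord to_nchw64(size_t n, size_t c, size_t h, size_t w) {
    return {n, c / 64, h, w, c % 64};
}

int main() {
    // channel 130 lands in outer block 2, inner slot 2
    Nchw64Coord p = to_nchw64(0, 130, 5, 7);
    assert(p.c_outer == 2 && p.c_inner == 2 && p.h == 5 && p.w == 7);
    return 0;
}
```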
dnn/src/common/relayout_format.cpp

```diff
@@ -251,6 +251,23 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             dst[3] = src[3];
             megdnn_assert(dst[1] % param().group == 0);
             break;
+        case Param::Mode::NCHW_NCHW64:
+            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
+            dst.ndim = 5;
+            dst[0] = src[0];
+            dst[1] = src[1] / 64;
+            dst[2] = src[2];
+            dst[3] = src[3];
+            dst[4] = 64;
+            break;
+        case Param::Mode::NCHW64_NCHW:
+            megdnn_assert(src.ndim == 5);
+            dst.ndim = 4;
+            dst[0] = src[0];
+            dst[1] = src[1] * 64;
+            dst[2] = src[2];
+            dst[3] = src[3];
+            break;
         default:
             megdnn_assert(0, "Invalid RelayoutFormat Mode");
             break;
@@ -352,7 +369,12 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
             CHECK_SRC(DefaultTensorFormat::make());
             dst = src;
             break;
+        case Param::Mode::NCHW_NCHW64:
+            dst = src;
+            break;
+        case Param::Mode::NCHW64_NCHW:
+            dst = src;
+            break;
         default:
             megdnn_throw("Invalid relayout format mode");
             break;
@@ -633,6 +655,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
             exec_src = src.dimshuffle({3, 0, 1, 2, 4});
             exec_dst = dst;
             break;
+        case Param::Mode::NCHW_NCHW64:
+            // src is {N, C, H, W}
+            // dst is {N, C/64, H, W, 64}
+            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
+                               .dimshuffle({0, 1, 3, 4, 2});
+            exec_dst = dst;
+            break;
+        case Param::Mode::NCHW64_NCHW:
+            // src is {N, C/64, H, W, 64}
+            // dst is {N, C, H, W}
+            exec_src = src.dimshuffle({0, 1, 4, 2, 3});
+            exec_dst = dst;
+            break;
         default:
             megdnn_assert(0, "Invalid RelayoutFormat Mode");
     }
```
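deduce_exec_layout implements NCHW_NCHW64 as a reshape to {N, C/64, 64, H, W} followed by dimshuffle {0, 1, 3, 4, 2}, and NCHW64_NCHW as dimshuffle {0, 1, 4, 2, 3} on the 5-dim layout. The two permutations are mutual inverses, which is what lets the mode pair round-trip; a quick standalone check of that claim (plain C++, hypothetical helper, not MegDNN code):

```cpp
#include <array>
#include <cassert>

// compose(p, q)[i] = p[q[i]]: apply q first, then p, in the dimshuffle
// convention out_dim[i] = in_dim[perm[i]].
std::array<int, 5> compose(const std::array<int, 5>& p,
                           const std::array<int, 5>& q) {
    std::array<int, 5> r{};
    for (int i = 0; i < 5; ++i)
        r[i] = p[q[i]];
    return r;
}

int main() {
    const std::array<int, 5> nchw_to_nchw64{0, 1, 3, 4, 2};
    const std::array<int, 5> nchw64_to_nchw{0, 1, 4, 2, 3};
    // Composing the two dimshuffles yields the identity permutation,
    // so NCHW_NCHW64 followed by NCHW64_NCHW is a no-op on coordinates.
    auto id = compose(nchw_to_nchw64, nchw64_to_nchw);
    for (int i = 0; i < 5; ++i)
        assert(id[i] == i);
    return 0;
}
```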
dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp

```diff
@@ -69,12 +69,9 @@ size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes(
 void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
         const ExecArgs& args) const {
-    using Format = Param::Format;
-    auto&& param = args.opr->param();
-    auto&& fm = args.filter_meta;
     auto layouts = make_underlying_tensor_layout(
-            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
-            *(args.dst_layout));
+            *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
+            *(args.z_layout), *(args.dst_layout));
     auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
     auto ws_src = ws.get(0);
     auto ws_filter = ws.get(1);
@@ -82,20 +79,27 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
     void* ws_z = nullptr;
     if (args.z_layout->ndim > 0)
         ws_z = ws.get(4);
-    auto&& stream = cuda_stream(args.opr->handle());
-    auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) {
-        if (raw_dptr == nullptr)
+    //    auto&& stream = cuda_stream(args.opr->handle());
+    auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) {
+        if (dst.raw_ptr == nullptr)
             return;
+        auto relayout = args.handle->create_operator<RelayoutFormat>();
+        relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64;
+        Workspace dummy;
+        relayout->exec(src, dst, dummy);
     };
-    auto nchw642nchw = [](const TensorND& src, void* raw_dptr) {
+    auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) {
+        auto relayout = args.handle->create_operator<RelayoutFormat>();
+        relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW;
+        Workspace dummy;
+        relayout->exec(src, dst, dummy);
     };
     // reformat src
-    nchw2nchw64(*(args.src_tensor), ws_src);
+    nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]});
     // reformat filter
-    nchw2nchw64(*(args.filter_tensor), ws_filter);
+    nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]});
     // reformat z
-    nchw2nchw64(*(args.z_tensor), ws_z);
+    nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]});
     TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
             bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
             dst_{ws_dst, layouts[4]};
@@ -109,22 +113,22 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
             args.preprocessed_filter};
     m_underlying_algo.exec(args_);
     // reformat dst
-    nchw642nchw(dst_, args.dst_tensor->raw_ptr);
+    nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout});
 }

 SmallVector<TensorLayout>
 ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
-        const TensorLayout& src, const CanonizedFilterMeta& filter_meta,
+        const TensorLayout& src, const TensorLayout& filter,
         const TensorLayout& bias, const TensorLayout& z,
         const TensorLayout& dst) const {
     size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
     size_t co = dst[1], ho = dst[2], wo = dst[3];
-    size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1];
+    size_t fh = filter[2], fw = filter[3];
     SmallVector<TensorLayout> rst;
     rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype});
     rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype});
     rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype});
-    if (z.layout.ndim > 0) {
+    if (z.ndim > 0) {
         rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype});
     } else {
         rst.emplace_back(TensorLayout{});
@@ -134,15 +138,13 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
 }

 WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
-        void* ptr, const SizeArgs& args) const {
+        void* raw_ptr, const SizeArgs& args) const {
     size_t ws_size_src = args.src_layout->span().dist_byte();
     size_t ws_size_filter = args.filter_layout->span().dist_byte();
     size_t ws_size_dst = args.dst_layout->span().dist_byte();
-    auto&& param = args.opr->param();
-    auto&& fm = args.filter_meta;
     auto layouts = make_underlying_tensor_layout(
-            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
-            *(args.dst_layout));
+            *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
+            *(args.z_layout), *(args.dst_layout));
     SizeArgs args_{args.opr, layouts[0], layouts[1],
```
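With this change make_underlying_tensor_layout takes the filter as a plain TensorLayout, reading fh/fw from filter[2]/filter[3] instead of the canonized filter meta. A minimal sketch of the shape arithmetic, assuming channel counts divisible by 64 and a dst entry shaped like z (illustrative structs, not the MegDNN types):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Shape5 {
    size_t d[5];
};

// Mirror of the five NCHW64 shapes built for src/filter/bias/z/dst.
// Assumes ci and co are multiples of 64; the real code asserts this
// upstream via the NCHW_NCHW64 layout deduction.
std::vector<Shape5> make_underlying_shapes(size_t n, size_t ci, size_t hi,
                                           size_t wi, size_t co, size_t ho,
                                           size_t wo, size_t fh, size_t fw) {
    return {
            {n, ci / 64, hi, wi, 64},   // src
            {co, ci / 64, fh, fw, 64},  // filter
            {1, co / 64, 1, 1, 64},     // bias
            {n, co / 64, ho, wo, 64},   // z (when present)
            {n, co / 64, ho, wo, 64},   // dst (assumed, matching z)
    };
}

int main() {
    auto s = make_underlying_shapes(2, 64, 28, 28, 128, 28, 28, 3, 3);
    assert(s[1].d[0] == 128 && s[1].d[1] == 1);  // filter {co, ci/64, fh, fw, 64}
    return 0;
}
```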
dnn/src/cuda/relayout_format/opr_impl.cpp

```diff
@@ -78,29 +78,33 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
         return handle()->create_operator<RelayoutForward>()->exec(
                 {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
     }
-    if (param().mode == Param::Mode::NCHW_NCHW4 ||
-        param().mode == Param::Mode::NCHW4_NCHW ||
-        param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) {
+    bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 ||
+                           param().mode == Param::Mode::NCHW64_NCHW) &&
+                          (src_dtype.enumv() == DTypeEnum::QuantizedS4 ||
+                           src_dtype.enumv() == DTypeEnum::Quantized4Asymm);
+    bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 ||
+                         param().mode == Param::Mode::NCHW4_NCHW ||
+                         param().mode == Param::Mode::NCHW_NCHW4_WEIGHT;
+    if (is_trans_4bits || is_nchw_nchw4) {
         bool is_usable = relayout_format::RelayoutFormatFast::usable(
                 src.layout, dst.layout);
         megdnn_assert(is_usable,
-                      "RelayoutFormat NCHW_NCHW4 kernel not usable for %s(%s) "
-                      "to %s(%s)",
+                      "RelayoutFormat Fast kernel is not usable for "
+                      "transforming %s(%s) to %s(%s).",
                       src.layout.to_string().c_str(), src.layout.dtype.name(),
                       dst.layout.to_string().c_str(), dst.layout.dtype.name());
-        relayout_format::RelayoutFormatFast::exec(src, dst,
-                                                  cuda_stream(this->handle()),
-                                                  param().mode, param().group);
-    } else {
-        TensorLayout exec_src, exec_dst, exec_workspace;
-        deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src,
-                           exec_dst);
-        TensorND exec_src_nd{src.raw_ptr, exec_src};
-        TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
-        handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
-                                                           exec_dst_nd);
+        return relayout_format::RelayoutFormatFast::exec(
+                src, dst, cuda_stream(this->handle()), param().mode,
+                param().group);
     }
+    // fallback impls
+    TensorLayout exec_src, exec_dst, exec_workspace;
+    deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src,
+                       exec_dst);
+    TensorND exec_src_nd{src.raw_ptr, exec_src};
+    TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
+    handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
+                                                       exec_dst_nd);
 }

 size_t RelayoutFormatImpl::get_workspace_in_bytes(
```
dnn/src/cuda/relayout_format/relayout_format.cpp

```diff
@@ -24,6 +24,8 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
         scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale;
     } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) {
         scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
+    } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) {
+        scale = tensor_dtype.param<dtype::QuantizedS4>().scale;
     }
 }
@@ -39,9 +41,8 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
                                                cudaStream_t stream,
                                                RelayoutFormat::Param::Mode mode,
                                                int group) {
-    size_t ih = src.layout[2];
-    size_t iw = src.layout[3];
-    size_t hw = ih * iw;
     auto&& stype = src.layout.dtype;
     auto&& dtype = dst.layout.dtype;
     float src_scale = 1.f;
     float dst_scale = 1.f;
     uint8_t src_zero_point = 0;
@@ -51,22 +52,28 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
     if (src.layout.dtype.enumv() == DTypeEnum::Uint8) {
         src_zero_point = 128;
     }
-    if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) {
-        if (hw % 4 == 0) {
-            relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale,
-                                               dst_scale, src_zero_point,
-                                               dst_zero_point, group);
-        } else {
-            relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale,
-                                               dst_scale, src_zero_point,
-                                               dst_zero_point);
-        }
+    if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4 ||
+        mode == RelayoutFormat::Param::Mode::NCHW_NCHW64) {
+        return relayout_format_cuda_nchw_nchwx(src, dst, stream, src_scale,
+                                               dst_scale, src_zero_point,
+                                               dst_zero_point, group);
+    } else if (mode == RelayoutFormat::Param::Mode::NCHW64_NCHW) {
+        megdnn_assert(group == 1,
+                      "RelayoutFormat kernel only support transforming NCHW64 "
+                      "to NCHW with group = 1(group:%d)",
+                      group);
+        return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale,
+                                               dst_scale, src_zero_point,
+                                               dst_zero_point);
     } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) {
-        relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
+        return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
     } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) {
-        relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
+        return relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
     } else {
-        megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format");
+        megdnn_throw(
+                "only support nchw_nchw64/nchw64_nchw/nchw_nchw4/nchw4_nchw "
+                "layout_format");
     }
 }

 // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
```
dnn/src/cuda/relayout_format/relayout_format.cu

This diff is collapsed in the page view (+872, -157; not shown here).
dnn/src/cuda/relayout_format/relayout_format.cuh

```diff
@@ -19,14 +19,11 @@ namespace megdnn {
 namespace cuda {
 namespace relayout_format {

-template <int pack_w = 1>
-void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst,
-                                     const cudaStream_t& stream,
-                                     const float src_scale = 1.f,
-                                     const float dst_scale = 1.f,
-                                     const uint8_t src_zero_point = 0,
-                                     const uint8_t dst_zero_point = 0,
-                                     const int group = 1);
+void relayout_format_cuda_nchw_nchwx(const TensorND& src, const TensorND& dst,
+                                     const cudaStream_t& stream,
+                                     const float src_scale = 1.f,
+                                     const float dst_scale = 1.f,
+                                     const uint8_t src_zero_point = 0,
+                                     const uint8_t dst_zero_point = 0,
+                                     const int group = 1);

 bool relayout_format_cuda_usable(const TensorLayout& src_layout,
                                  const TensorLayout& dst_layout);
@@ -35,6 +32,13 @@ void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst,
                                      const cudaStream_t& stream,
                                      const int group);

+void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
+                                     const cudaStream_t& stream,
+                                     const float src_scale = 1.f,
+                                     const float dst_scale = 1.f,
+                                     const uint8_t src_zero_point = 0,
+                                     const uint8_t dst_zero_point = 0);

 void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
                                             const TensorND& dst,
                                             const cudaStream_t& stream);
```
dnn/src/cuda/utils.cuh

```diff
@@ -110,6 +110,12 @@ MEGDNN_NORETURN void report_error(const char* msg);
 template <typename T, size_t N>
 struct array_wrapper {
     T data[N];
+    MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) {
+        return reinterpret_cast<T&>(data[pos]);
+    }
+    MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const {
+        return reinterpret_cast<T const&>(data[pos]);
+    }
 };

 /*!
```
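The new subscript operators let array_wrapper be indexed like a raw array from device code; the reinterpret_cast is an identity cast here and mainly keeps the accessor shape uniform with other wrappers in this header. A host-compilable sketch of the same pattern:

```cpp
#include <cassert>
#include <cstddef>

// Host-side analogue of the device array_wrapper accessors
// (MEGDNN_DEVICE/__forceinline__ dropped so this compiles with any C++ compiler).
template <typename T, size_t N>
struct array_wrapper {
    T data[N];
    T& operator[](size_t pos) { return reinterpret_cast<T&>(data[pos]); }
    T const& operator[](size_t pos) const {
        return reinterpret_cast<T const&>(data[pos]);
    }
};

int main() {
    array_wrapper<int, 4> w{{1, 2, 3, 4}};
    w[2] = 30;  // reads and writes forward to the underlying array
    assert(w[2] == 30);
    return 0;
}
```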
```diff
@@ -207,12 +213,29 @@ struct CudaDTypeParamImpl<dt_quint4> : DTypeParamImpl<dt_quint4> {
     CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param)
             : CudaDTypeParamImpl(param.scale, param.zero_point) {}
-    __device__ uint8_t quantize(float in) const {
+    __device__ dt_quint4 quantize(float in) const {
         float v = in * inv_scale;
         v = roundf(v);
         v = v + zero_point;
         v = fmin(fmax(0.f, v), 15.f);
-        return static_cast<uint8_t>(v);
+        return static_cast<dt_quint4>(v);
     }
 };

+template <>
+struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
+    float inv_scale;
+    CudaDTypeParamImpl() = default;
+    CudaDTypeParamImpl(float scale)
+            : DTypeParamImpl<dt_qint4>(scale), inv_scale(1.0f / scale) {}
+    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint4>& param)
+            : CudaDTypeParamImpl(param.scale) {}
+    __device__ dt_qint4 quantize(float in) const {
+        float v = in * inv_scale;
+        v = roundf(v);
+        v = fmin(fmax(-8.f, v), 7.f);
+        return static_cast<dt_qint4>(v);
+    }
+};
```
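Both specializations follow the same scale, round, clamp recipe; only the clamp range differs ([-8, 7] for signed int4, [0, 15] plus a zero point for unsigned). A host-side reference of the arithmetic, using division by scale where the device code multiplies by the cached inv_scale (plain integer/float types, not the dt_qint4/dt_quint4 wrappers):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Quantize to signed int4: round(in / scale), clamped to [-8, 7].
int8_t quantize_qint4(float in, float scale) {
    float v = std::round(in / scale);
    v = std::fmin(std::fmax(-8.f, v), 7.f);
    return static_cast<int8_t>(v);
}

// Quantize to unsigned int4 with zero point, clamped to [0, 15].
uint8_t quantize_quint4(float in, float scale, uint8_t zero_point) {
    float v = std::round(in / scale) + zero_point;
    v = std::fmin(std::fmax(0.f, v), 15.f);
    return static_cast<uint8_t>(v);
}

int main() {
    assert(quantize_qint4(100.f, 1.f) == 7);     // saturates at +7
    assert(quantize_qint4(-3.2f, 0.5f) == -6);   // -6.4 rounds to -6
    assert(quantize_quint4(-1.f, 1.f, 8) == 7);  // zero_point shifts the range
    return 0;
}
```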
```diff
@@ -351,6 +374,110 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval,
     return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z,
                        lval.w + rval.w);
 }

+MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
+        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
+    unsigned out;
+#if __CUDA_ARCH__ >= 750
+    asm volatile(
+            "{ .reg .u32 r4;"
+            "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
+            "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;"
+            "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;"
+            "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;"
+            "}"
+            : "=r"(out)
+            : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
+              "r"(s7));
+#else
+#define CVT_SAT_S4_S32(r, bits)         \
+    r = r <= -8 ? -8 : r;               \
+    r = r > 7 ? 7 : r;                  \
+    r = (((unsigned)r & 0xf) << bits);
+    CVT_SAT_S4_S32(s0, 0)
+    CVT_SAT_S4_S32(s1, 4)
+    CVT_SAT_S4_S32(s2, 8)
+    CVT_SAT_S4_S32(s3, 12)
+    CVT_SAT_S4_S32(s4, 16)
+    CVT_SAT_S4_S32(s5, 20)
+    CVT_SAT_S4_S32(s6, 24)
+    CVT_SAT_S4_S32(s7, 28)
+    out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
+#undef CVT_SAT_S4_S32
+#endif
+    return reinterpret_cast<int const&>(out);
+}
```
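The #else branch is ordinary integer code, so the packing is easy to check on the host: saturate each lane to [-8, 7], keep the low nibble, and merge the eight nibbles into one 32-bit word (OR here, equivalent to the additive merge above because the nibbles are disjoint). The sketch below pairs that with the sign-extending unpack defined further down in this file and asserts a round trip:

```cpp
#include <cassert>
#include <cstdint>

// Pack eight ints into a u32, 4 bits per lane, saturating to [-8, 7]
// (host-side mirror of the non-PTX branch of transform_int8_to_int4x8).
uint32_t pack_int4x8(const int (&s)[8]) {
    uint32_t out = 0;
    for (int i = 0; i < 8; ++i) {
        int r = s[i];
        r = r < -8 ? -8 : r;
        r = r > 7 ? 7 : r;
        out |= ((unsigned)r & 0xf) << (i * 4);
    }
    return out;
}

// Extract one signed 4-bit lane and sign-extend it
// (mirror of unpack_integer_4bits<true>).
int unpack_int4(uint32_t storage, unsigned bits) {
    uint8_t v = (storage >> bits) & 0xf;
    return (v & 0x8) ? ((int)v | ~0xf) : (int)v;
}

int main() {
    int src[8] = {-8, -1, 0, 7, 100, -100, 3, -4};  // 100/-100 saturate
    uint32_t packed = pack_int4x8(src);
    int expect[8] = {-8, -1, 0, 7, 7, -8, 3, -4};
    for (int i = 0; i < 8; ++i)
        assert(unpack_int4(packed, i * 4) == expect[i]);
    return 0;
}
```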
Continuing in the same hunk of dnn/src/cuda/utils.cuh:

```diff
+MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
+        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
+    unsigned out;
+#if __CUDA_ARCH__ >= 750
+    asm volatile(
+            "{ .reg .u32 r4;"
+            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
+            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
+            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
+            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
+            "}"
+            : "=r"(out)
+            : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
+              "r"(s7));
+#else
+#define CVT_SAT_U4_S32(r, bits)         \
+    r = r <= 0 ? 0 : r;                 \
+    r = r > 15 ? 15 : r;                \
+    r = (((unsigned)r & 0xf) << bits);
+    CVT_SAT_U4_S32(s0, 0)
+    CVT_SAT_U4_S32(s1, 4)
+    CVT_SAT_U4_S32(s2, 8)
+    CVT_SAT_U4_S32(s3, 12)
+    CVT_SAT_U4_S32(s4, 16)
+    CVT_SAT_U4_S32(s5, 20)
+    CVT_SAT_U4_S32(s6, 24)
+    CVT_SAT_U4_S32(s7, 28)
+    out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
+#undef CVT_SAT_U4_S32
+#endif
+    return reinterpret_cast<int const&>(out);
+}
+
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage,
+                                                              unsigned bits);
+
+template <>
+MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<true>(unsigned storage,
+                                                             unsigned bits) {
+    uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
+    static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
+    return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
+                                      : (int)(result);
+}
+
+template <>
+MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<false>(unsigned storage,
+                                                              unsigned bits) {
+    uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
+    return (int)(result);
+}
+
+MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
+        int (&result)[8], const int& source) {
+#pragma unroll
+    for (int i = 0; i < 8; i++) {
+        result[i] = unpack_integer_4bits<true>(
+                reinterpret_cast<unsigned const&>(source), (i << 2));
+    }
+}
+
+MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
+        int (&result)[8], const int& source) {
+#pragma unroll
+    for (int i = 0; i < 8; i++) {
+        result[i] = unpack_integer_4bits<false>(
+                reinterpret_cast<unsigned const&>(source), (i << 2));
+    }
+}
 #endif

 }  // namespace cuda
 }  // namespace megdnn
```