Commit 43c59204 (MegEngine / MegEngine)

refactor(dnn/cuda): refactor relayout format kernels

Authored by Megvii Engine Team on May 31, 2021; committed by huangxinda on Jul 19, 2021.
Parent: f41a8086
GitOrigin-RevId: ab86e6653342ae9f74dd069cfb85aabca6dc637c
Showing 9 changed files, with 1255 additions and 1451 deletions (+1255, −1451).
dnn/src/cuda/integer_subbyte_utils.cuh                    +16    −18
dnn/src/cuda/relayout_format/cuda_post_process.cuh        +171   −0
dnn/src/cuda/relayout_format/helper.cuh                   +0     −252
dnn/src/cuda/relayout_format/relayout_format.cu           +4     −1096
dnn/src/cuda/relayout_format/relayout_format.cuh          +14    −0
dnn/src/cuda/relayout_format/relayout_format_kern.cuh     +346   −0
dnn/src/cuda/relayout_format/relayout_format_utils.cuh    +128   −0
dnn/src/cuda/relayout_format/translayout.cuh              +537   −0
dnn/src/cuda/warp_perspective/forward.cu                  +39    −85
dnn/src/cuda/integer_subbyte_utils.cuh (+16, −18):

```diff
@@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
     return (result << (shift - bits)) >> shift;
 }
 
-MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
-        int (&result)[8], const int& source) {
-#pragma unroll
-    for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<true>(
-                reinterpret_cast<unsigned const&>(source), (i << 2));
-    }
-}
-
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8(
         int (&result)[8], const int& source) {
 #pragma unroll
     for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<false>(
+        result[i] = unpack_integer_4bits<signedness>(
                 reinterpret_cast<unsigned const&>(source), (i << 2));
     }
 }
 
-MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8(
         int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<true>(source, 0);
-    result[1] = unpack_integer_4bits<true>(source, 4);
-}
-
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8(
-        int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<false>(source, 0);
-    result[1] = unpack_integer_4bits<false>(source, 4);
+    result[0] = unpack_integer_4bits<signedness>(source, 0);
+    result[1] = unpack_integer_4bits<signedness>(source, 4);
+}
+
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8(
+        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
+    if (signedness) {
+        return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    } else {
+        return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    }
 }
 
 }  // namespace integer_subbyte
 }  // namespace cuda
 }  // namespace megdnn
```
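The signed/unsigned merge above leans on the sign-extension trick visible in the hunk header: `unpack_integer_4bits` shifts the chosen nibble up to the top of the word and back down, so an arithmetic right shift sign-extends signed nibbles while the unsigned path zero-fills. A minimal host-side sketch of that behavior (the function name and the `28 - offset` constant are illustrative, not MegDNN's exact code; like the device code, it relies on arithmetic right shift of negative `int`):

```cpp
#include <cstdint>
#include <cstdio>

template <bool signedness>
static int unpack_4bits_ref(unsigned storage, int offset) {
    if (signedness) {
        // move the nibble to bits 31..28, then arithmetic-shift back:
        // bit 3 of the nibble is replicated, giving values in [-8, 7]
        return (int(storage << (28 - offset))) >> 28;
    } else {
        // logical path: just mask out the nibble, giving values in [0, 15]
        return (storage >> offset) & 0xf;
    }
}

int main() {
    unsigned packed = 0x8F;  // low nibble 0xF, next nibble 0x8
    printf("%d %d\n", unpack_4bits_ref<true>(packed, 0),
           unpack_4bits_ref<true>(packed, 4));  // prints: -1 -8
    printf("%d %d\n", unpack_4bits_ref<false>(packed, 0),
           unpack_4bits_ref<false>(packed, 4));  // prints: 15 8
    return 0;
}
```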
dnn/src/cuda/relayout_format/cuda_post_process.cuh (new file, 0 → 100644, +171):

```cpp
/**
 * \file dnn/src/cuda/relayout_format/cuda_post_process.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/relayout_format/relayout_format.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;

template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(uint8_t val) { return val - 128; }
};

template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return m_dst_type_cvt.quantize((float)val - 128.f).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_quint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
    uint8_t m_src_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) {
        m_src_zero_point = src_zero_point;
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return val - m_src_zero_point;
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale);
    };
    inline __device__ int8_t operator()(int8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
    CudaPostProcess(){};
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};

template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> {
    CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint32> m_src_type_cvt;
    CudaPostProcess(float src_scale, int, float dst_scale, int) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint32>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint32>(src_scale);
    };
    inline __device__ int operator()(int val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint32(val));
        return m_dst_type_cvt.quantize(med_var).as_int32();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
    CudaPostProcess(float, int, float, int){};
    inline __device__ int operator()(int val) { return val; }
};

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint4>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint4>(src_scale);
    }
    inline __device__ int8_t operator()(int8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_qint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t dst_zero_point) {
        m_dst_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(dst_scale, dst_zero_point);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(src_scale, src_zero_point);
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_quint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_uint8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    uint8_t m_src_zero_point = 0;
    uint8_t m_dst_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float,
                    uint8_t dst_zero_point) {
        m_src_zero_point = src_zero_point;
        m_dst_zero_point = dst_zero_point;
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        int result = val - m_src_zero_point + m_dst_zero_point;
        result = result >= 0 ? result : 0;
        result = result < 16 ? result : 15;
        return static_cast<uint8_t>(result);
    }
};

}  // namespace internal
}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
```
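Of the specializations above, only the same-scale `Quantized4Asymm` case does more than a cast or a dequantize/requantize round trip: it shifts by the zero-point difference and clamps to the unsigned 4-bit range. A host-side sketch of that arithmetic, with illustrative values:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors CudaPostProcess<Quantized4Asymm, Quantized4Asymm, true>::operator():
// re-center on the destination zero point, then clamp to [0, 15].
static uint8_t post_process_u4_same_scale(uint8_t val, uint8_t src_zp,
                                          uint8_t dst_zp) {
    int result = val - src_zp + dst_zp;
    result = result >= 0 ? result : 0;
    result = result < 16 ? result : 15;
    return static_cast<uint8_t>(result);
}

int main() {
    assert(post_process_u4_same_scale(7, 8, 8) == 7);    // identical zero points
    assert(post_process_u4_same_scale(1, 8, 0) == 0);    // clamped at 0
    assert(post_process_u4_same_scale(15, 0, 8) == 15);  // clamped at 15
    return 0;
}
```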
dnn/src/cuda/relayout_format/helper.cuh (deleted, 100644 → 0, −252):

```cpp
/**
 * \file dnn/src/cuda/relayout_format/helper.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
namespace megdnn {
namespace cuda {
namespace relayout_format {

#define devfunc __forceinline__ __device__

template <int size_nbits>
devfunc int make_zero(int zero_point);

template <>
devfunc int make_zero<4>(int zero_point) {
    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
                                     zero_point, zero_point, zero_point,
                                     zero_point, zero_point);
}

template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;

/////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////

// The redundant mov PTX instruction is used to enforce the compiler to
// initialize data to zero before ld.global
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4* data = reinterpret_cast<uint4*>(&D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %9, 0;\n"
                "  mov.b32 %0, %10;\n"
                "  mov.b32 %1, %10;\n"
                "  mov.b32 %2, %10;\n"
                "  mov.b32 %3, %10;\n"
                "  mov.b32 %4, %10;\n"
                "  mov.b32 %5, %10;\n"
                "  mov.b32 %6, %10;\n"
                "  mov.b32 %7, %10;\n"
                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
                "  @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
                "}\n"
                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
                  "=r"(data[1].z), "=r"(data[1].w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)),
                  "l"(((uint8_t*)ptr) + 16));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4& data = reinterpret_cast<uint4&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  mov.b32 %0, %6;\n"
                "  mov.b32 %1, %6;\n"
                "  mov.b32 %2, %6;\n"
                "  mov.b32 %3, %6;\n"
                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint2& data = reinterpret_cast<uint2&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %3, 0;\n"
                "  mov.b32 %0, %4;\n"
                "  mov.b32 %1, %4;\n"
                "  @p ld.global.v2.u32 {%0, %1}, [%2];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        unsigned& data = reinterpret_cast<unsigned&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  mov.b32 %0, %3;\n"
                "  @p ld.global.u32 %0, [%1];\n"
                "}\n"
                : "=r"(data)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        if (pred_guard)
            D = *(reinterpret_cast<AccessType const*>(ptr));
        else {
            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
            uint8_t& data = reinterpret_cast<uint8_t&>(D);
            data = uv & 0xff;
        }
    }
};

/////////////////////////////////////////////////////////////////////////////

template <
        /// Fragment type to store loaded data
        typename AccessType,
        /// The bytes of loading
        int LoadBytes>
struct global_store;

/////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////

template <typename AccessType>
struct global_store<AccessType, 32> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const* data = reinterpret_cast<uint4 const*>(&D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                "  @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
                  "r"(data[0].w), "r"((int)pred_guard),
                  "l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y),
                  "r"(data[1].z), "r"(data[1].w));
    }
};

template <typename AccessType>
struct global_store<AccessType, 16> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const& data = reinterpret_cast<uint4 const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z),
                  "r"(data.w), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 8> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint2 const& data = reinterpret_cast<uint2 const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %3, 0;\n"
                "  @p st.global.v2.u32 [%0], {%1, %2};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 4> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint32_t const& data = reinterpret_cast<uint32_t const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  @p st.global.u32 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "r"(data), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 2> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint16_t const& data = reinterpret_cast<uint16_t const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  @p st.global.u16 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "h"(data), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 1> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        if (pred_guard)
            *(reinterpret_cast<AccessType*>(ptr)) = D;
    }
};

#undef devfunc

}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
```
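The inline-PTX loads deleted here all implement a single contract: if the predicate holds, read from global memory; otherwise fill the destination with the packed zero-point pattern. (The new iterators in `relayout_format_kern.cuh` call a similar `global_load` from `src/cuda/memory_utils.cuh`.) Below is a plain-CUDA sketch of that contract for the 4-byte case, a semantic stand-in only; the PTX originals additionally force the compiler to initialize the registers before the predicated `ld.global`:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Guarded 32-bit load: real read when pred_guard is true, zero-point fill
// otherwise. Names and the demo kernel are illustrative.
__device__ __forceinline__ unsigned guarded_load_u32(const unsigned* ptr,
                                                     bool pred_guard,
                                                     unsigned zero_point) {
    return pred_guard ? *ptr : zero_point;
}

__global__ void demo(const unsigned* src, unsigned* dst, int n, int bound,
                     unsigned zp) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        dst[i] = guarded_load_u32(src + i, i < bound, zp);  // tail padded
}

int main() {
    const int n = 8, bound = 5;
    unsigned h[n], *d_src, *d_dst;
    for (int i = 0; i < n; ++i) h[i] = 100 + i;
    cudaMalloc(&d_src, n * sizeof(unsigned));
    cudaMalloc(&d_dst, n * sizeof(unsigned));
    cudaMemcpy(d_src, h, n * sizeof(unsigned), cudaMemcpyHostToDevice);
    demo<<<1, n>>>(d_src, d_dst, n, bound, 0xdeadbeefu);
    cudaMemcpy(h, d_dst, n * sizeof(unsigned), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i)
        printf("%u ", h[i]);  // last 3 entries are the zero-point pattern
    printf("\n");
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}
```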
dnn/src/cuda/relayout_format/relayout_format.cu (+4, −1096):

(Diff collapsed in the original view.)
dnn/src/cuda/relayout_format/relayout_format.cuh (+14, −0):

```diff
@@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
                                      const uint8_t src_zero_point = 0,
                                      const uint8_t dst_zero_point = 0);
 
+void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
+
+void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
+
 void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
                                             const TensorND& dst,
                                             const cudaStream_t& stream);
```
dnn/src/cuda/relayout_format/relayout_format_kern.cuh (new file, 0 → 100644, +346):

```cpp
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/memory_utils.cuh"
#include "src/cuda/relayout_format/translayout.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

using namespace memory;

template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class TensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;

    MEGDNN_HOST TensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_,
                                          int chan_stride_in_elements_,
                                          int channel_, int, int)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_} {}

    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements +
                   hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
        channel -= c_idx;
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }

    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx =
                        i / pack_size *
                                (lane_size_in_type / pack_size_in_type) +
                        j;
                bool guard = i < channel;
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard, zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr =
                reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx =
                        i / pack_size *
                                (lane_size_in_type / pack_size_in_type) +
                        j;
                bool guard = i < channel;
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
};

template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class MaskedTensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    static constexpr int accesses = elements_in_type / pack_size_in_type;
    static constexpr int mask_size = (accesses + 32 - 1) / 32;
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;

    MEGDNN_HOST MaskedTensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_,
                                                int chan_stride_in_elements_,
                                                int channel_, int bound_,
                                                int div_)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_},
              bound{bound_},
              div{uint32_t(div_)} {}

    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements;
        channel -= c_idx;
        int w[lane_size_in_type / pack_size_in_type];
#pragma unroll
        for (int i = 0; i < mask_size; ++i) {
            mask[i] = 0;
        }
#pragma unroll
        for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
            int offset = hw_idx + j;
            int h = (int)((uint32_t)(offset) / div);
            w[j] = (int)((uint32_t)(offset) % div);
            stride[j] = (h * bound + w[j]) * pack_size * size_nbits /
                        (8 * sizeof(Type));
        }
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                bool guard = (i < channel) && (w[j] < bound);
                int index = (i / pack_size) *
                                    (lane_size_in_type / pack_size_in_type) +
                            j;
                int mask_index = (index >> 5);
                int mask_shift = (index & 0x1f);
                mask[mask_index] |= (guard << mask_shift);
            }
        }
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }

    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx =
                        i / pack_size *
                                (lane_size_in_type / pack_size_in_type) +
                        j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
                        zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr =
                reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx =
                        i / pack_size *
                                (lane_size_in_type / pack_size_in_type) +
                        j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
    int bound;
    Uint32Fastdiv div;
    uint32_t mask[mask_size];
    size_t stride[lane_size_in_type / pack_size_in_type];
};

template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
          int width_, int size_nbits_>
struct TensorIteratorPolicy;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
                                            width_, size_nbits_>;
};
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator = TensorIteratorOverChannel<Type_, pack_size_,
                                                     chan_blk_, width_,
                                                     size_nbits_>;
};

template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
          typename CudaPostProcess_>
struct RelayoutProblem {
    using SrcIterator = SrcIterator_;
    using DstIterator = DstIterator_;
    using Transpose = Transpose_;
    using CudaPostProcess = CudaPostProcess_;
    MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk,
                         "channel block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width,
                         "width block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits,
                         "size in bits of elements mismatch");
    static constexpr int pack_chan = SrcIterator::chan_blk;
    static constexpr int pack_width = SrcIterator::width;
    using DnnSrcType = typename CudaPostProcess::SrcType;
    using DnnDstType = typename CudaPostProcess::DstType;
    struct Param {
        SrcIterator src_iterator;
        DstIterator dst_iterator;
        CudaPostProcess post_process;
        int n_stride_src;
        int n_stride_dst;
        int batch_size;
        int channels;
        int hw;
        int zero_point;
        MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                        DstIterator dst_iterator_,
                                        CudaPostProcess post_process_,
                                        int n_stride_src_, int n_stride_dst_,
                                        int batch_size_, int channels_,
                                        int hw_, int zero_point_)
                : src_iterator{src_iterator_},
                  dst_iterator{dst_iterator_},
                  post_process{post_process_},
                  n_stride_src{n_stride_src_},
                  n_stride_dst{n_stride_dst_},
                  batch_size{batch_size_},
                  channels{channels_},
                  hw{hw_},
                  zero_point{zero_point_} {}
    };
};

template <typename RelayoutProblem_>
__global__ void relayout_kern(typename RelayoutProblem_::Param param) {
    using SrcIterator = typename RelayoutProblem_::SrcIterator;
    using DstIterator = typename RelayoutProblem_::DstIterator;
    static constexpr int pack_chan = RelayoutProblem_::pack_chan;
    static constexpr int pack_width = RelayoutProblem_::pack_width;
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_offset = thread_idx * pack_width;
    const int hw_idx = (thread_offset % param.hw);
    const int nc_blks = thread_offset / param.hw;
    const int c_blks = (param.channels + pack_chan - 1) / pack_chan;
    const int n_idx = nc_blks / c_blks;
    const int c_blk_idx = nc_blks % c_blks;
    const int c_idx = c_blk_idx * pack_chan;
    if (n_idx < param.batch_size) {
        const int src_offset = n_idx * param.n_stride_src;
        const int dst_offset = n_idx * param.n_stride_dst;
        param.src_iterator.add_pointer_offset(src_offset);
        param.dst_iterator.add_pointer_offset(dst_offset);
        param.src_iterator.initialize(c_idx, hw_idx);
        param.dst_iterator.initialize(c_idx, hw_idx);
        typename SrcIterator::Fragment src_frag;
        typename DstIterator::Fragment dst_frag;
        int zp = make_zero<SrcIterator::size_nbits>(param.zero_point);
        param.src_iterator.load(src_frag, zp);
        RelayoutProblem_::Transpose::trans(
                reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                src_frag, param.post_process);
        param.dst_iterator.store(dst_frag);
    }
}

}  // namespace internal
}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
```
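`relayout_kern` assigns each thread `pack_width` spatial positions and decodes its linear offset into (batch, channel-block, hw) coordinates, with out-of-range batches masked by the trailing `if`. A host-side sketch of that decomposition (the parameter values are illustrative):

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
    const int pack_width = 8, pack_chan = 64;
    const int batch_size = 2, channels = 96, hw = 1024;
    const int c_blks = (channels + pack_chan - 1) / pack_chan;  // = 2

    for (int thread_idx : {0, 100, 300}) {
        int thread_offset = thread_idx * pack_width;  // elements per thread
        int hw_idx = thread_offset % hw;              // spatial position
        int nc_blks = thread_offset / hw;             // fused (n, c-block)
        int n_idx = nc_blks / c_blks;                 // batch index
        int c_idx = (nc_blks % c_blks) * pack_chan;   // first channel of block
        if (n_idx < batch_size)                       // same guard as the kernel
            printf("thread %3d -> n=%d c=%2d hw=%4d\n", thread_idx, n_idx,
                   c_idx, hw_idx);
    }
    return 0;
}
```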
dnn/src/cuda/relayout_format/relayout_format_utils.cuh (new file, 0 → 100644, +128):

```cpp
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

template <typename ctype, int pack_w, typename enable = void>
struct DTypeRWHelper;

template <typename ctype>
struct DTypeRWHelper<
        ctype, 1,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char;
    using DstDtype = char4;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 4,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char4;
    using DstDtype = char4;
};

template <>
struct DTypeRWHelper<dt_qint32, 1> {
    using InnerDtype = int;
    using DstDtype = int4;
};

template <>
struct DTypeRWHelper<dt_qint32, 4> {
    using InnerDtype = int4;
    using DstDtype = int4;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 2,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = char;
    using DstDtype = array_wrapper<uint8_t, 32>;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 8,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = unsigned;
    using DstDtype = array_wrapper<uint8_t, 32>;
};

template <typename DstType>
inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
    return zero_point;
}

template <>
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
    char izp = reinterpret_cast<const char&>(zero_point);
    return {izp, izp, izp, izp};
}

template <>
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
    return {zero_point, zero_point, zero_point, zero_point};
}

template <int size_nbits>
inline __device__ int make_zero(int zero_point);

template <>
inline __device__ int make_zero<4>(int zero_point) {
    return integer_subbyte::transform_int8_to_uint4x8(
            zero_point, zero_point, zero_point, zero_point, zero_point,
            zero_point, zero_point, zero_point);
}

template <typename DstDtype>
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) {
    *ptr = val;
}

template <>
inline __device__ void write_helper<char4>(char4* ptr, char4 val) {
    int32_t* rel_ptr = (int32_t*)ptr;
    *rel_ptr = *(int32_t*)(&val);
}

template <>
inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
        array_wrapper<uint8_t, 32>* ptr, array_wrapper<uint8_t, 32> val) {
    uint4 const* data = reinterpret_cast<uint4 const*>(&val);
    void* ptr_ = reinterpret_cast<void*>(ptr);
    asm volatile(
            "{\n"
            "  st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
            "  st.global.v4.u32 [%5], {%6, %7, %8, %9};\n"
            "}\n"
            :
            : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
              "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
              "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
}

}  // namespace internal
}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
```
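`DTypeRWHelper` is a small type function: given an element type and how many elements a thread packs per access (`pack_w`), it selects the vector type used for reads (`InnerDtype`) and the packed type used for writes (`DstDtype`). A stripped-down, stand-alone illustration of the pattern; the types here are stand-ins, not MegDNN's:

```cpp
#include <type_traits>

struct byte1_t { char c; };      // stand-in for a 1-byte read unit
struct byte4_t { char c[4]; };   // stand-in for CUDA's char4

// Wider pack_w widens the read unit; the write unit stays the packed vector.
template <int pack_w> struct RWHelper;
template <> struct RWHelper<1> {
    using InnerDtype = byte1_t;
    using DstDtype = byte4_t;
};
template <> struct RWHelper<4> {
    using InnerDtype = byte4_t;
    using DstDtype = byte4_t;
};

static_assert(std::is_same<RWHelper<1>::InnerDtype, byte1_t>::value,
              "pack_w == 1 reads one element at a time");
static_assert(std::is_same<RWHelper<4>::DstDtype, byte4_t>::value,
              "writes always use the full packed vector");

int main() { return 0; }
```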
dnn/src/cuda/relayout_format/translayout.cuh (new file, 0 → 100644, +537):

(Diff collapsed in the original view.)
dnn/src/cuda/warp_perspective/forward.cu (+39, −85):

```diff
@@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
     }
 }
 
-template <bool signedness>
-MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(
-        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7);
-
-template <>
-MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>(
-        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
-    return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
-}
-
-template <>
-MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>(
-        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
-    return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
-}
-
-template <bool signedness>
-MEGDNN_DEVICE __forceinline__ void transform_bit4x8_to_int8(
-        int (&result)[8], const int& source);
-
-template <>
-MEGDNN_DEVICE __forceinline__ void transform_bit4x8_to_int8<true>(
-        int (&result)[8], const int& source) {
-    transform_int4x8_to_int8(result, source);
-}
-
-template <>
-MEGDNN_DEVICE __forceinline__ void transform_bit4x8_to_int8<false>(
-        int (&result)[8], const int& source) {
-    transform_uint4x8_to_int8(result, source);
-}
-
 template <bool signedness, typename OutputConverter>
 MEGDNN_DEVICE __forceinline__ int pack_output_func(
         OutputConverter& output_converter, int (&s00)[8], int (&s01)[8],
         int (&s10)[8], int (&s11)[8], float w00, float w01, float w10,
         float w11) {
-#define warp_perspective_transform(idx)                 \
-    static_cast<int>(output_converter(s00[idx] * w00 +  \
-                                      s01[idx] * w01 +  \
-                                      s10[idx] * w10 +  \
-                                      s11[idx] * w11)   \
+#define warp_perspective_transform(idx)                                 \
+    static_cast<int>(output_converter(s00[idx] * w00 + s01[idx] * w01 + \
+                                      s10[idx] * w10 + s11[idx] * w11)  \
                              .as_storage())
-    return transform_int8_to_bit4x8<signedness>(
+    return transform_int8_to_b4x8<signedness>(
             warp_perspective_transform(0), warp_perspective_transform(1),
             warp_perspective_transform(2), warp_perspective_transform(3),
             warp_perspective_transform(4), warp_perspective_transform(5),
             warp_perspective_transform(6), warp_perspective_transform(7));
 #undef warp_perspective_transform
 }
@@ -278,31 +240,31 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
         s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
         s[3] = __ldg(sptr_int4 + i_coor_11 + c1);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].x);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].x);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].x);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].x);
+        transform_b4x8_to_int8<signedness>(s00, s[0].x);
+        transform_b4x8_to_int8<signedness>(s01, s[1].x);
+        transform_b4x8_to_int8<signedness>(s10, s[2].x);
+        transform_b4x8_to_int8<signedness>(s11, s[3].x);
         d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].y);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].y);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].y);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].y);
+        transform_b4x8_to_int8<signedness>(s00, s[0].y);
+        transform_b4x8_to_int8<signedness>(s01, s[1].y);
+        transform_b4x8_to_int8<signedness>(s10, s[2].y);
+        transform_b4x8_to_int8<signedness>(s11, s[3].y);
         d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].z);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].z);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].z);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].z);
+        transform_b4x8_to_int8<signedness>(s00, s[0].z);
+        transform_b4x8_to_int8<signedness>(s01, s[1].z);
+        transform_b4x8_to_int8<signedness>(s10, s[2].z);
+        transform_b4x8_to_int8<signedness>(s11, s[3].z);
         d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].w);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].w);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].w);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].w);
+        transform_b4x8_to_int8<signedness>(s00, s[0].w);
+        transform_b4x8_to_int8<signedness>(s01, s[1].w);
+        transform_b4x8_to_int8<signedness>(s10, s[2].w);
+        transform_b4x8_to_int8<signedness>(s11, s[3].w);
         d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
@@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
         }
     }
 }
-
-template <bool signedness>
-MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8(
-        int (&result)[8], const int& source) {
-#pragma unroll
-    for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<signedness>(
-                reinterpret_cast<unsigned const&>(source), (i << 2));
-    }
-}
 
 template <typename ctype, typename SrcVisitor, typename OutputConverter>
 __global__ void kern_const_border_nchw64(SrcVisitor src,
@@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
     bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, flag10 = okh1 && okw0,
          flag11 = okh1 && okw1;
     int8_t bval_4 = bval.as_storage() & 0xF;
-    int bval_8 = transform_int8_to_bit4x8<signedness>(
+    int bval_8 = transform_int8_to_b4x8<signedness>(
             bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
     int4 bval_int4;
     bval_int4.x = bval_8;
@@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
             s[3] = bval_int4;
         }
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].x);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].x);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].x);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].x);
+        transform_b4x8_to_int8<signedness>(s00, s[0].x);
+        transform_b4x8_to_int8<signedness>(s01, s[1].x);
+        transform_b4x8_to_int8<signedness>(s10, s[2].x);
+        transform_b4x8_to_int8<signedness>(s11, s[3].x);
         d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].y);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].y);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].y);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].y);
+        transform_b4x8_to_int8<signedness>(s00, s[0].y);
+        transform_b4x8_to_int8<signedness>(s01, s[1].y);
+        transform_b4x8_to_int8<signedness>(s10, s[2].y);
+        transform_b4x8_to_int8<signedness>(s11, s[3].y);
         d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].z);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].z);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].z);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].z);
+        transform_b4x8_to_int8<signedness>(s00, s[0].z);
+        transform_b4x8_to_int8<signedness>(s01, s[1].z);
+        transform_b4x8_to_int8<signedness>(s10, s[2].z);
+        transform_b4x8_to_int8<signedness>(s11, s[3].z);
         d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
 
-        transform_bit4x8_to_int8<signedness>(s00, s[0].w);
-        transform_bit4x8_to_int8<signedness>(s01, s[1].w);
-        transform_bit4x8_to_int8<signedness>(s10, s[2].w);
-        transform_bit4x8_to_int8<signedness>(s11, s[3].w);
+        transform_b4x8_to_int8<signedness>(s00, s[0].w);
+        transform_b4x8_to_int8<signedness>(s01, s[1].w);
+        transform_b4x8_to_int8<signedness>(s10, s[2].w);
+        transform_b4x8_to_int8<signedness>(s11, s[3].w);
         d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                            s11, w00, w01, w10, w11);
```
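Each lane of `pack_output_func` performs the bilinear blend the macro expands to: mix four unpacked nibbles with the four interpolation weights, round through the output converter, and hand the eight results back to `transform_int8_to_b4x8` for repacking. A host-side sketch of the per-lane arithmetic, with a simple rounding stand-in for `OutputConverter`:

```cpp
#include <cmath>
#include <cstdio>

// One lane of the bilinear blend: weights w00..w11 come from the sample
// point's fractional position and sum to 1.
static int blend(int s00, int s01, int s10, int s11, float w00, float w01,
                 float w10, float w11) {
    return (int)std::lround(s00 * w00 + s01 * w01 + s10 * w10 + s11 * w11);
}

int main() {
    // weights for a sample point 25% along x and 50% along y
    float w00 = 0.375f, w01 = 0.125f, w10 = 0.375f, w11 = 0.125f;
    printf("%d\n", blend(-8, 7, -8, 7, w00, w01, w10, w11));  // prints -4
    return 0;
}
```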