Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7b855dc6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7b855dc6
编写于
9月 02, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/cuda): fix compilation for windows bazel
GitOrigin-RevId: 2023dea19c04dbdd17f559f80cfa2e6b4be27a0e
上级
3abe0b24
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
46 addition
and
30 deletion
+46
-30
dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
+7
-5
dnn/src/cuda/transpose_utils.cuh
dnn/src/cuda/transpose_utils.cuh
+39
-25
未找到文件。
dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu
浏览文件 @
7b855dc6
...
...
@@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
const
int32_t
fhfw
=
blockIdx
.
x
*
BLOCKSIZE_Y
+
threadIdx
.
x
;
if
(
fhfw
<
FHFW
&&
icb
<
IC
/
4
)
{
int
src_value
[
4
],
dst_value
[
4
];
array_wrapper
<
int
,
4
>
src_value
;
int
dst_value
[
4
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
src_value
[
i
]
=
*
reinterpret_cast
<
const
int
*>
(
...
...
@@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
}
// transpose 4x4
transpose_int8_interleavedx4
<
4
,
int
>
(
src_value
,
dst_value
);
auto
trans
=
transpose_int8_interleavedx4
<
4
,
int
>
();
trans
(
src_value
,
dst_value
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
...
@@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
const
int32_t
icb
=
fhfw_icb
%
(
IC
/
4
);
if
(
ocb
<
OC
/
interleaved
&&
fhfw
<
FHFW
)
{
int
src_value
[
interleaved
]
;
array_wrapper
<
int
,
interleaved
>
src_value
;
vec_type
dst_value
[
4
];
#pragma unroll
...
...
@@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
icb
*
4
);
}
transpose_int8_interleavedx4
<
interleaved
,
vec_type
>
(
src_value
,
dst_value
);
auto
trans
=
transpose_int8_interleavedx4
<
interleaved
,
vec_type
>
();
trans
(
src_value
,
dst_value
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
...
dnn/src/cuda/transpose_utils.cuh
浏览文件 @
7b855dc6
...
...
@@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
}
template
<
uint32_t
interleaved
,
typename
vec_type
>
MEGDNN_DEVICE
__forceinline__
void
transpose_int8_interleavedx4
(
const
int
src
[
interleaved
],
vec_type
(
&
dst
)[
4
]);
struct
transpose_int8_interleavedx4
;
template
<
>
MEGDNN_DEVICE
__forceinline__
void
transpose_int8_interleavedx4
<
4
,
int
>
(
const
int
src
[
4
],
int
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
],
dst
[
1
],
dst
[
2
],
dst
[
3
]);
}
struct
transpose_int8_interleavedx4
<
4
,
int
>
{
static
constexpr
uint32_t
interleaved
=
4
;
using
vec_type
=
int
;
using
Fragment
=
array_wrapper
<
int
,
interleaved
>
;
MEGDNN_DEVICE
__forceinline__
void
operator
()(
const
Fragment
src
,
vec_type
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
],
dst
[
1
],
dst
[
2
],
dst
[
3
]);
}
};
template
<
>
MEGDNN_DEVICE
__forceinline__
void
transpose_int8_interleavedx4
<
8
,
int2
>
(
const
int
src
[
8
],
int2
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
].
x
,
dst
[
1
].
x
,
dst
[
2
].
x
,
dst
[
3
].
x
);
transpose_int8_4x4_impl
(
src
[
4
],
src
[
5
],
src
[
6
],
src
[
7
],
dst
[
0
].
y
,
dst
[
1
].
y
,
dst
[
2
].
y
,
dst
[
3
].
y
);
}
struct
transpose_int8_interleavedx4
<
8
,
int2
>
{
static
constexpr
uint32_t
interleaved
=
8
;
using
vec_type
=
int2
;
using
Fragment
=
array_wrapper
<
int
,
interleaved
>
;
MEGDNN_DEVICE
__forceinline__
void
operator
()(
const
Fragment
src
,
vec_type
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
].
x
,
dst
[
1
].
x
,
dst
[
2
].
x
,
dst
[
3
].
x
);
transpose_int8_4x4_impl
(
src
[
4
],
src
[
5
],
src
[
6
],
src
[
7
],
dst
[
0
].
y
,
dst
[
1
].
y
,
dst
[
2
].
y
,
dst
[
3
].
y
);
}
};
template
<
>
MEGDNN_DEVICE
__forceinline__
void
transpose_int8_interleavedx4
<
16
,
int4
>
(
const
int
src
[
16
],
int4
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
].
x
,
dst
[
1
].
x
,
dst
[
2
].
x
,
dst
[
3
].
x
);
transpose_int8_4x4_impl
(
src
[
4
],
src
[
5
],
src
[
6
],
src
[
7
],
dst
[
0
].
y
,
dst
[
1
].
y
,
dst
[
2
].
y
,
dst
[
3
].
y
);
transpose_int8_4x4_impl
(
src
[
8
],
src
[
9
],
src
[
10
],
src
[
11
],
dst
[
0
].
z
,
dst
[
1
].
z
,
dst
[
2
].
z
,
dst
[
3
].
z
);
transpose_int8_4x4_impl
(
src
[
12
],
src
[
13
],
src
[
14
],
src
[
15
],
dst
[
0
].
w
,
dst
[
1
].
w
,
dst
[
2
].
w
,
dst
[
3
].
w
);
}
struct
transpose_int8_interleavedx4
<
16
,
int4
>
{
static
constexpr
uint32_t
interleaved
=
16
;
using
vec_type
=
int4
;
using
Fragment
=
array_wrapper
<
int
,
interleaved
>
;
MEGDNN_DEVICE
__forceinline__
void
operator
()(
const
Fragment
src
,
vec_type
(
&
dst
)[
4
])
{
transpose_int8_4x4_impl
(
src
[
0
],
src
[
1
],
src
[
2
],
src
[
3
],
dst
[
0
].
x
,
dst
[
1
].
x
,
dst
[
2
].
x
,
dst
[
3
].
x
);
transpose_int8_4x4_impl
(
src
[
4
],
src
[
5
],
src
[
6
],
src
[
7
],
dst
[
0
].
y
,
dst
[
1
].
y
,
dst
[
2
].
y
,
dst
[
3
].
y
);
transpose_int8_4x4_impl
(
src
[
8
],
src
[
9
],
src
[
10
],
src
[
11
],
dst
[
0
].
z
,
dst
[
1
].
z
,
dst
[
2
].
z
,
dst
[
3
].
z
);
transpose_int8_4x4_impl
(
src
[
12
],
src
[
13
],
src
[
14
],
src
[
15
],
dst
[
0
].
w
,
dst
[
1
].
w
,
dst
[
2
].
w
,
dst
[
3
].
w
);
}
};
}
// namespace cuda
}
// namespace megdnn
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录