Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
14b65e4d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
14b65e4d
编写于
5月 10, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add reduce_filter_and_update_bias
GitOrigin-RevId: 31b6e6b0abe2790029e63c9f91c64290a1801958
上级
2d4e62ef
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
77 addition
and
6 deletion
+77
-6
dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp
dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp
+1
-1
dnn/src/cuda/conv_bias/reduce_filter.cu
dnn/src/cuda/conv_bias/reduce_filter.cu
+62
-3
dnn/src/cuda/conv_bias/reduce_filter.cuh
dnn/src/cuda/conv_bias/reduce_filter.cuh
+14
-2
未找到文件。
dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp
浏览文件 @
14b65e4d
...
...
@@ -15,7 +15,7 @@
#include "./quint4x4x32_wmma/activation_u4.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_data.cuh"
#include "./reduce_
with_scale_
filter.cuh"
#include "./reduce_filter.cuh"
#include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh"
using
namespace
megdnn
;
...
...
dnn/src/cuda/conv_bias/reduce_
with_scale_
filter.cu
→
dnn/src/cuda/conv_bias/reduce_filter.cu
浏览文件 @
14b65e4d
...
...
@@ -25,7 +25,7 @@
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_
with_scale_
filter.cu
* \file dnn/src/cuda/conv_bias/reduce_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...
...
@@ -36,9 +36,11 @@
* implied.
*/
#include "./reduce_with_scale_filter.cuh"
#include "src/cuda/reduce_helper.cuh"
#include "./reduce_filter.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/reduce_helper.cuh"
using
namespace
megdnn
;
using
namespace
cuda
;
...
...
@@ -76,6 +78,38 @@ struct ReduceWithScaleInt4Op {
#endif
};
template
<
bool
signedness
>
struct
ReduceUpdateBiasInt4Op
{
typedef
int32_t
wtype
;
const
uint8_t
*
filter
;
const
int32_t
*
src_bias
;
int32_t
*
dst_bias
;
int32_t
zero_point
;
static
const
wtype
INIT
=
0
;
#if MEGDNN_CC_CUDA
__host__
__device__
void
write
(
uint32_t
idx
,
wtype
val
)
{
dst_bias
[
idx
]
=
src_bias
[
idx
]
-
val
*
zero_point
;
}
__host__
__device__
static
wtype
apply
(
wtype
a
,
wtype
b
)
{
return
a
+
b
;
}
__device__
wtype
read
(
uint32_t
idx
)
{
constexpr
uint32_t
subbytes_per_pixel
=
8
;
const
uint32_t
*
fptr
=
(
const
uint32_t
*
)(
filter
+
subbytes_per_pixel
*
idx
/
2
);
uint32_t
val
=
*
fptr
;
int32_t
ret
=
0
;
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
ret
+=
integer_subbyte
::
unpack_integer_4bits
<
signedness
>
(
val
,
(
j
<<
2
));
}
return
ret
;
}
#endif
};
}
// namespace
template
<
bool
signedness
>
...
...
@@ -106,6 +140,31 @@ INST(false);
INST
(
true
);
#undef INST
template
<
bool
signedness
>
void
megdnn
::
cuda
::
do_dispatch_reduce_filter_and_update_bias_4bit
(
const
uint8_t
*
filter
,
const
int32_t
*
src_bias
,
uint32_t
rows
,
uint32_t
cols
,
int32_t
*
dst_bias
,
int32_t
*
workspace
,
int32_t
zero_point
,
cudaStream_t
stream
)
{
ReduceUpdateBiasInt4Op
<
signedness
>
op
;
op
.
filter
=
filter
;
op
.
src_bias
=
src_bias
;
op
.
dst_bias
=
dst_bias
;
op
.
zero_point
=
zero_point
;
run_reduce
<
ReduceUpdateBiasInt4Op
<
signedness
>
,
false
>
(
workspace
,
rows
,
cols
,
1
,
stream
,
op
);
}
#define INST(signedness) \
template void \
megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit<signedness>( \
const uint8_t* filter, const int32_t* src_bias, uint32_t rows, \
uint32_t cols, int32_t* dst_bias, int32_t* workspace, \
int32_t zero_point, cudaStream_t stream)
INST
(
false
);
INST
(
true
);
#undef INST
size_t
megdnn
::
cuda
::
do_dispatch_reduce_workspace_in_bytes
(
size_t
A
,
size_t
B
,
size_t
C
)
{
return
get_reduce_workspace_in_bytes
<
ReduceWithScaleInt4Op
<
false
>>
(
A
,
B
,
C
);
...
...
dnn/src/cuda/conv_bias/reduce_
with_scale_
filter.cuh
→
dnn/src/cuda/conv_bias/reduce_filter.cuh
浏览文件 @
14b65e4d
...
...
@@ -25,7 +25,7 @@
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/conv_bias/reduce_
with_scale_
filter.cuh
* \file dnn/src/cuda/conv_bias/reduce_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...
...
@@ -36,16 +36,28 @@
* implied.
*/
#include "src/cuda/utils.cuh"
#include <stddef.h>
#include <stdint.h>
#include <cuda_runtime.h>
namespace
megdnn
{
namespace
cuda
{
template
<
bool
signedness
>
void
do_dispatch_reduce_with_scale_filter_4bit
(
const
uint8_t
*
src
,
int32_t
scale
,
uint32_t
rows
,
uint32_t
cols
,
int32_t
*
dst
,
cudaStream_t
stream
);
template
<
bool
signedness
>
void
do_dispatch_reduce_filter_and_update_bias_4bit
(
const
uint8_t
*
filter
,
const
int32_t
*
src_bias
,
uint32_t
rows
,
uint32_t
cols
,
int32_t
*
dst_bias
,
int32_t
*
workspace
,
int
zero_point
,
cudaStream_t
stream
);
size_t
do_dispatch_reduce_workspace_in_bytes
(
size_t
A
,
size_t
B
,
size_t
C
);
}
// namespace cuda
}
// namespace megdnn
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录