MegEngine 天元 / MegEngine

Commit d6b098a0
Authored on Apr 24, 2020 by Megvii Engine Team
Committed by Xinran Xu on May 06, 2020
refactor(dnn/cuda): remove function name weiming
GitOrigin-RevId: f36495a46a36ab0976f3c70254af95e07cd92a80
Parent: ca24c4cd
Showing 3 changed files with 23 additions and 60 deletions (+23 -60)

    dnn/src/cuda/local/forward.cpp   +9  -44
    dnn/src/cuda/local/forward.cu    +7  -2
    dnn/src/cuda/local/local.cuh     +7  -14
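For orientation (this summary and the sketch below are reconstructed from the hunks that follow, not taken verbatim from the commit): the function name weiming is dropped, so forward_proxy_weiming becomes forward_proxy_default, and check_input, which only asserted a hard-coded IH * IW <= 768 bound, is replaced by forward_proxy_default_share_mem_in_bytes so the caller can compare the kernel's shared-memory requirement against the actual device limit. After the change, the local namespace in local.cuh declares roughly:

    size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW);

    void forward_proxy_default(const float *src, const float *filter, float *dst,
            size_t N,
            size_t IC, size_t IH, size_t IW,
            size_t OC, size_t OH, size_t OW,
            size_t FH, size_t FW,
            /* remaining padding/stride parameters, is_xcorr and the CUDA
               stream, unchanged by this commit */);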
dnn/src/cuda/local/forward.cpp  (+9 -44)

@@ -14,43 +14,7 @@
 #include "src/cuda/utils.h"
 #include "src/cuda/handle.h"
-
-namespace megdnn {
-namespace cuda {
-namespace local {
-
-void check_input(size_t N,
-        size_t IC, size_t IH, size_t IW,
-        size_t OC, size_t OH, size_t OW,
-        size_t FH, size_t FW,
-        size_t INs, size_t ONs,
-        size_t PH, size_t PW,
-        size_t SH, size_t SW,
-        bool is_xcorr)
-{
-    megdnn_ignore(N);
-    megdnn_ignore(IC);
-    megdnn_ignore(IH);
-    megdnn_ignore(IW);
-    megdnn_ignore(OC);
-    megdnn_ignore(OH);
-    megdnn_ignore(OW);
-    megdnn_ignore(FH);
-    megdnn_ignore(FW);
-    megdnn_ignore(INs);
-    megdnn_ignore(ONs);
-    megdnn_ignore(PH);
-    megdnn_ignore(PW);
-    megdnn_ignore(SH);
-    megdnn_ignore(SW);
-    megdnn_ignore(is_xcorr);
-    // shared memory constraint
-    megdnn_assert(IH * IW <= 768,
-            "spatial size should not be larger than 768.");
-    // megdnn_assert(4 * 4 * 4 * IH * IW <= 49152);
-}
-
-} // namespace local
-} // namespace cuda
-} // namespace megdnn
+#include "src/common/utils.cuh"
 
 namespace megdnn {
 namespace cuda {

@@ -94,13 +58,9 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 param().stride_h, param().stride_w,
                 cublas, stream,
                 one, zero);
-    } else {
-        local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW,
-                IC * IH * IW, OC * OH * OW,
-                param().pad_h, param().pad_w,
-                param().stride_h, param().stride_w,
-                is_xcorr);
-        local::forward_proxy_weiming(src.ptr<dt_float32>(),
+    } else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <=
+               handle->device_prop().sharedMemPerBlock) {
+        local::forward_proxy_default(src.ptr<dt_float32>(),
                 filter.ptr<dt_float32>(),
                 dst.ptr<dt_float32>(),
                 N,

@@ -112,6 +72,11 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 param().stride_h, param().stride_w,
                 is_xcorr,
                 stream);
+    } else {
+        megdnn_throw(ssprintf(
+                "No usable kernel for local conv, src: %s filter: %s\n",
+                src.layout.to_string().c_str(),
+                filter.layout.to_string().c_str()));
     }
 }
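Read together, the hunks above turn the unconditional fallback into a guarded one. A minimal sketch of the resulting control flow in LocalForwardImpl::exec; the first branch and the elided arguments are assumptions inferred from the surrounding context lines (cublas, one, zero), not part of the diff itself:

    void LocalForwardImpl::exec(_megdnn_tensor_in src, /* filter, dst, ... */) {
        if (/* cuBLAS-backed convnet proxy applies (pre-existing path, assumed) */) {
            // unchanged by this commit
        } else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <=
                   handle->device_prop().sharedMemPerBlock) {
            // new guard: take the default kernel only when its shared-memory
            // requirement fits the device
            local::forward_proxy_default(src.ptr<dt_float32>(),
                    filter.ptr<dt_float32>(), dst.ptr<dt_float32>(),
                    /* N, shapes, padding, strides as in the hunk */
                    is_xcorr, stream);
        } else {
            // new fallback: report the unsupported case instead of relying on
            // the assert that lived inside the removed check_input
            megdnn_throw(ssprintf(
                    "No usable kernel for local conv, src: %s filter: %s\n",
                    src.layout.to_string().c_str(),
                    filter.layout.to_string().c_str()));
        }
    }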
dnn/src/cuda/local/forward.cu  (+7 -2)

@@ -18,6 +18,12 @@ namespace megdnn {
 namespace cuda {
 namespace local {
 
+constexpr size_t Ns = 4, ICs = 4;
+
+size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) {
+    return Ns * ICs * sizeof(float) * IH * IW;
+}
+
 // blockIdx.y is OC*OH*OW/1024
 // blockIdx.x is N/4
 // threadIdx.x is [0, 1024)

@@ -96,7 +102,7 @@ __global__ void forward_kernel(const float * __restrict__ src,
     }
 }
 
-void forward_proxy_weiming(const float *src, const float *filter, float *dst,
+void forward_proxy_default(const float *src, const float *filter, float *dst,
         size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,

@@ -108,7 +114,6 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
         cudaStream_t stream)
 {
     size_t threads = 256;
-    const size_t Ns = 4, ICs = 4;
     dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads));
     if (is_xcorr) {
         forward_kernel<Ns, ICs, true><<<blocks, threads,
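A note on the numbers (an observation about the code above, not text from the commit): with Ns = ICs = 4, the new helper returns 4 * 4 * sizeof(float) * IH * IW = 64 * IH * IW bytes per block. The removed check_input asserted IH * IW <= 768, i.e. at most 49152 bytes (48 KB), exactly the bound in the commented-out assert it carried; the dispatch in forward.cpp now compares the real requirement against the device's reported sharedMemPerBlock instead. A small worked example with a hypothetical spatial size:

    // Hypothetical 28x28 input plane, not from the commit:
    size_t IH = 28, IW = 28;
    size_t need = 4 /*Ns*/ * 4 /*ICs*/ * sizeof(float) * IH * IW;  // 50176 bytes
    // 50176 > 49152, so the old check_input would have asserted; with this
    // change, a device reporting only 48 KB of shared memory per block skips
    // the else-if branch and exec() throws "No usable kernel for local conv".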
dnn/src/cuda/local/local.cuh  (+7 -14)

@@ -17,17 +17,10 @@ namespace megdnn {
 namespace cuda {
 namespace local {
 
-void check_input(size_t N,
-        size_t IC, size_t IH, size_t IW,
-        size_t OC, size_t OH, size_t OW,
-        size_t FH, size_t FW,
-        size_t INs, size_t ONs,
-        size_t PH, size_t PW,
-        size_t SH, size_t SW,
-        bool is_xcorr);
+size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW);
 
-void forward_proxy_weiming(const float *src, const float *filter, float *dst,
+void forward_proxy_default(const float *src, const float *filter, float *dst,
         size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
         ...

@@ -39,7 +32,7 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
 /// forward
 bool can_forward_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
         ...

@@ -70,7 +63,7 @@ size_t get_workspace_in_floats_forward_proxy_convnet(size_t N,
 /// bwd data
 bool can_backward_data_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
         ...

@@ -78,7 +71,7 @@ bool can_backward_data_proxy_convnet(size_t N,
         size_t PH, size_t PW,
         size_t SH, size_t SW);
 void backward_data_proxy_convnet(const float *filter,
         const float *diff,
         float *grad,
         float *workspace,
         ...

@@ -103,7 +96,7 @@ size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N,
 /// bwd filter
 bool can_backward_filter_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
         ...