Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
dbb65880
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dbb65880
编写于
7月 18, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modity the format
上级
44927bf7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
61 addition
and
64 deletion
+61
-64
paddle/function/DepthwiseConvOp.cpp
paddle/function/DepthwiseConvOp.cpp
+3
-6
paddle/function/DepthwiseConvOpGpu.cu
paddle/function/DepthwiseConvOpGpu.cu
+58
-58
未找到文件。
paddle/function/DepthwiseConvOp.cpp
浏览文件 @
dbb65880
...
@@ -99,8 +99,7 @@ public:
...
@@ -99,8 +99,7 @@ public:
ConvFunctionBase
::
init
(
config
);
ConvFunctionBase
::
init
(
config
);
}
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
...
@@ -162,8 +161,7 @@ public:
...
@@ -162,8 +161,7 @@ public:
ConvFunctionBase
::
init
(
config
);
ConvFunctionBase
::
init
(
config
);
}
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
...
@@ -225,8 +223,7 @@ public:
...
@@ -225,8 +223,7 @@ public:
ConvFunctionBase
::
init
(
config
);
ConvFunctionBase
::
init
(
config
);
}
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
...
...
paddle/function/DepthwiseConvOpGpu.cu
浏览文件 @
dbb65880
...
@@ -20,58 +20,58 @@ namespace paddle {
...
@@ -20,58 +20,58 @@ namespace paddle {
// CUDA kernel to compute the depthwise convolution forward pass
// CUDA kernel to compute the depthwise convolution forward pass
template
<
class
T
>
template
<
class
T
>
__global__
__global__
void
ConvolutionDepthwiseForward
(
const
int
nthreads
,
void
ConvolutionDepthwiseForward
(
const
int
nthreads
,
const
T
*
const
inputData
,
const
T
*
const
filterData
,
const
T
*
const
inputData
,
const
T
*
const
filterData
,
const
int
batchSize
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
batchSize
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
strideW
,
const
int
paddingH
,
const
int
padding
W
,
const
int
filterWidth
,
const
int
strideH
,
const
int
stride
W
,
T
*
const
outputData
)
{
const
int
paddingH
,
const
int
paddingW
,
T
*
const
outputData
)
{
int
index
=
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
outputChannels
/
outputHeight
/
outputWidth
;
const
int
batch
=
index
/
outputChannels
/
outputHeight
/
outputWidth
;
const
int
c_out
=
(
index
/
outputHeight
/
outputWidth
)
%
outputChannels
;
const
int
c_out
=
(
index
/
outputHeight
/
outputWidth
)
%
outputChannels
;
const
int
h_out
=
(
index
/
outputWidth
)
%
outputHeight
;
const
int
h_out
=
(
index
/
outputWidth
)
%
outputHeight
;
const
int
w_out
=
index
%
outputWidth
;
const
int
w_out
=
index
%
outputWidth
;
const
int
c_in
=
c_out
/
filterMultiplier
;
const
int
c_in
=
c_out
/
filterMultiplier
;
const
T
*
weight
=
filterData
+
c_out
*
filterHeight
*
filterWidth
;
const
T
*
weight
=
filterData
+
c_out
*
filterHeight
*
filterWidth
;
T
value
=
0
;
T
value
=
0
;
const
int
h_in_start
=
-
paddingH
+
h_out
*
strideH
;
const
int
h_in_start
=
-
paddingH
+
h_out
*
strideH
;
const
int
w_in_start
=
-
paddingW
+
w_out
*
strideW
;
const
int
w_in_start
=
-
paddingW
+
w_out
*
strideW
;
const
int
h_in_end
=
-
paddingH
+
h_out
*
strideH
+
filterHeight
-
1
;
const
int
h_in_end
=
-
paddingH
+
h_out
*
strideH
+
filterHeight
-
1
;
const
int
w_in_end
=
-
paddingW
+
w_out
*
strideW
+
filterWidth
-
1
;
const
int
w_in_end
=
-
paddingW
+
w_out
*
strideW
+
filterWidth
-
1
;
if
((
h_in_start
>=
0
)
&&
(
h_in_end
<
inputHeight
)
if
((
h_in_start
>=
0
)
&&
(
h_in_end
<
inputHeight
)
&&
(
w_in_start
>=
0
)
&&
(
w_in_end
<
inputWidth
))
{
&&
(
w_in_start
>=
0
)
&&
(
w_in_end
<
inputWidth
))
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputWidth
+
w_in
;
*
inputHeight
+
h_in
)
*
inputWidth
+
w_in
;
value
+=
(
*
weight
)
*
inputData
[
offset
];
value
+=
(
*
weight
)
*
inputData
[
offset
];
++
weight
;
++
weight
;
}
}
}
}
}
else
{
}
else
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
h_in
=
-
paddingH
+
h_out
*
strideH
+
kh
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
const
int
offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputWidth
+
w_in
;
*
input
Height
+
h_in
)
*
input
Width
+
w_in
;
value
+=
(
*
weight
)
*
inputData
[
offset
];
value
+=
(
*
weight
)
*
inputData
[
offset
];
}
}
++
weight
;
++
weight
;
}
}
}
}
}
}
outputData
[
index
]
=
value
;
outputData
[
index
]
=
value
;
}
}
}
}
...
@@ -82,21 +82,21 @@ __global__
...
@@ -82,21 +82,21 @@ __global__
void
ConvolutionDepthwiseInputBackward
(
const
int
nthreads
,
void
ConvolutionDepthwiseInputBackward
(
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
weight_data
,
const
T
*
const
top_diff
,
const
T
*
const
weight_data
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
strideW
,
const
int
paddingH
,
const
int
padding
W
,
const
int
filterWidth
,
const
int
strideH
,
const
int
stride
W
,
T
*
const
bottom_diff
)
{
const
int
paddingH
,
const
int
paddingW
,
T
*
const
bottom_diff
)
{
int
index
=
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
inputChannels
/
inputHeight
/
inputWidth
;
const
int
batch
=
index
/
inputChannels
/
inputHeight
/
inputWidth
;
const
int
c_in
=
(
index
/
inputHeight
/
inputWidth
)
%
inputChannels
;
const
int
c_in
=
(
index
/
inputHeight
/
inputWidth
)
%
inputChannels
;
const
int
h_in
=
(
index
/
inputWidth
)
%
inputHeight
;
const
int
h_in
=
(
index
/
inputWidth
)
%
inputHeight
;
const
int
w_in
=
index
%
inputWidth
;
const
int
w_in
=
index
%
inputWidth
;
const
int
c_out_start
=
c_in
*
filterMultiplier
;
const
int
c_out_start
=
c_in
*
filterMultiplier
;
T
value
=
0
;
T
value
=
0
;
for
(
int
c_out
=
c_out_start
;
c_out
<
c_out_start
+
filterMultiplier
;
c_out
++
){
for
(
int
c_out
=
c_out_start
;
//weight bixu c_out
c_out
<
c_out_start
+
filterMultiplier
;
c_out
++
)
{
const
T
*
weight
=
weight_data
+
c_out
*
filterHeight
*
filterWidth
;
const
T
*
weight
=
weight_data
+
c_out
*
filterHeight
*
filterWidth
;
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kh
=
0
;
kh
<
filterHeight
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
for
(
int
kw
=
0
;
kw
<
filterWidth
;
++
kw
)
{
...
@@ -105,11 +105,12 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
...
@@ -105,11 +105,12 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
if
(((
h_out_s
%
strideH
)
==
0
)
&&
((
w_out_s
%
strideW
)
==
0
))
{
if
(((
h_out_s
%
strideH
)
==
0
)
&&
((
w_out_s
%
strideW
)
==
0
))
{
const
int
h_out
=
h_out_s
/
strideH
;
const
int
h_out
=
h_out_s
/
strideH
;
const
int
w_out
=
w_out_s
/
strideW
;
const
int
w_out
=
w_out_s
/
strideW
;
// TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize
// TODO(zhaolong) : the 'if' affect the effectiveness,
// it needs to optimize
if
((
h_out
>=
0
)
&&
(
h_out
<
outputHeight
)
if
((
h_out
>=
0
)
&&
(
h_out
<
outputHeight
)
&&
(
w_out
>=
0
)
&&
(
w_out
<
outputWidth
))
{
&&
(
w_out
>=
0
)
&&
(
w_out
<
outputWidth
))
{
const
int
offset
=
((
batch
*
outputChannels
+
c_out
)
*
outputHeight
+
h_out
)
const
int
offset
=
((
batch
*
outputChannels
+
c_out
)
*
outputWidth
+
w_out
;
*
outputHeight
+
h_out
)
*
outputWidth
+
w_out
;
value
+=
(
*
weight
)
*
top_diff
[
offset
];
value
+=
(
*
weight
)
*
top_diff
[
offset
];
}
}
}
}
...
@@ -127,10 +128,10 @@ __global__
...
@@ -127,10 +128,10 @@ __global__
void
ConvolutionDepthwiseFilterBackward
(
const
int
num_i
,
const
int
nthreads
,
void
ConvolutionDepthwiseFilterBackward
(
const
int
num_i
,
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
inputData
,
const
T
*
const
top_diff
,
const
T
*
const
inputData
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
num
,
const
int
outputChannels
,
const
int
outputHeight
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
inputWidth
,
const
int
outputWidth
,
const
int
inputChannels
,
const
int
inputHeight
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
filterWidth
,
const
int
strideH
,
const
int
inputWidth
,
const
int
filterMultiplier
,
const
int
filterHeight
,
const
int
strideW
,
const
int
paddingH
,
const
int
padding
W
,
const
int
filterWidth
,
const
int
strideH
,
const
int
stride
W
,
T
*
const
buffer_data
)
{
const
int
paddingH
,
const
int
paddingW
,
T
*
const
buffer_data
)
{
int
index
=
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
if
(
index
<
nthreads
)
{
...
@@ -143,13 +144,14 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
...
@@ -143,13 +144,14 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
const
int
w_in
=
-
paddingW
+
w_out
*
strideW
+
kw
;
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
if
((
h_in
>=
0
)
&&
(
h_in
<
inputHeight
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
&&
(
w_in
>=
0
)
&&
(
w_in
<
inputWidth
))
{
const
int
c_out
=
index
/
filterHeight
/
filterWidth
/
outputHeight
/
outputWidth
;
const
int
c_out
=
index
/
const
int
c_in
=
c_out
/
filterMultiplier
;
(
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
);
const
int
c_in
=
c_out
/
filterMultiplier
;
const
int
batch
=
num_i
;
const
int
batch
=
num_i
;
const
int
top_offset
=
((
batch
*
outputChannels
+
c_out
)
*
outputHeight
+
h_out
)
const
int
top_offset
=
((
batch
*
outputChannels
+
c_out
)
*
*
outputWidth
+
w_out
;
outputHeight
+
h_out
)
*
outputWidth
+
w_out
;
const
int
bottom_offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputHeight
+
h_in
)
const
int
bottom_offset
=
((
batch
*
inputChannels
+
c_in
)
*
inputWidth
+
w_in
;
*
input
Height
+
h_in
)
*
input
Width
+
w_in
;
buffer_data
[
index
]
=
top_diff
[
top_offset
]
*
inputData
[
bottom_offset
];
buffer_data
[
index
]
=
top_diff
[
top_offset
]
*
inputData
[
bottom_offset
];
}
else
{
}
else
{
buffer_data
[
index
]
=
0
;
buffer_data
[
index
]
=
0
;
...
@@ -160,13 +162,13 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
...
@@ -160,13 +162,13 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
template
<
class
T
>
template
<
class
T
>
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
public:
void
operator
()(
const
T
*
inputData
,
void
operator
()(
const
T
*
inputData
,
const
T
*
filterData
,
const
T
*
filterData
,
int
batchSize
,
int
batchSize
,
int
outputChannels
,
int
outputChannels
,
int
outputHeight
,
int
outputHeight
,
int
outputWidth
,
int
outputWidth
,
int
inputChannels
,
int
inputChannels
,
int
inputHeight
,
int
inputHeight
,
int
inputWidth
,
int
inputWidth
,
int
filterMultiplier
,
int
filterMultiplier
,
...
@@ -177,7 +179,6 @@ public:
...
@@ -177,7 +179,6 @@ public:
int
paddingH
,
int
paddingH
,
int
paddingW
,
int
paddingW
,
T
*
outputData
){
T
*
outputData
){
int
outputSize
=
batchSize
*
outputChannels
*
outputHeight
*
outputWidth
;
int
outputSize
=
batchSize
*
outputChannels
*
outputHeight
*
outputWidth
;
size_t
blocks
=
(
outputSize
+
1024
-
1
)
/
1024
;
size_t
blocks
=
(
outputSize
+
1024
-
1
)
/
1024
;
...
@@ -188,14 +189,14 @@ public:
...
@@ -188,14 +189,14 @@ public:
ConvolutionDepthwiseForward
<
T
>
ConvolutionDepthwiseForward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
outputSize
,
outputSize
,
inputData
,
inputData
,
filterData
,
filterData
,
batchSize
,
batchSize
,
outputChannels
,
outputChannels
,
outputHeight
,
outputHeight
,
outputWidth
,
outputWidth
,
inputChannels
,
inputChannels
,
inputHeight
,
inputHeight
,
inputWidth
,
inputWidth
,
filterMultiplier
,
filterMultiplier
,
...
@@ -229,7 +230,6 @@ public:
...
@@ -229,7 +230,6 @@ public:
int
paddingH
,
int
paddingH
,
int
paddingW
,
int
paddingW
,
T
*
inputGrad
){
T
*
inputGrad
){
int
inputSize
=
batchSize
*
inputChannels
*
inputHeight
*
inputWidth
;
int
inputSize
=
batchSize
*
inputChannels
*
inputHeight
*
inputWidth
;
size_t
blocks
=
(
inputSize
+
1024
-
1
)
/
1024
;
size_t
blocks
=
(
inputSize
+
1024
-
1
)
/
1024
;
...
@@ -249,7 +249,7 @@ public:
...
@@ -249,7 +249,7 @@ public:
outputChannels
,
outputChannels
,
outputHeight
,
outputHeight
,
outputWidth
,
outputWidth
,
inputChannels
,
inputChannels
,
inputHeight
,
inputHeight
,
inputWidth
,
inputWidth
,
filterMultiplier
,
filterMultiplier
,
...
@@ -284,17 +284,18 @@ public:
...
@@ -284,17 +284,18 @@ public:
int
paddingW
,
int
paddingW
,
T
*
colData
,
T
*
colData
,
T
*
filterGrad
){
T
*
filterGrad
){
int
colDataSize
=
outputChannels
*
filterHeight
*
filterWidth
int
colDataSize
=
outputChannels
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
*
outputHeight
*
outputWidth
;
size_t
blocks
=
(
colDataSize
+
1024
-
1
)
/
1024
;
size_t
blocks
=
(
colDataSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
dim3
grid
(
blockX
,
blockY
);
BaseMatrix
filterGradMatrix
(
outputChannels
*
filterHeight
*
filterWidth
,
1
,
filterGrad
,
false
,
true
);
BaseMatrix
filterGradMatrix
(
outputChannels
*
filterHeight
*
filterWidth
,
1
,
filterGrad
,
false
,
true
);
for
(
int
i
=
0
;
i
<
batchSize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batchSize
;
i
++
)
{
ConvolutionDepthwiseFilterBackward
<
T
>
ConvolutionDepthwiseFilterBackward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
i
,
i
,
...
@@ -305,24 +306,23 @@ public:
...
@@ -305,24 +306,23 @@ public:
outputChannels
,
outputChannels
,
outputHeight
,
outputHeight
,
outputWidth
,
outputWidth
,
inputChannels
,
inputChannels
,
inputHeight
,
inputHeight
,
inputWidth
,
inputWidth
,
filterMultiplier
,
filterMultiplier
,
filterHeight
,
filterHeight
,
filterWidth
,
filterWidth
,
strideH
,
strideH
,
strideW
,
strideW
,
paddingH
,
paddingH
,
paddingW
,
paddingW
,
colData
colData
);
);
int
K
=
outputHeight
*
outputWidth
;
int
K
=
outputHeight
*
outputWidth
;
int
M
=
colDataSize
/
K
;
int
M
=
colDataSize
/
K
;
BaseMatrix
colMatrix
(
M
,
K
,
colData
,
false
,
true
);
BaseMatrix
colMatrix
(
M
,
K
,
colData
,
false
,
true
);
filterGradMatrix
.
sumRows
(
colMatrix
,
(
T
)
1.0
,
(
T
)
1.0
);
filterGradMatrix
.
sumRows
(
colMatrix
,
(
T
)
1.0
,
(
T
)
1.0
);
}
}
}
}
};
};
...
@@ -330,7 +330,7 @@ public:
...
@@ -330,7 +330,7 @@ public:
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
double
>;
#else
#else
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
float
>;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录