Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
eeb17c26
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eeb17c26
编写于
7月 04, 2017
作者:
Z
zlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add depthwise operation and depthwise conv layer
上级
211f83fa
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
911 addition
and
0 deletion
+911
-0
paddle/function/DepthwiseConvOp.cpp
paddle/function/DepthwiseConvOp.cpp
+308
-0
paddle/function/DepthwiseConvOp.h
paddle/function/DepthwiseConvOp.h
+91
-0
paddle/function/DepthwiseConvOpGpu.cu
paddle/function/DepthwiseConvOpGpu.cu
+295
-0
paddle/gserver/layers/DepthwiseConvLayer.cpp
paddle/gserver/layers/DepthwiseConvLayer.cpp
+165
-0
paddle/gserver/layers/DepthwiseConvLayer.h
paddle/gserver/layers/DepthwiseConvLayer.h
+52
-0
未找到文件。
paddle/function/DepthwiseConvOp.cpp
0 → 100644
浏览文件 @
eeb17c26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace
paddle
{
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template
<
class
T
>
class
DepthwiseConvFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
int
outputSize
,
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
)
{
// NO_IMPLEMENTATION
}
};
template
<
class
T
>
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
int
inputSize
,
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
)
{}
};
template
<
class
T
>
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
int
num_i
,
int
colDataSize
,
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
multiplierData
,
T
*
filterGrad
)
{}
};
/*
* \brief Forward calculation of convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
check
(
inputs
,
outputs
);
const
TensorShape
&
input
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
output
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
// size_t inputChannels = input[1];
// size_t inputHeight = input[2];
// size_t inputWidth = input[3];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
real
*
inputData
=
inputs
[
0
].
data
<
real
>
();
real
*
filterData
=
inputs
[
1
].
data
<
real
>
();
real
*
outputData
=
outputs
[
0
].
data
<
real
>
();
size_t
outputSize
=
batchSize
*
outputChannels
*
outputHeight
*
outputWidth
;
DepthwiseConvFunctor
<
Device
,
real
>
depthwiseConv
;
depthwiseConv
(
outputSize
,
inputData
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
outputData
);
}
};
/*
* \brief Backward input calculation of convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvGradInputFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
check
(
inputs
,
outputs
);
// Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
filterData
=
inputs
[
1
].
data
<
real
>
();
real
*
inputGrad
=
outputs
[
0
].
data
<
real
>
();
size_t
inputSize
=
batchSize
*
inputChannels
*
inputHeight
*
inputWidth
;
DepthwiseConvGradInputFunctor
<
Device
,
real
>
depthwiseConvGradInput
;
depthwiseConvGradInput
(
inputSize
,
outputGrad
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
inputGrad
);
}
};
/*
* \brief Backward filter calculation of convolution.
*/
template
<
DeviceType
Device
>
class
DepthwiseConvGradFilterFunction
:
public
ConvFunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
ConvFunctionBase
::
init
(
config
);
}
virtual
void
check
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
checkShape
(
input
,
filter
,
output
);
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
check
(
inputs
,
outputs
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
// const TensorShape& multiplier = inputs[2].shape();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
getFilterHeight
(
filter
);
size_t
filterWidth
=
getFilterWidth
(
filter
);
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
inputData
=
inputs
[
1
].
data
<
real
>
();
real
*
multiplierData
=
inputs
[
2
].
data
<
real
>
();
real
*
filterGrad
=
outputs
[
0
].
data
<
real
>
();
size_t
size
=
inputChannels
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
resizeBuffer
<
Device
>
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
DepthwiseConvGradFilterFunctor
<
Device
,
real
>
depthwiseConvGradFilter
;
for
(
size_t
i
=
0
;
i
<
batchSize
;
i
++
)
{
depthwiseConvGradFilter
(
i
,
size
,
outputGrad
,
inputData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
colData
,
multiplierData
,
filterGrad
);
}
}
};
REGISTER_TYPED_FUNC
(
DepthwiseConv
,
CPU
,
DepthwiseConvFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradInput
,
CPU
,
DepthwiseConvGradInputFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradFilter
,
CPU
,
DepthwiseConvGradFilterFunction
);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC
(
DepthwiseConv
,
GPU
,
DepthwiseConvFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradInput
,
GPU
,
DepthwiseConvGradInputFunction
);
REGISTER_TYPED_FUNC
(
DepthwiseConvGradFilter
,
GPU
,
DepthwiseConvGradFilterFunction
);
#endif
}
// namespace paddle
paddle/function/DepthwiseConvOp.h
0 → 100644
浏览文件 @
eeb17c26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvOp.h"
namespace
paddle
{
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvFunctor
{
public:
void
operator
()(
int
outputSize
,
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
);
};
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvGradInputFunctor
{
public:
void
operator
()(
int
inputSize
,
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
);
};
template
<
DeviceType
Device
,
class
T
>
class
DepthwiseConvGradFilterFunctor
{
public:
void
operator
()(
int
num_i
,
int
colDataSize
,
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
multiplierData
,
T
*
filterGrad
);
};
// namespace paddle
}
// namespace paddle
paddle/function/DepthwiseConvOpGpu.cu
0 → 100644
浏览文件 @
eeb17c26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "DepthwiseConvOp.h"
namespace
paddle
{
template
<
class
T
>
__global__
void
ConvolutionDepthwiseWeightForward
(
const
int
nthreads
,
const
T
*
const
bottom_data
,
const
T
*
const
weight_data
,
const
int
num
,
const
int
channels
,
const
int
top_height
,
const
int
top_width
,
const
int
bottom_height
,
const
int
bottom_width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
T
*
const
top_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
n
=
index
/
channels
/
top_height
/
top_width
;
const
int
c
=
(
index
/
top_height
/
top_width
)
%
channels
;
const
int
h
=
(
index
/
top_width
)
%
top_height
;
const
int
w
=
index
%
top_width
;
const
T
*
weight
=
weight_data
+
c
*
kernel_h
*
kernel_w
;
T
value
=
0
;
for
(
int
kh
=
0
;
kh
<
kernel_h
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
kernel_w
;
++
kw
)
{
const
int
h_in
=
-
pad_h
+
h
*
stride_h
+
kh
*
dilation_h
;
const
int
w_in
=
-
pad_w
+
w
*
stride_w
+
kw
*
dilation_w
;
if
((
h_in
>=
0
)
&&
(
h_in
<
bottom_height
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
bottom_width
))
{
const
int
offset
=
((
n
*
channels
+
c
)
*
bottom_height
+
h_in
)
*
bottom_width
+
w_in
;
value
+=
(
*
weight
)
*
bottom_data
[
offset
];
}
++
weight
;
}
}
top_data
[
index
]
=
value
;
}
}
template
<
class
T
>
__global__
void
ConvolutionDepthwiseBottomBackward
(
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
weight_data
,
const
int
num
,
const
int
channels
,
const
int
top_height
,
const
int
top_width
,
const
int
bottom_height
,
const
int
bottom_width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
T
*
const
bottom_diff
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
n
=
index
/
channels
/
bottom_height
/
bottom_width
;
const
int
c
=
(
index
/
bottom_height
/
bottom_width
)
%
channels
;
const
int
h
=
(
index
/
bottom_width
)
%
bottom_height
;
const
int
w
=
index
%
bottom_width
;
const
T
*
weight
=
weight_data
+
c
*
kernel_h
*
kernel_w
;
T
value
=
0
;
for
(
int
kh
=
0
;
kh
<
kernel_h
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
kernel_w
;
++
kw
)
{
const
int
h_out_s
=
h
+
pad_h
-
kh
*
dilation_h
;
const
int
w_out_s
=
w
+
pad_w
-
kw
*
dilation_w
;
if
(((
h_out_s
%
stride_h
)
==
0
)
&&
((
w_out_s
%
stride_w
)
==
0
))
{
const
int
h_out
=
h_out_s
/
stride_h
;
const
int
w_out
=
w_out_s
/
stride_w
;
//it affect the effectives
if
((
h_out
>=
0
)
&&
(
h_out
<
top_height
)
&&
(
w_out
>=
0
)
&&
(
w_out
<
top_width
))
{
const
int
offset
=
((
n
*
channels
+
c
)
*
top_height
+
h_out
)
*
top_width
+
w_out
;
value
+=
(
*
weight
)
*
top_diff
[
offset
];
}
}
++
weight
;
}
}
bottom_diff
[
index
]
+=
value
;
}
}
template
<
class
T
>
__global__
void
ConvolutionDepthwiseWeightBackward
(
const
int
num_i
,
const
int
nthreads
,
const
T
*
const
top_diff
,
const
T
*
const
bottom_data
,
const
int
num
,
const
int
channels
,
const
int
top_height
,
const
int
top_width
,
const
int
bottom_height
,
const
int
bottom_width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
T
*
const
buffer_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
h
=
(
index
/
top_width
)
%
top_height
;
const
int
w
=
index
%
top_width
;
const
int
kh
=
(
index
/
kernel_w
/
top_height
/
top_width
)
%
kernel_h
;
const
int
kw
=
(
index
/
top_height
/
top_width
)
%
kernel_w
;
const
int
h_in
=
-
pad_h
+
h
*
stride_h
+
kh
*
dilation_h
;
const
int
w_in
=
-
pad_w
+
w
*
stride_w
+
kw
*
dilation_w
;
if
((
h_in
>=
0
)
&&
(
h_in
<
bottom_height
)
&&
(
w_in
>=
0
)
&&
(
w_in
<
bottom_width
))
{
const
int
c
=
index
/
kernel_h
/
kernel_w
/
top_height
/
top_width
;
const
int
n
=
num_i
;
const
int
top_offset
=
((
n
*
channels
+
c
)
*
top_height
+
h
)
*
top_width
+
w
;
const
int
bottom_offset
=
((
n
*
channels
+
c
)
*
bottom_height
+
h_in
)
*
bottom_width
+
w_in
;
buffer_data
[
index
]
=
top_diff
[
top_offset
]
*
bottom_data
[
bottom_offset
];
}
else
{
buffer_data
[
index
]
=
0
;
}
}
}
template
<
class
T
>
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
int
outputSize
,
const
T
*
inputData
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
outputData
){
size_t
blocks
=
(
outputSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
ConvolutionDepthwiseWeightForward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
outputSize
,
inputData
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
outputData
);
}
};
template
<
class
T
>
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
int
inputSize
,
const
T
*
outputGrad
,
const
T
*
filterData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
inputGrad
){
size_t
blocks
=
(
inputSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
ConvolutionDepthwiseBottomBackward
<
T
>
// NOLINT_NEXT_LINE(whitespace/operators)
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
inputSize
,
outputGrad
,
filterData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
inputGrad
);
}
};
template
<
class
T
>
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
int
num_i
,
int
colDataSize
,
const
T
*
outputGrad
,
const
T
*
inputData
,
int
batchSize
,
int
outputChannels
,
int
outputHeight
,
int
outputWidth
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
T
*
colData
,
T
*
multiplierData
,
T
*
filterGrad
){
size_t
blocks
=
(
colDataSize
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
ConvolutionDepthwiseWeightBackward
<
T
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
i
,
size
,
outputGrad
,
inputData
,
batchSize
,
outputChannels
,
outputHeight
,
outputWidth
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
,
strideW
,
paddingH
,
paddingW
,
colData
);
GemmFunctor
<
Device
,
real
>
gemm
;
int
M
=
size
/
outputHeight
/
outputWidth
;
int
N
=
1
;
int
K
=
outputHeight
*
outputWidth
;
gemm
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
colData
,
K
,
multiplierData
,
N
,
1.0
f
,
filterGrad
,
N
);
//gemv
}
};
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradInputFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
DepthwiseConvGradFilterFunctor
<
DEVICE_TYPE_GPU
,
double
>;
}
// namespace paddle
paddle/gserver/layers/DepthwiseConvLayer.cpp
0 → 100644
浏览文件 @
eeb17c26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DepthwiseConvLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
/*
* The calculation of the exconvt(convolution transpose (deconv) operation)
* is a swap of forward and backward of the calculation of exconv.
* */
REGISTER_LAYER
(
depthwise_conv
,
DepthwiseConvLayer
);
bool
DepthwiseConvLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
/* Initialize the basic convolutional parent class */
ExpandConvBaseLayer
::
init
(
layerMap
,
parameterMap
);
size_t
numInputs
=
config_
.
inputs_size
();
inputShape_
.
resize
(
numInputs
);
filterShape_
.
resize
(
numInputs
);
outputShape_
.
resize
(
numInputs
);
multiplierShape_
.
resize
(
numInputs
);
weightMultiplier_
.
resize
(
numInputs
);
for
(
int
i
=
0
;
i
<
config_
.
inputs_size
();
i
++
)
{
std
::
vector
<
size_t
>
paddings
=
{(
size_t
)
paddingY_
[
i
],
(
size_t
)
padding_
[
i
]};
std
::
vector
<
size_t
>
strides
=
{(
size_t
)
strideY_
[
i
],
(
size_t
)
stride_
[
i
]};
Matrix
::
resizeOrCreate
(
weightMultiplier_
[
i
],
(
size_t
)
outputH_
[
i
]
*
(
size_t
)
outputW_
[
i
],
(
size_t
)
1
,
false
,
useGpu_
);
weightMultiplier_
[
i
]
->
one
();
createFunction
(
forward_
,
"DepthwiseConv"
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
"DepthwiseConvGradInput"
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
"DepthwiseConvGradFilter"
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
}
return
true
;
}
// i is the index of input layers
#define BACKWARD_INPUT(i, inputs, outputs) \
backward_[2 * i]->calc(inputs, outputs)
#define BACKWARD_FILTER(i, inputs, outputs) \
backward_[2 * i + 1]->calc(inputs, outputs)
void
DepthwiseConvLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
size_t
batchSize
=
inputLayers_
[
0
]
->
getOutputValue
()
->
getHeight
();
resetOutput
(
batchSize
,
getOutputSize
());
// Calculate the shape of the input, output, and filter.
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
++
i
)
{
inputShape_
[
i
]
=
TensorShape
({(
size_t
)
batchSize
,
(
size_t
)
channels_
[
i
],
(
size_t
)
imgSizeH_
[
i
],
(
size_t
)
imgSizeW_
[
i
]});
multiplierShape_
[
i
]
=
TensorShape
({(
size_t
)
outputH_
[
i
]
*
(
size_t
)
outputW_
[
i
],
(
size_t
)
1
});
filterShape_
[
i
]
=
TensorShape
({(
size_t
)
groups_
[
i
],
(
size_t
)
numFilters_
/
groups_
[
i
],
(
size_t
)
channels_
[
i
]
/
groups_
[
i
],
(
size_t
)
filterSizeY_
[
i
],
(
size_t
)
filterSize_
[
i
]});
outputShape_
[
i
]
=
TensorShape
({(
size_t
)
batchSize
,
(
size_t
)
numFilters_
,
(
size_t
)
outputH_
[
i
],
(
size_t
)
outputW_
[
i
]});
}
// Calculate the output value.
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
++
i
)
{
BufferArgs
inputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
getInputValue
(
i
),
inputShape_
[
i
]);
inputs
.
addArg
(
*
weights_
[
i
]
->
getW
(),
filterShape_
[
i
]);
outputs
.
addArg
(
*
getOutputValue
(),
outputShape_
[
i
],
i
==
0
?
ASSIGN_TO
:
ADD_TO
);
forward_
[
i
]
->
calc
(
inputs
,
outputs
);
}
/* add the bias-vector */
if
(
biases_
.
get
())
{
if
(
sharedBiases_
)
{
addSharedBias
();
}
else
{
addUnsharedBias
();
}
}
/* activation */
forwardActivation
();
}
void
DepthwiseConvLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
backwardActivation
();
MatrixPtr
outGrad
=
getOutputGrad
();
if
(
biases_
&&
biases_
->
getWGrad
())
{
bpropBiases
(
outGrad
);
/* Increasing the number of gradient */
biases_
->
getParameterPtr
()
->
incUpdate
(
callback
);
}
// Calculate the input grad and filter grad.
for
(
size_t
i
=
0
;
i
<
inputLayers_
.
size
();
++
i
)
{
if
(
getInputGrad
(
i
))
{
BufferArgs
inputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
getOutputGrad
(),
outputShape_
[
i
]);
inputs
.
addArg
(
*
weights_
[
i
]
->
getW
(),
filterShape_
[
i
]);
outputs
.
addArg
(
*
getInputGrad
(
i
),
inputShape_
[
i
],
ADD_TO
);
BACKWARD_INPUT
(
i
,
inputs
,
outputs
);
}
if
(
weights_
[
i
]
->
getWGrad
())
{
BufferArgs
inputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
getOutputGrad
(),
outputShape_
[
i
]);
inputs
.
addArg
(
*
getInputValue
(
i
),
inputShape_
[
i
]);
inputs
.
addArg
(
*
weightMultiplier_
[
i
],
multiplierShape_
[
i
]);
// weight_multiplier
outputs
.
addArg
(
*
weights_
[
i
]
->
getWGrad
(),
filterShape_
[
i
],
ADD_TO
);
BACKWARD_FILTER
(
i
,
inputs
,
outputs
);
/* Increasing the number of gradient */
weights_
[
i
]
->
getParameterPtr
()
->
incUpdate
(
callback
);
}
}
}
}
// namespace paddle
paddle/gserver/layers/DepthwiseConvLayer.h
0 → 100644
浏览文件 @
eeb17c26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ExpandConvBaseLayer.h"
#include "paddle/math/Matrix.h"
namespace
paddle
{
/**
* @brief A subclass of convolution layer.
* This layer expands input and use matrix multiplication to
* calculate convolution operation.
*
* The config file api is img_conv_layer.
*/
class
DepthwiseConvLayer
:
public
ExpandConvBaseLayer
{
public:
explicit
DepthwiseConvLayer
(
const
LayerConfig
&
config
)
:
ExpandConvBaseLayer
(
config
)
{}
~
DepthwiseConvLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
override
;
void
forward
(
PassType
passType
)
override
;
void
backward
(
const
UpdateCallback
&
callback
)
override
;
protected:
std
::
vector
<
TensorShape
>
inputShape_
;
std
::
vector
<
TensorShape
>
filterShape_
;
std
::
vector
<
TensorShape
>
outputShape_
;
std
::
vector
<
TensorShape
>
multiplierShape_
;
std
::
vector
<
MatrixPtr
>
weightMultiplier_
;
};
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录