Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d7319c22
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d7319c22
编写于
11月 14, 2017
作者:
Z
Zhaolong Xing
提交者:
GitHub
11月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5165 from NHZlX/add_dilation
Add dilation for exconv layer
上级
3e6f7684
f3818bd3
变更
17
隐藏空白更改
内联
并排
Showing
17 changed files
with
299 additions
and
157 deletions
+299
-157
paddle/function/ConvOp.h
paddle/function/ConvOp.h
+6
-0
paddle/function/ConvOpTest.h
paddle/function/ConvOpTest.h
+53
-34
paddle/function/GemmConvOp.cpp
paddle/function/GemmConvOp.cpp
+9
-3
paddle/function/Im2Col.h
paddle/function/Im2Col.h
+6
-2
paddle/function/Im2ColOp.cpp
paddle/function/Im2ColOp.cpp
+24
-14
paddle/function/Im2ColOpGpu.cu
paddle/function/Im2ColOpGpu.cu
+55
-16
paddle/function/Im2ColTest.cpp
paddle/function/Im2ColTest.cpp
+92
-76
paddle/gserver/layers/ExpandConvLayer.cpp
paddle/gserver/layers/ExpandConvLayer.cpp
+10
-2
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+1
-1
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+26
-8
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+3
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
...config_helpers/tests/configs/protostr/img_layers.protostr
+2
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
..._helpers/tests/configs/protostr/img_trans_layers.protostr
+2
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr
...pers/tests/configs/protostr/test_bilinear_interp.protostr
+2
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr
...onfig_helpers/tests/configs/protostr/test_maxout.protostr
+4
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr
...r_config_helpers/tests/configs/protostr/test_pad.protostr
+2
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr
...lpers/tests/configs/protostr/test_roi_pool_layer.protostr
+2
-0
未找到文件。
paddle/function/ConvOp.h
浏览文件 @
d7319c22
...
...
@@ -61,6 +61,7 @@ public:
// function arguments
strides_
=
config
.
get
<
std
::
vector
<
size_t
>>
(
"strides"
);
paddings_
=
config
.
get
<
std
::
vector
<
size_t
>>
(
"paddings"
);
dilations_
=
config
.
get
<
std
::
vector
<
size_t
>>
(
"dilations"
);
groups_
=
config
.
get
<
size_t
>
(
"groups"
);
// number of inputs and outputs
...
...
@@ -118,6 +119,7 @@ protected:
std
::
vector
<
size_t
>
strides_
;
std
::
vector
<
size_t
>
paddings_
;
std
::
vector
<
size_t
>
dilations_
;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
...
...
@@ -133,6 +135,10 @@ protected:
inline
int
paddingW
()
const
{
return
paddings_
[
1
];
}
inline
int
dilationH
()
const
{
return
dilations_
[
0
];
}
inline
int
dilationW
()
const
{
return
dilations_
[
1
];
}
// A temporary memory in convolution calculation.
MemoryHandlePtr
memory_
;
...
...
paddle/function/ConvOpTest.h
浏览文件 @
d7319c22
...
...
@@ -79,45 +79,59 @@ void Convolution(const std::string& conv1,
if
(
outputChannels
<
inputChannels
)
continue
;
for
(
size_t
stride
:
{
1
,
2
})
{
for
(
size_t
padding
:
{
0
,
1
})
{
if
(
padding
>=
filterSize
)
break
;
for
(
size_t
dilation
:
{
1
,
3
})
{
if
(
padding
>=
filterSize
)
break
;
size_t
filterS
=
(
filterSize
-
1
)
*
dilation
+
1
;
// NNPACK only supports stride = 1 if batchSize > 1
if
((
conv1
==
"NNPACKConv-CPU"
||
conv2
==
"NNPACKConv-CPU"
)
&&
batchSize
>
1
&&
stride
>
1
)
break
;
if
(
inputSize
+
2
*
padding
<
filterS
)
break
;
size_t
outputSize
=
(
inputSize
-
filterSize
+
2
*
padding
+
stride
)
/
stride
;
VLOG
(
3
)
<<
" batchSize="
<<
batchSize
<<
" inputChannels="
<<
inputChannels
<<
" inputHeight="
<<
inputSize
<<
" inputWidth="
<<
inputSize
<<
" outputChannels="
<<
outputChannels
<<
" filterHeight="
<<
filterSize
<<
" filterWidth="
<<
filterSize
<<
" outputHeight="
<<
outputSize
<<
" outputWidth="
<<
outputSize
<<
" stride="
<<
stride
<<
" padding="
<<
padding
;
if
((
conv1
==
"NaiveConv-CPU"
||
conv2
==
"NaiveConv-CPU"
||
conv1
==
"NNPACKConv-CPU"
||
conv2
==
"NNPACKConv-CPU"
)
&&
dilation
>
1
)
break
;
std
::
vector
<
size_t
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
size_t
>
strides
=
{
stride
,
stride
};
Compare2Function
<
DType1
,
DType2
>
test
(
conv1
,
conv2
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"algo"
,
(
std
::
string
)
"auto"
));
// NNPACK only supports stride = 1 if batchSize > 1
if
((
conv1
==
"NNPACKConv-CPU"
||
conv2
==
"NNPACKConv-CPU"
)
&&
batchSize
>
1
&&
stride
>
1
)
break
;
TensorShape
input
{
batchSize
,
inputChannels
,
inputSize
,
inputSize
};
TensorShape
filter
{
outputChannels
,
inputChannels
,
filterSize
,
filterSize
};
TensorShape
output
{
batchSize
,
outputChannels
,
outputSize
,
outputSize
};
size_t
outputSize
=
(
inputSize
-
filterS
+
2
*
padding
+
stride
)
/
stride
;
VLOG
(
3
)
<<
" batchSize="
<<
batchSize
<<
" inputChannels="
<<
inputChannels
<<
" inputHeight="
<<
inputSize
<<
" inputWidth="
<<
inputSize
<<
" outputChannels="
<<
outputChannels
<<
" filterHeight="
<<
filterSize
<<
" filterWidth="
<<
filterSize
<<
" outputHeight="
<<
outputSize
<<
" outputWidth="
<<
outputSize
<<
" stride="
<<
stride
<<
" padding="
<<
padding
;
function
(
test
,
input
,
filter
,
output
);
std
::
vector
<
size_t
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
size_t
>
strides
=
{
stride
,
stride
};
std
::
vector
<
size_t
>
dilations
=
{
dilation
,
dilation
};
Compare2Function
<
DType1
,
DType2
>
test
(
conv1
,
conv2
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"algo"
,
(
std
::
string
)
"auto"
));
TensorShape
input
{
batchSize
,
inputChannels
,
inputSize
,
inputSize
};
TensorShape
filter
{
outputChannels
,
inputChannels
,
filterSize
,
filterSize
};
TensorShape
output
{
batchSize
,
outputChannels
,
outputSize
,
outputSize
};
function
(
test
,
input
,
filter
,
output
);
}
}
}
}
...
...
@@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1,
for
(
size_t
outputChannels
:
{
7
})
{
size_t
stride
=
1
;
size_t
padding
=
0
;
size_t
dilation
=
1
;
size_t
outputHeight
=
(
inputHeight
-
filterHeight
+
2
*
padding
+
stride
)
/
stride
;
...
...
@@ -162,6 +177,7 @@ void Convolution2(const std::string& conv1,
std
::
vector
<
size_t
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
size_t
>
strides
=
{
stride
,
stride
};
std
::
vector
<
size_t
>
dilations
=
{
dilation
,
dilation
};
Compare2Function
<
DType1
,
DType2
>
test
(
conv1
,
conv2
,
...
...
@@ -169,6 +185,7 @@ void Convolution2(const std::string& conv1,
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"algo"
,
(
std
::
string
)
"auto"
));
TensorShape
input
{
...
...
@@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1,
std
::
vector
<
size_t
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
size_t
>
strides
=
{
stride
,
stride
};
std
::
vector
<
size_t
>
dilations
=
{
1
,
1
};
size_t
groups
=
inputChannels
;
Compare2Function
<
DType1
,
DType2
>
test
(
conv1
,
...
...
@@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1,
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
groups
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"algo"
,
(
std
::
string
)
"auto"
));
TensorShape
input
{
...
...
paddle/function/GemmConvOp.cpp
浏览文件 @
d7319c22
...
...
@@ -100,7 +100,9 @@ public:
strideH
(),
strideW
(),
paddingH
(),
paddingW
());
paddingW
(),
dilationH
(),
dilationW
());
}
else
{
colData
=
inputData
+
g
*
inputOffset
;
}
...
...
@@ -223,7 +225,9 @@ public:
strideH
(),
strideW
(),
paddingH
(),
paddingW
());
paddingW
(),
dilationH
(),
dilationW
());
}
}
inputGrad
+=
inputChannels
*
inputHeight
*
inputWidth
;
...
...
@@ -310,7 +314,9 @@ public:
strideH
(),
strideW
(),
paddingH
(),
paddingW
());
paddingW
(),
dilationH
(),
dilationW
());
}
else
{
colData
=
inputData
+
g
*
inputOffset
;
}
...
...
paddle/function/Im2Col.h
浏览文件 @
d7319c22
...
...
@@ -78,7 +78,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
);
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
);
};
template
<
ColFormat
Format
,
DeviceType
Device
,
class
T
>
...
...
@@ -91,7 +93,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
);
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
);
};
}
// namespace paddle
paddle/function/Im2ColOp.cpp
浏览文件 @
d7319c22
...
...
@@ -31,7 +31,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -47,8 +49,8 @@ public:
int
c_im
=
c
/
filterWidth
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
int
imRowIdx
=
h
*
strideHeight
+
hOffset
*
dilationHeight
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
*
dilationWidth
;
if
((
imRowIdx
-
paddingHeight
)
<
0
||
(
imRowIdx
-
paddingHeight
)
>=
inputHeight
||
(
imColIdx
-
paddingWidth
)
<
0
||
...
...
@@ -81,7 +83,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -97,8 +101,8 @@ public:
int
c_im
=
c
/
filterWidth
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
int
imRowIdx
=
h
*
strideHeight
+
hOffset
*
dilationHeight
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
*
dilationWidth
;
if
((
imRowIdx
-
paddingHeight
)
>=
0
&&
(
imRowIdx
-
paddingHeight
)
<
inputHeight
&&
(
imColIdx
-
paddingWidth
)
>=
0
&&
...
...
@@ -134,7 +138,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -147,9 +153,10 @@ public:
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
int
imRowOffset
=
outputH
*
strideHeight
+
filterH
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
-
paddingWidth
;
int
imRowOffset
=
outputH
*
strideHeight
+
filterH
*
dilationHeight
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
*
dilationWidth
-
paddingWidth
;
int
colDataOffset
=
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
channel
)
*
...
...
@@ -189,7 +196,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
=
1
,
int
dilationWidth
=
1
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -202,9 +211,10 @@ public:
for
(
int
channel
=
0
;
channel
<
inputChannels
;
++
channel
)
{
for
(
int
filterH
=
0
;
filterH
<
filterHeight
;
++
filterH
)
{
for
(
int
filterW
=
0
;
filterW
<
filterWidth
;
++
filterW
)
{
int
imRowOffset
=
outputH
*
strideHeight
+
filterH
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
-
paddingWidth
;
int
imRowOffset
=
outputH
*
strideHeight
+
filterH
*
dilationHeight
-
paddingHeight
;
int
imColOffset
=
outputW
*
strideWidth
+
filterW
*
dilationWidth
-
paddingWidth
;
int
colDataOffset
=
(((
outputH
*
outputWidth
+
outputW
)
*
inputChannels
+
channel
)
*
...
...
paddle/function/Im2ColOpGpu.cu
浏览文件 @
d7319c22
...
...
@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int
strideW
,
int
paddingH
,
int
paddingW
,
int
dilationH
,
int
dilationW
,
int
height_col
,
int
width_col
,
T
*
data_col
)
{
...
...
@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col
+=
(
channel_out
*
height_col
+
h_out
)
*
width_col
+
w_out
;
for
(
int
i
=
0
;
i
<
blockH
;
++
i
)
{
for
(
int
j
=
0
;
j
<
blockW
;
++
j
)
{
int
rIdx
=
int
(
h_in
+
i
);
int
cIdx
=
int
(
w_in
+
j
);
int
rIdx
=
int
(
h_in
+
i
*
dilationH
);
int
cIdx
=
int
(
w_in
+
j
*
dilationW
);
if
((
rIdx
-
(
int
)
paddingH
)
>=
(
int
)
height
||
(
rIdx
-
(
int
)
paddingH
)
<
0
||
(
cIdx
-
(
int
)
paddingW
)
>=
(
int
)
width
||
...
...
@@ -77,7 +79,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -102,6 +106,8 @@ public:
strideWidth
,
paddingHeight
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputWidth
,
colData
);
...
...
@@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t
strideW
,
size_t
paddingH
,
size_t
paddingW
,
size_t
dilationH
,
size_t
dilationW
,
size_t
height_col
,
size_t
width_col
,
T
*
data_im
)
{
...
...
@@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int
w
=
int
(
index
%
width
);
int
h
=
int
((
index
/
width
)
%
height
);
int
c
=
int
(
index
/
(
width
*
height
));
int
filterH
=
(
blockH
-
1
)
*
dilationH
+
1
;
int
filterW
=
(
blockW
-
1
)
*
dilationW
+
1
;
if
((
w
-
(
int
)
paddingW
)
>=
0
&&
(
w
-
(
int
)
paddingW
)
<
(
width
-
2
*
paddingW
)
&&
(
h
-
(
int
)
paddingH
)
>=
0
&&
(
h
-
paddingH
)
<
(
height
-
2
*
paddingH
))
{
// compute the start and end of the output
int
w_col_start
=
(
w
<
(
int
)
blockW
)
?
0
:
(
w
-
int
(
block
W
))
/
(
int
)
strideW
+
1
;
(
w
<
(
int
)
filterW
)
?
0
:
(
w
-
int
(
filter
W
))
/
(
int
)
strideW
+
1
;
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
strideW
+
1
),
(
int
)(
width_col
));
int
h_col_start
=
(
h
<
(
int
)
blockH
)
?
0
:
(
h
-
(
int
)
block
H
)
/
(
int
)
strideH
+
1
;
(
h
<
(
int
)
filterH
)
?
0
:
(
h
-
(
int
)
filter
H
)
/
(
int
)
strideH
+
1
;
int
h_col_end
=
min
(
int
(
h
/
strideH
+
1
),
int
(
height_col
));
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
// the col location: [c * width * height + h_out, w_out]
int
c_col
=
int
(
c
*
blockH
*
blockW
)
+
(
h
-
h_col
*
(
int
)
strideH
)
*
(
int
)
blockW
+
(
w
-
w_col
*
(
int
)
strideW
);
val
+=
data_col
[(
c_col
*
height_col
+
h_col
)
*
width_col
+
w_col
];
int
h_k
=
(
h
-
h_col
*
strideH
);
int
w_k
=
(
w
-
w_col
*
strideW
);
if
(
h_k
%
dilationH
==
0
&&
w_k
%
dilationW
==
0
)
{
h_k
/=
dilationH
;
w_k
/=
dilationW
;
int
c_col
=
(((
c
*
blockH
+
h_k
)
*
blockW
+
w_k
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
val
+=
data_col
[
c_col
];
}
}
}
h
-=
paddingH
;
...
...
@@ -173,7 +192,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -205,6 +226,8 @@ public:
strideWidth
,
paddingHeight
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputWidth
,
imData
);
...
...
@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
,
int
outputHeight
,
int
outputWidth
)
{
int
swId
=
blockIdx
.
x
;
...
...
@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
int
widthOffset
=
idx
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
+
shId
*
strideHeight
-
paddingHeight
;
int
widthOffset
=
idx
*
dilationHeight
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
*
dilationWidth
+
shId
*
strideHeight
-
paddingHeight
;
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
channelId
*
inputHeight
*
inputWidth
;
...
...
@@ -273,7 +300,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -312,6 +341,8 @@ public:
strideWidth
,
paddingHeight
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputWidth
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
...
...
@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
,
int
outputHeight
,
int
outputWidth
)
{
int
swId
=
blockIdx
.
x
;
...
...
@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filterHeight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filterWidth
;
idx
+=
blockDim
.
x
)
{
int
widthOffset
=
idx
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
+
shId
*
strideHeight
-
paddingHeight
;
int
widthOffset
=
idx
*
dilationWidth
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
*
dilationHeight
+
shId
*
strideHeight
-
paddingHeight
;
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
channelId
*
inputHeight
*
inputWidth
;
...
...
@@ -372,7 +407,9 @@ public:
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
paddingWidth
,
int
dilationHeight
,
int
dilationWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
...
...
@@ -411,6 +448,8 @@ public:
strideWidth
,
paddingHeight
,
paddingWidth
,
dilationHeight
,
dilationWidth
,
outputHeight
,
outputWidth
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
...
...
paddle/function/Im2ColTest.cpp
浏览文件 @
d7319c22
...
...
@@ -29,82 +29,98 @@ void TestIm2ColFunctor() {
for
(
size_t
filterWidth
:
{
3
,
7
})
{
for
(
size_t
stride
:
{
1
,
2
})
{
for
(
size_t
padding
:
{
0
,
1
})
{
if
(
inputHeight
<=
filterHeight
||
inputWidth
<=
filterWidth
)
break
;
if
(
padding
>=
filterHeight
||
padding
>=
filterWidth
)
break
;
size_t
outputHeight
=
(
inputHeight
-
filterHeight
+
2
*
padding
+
stride
)
/
stride
;
size_t
outputWidth
=
(
inputWidth
-
filterWidth
+
2
*
padding
+
stride
)
/
stride
;
TensorShape
imShape
=
TensorShape
({
channels
,
inputHeight
,
inputWidth
});
TensorShape
colShape1
=
TensorShape
({
channels
,
filterHeight
,
filterWidth
,
outputHeight
,
outputWidth
});
TensorShape
colShape2
=
TensorShape
({
outputHeight
,
outputWidth
,
channels
,
filterHeight
,
filterWidth
});
size_t
height
=
channels
*
filterHeight
*
filterWidth
;
size_t
width
=
outputHeight
*
outputWidth
;
VectorPtr
input1
=
Vector
::
create
(
imShape
.
getElements
(),
false
);
VectorPtr
input2
=
Vector
::
create
(
imShape
.
getElements
(),
false
);
MatrixPtr
output1
=
Matrix
::
create
(
height
,
width
,
false
,
false
);
MatrixPtr
output2
=
Matrix
::
create
(
width
,
height
,
false
,
false
);
input1
->
uniform
(
0.001
,
1
);
input2
->
copyFrom
(
*
input1
);
Im2ColFunctor
<
kCFO
,
Device
,
T
>
im2Col1
;
Im2ColFunctor
<
kOCF
,
Device
,
T
>
im2Col2
;
im2Col1
(
input1
->
getData
(),
imShape
,
output1
->
getData
(),
colShape1
,
stride
,
stride
,
padding
,
padding
);
im2Col2
(
input2
->
getData
(),
imShape
,
output2
->
getData
(),
colShape2
,
stride
,
stride
,
padding
,
padding
);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr
test
;
output2
->
transpose
(
test
,
true
);
autotest
::
TensorCheckErr
(
*
output1
,
*
test
);
Col2ImFunctor
<
kCFO
,
Device
,
T
>
col2Im1
;
Col2ImFunctor
<
kOCF
,
Device
,
T
>
col2Im2
;
col2Im1
(
input1
->
getData
(),
imShape
,
output1
->
getData
(),
colShape1
,
stride
,
stride
,
padding
,
padding
);
col2Im2
(
input2
->
getData
(),
imShape
,
output2
->
getData
(),
colShape2
,
stride
,
stride
,
padding
,
padding
);
autotest
::
TensorCheckErr
(
*
input1
,
*
input2
);
for
(
size_t
dilation
:
{
1
,
3
})
{
size_t
filterSizeH
=
(
filterHeight
-
1
)
*
dilation
+
1
;
size_t
filterSizeW
=
(
filterWidth
-
1
)
*
dilation
+
1
;
if
(
inputHeight
+
2
*
padding
<
filterSizeH
||
inputWidth
+
2
*
padding
<
filterSizeW
)
break
;
if
(
padding
>=
filterSizeH
||
padding
>=
filterSizeW
)
break
;
size_t
outputHeight
=
(
inputHeight
-
filterSizeH
+
2
*
padding
)
/
stride
+
1
;
size_t
outputWidth
=
(
inputWidth
-
filterSizeW
+
2
*
padding
)
/
stride
+
1
;
TensorShape
imShape
=
TensorShape
({
channels
,
inputHeight
,
inputWidth
});
TensorShape
colShape1
=
TensorShape
({
channels
,
filterHeight
,
filterWidth
,
outputHeight
,
outputWidth
});
TensorShape
colShape2
=
TensorShape
({
outputHeight
,
outputWidth
,
channels
,
filterHeight
,
filterWidth
});
size_t
height
=
channels
*
filterHeight
*
filterWidth
;
size_t
width
=
outputHeight
*
outputWidth
;
VectorPtr
input1
=
Vector
::
create
(
imShape
.
getElements
(),
false
);
VectorPtr
input2
=
Vector
::
create
(
imShape
.
getElements
(),
false
);
MatrixPtr
output1
=
Matrix
::
create
(
height
,
width
,
false
,
false
);
MatrixPtr
output2
=
Matrix
::
create
(
width
,
height
,
false
,
false
);
input1
->
uniform
(
0.001
,
1
);
input2
->
copyFrom
(
*
input1
);
Im2ColFunctor
<
kCFO
,
Device
,
T
>
im2Col1
;
Im2ColFunctor
<
kOCF
,
Device
,
T
>
im2Col2
;
im2Col1
(
input1
->
getData
(),
imShape
,
output1
->
getData
(),
colShape1
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
);
im2Col2
(
input2
->
getData
(),
imShape
,
output2
->
getData
(),
colShape2
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr
test
;
output2
->
transpose
(
test
,
true
);
autotest
::
TensorCheckErr
(
*
output1
,
*
test
);
Col2ImFunctor
<
kCFO
,
Device
,
T
>
col2Im1
;
Col2ImFunctor
<
kOCF
,
Device
,
T
>
col2Im2
;
col2Im1
(
input1
->
getData
(),
imShape
,
output1
->
getData
(),
colShape1
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
);
col2Im2
(
input2
->
getData
(),
imShape
,
output2
->
getData
(),
colShape2
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
);
autotest
::
TensorCheckErr
(
*
input1
,
*
input2
);
}
}
}
}
...
...
paddle/gserver/layers/ExpandConvLayer.cpp
浏览文件 @
d7319c22
...
...
@@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for
(
int
i
=
0
;
i
<
config_
.
inputs_size
();
i
++
)
{
std
::
vector
<
size_t
>
paddings
=
{(
size_t
)
paddingY_
[
i
],
(
size_t
)
padding_
[
i
]};
std
::
vector
<
size_t
>
strides
=
{(
size_t
)
strideY_
[
i
],
(
size_t
)
stride_
[
i
]};
std
::
vector
<
size_t
>
dilations
=
{(
size_t
)
dilationY_
[
i
],
(
size_t
)
dilation_
[
i
]};
bool
useDilation
=
((
size_t
)
dilationY_
[
i
]
>
1
||
(
size_t
)
dilation_
[
i
]
>
1
);
// Convolution Layer uses the GemmConv function by default.
convType
=
"GemmConv"
;
...
...
@@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if
((
filterSize_
[
i
]
==
filterSizeY_
[
i
])
&&
(
filterSize_
[
i
]
==
3
||
filterSize_
[
i
]
==
4
)
&&
(
stride_
[
i
]
==
strideY_
[
i
])
&&
(
stride_
[
i
]
==
1
||
stride_
[
i
]
==
2
))
{
(
stride_
[
i
]
==
strideY_
[
i
])
&&
(
stride_
[
i
]
==
1
||
stride_
[
i
]
==
2
)
&&
!
useDilation
)
{
convType
=
"NeonDepthwiseConv"
;
}
#endif
}
if
(
FLAGS_use_nnpack
&&
!
isDeconv_
)
{
if
(
FLAGS_use_nnpack
&&
!
isDeconv_
&&
!
useDilation
)
{
createFunction
(
forward_
,
"NNPACKConv"
,
FuncConfig
()
...
...
@@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
...
...
@@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
createFunction
(
backward_
,
...
...
@@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"dilations"
,
dilations
)
.
set
(
"groups"
,
(
size_t
)
groups_
[
i
]));
}
}
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
d7319c22
...
...
@@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config
.
layerConfig
.
set_partial_sum
(
1
);
config
.
layerConfig
.
set_shared_biases
(
true
);
int
dilation
=
1
;
int
dilation
=
2
;
if
(
type
==
"cudnn_conv"
)
{
#if CUDNN_VERSION >= 6000
dilation
=
2
;
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
d7319c22
...
...
@@ -1200,8 +1200,14 @@ def TestData(data_config, async_load_data=None):
#caffe_mode: compute the output size using floor instead of ceil,
# which is consistent with caffe and CuDNN's convention.
def
cnn_output_size
(
img_size
,
filter_size
,
padding
,
stride
,
caffe_mode
):
output
=
(
2
*
padding
+
img_size
-
filter_size
)
/
float
(
stride
)
def
cnn_output_size
(
img_size
,
filter_size
,
padding
,
stride
,
caffe_mode
,
dilation
=
1
):
filter_s
=
(
filter_size
-
1
)
*
dilation
+
1
output
=
(
2
*
padding
+
img_size
-
filter_s
)
/
float
(
stride
)
if
caffe_mode
:
return
1
+
int
(
math
.
floor
(
output
))
else
:
...
...
@@ -1210,8 +1216,14 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
#calculate image_size based on output_size for de-convolution (ConvTransLayer).
#It is the inverse function of cnn_output_size
def
cnn_image_size
(
output_size
,
filter_size
,
padding
,
stride
,
caffe_mode
):
img_size
=
(
output_size
-
1
)
*
stride
+
filter_size
-
2
*
padding
def
cnn_image_size
(
output_size
,
filter_size
,
padding
,
stride
,
caffe_mode
,
dilation
=
1
):
filter_s
=
(
filter_size
-
1
)
*
dilation
+
1
img_size
=
(
output_size
-
1
)
*
stride
+
filter_s
-
2
*
padding
if
not
caffe_mode
:
img_size
=
img_size
+
1
return
img_size
...
...
@@ -1376,6 +1388,12 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf
.
stride_y
=
conv
.
stride_y
conv_conf
.
groups
=
conv
.
groups
conv_conf
.
caffe_mode
=
conv
.
caffe_mode
if
not
conv
.
dilation
:
conv
.
dilation
=
1
conv
.
dilation_y
=
1
else
:
conv_conf
.
dilation
=
conv
.
dilation
conv_conf
.
dilation_y
=
conv
.
dilation_y
if
not
trans
:
conv_conf
.
filter_channels
=
conv
.
channels
/
conv
.
groups
...
...
@@ -1383,20 +1401,20 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
get_img_size
(
input_layer_name
,
conv
.
channels
)
conv_conf
.
output_x
=
cnn_output_size
(
conv_conf
.
img_size
,
conv_conf
.
filter_size
,
conv_conf
.
padding
,
conv_conf
.
stride
,
conv_conf
.
caffe_mode
)
conv_conf
.
stride
,
conv_conf
.
caffe_mode
,
conv
.
dilation
)
conv_conf
.
output_y
=
cnn_output_size
(
conv_conf
.
img_size_y
,
conv_conf
.
filter_size_y
,
conv_conf
.
padding_y
,
conv_conf
.
stride_y
,
conv_conf
.
caffe_mode
)
conv_conf
.
stride_y
,
conv_conf
.
caffe_mode
,
conv
.
dilation_y
)
else
:
conv_conf
.
filter_channels
=
num_filters
/
conv
.
groups
conv_conf
.
output_x
,
conv_conf
.
output_y
=
\
get_img_size
(
input_layer_name
,
conv
.
channels
)
conv_conf
.
img_size
=
cnn_image_size
(
conv_conf
.
output_x
,
conv_conf
.
filter_size
,
conv_conf
.
padding
,
conv_conf
.
stride
,
conv_conf
.
caffe_mode
)
conv_conf
.
stride
,
conv_conf
.
caffe_mode
,
conv
.
dilation
)
conv_conf
.
img_size_y
=
cnn_image_size
(
conv_conf
.
output_y
,
conv_conf
.
filter_size_y
,
conv_conf
.
padding_y
,
conv_conf
.
stride_y
,
conv_conf
.
caffe_mode
)
conv_conf
.
stride_y
,
conv_conf
.
caffe_mode
,
conv
.
dilation_y
)
#caffe_mode: compute the output size using floor instead of ceil,
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
d7319c22
...
...
@@ -2571,7 +2571,9 @@ def img_conv_layer(input,
if
layer_type
:
if
dilation
>
1
or
dilation_y
>
1
:
assert
layer_type
in
[
"cudnn_conv"
,
"cudnn_convt"
]
assert
layer_type
in
[
"cudnn_conv"
,
"cudnn_convt"
,
"exconv"
,
"exconvt"
]
if
trans
:
assert
layer_type
in
[
"exconvt"
,
"cudnn_convt"
]
else
:
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
浏览文件 @
d7319c22
...
...
@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 227
img_size_y: 256
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
浏览文件 @
d7319c22
...
...
@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 227
img_size_y: 256
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr
浏览文件 @
d7319c22
...
...
@@ -28,6 +28,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr
浏览文件 @
d7319c22
...
...
@@ -30,6 +30,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
@@ -105,6 +107,8 @@ layers {
stride_y: 1
output_y: 24
img_size_y: 24
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_1__.wbias"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr
浏览文件 @
d7319c22
...
...
@@ -30,6 +30,8 @@ layers {
stride_y: 1
output_y: 48
img_size_y: 48
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr
浏览文件 @
d7319c22
...
...
@@ -36,6 +36,8 @@ layers {
stride_y: 1
output_y: 14
img_size_y: 14
dilation: 1
dilation_y: 1
}
}
bias_parameter_name: "___conv_0__.wbias"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录