Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4fc9f55e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4fc9f55e
编写于
11月 15, 2017
作者:
C
chengduo
提交者:
GitHub
11月 15, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5472 from chengduoZH/refine_im2col
Add dilations for conv2d and optimize conv2d code
上级
09866fb7
00e0881b
变更
17
展开全部
隐藏空白更改
内联
并排
Showing
17 changed file
with
944 addition
and
635 deletion
+944
-635
paddle/operators/conv_cudnn_op.cc
paddle/operators/conv_cudnn_op.cc
+0
-2
paddle/operators/conv_op.cc
paddle/operators/conv_op.cc
+40
-14
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+79
-46
paddle/operators/conv_transpose_op.cc
paddle/operators/conv_transpose_op.cc
+12
-9
paddle/operators/conv_transpose_op.h
paddle/operators/conv_transpose_op.h
+24
-19
paddle/operators/math/context_project.h
paddle/operators/math/context_project.h
+29
-24
paddle/operators/math/im2col.cc
paddle/operators/math/im2col.cc
+142
-151
paddle/operators/math/im2col.cu
paddle/operators/math/im2col.cu
+210
-193
paddle/operators/math/im2col.h
paddle/operators/math/im2col.h
+17
-7
paddle/operators/math/im2col_test.cc
paddle/operators/math/im2col_test.cc
+12
-12
paddle/operators/math/vol2col.cc
paddle/operators/math/vol2col.cc
+83
-38
paddle/operators/math/vol2col.cu
paddle/operators/math/vol2col.cu
+120
-58
paddle/operators/math/vol2col.h
paddle/operators/math/vol2col.h
+19
-8
paddle/operators/math/vol2col_test.cc
paddle/operators/math/vol2col_test.cc
+11
-9
paddle/operators/sequence_conv_op.h
paddle/operators/sequence_conv_op.h
+11
-11
python/paddle/v2/fluid/tests/test_conv2d_op.py
python/paddle/v2/fluid/tests/test_conv2d_op.py
+72
-13
python/paddle/v2/fluid/tests/test_conv3d_op.py
python/paddle/v2/fluid/tests/test_conv3d_op.py
+63
-21
未找到文件。
paddle/operators/conv_cudnn_op.cc
浏览文件 @
4fc9f55e
...
...
@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
Conv2DOpMaker
(
proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"dilations of convolution operator."
)
.
SetDefault
(
std
::
vector
<
int
>
{
1
,
1
});
AddAttr
<
int
>
(
"workspace_size_MB"
,
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
...
...
paddle/operators/conv_op.cc
浏览文件 @
4fc9f55e
...
...
@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
std
::
vector
<
int
>
dilations
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dilations"
);
int
input_channels
=
in_dims
[
1
];
int
output_channels
=
filter_dims
[
0
];
...
...
@@ -52,9 +53,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"The number of output channels should be divided by groups."
);
std
::
vector
<
int64_t
>
output_shape
({
in_dims
[
0
],
filter_dims
[
0
]});
for
(
size_t
i
=
0
;
i
<
paddings
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
strides
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
in_dims
[
i
+
2
]
+
2
*
paddings
[
i
]
-
(
dilations
[
i
]
*
(
filter_dims
[
i
+
2
]
-
1
)
+
1
)
>
0
,
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again."
);
output_shape
.
push_back
(
OutputSize
(
in_dims
[
i
+
2
],
filter_dims
[
i
+
2
],
paddings
[
i
],
strides
[
i
]));
dilations
[
i
],
paddings
[
i
],
strides
[
i
]));
}
ctx
->
SetOutputDim
(
"Output"
,
framework
::
make_ddim
(
output_shape
));
}
...
...
@@ -78,9 +85,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"strides of convolution operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator."
)
.
SetDefault
({
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"paddings of convolution operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector<int> default:{0, 0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator."
)
.
SetDefault
({
0
,
0
});
AddAttr
<
int
>
(
"groups"
,
...
...
@@ -90,15 +103,20 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels."
)
.
SetDefault
(
1
);
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator."
)
.
SetDefault
({
1
,
1
});
AddComment
(
R"DOC(
Convolution Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
and strides, paddings, groups
, dilations
parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature. Parameters(ksize, strides, paddings) are two elements.
the width of the feature. Parameters(ksize, strides, paddings
, dilations
) are two elements.
These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different.
...
...
@@ -109,8 +127,8 @@ Example:
Output:
Output shape: (N, C_out, H_out, W_out)
where
H_out = (H_in
- filter_size[0] + 2 * paddings[0]
) / strides[0] + 1;
W_out = (W_in
- filter_size[1] + 2 * paddings[1]
) / strides[1] + 1;
H_out = (H_in
+ 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)
) / strides[0] + 1;
W_out = (W_in
+ 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)
) / strides[1] + 1;
)DOC"
);
}
...
...
@@ -135,13 +153,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector, default:{0, 0, 0}), the strides of convolution operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
"convolution operator."
)
.
SetDefault
({
1
,
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector, default:{0, 0, 0}), the paddings of convolution operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector<int>, default:{0, 0, 0}), the "
"paddings(d_pad, h_pad, w_pad) of convolution "
"operator."
)
.
SetDefault
({
0
,
0
,
0
});
AddAttr
<
int
>
(
"groups"
,
...
...
@@ -151,6 +171,12 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels."
)
.
SetDefault
(
1
);
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation, h_dilation, w_dilation) of "
"convolution operator. Currently, conv3d doesn't "
"support dilation."
)
.
SetDefault
({
1
,
1
,
1
});
AddComment
(
R"DOC(
Convolution3D Operator.
...
...
paddle/operators/conv_op.h
浏览文件 @
4fc9f55e
...
...
@@ -27,11 +27,24 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
inline
int
OutputSize
(
int
input_size
,
int
filter_size
,
int
padding
,
int
stride
)
{
int
output_size
=
(
input_size
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
inline
int
OutputSize
(
int
input_size
,
int
filter_size
,
int
dilation
,
int
padding
,
int
stride
)
{
const
int
dkernel
=
dilation
*
(
filter_size
-
1
)
+
1
;
const
int
output_size
=
(
input_size
+
2
*
padding
-
dkernel
)
/
stride
+
1
;
return
output_size
;
}
inline
bool
IsExpand
(
std
::
vector
<
int64_t
>&
filter_dim
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
std
::
vector
<
int
>&
dilations
)
{
bool
filter_1
=
true
,
strides_1
=
true
,
padding_0
=
true
,
dilation_1
=
true
;
for
(
size_t
j
=
0
;
j
<
strides
.
size
();
++
j
)
{
filter_1
=
filter_1
&&
(
static_cast
<
int
>
(
filter_dim
[
j
])
==
1
);
strides_1
=
strides_1
&&
(
strides
[
j
]
==
1
);
padding_0
=
padding_0
&&
(
paddings
[
j
]
==
0
);
dilation_1
=
dilation_1
&&
(
dilations
[
j
]
==
1
);
}
return
!
(
filter_1
&&
strides_1
&&
padding_0
&&
dilation_1
);
}
// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
...
...
@@ -50,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class
ConvOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
class
ConvOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
@@ -73,9 +84,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
context
.
Attr
<
int
>
(
"group
s"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilation
s"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
...
...
@@ -106,14 +118,17 @@ class GemmConvKernel : public framework::OpKernel<T> {
framework
::
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
+
1
);
bool
is_expand
=
IsExpand
(
filter_shape_vec
,
strides
,
paddings
,
dilations
);
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor
col_matrix
;
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
if
(
is_expand
)
{
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
framework
::
DDim
input_shape
=
framework
::
slice_ddim
(
input
->
dims
(),
1
,
static_cast
<
int
>
(
input
->
dims
().
size
()));
...
...
@@ -130,24 +145,30 @@ class GemmConvKernel : public framework::OpKernel<T> {
int
in_step
=
static_cast
<
int
>
(
input
->
dims
()[
1
])
/
groups
;
int
out_step
=
static_cast
<
int
>
(
output
->
dims
()[
1
])
/
groups
;
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
in_batch
=
input
->
Slice
(
i
,
i
+
1
).
Resize
(
input_shape
);
Tensor
out_batch
=
output
->
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
Tensor
in_slice
=
in_batch
.
Slice
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
!
is_expand
)
{
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
2
)
{
// im2col
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
in_slice
,
col
,
stride
s
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
]
,
paddings
[
1
]
);
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
padding
s
[
0
],
paddings
[
1
]}
,
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// vol2col
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
}
// gemm
...
...
@@ -178,9 +199,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if
(
!
input_grad
&&
!
filter_grad
)
return
;
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
context
.
Attr
<
int
>
(
"group
s"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilation
s"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
...
...
@@ -230,14 +252,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
int
in_step
=
static_cast
<
int
>
(
input
->
dims
()[
1
])
/
groups
;
int
out_step
=
static_cast
<
int
>
(
output_grad
->
dims
()[
1
])
/
groups
;
bool
is_expand
=
IsExpand
(
filter_shape_vec
,
strides
,
paddings
,
dilations
);
Tensor
col
;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor
col_matrix
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
if
(
is_expand
)
{
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
math
::
SetConstant
<
Place
,
T
>
set_zero
;
...
...
@@ -245,6 +270,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
set_zero
(
context
.
device_context
(),
input_grad
,
static_cast
<
T
>
(
0
));
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
out_grad_batch
=
output_grad
->
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
@@ -254,24 +282,26 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
Tensor
out_grad_slice
=
out_grad_batch
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
filter_slice
=
filter
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter_slice
,
true
,
out_grad_slice
,
false
,
T
(
1.0
),
&
col_matrix
,
T
(
0.0
));
// col2im
Tensor
in_grad_slice
=
in_grad_batch
.
Slice
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
if
(
filter_shape_vec
.
size
()
==
2
)
{
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
col2im
(
context
.
device_context
(),
in_grad_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
if
(
!
is_expand
)
{
col_matrix
.
ShareDataWith
(
in_grad_slice
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter_slice
,
true
,
out_grad_slice
,
false
,
T
(
1.0
),
&
col_matrix
,
T
(
0.0
));
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
col2vol
(
context
.
device_context
(),
in_grad_slice
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
2
)
{
col2im
(
context
.
device_context
(),
col
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
in_grad_slice
);
}
else
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
3
)
{
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
paddings
,
&
in_grad_slice
);
}
}
}
...
...
@@ -282,7 +312,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
Tensor
filter_grad_
=
*
filter_grad
;
filter_grad_
.
Resize
(
filter_matrix_shape
);
set_zero
(
context
.
device_context
(),
filter_grad
,
static_cast
<
T
>
(
0
));
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
out_grad_batch
=
output_grad
->
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
@@ -293,16 +324,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
out_grad_batch
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
in_slice
=
in_batch
.
Slice
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
if
(
filter_shape_vec
.
size
()
==
2
)
{
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
if
(
!
is_expand
)
{
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
2
)
{
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
}
// gemm
...
...
paddle/operators/conv_transpose_op.cc
浏览文件 @
4fc9f55e
...
...
@@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"as the number of filters."
);
std
::
vector
<
int64_t
>
output_shape
({
in_dims
[
0
],
filter_dims
[
1
]});
for
(
size_t
i
=
0
;
i
<
padding
s
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
stride
s
.
size
();
++
i
)
{
output_shape
.
push_back
((
in_dims
[
i
+
2
]
-
1
)
*
strides
[
i
]
+
filter_dims
[
i
+
2
]);
}
...
...
@@ -79,11 +79,13 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
"The format of output tensor is also NCHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector defalut:{1, 1}), strides of convolution transpose operator."
)
"(vector<int> defalut:{1, 1}), the strides(h_stride, w_stride) of "
"convolution transpose operator."
)
.
SetDefault
({
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector defalut:{0, 0}), paddings of convolution transpose operator."
)
"(vector<int> defalut:{0, 0}), the paddings(h_pad, w_pad) of convolution "
"transpose operator."
)
.
SetDefault
({
0
,
0
});
AddComment
(
R"DOC(
Convolution2D Transpose Operator.
...
...
@@ -132,13 +134,14 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector defalut:{1, 1, 1}), strides of convolution transpose operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int> defalut:{1, 1, 1}), the "
"strides{d_stride, h_stride, w_stride} of "
"convolution transpose operator."
)
.
SetDefault
({
1
,
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector defalut:{0, 0, 0}), paddings
of convolution transpose operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector<int> defalut:{0, 0, 0}), paddings(d_pad, "
"h_pad, w_pad)
of convolution transpose operator."
)
.
SetDefault
({
0
,
0
,
0
});
AddComment
(
R"DOC(
Convolution3D Transpose Operator.
...
...
paddle/operators/conv_transpose_op.h
浏览文件 @
4fc9f55e
...
...
@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
class
ConvTransposeOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
class
ConvTransposeOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
@@ -66,6 +62,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
// Actually, no paddings and groups allowed in conv transpose.
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
...
...
@@ -120,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
Place
,
T
>
set_zero
;
set_zero
(
context
.
device_context
(),
output
,
static_cast
<
T
>
(
0
));
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
std
::
vector
<
int
>
dilations
({
1
,
1
,
1
});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input)
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
...
...
@@ -138,16 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
if
(
filter_shape_vec
.
size
()
==
2
)
{
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
col2im
(
context
.
device_context
(),
output_batch
,
col
,
strides
[
0
],
strides
[
1
],
0
,
0
,
0
,
0
);
col2im
(
context
.
device_context
(),
col
,
std
::
vector
<
int
>
{
dilations
[
0
],
dilations
[
1
]},
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
output_batch
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
col2vol
(
context
.
device_context
(),
output_batch
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
0
,
0
,
0
);
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
std
::
vector
<
int
>
{
0
,
0
,
0
},
&
output_batch
);
}
}
}
...
...
@@ -228,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor
filter_grad_
;
math
::
SetConstant
<
Place
,
T
>
set_zero
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
std
::
vector
<
int
>
dilations
({
1
,
1
,
1
});
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
set_zero
(
context
.
device_context
(),
input_grad
,
static_cast
<
T
>
(
0
));
...
...
@@ -247,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if
(
filter_shape_vec
.
size
()
==
2
)
{
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
im2col
(
context
.
device_context
(),
output_grad_batch
,
std
::
vector
<
int
>
{
dilations
[
0
],
dilations
[
1
]},
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
output_grad_batch
,
dilations
,
strides
,
paddings
,
&
col
);
}
if
(
input_grad
)
{
...
...
paddle/operators/math/context_project.h
浏览文件 @
4fc9f55e
...
...
@@ -88,13 +88,18 @@ template <typename Place, typename T>
class
ContextProjectFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
LoDTensor
&
in
,
const
Tensor
&
padding_data
,
Tensor
&
col
,
bool
padding_trainable
,
int
context_start
,
int
context_length
,
int
context_stride
,
int
up_pad
,
int
down_pad
)
{
const
Tensor
&
padding_data
,
bool
padding_trainable
,
const
int
context_start
,
const
int
context_length
,
const
int
context_stride
,
const
int
up_pad
,
const
int
down_pad
,
Tensor
*
col
)
{
auto
lod_level_0
=
in
.
lod
()[
0
];
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
std
::
vector
<
int
>
dilation
({
1
,
1
});
std
::
vector
<
int
>
padding
({
up_pad
,
0
,
down_pad
,
0
});
std
::
vector
<
int
>
stride
({
context_stride
,
1
});
int
input_row_begin
,
input_row_end
;
int
sequence_height
,
sequence_width
;
sequence_width
=
in
.
dims
()[
1
];
...
...
@@ -105,8 +110,8 @@ class ContextProjectFunctor {
:
static_cast
<
int
>
(
lod_level_0
[
i
]);
input_row_end
=
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]);
Tensor
out_t
=
col
.
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
Tensor
out_t
=
col
->
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
sequence_height
=
static_cast
<
int
>
(
out_t
.
dims
()[
0
]);
...
...
@@ -123,17 +128,14 @@ class ContextProjectFunctor {
{
1
,
input_row_end
-
input_row_begin
,
sequence_width
});
// input_channels, input_height, input_width
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
im2col_ocf
(
context
,
in_t
,
out_t
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
up_pad
,
down_pad
,
0
,
0
);
im2col_ocf
(
context
,
in_t
,
dilation
,
stride
,
padding
,
&
out_t
);
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
}
}
if
(
padding_trainable
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod_level_0
.
size
())
-
1
;
++
i
)
{
Tensor
out_t
=
col
.
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
Tensor
out_t
=
col
->
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
sequence_height
=
static_cast
<
int
>
(
out_t
.
dims
()[
0
]);
...
...
@@ -196,14 +198,19 @@ class ContextProjectFunctor {
template
<
typename
Place
,
typename
T
>
class
ContextProjectGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
LoDTensor
&
in
,
Tensor
&
padding_data
,
Tensor
&
col
,
bool
padding_trainable
,
int
context_start
,
int
context_length
,
int
context_stride
,
int
up_pad
,
int
down_pad
,
bool
input_grad
,
bool
pad_grad
)
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
LoDTensor
&
in
,
bool
padding_trainable
,
const
int
context_start
,
const
int
context_length
,
const
int
context_stride
,
const
int
up_pad
,
const
int
down_pad
,
bool
pad_grad
,
bool
input_grad
,
Tensor
*
padding_data
,
Tensor
*
col
)
{
auto
lod_level_0
=
in
.
lod
()[
0
];
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
col2im_ocf
;
std
::
vector
<
int
>
dilation
({
1
,
1
});
std
::
vector
<
int
>
padding
({
up_pad
,
0
,
down_pad
,
0
});
std
::
vector
<
int
>
stride
({
context_stride
,
1
});
int
input_row_begin
,
input_row_end
;
int
sequence_height
,
sequence_width
;
sequence_width
=
in
.
dims
()[
1
];
...
...
@@ -215,8 +222,8 @@ class ContextProjectGradFunctor {
:
static_cast
<
int
>
(
lod_level_0
[
i
]);
input_row_end
=
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]);
Tensor
out_t
=
col
.
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
Tensor
out_t
=
col
->
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
sequence_height
=
static_cast
<
int
>
(
out_t
.
dims
()[
0
]);
...
...
@@ -234,9 +241,7 @@ class ContextProjectGradFunctor {
sequence_width
});
// input_channels, input_height, input_width
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
col2im_ocf
(
context
,
in_t
,
out_t
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
up_pad
,
down_pad
,
0
,
0
);
col2im_ocf
(
context
,
out_t
,
dilation
,
stride
,
padding
,
&
in_t
);
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
}
}
...
...
@@ -244,8 +249,8 @@ class ContextProjectGradFunctor {
if
(
pad_grad
)
{
if
(
padding_trainable
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod_level_0
.
size
())
-
1
;
++
i
)
{
Tensor
out_t
=
col
.
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
Tensor
out_t
=
col
->
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
static_cast
<
int
>
(
lod_level_0
[
i
+
1
]));
sequence_height
=
static_cast
<
int
>
(
out_t
.
dims
()[
0
]);
out_t
.
Resize
({
sequence_height
*
context_length
,
sequence_width
});
...
...
@@ -259,7 +264,7 @@ class ContextProjectGradFunctor {
k
+
context_length
<
up_pad
?
context_length
:
up_pad
-
k
;
Tensor
out_t_sub
=
out_t
.
Slice
(
k
*
context_length
,
k
*
context_length
+
padding_size
);
Tensor
w_sub
=
padding_data
.
Slice
(
k
,
k
+
padding_size
);
Tensor
w_sub
=
padding_data
->
Slice
(
k
,
k
+
padding_size
);
auto
out_t_sub_e
=
EigenMatrix
<
T
>::
From
(
out_t_sub
);
auto
w_sub_e
=
EigenMatrix
<
T
>::
From
(
w_sub
);
w_sub_e
.
device
(
*
context
.
GetEigenDevice
<
Place
>
())
=
...
...
@@ -292,7 +297,7 @@ class ContextProjectGradFunctor {
Tensor
out_t_sub
=
out_t
.
Slice
(
(
down_pad_begin_row
+
t
)
*
context_length
-
padding_size
,
(
down_pad_begin_row
+
t
)
*
context_length
);
Tensor
w_sub
=
padding_data
.
Slice
(
Tensor
w_sub
=
padding_data
->
Slice
(
up_pad
+
padding_idx
,
up_pad
+
padding_idx
+
padding_size
);
auto
out_t_sub_e
=
EigenMatrix
<
T
>::
From
(
out_t_sub
);
auto
w_sub_e
=
EigenMatrix
<
T
>::
From
(
w_sub
);
...
...
paddle/operators/math/im2col.cc
浏览文件 @
4fc9f55e
...
...
@@ -28,57 +28,55 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
const
framework
::
Tensor
&
im
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
col
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
->
dims
()[
1
];
int
filter_width
=
col
->
dims
()[
2
];
int
col_height
=
col
->
dims
()[
3
];
int
col_width
=
col
->
dims
()[
4
];
PADDLE_ENFORCE_EQ
(
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
,
output
_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
,
output
_width
,
"output_width and padding(padding_left, padding_right
) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
im_height
+
padding
[
0
]
+
padding
[
2
]
-
((
dilation
[
0
]
*
(
filter_height
-
1
)
+
1
))
)
/
stride
[
0
]
+
1
,
col
_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
im_width
+
padding
[
1
]
+
padding
[
3
]
-
((
dilation
[
1
]
*
(
filter_width
-
1
)
+
1
))
)
/
stride
[
1
]
+
1
,
col
_width
,
"Output_height and padding(padding_up, padding_down
) are "
"inconsistent."
);
int
channels_col
=
i
nput
_channels
*
filter_height
*
filter_width
;
int
channels_col
=
i
m
_channels
*
filter_height
*
filter_width
;
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
T
*
col_data
=
col
->
data
<
T
>
();
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride_height
+
h_offset
-
padding_up
;
int
im_col_idx
=
w
*
stride_width
+
w_offset
-
padding_left
;
for
(
int
h
=
0
;
h
<
col_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
col_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride
[
0
]
-
padding
[
0
]
+
h_offset
*
dilation
[
0
];
int
im_col_idx
=
w
*
stride
[
1
]
-
padding
[
1
]
+
w_offset
*
dilation
[
1
];
int
col_idx
=
(
c
*
col_height
+
h
)
*
col_width
+
w
;
int
im_idx
=
(
im_row_idx
+
c_im
*
im_height
)
*
im_width
+
im_col_idx
;
if
(
im_row_idx
<
0
||
im_row_idx
>=
input_height
||
im_col_idx
<
0
||
im_col_idx
>=
input_width
)
{
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
T
(
0
);
}
else
{
im_row_idx
+=
c_im
*
input_height
;
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
im_data
[
im_row_idx
*
input_width
+
im_col_idx
];
}
col_data
[
col_idx
]
=
(
im_row_idx
<
0
||
im_row_idx
>=
im_height
||
im_col_idx
<
0
||
im_col_idx
>=
im_width
)
?
static_cast
<
T
>
(
0
)
:
im_data
[
im_idx
];
}
}
}
...
...
@@ -94,54 +92,55 @@ template <class T>
class
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
im
)
{
PADDLE_ENFORCE
(
im
->
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput_channels
=
im
.
dims
()[
0
];
int
i
nput_height
=
im
.
dims
()[
1
];
int
i
nput_width
=
im
.
dims
()[
2
];
int
i
m_channels
=
im
->
dims
()[
0
];
int
i
m_height
=
im
->
dims
()[
1
];
int
i
m_width
=
im
->
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
output
_height
=
col
.
dims
()[
3
];
int
output
_width
=
col
.
dims
()[
4
];
int
col
_height
=
col
.
dims
()[
3
];
int
col
_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE_EQ
(
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
,
output
_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
,
output
_width
,
"output_width and padding(padding_left, padding_right
) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
im_height
+
padding
[
0
]
+
padding
[
2
]
-
((
dilation
[
0
]
*
(
filter_height
-
1
)
+
1
))
)
/
stride
[
0
]
+
1
,
col
_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
im_width
+
padding
[
1
]
+
padding
[
3
]
-
((
dilation
[
1
]
*
(
filter_width
-
1
)
+
1
))
)
/
stride
[
1
]
+
1
,
col
_width
,
"Output_height and padding(padding_up, padding_down
) are "
"inconsistent."
);
int
channels_col
=
i
nput
_channels
*
filter_height
*
filter_width
;
int
channels_col
=
i
m
_channels
*
filter_height
*
filter_width
;
T
*
im_data
=
im
.
data
<
T
>
();
T
*
im_data
=
im
->
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
for
(
int
h
=
0
;
h
<
output
_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
output
_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride
_height
+
h_offset
-
padding_up
;
int
im_col_idx
=
w
*
stride
_width
+
w_offset
-
padding_left
;
for
(
int
h
=
0
;
h
<
col
_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
col
_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride
[
0
]
-
padding
[
0
]
+
h_offset
*
dilation
[
0
]
;
int
im_col_idx
=
w
*
stride
[
1
]
-
padding
[
1
]
+
w_offset
*
dilation
[
1
]
;
if
((
im_row_idx
)
>=
0
&&
(
im_row_idx
)
<
i
nput
_height
&&
(
im_col_idx
)
>=
0
&&
(
im_col_idx
)
<
i
nput
_width
)
{
im_row_idx
+=
c_im
*
i
nput
_height
;
im_data
[
im_row_idx
*
i
nput
_width
+
im_col_idx
]
+=
col_data
[(
c
*
output_height
+
h
)
*
output
_width
+
w
];
if
((
im_row_idx
)
>=
0
&&
(
im_row_idx
)
<
i
m
_height
&&
(
im_col_idx
)
>=
0
&&
(
im_col_idx
)
<
i
m
_width
)
{
im_row_idx
+=
c_im
*
i
m
_height
;
im_data
[
im_row_idx
*
i
m
_width
+
im_col_idx
]
+=
col_data
[(
c
*
col_height
+
h
)
*
col
_width
+
w
];
}
}
}
...
...
@@ -168,64 +167,59 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
const
framework
::
Tensor
&
im
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
col
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
5
);
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
->
dims
()[
3
];
int
filter_width
=
col
->
dims
()[
4
];
int
col_height
=
col
->
dims
()[
0
];
int
col_width
=
col
->
dims
()[
1
];
PADDLE_ENFORCE_EQ
(
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
,
output_height
,
(
im_height
+
padding
[
0
]
+
padding
[
2
]
-
filter_height
)
/
stride
[
0
]
+
1
,
col_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
,
output_width
,
"output_width and padding(padding_left, padding_right) are "
(
im_width
+
padding
[
1
]
+
padding
[
3
]
-
filter_width
)
/
stride
[
1
]
+
1
,
col_width
,
"col_width and padding(padding_left, padding_right) are "
"inconsistent."
);
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
T
*
col_data
=
col
->
data
<
T
>
();
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
nput
_channels
;
++
channel
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
col
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
col
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
m
_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
++
filter_row_idx
)
{
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
++
filter_col_idx
)
{
int
im_row_offset
=
col_row_idx
*
stride
_height
+
filter_row_idx
-
padding_up
;
col_row_idx
*
stride
[
0
]
+
filter_row_idx
-
padding
[
0
]
;
int
im_col_offset
=
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
((((
col_row_idx
)
*
output_width
+
col_col_idx
)
*
input_channels
+
channel
)
*
filter_height
+
filter_row_idx
)
*
filter_width
+
filter_col_idx
;
if
(
im_row_offset
<
0
||
im_row_offset
>=
input_height
||
im_col_offset
<
0
||
im_col_offset
>=
input_width
)
{
col_data
[
col_offset
]
=
T
(
0
);
}
else
{
int
im_offset
=
(
channel
*
input_height
+
im_row_offset
)
*
input_width
+
im_col_offset
;
col_data
[
col_offset
]
=
im_data
[
im_offset
];
}
col_col_idx
*
stride
[
1
]
+
filter_col_idx
-
padding
[
1
];
int
col_offset
=
((((
col_row_idx
)
*
col_width
+
col_col_idx
)
*
im_channels
+
channel
)
*
filter_height
+
filter_row_idx
)
*
filter_width
+
filter_col_idx
;
int
im_offset
=
(
channel
*
im_height
+
im_row_offset
)
*
im_width
+
im_col_offset
;
col_data
[
col_offset
]
=
(
im_row_offset
<
0
||
im_row_offset
>=
im_height
||
im_col_offset
<
0
||
im_col_offset
>=
im_width
)
?
static_cast
<
T
>
(
0
)
:
im_data
[
im_offset
];
}
}
}
...
...
@@ -243,60 +237,57 @@ template <class T>
class
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kOCF
,
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
im
)
{
PADDLE_ENFORCE
(
im
->
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput_channels
=
im
.
dims
()[
0
];
int
i
nput_height
=
im
.
dims
()[
1
];
int
i
nput_width
=
im
.
dims
()[
2
];
int
i
m_channels
=
im
->
dims
()[
0
];
int
i
m_height
=
im
->
dims
()[
1
];
int
i
m_width
=
im
->
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output
_height
=
col
.
dims
()[
0
];
int
output
_width
=
col
.
dims
()[
1
];
int
col
_height
=
col
.
dims
()[
0
];
int
col
_width
=
col
.
dims
()[
1
];
PADDLE_ENFORCE_EQ
(
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
,
output_height
,
(
im_height
+
padding
[
0
]
+
padding
[
2
]
-
filter_height
)
/
stride
[
0
]
+
1
,
col_height
,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
,
output_width
,
"output_width and padding(padding_left, padding_right) are "
(
im_width
+
padding
[
1
]
+
padding
[
3
]
-
filter_width
)
/
stride
[
1
]
+
1
,
col_width
,
"col_width and padding(padding_left, padding_right) are "
"inconsistent."
);
T
*
im_data
=
im
.
data
<
T
>
();
T
*
im_data
=
im
->
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
nput
_channels
;
++
channel
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
col
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
col
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
m
_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
++
filter_row_idx
)
{
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
++
filter_col_idx
)
{
int
im_row_offset
=
col_row_idx
*
stride
_height
+
filter_row_idx
-
padding_up
;
col_row_idx
*
stride
[
0
]
+
filter_row_idx
-
padding
[
0
]
;
int
im_col_offset
=
col_col_idx
*
stride
_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
(((
col_row_idx
*
output_width
+
col_col_idx
)
*
input
_channels
+
channel
)
*
filter_height
+
filter_row_idx
)
*
filter_width
+
filter_col_idx
;
if
(
im_row_offset
>=
0
&&
im_row_offset
<
i
nput
_height
&&
im_col_offset
>=
0
&&
im_col_offset
<
i
nput
_width
)
{
col_col_idx
*
stride
[
1
]
+
filter_col_idx
-
padding
[
1
]
;
int
col_offset
=
(((
col_row_idx
*
col_width
+
col_col_idx
)
*
im
_channels
+
channel
)
*
filter_height
+
filter_row_idx
)
*
filter_width
+
filter_col_idx
;
if
(
im_row_offset
>=
0
&&
im_row_offset
<
i
m
_height
&&
im_col_offset
>=
0
&&
im_col_offset
<
i
m
_width
)
{
int
im_offset
=
(
channel
*
i
nput_height
+
im_row_offset
)
*
input
_width
+
(
channel
*
i
m_height
+
im_row_offset
)
*
im
_width
+
im_col_offset
;
im_data
[
im_offset
]
+=
col_data
[
col_offset
];
}
...
...
paddle/operators/math/im2col.cu
浏览文件 @
4fc9f55e
此差异已折叠。
点击以展开。
paddle/operators/math/im2col.h
浏览文件 @
4fc9f55e
...
...
@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
* \param colData Column data.
* \param colShape The shape of colData.
*
* \param dilations dilation data.
* \param 2-dimension [dilation_height, dilation_width].
*
* \param strides stride data.
* \param 2-dimension [stride_height, stride_width].
*
* \param paddings padding data.
* \param 4-dimension [up_pad, left_pad, down_pad, right_pad].
*
* If the template argument Format is kCFO, the shape of colData is:
* [input_channels, filter_height, filter_width, output_height, output_width]
* So, it is easy to reshape into a convolution matrix for convolution
...
...
@@ -73,18 +82,19 @@ template <ColFormat Format, typename Place, typename T>
class
Im2ColFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
);
const
framework
::
Tensor
&
im
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
col
);
};
template
<
ColFormat
Format
,
typename
Place
,
typename
T
>
class
Col2ImFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
);
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
framework
::
Tensor
*
im
);
};
}
// namespace math
...
...
paddle/operators/math/im2col_test.cc
浏览文件 @
4fc9f55e
...
...
@@ -45,10 +45,14 @@ void testIm2col() {
int
input_height
=
2
;
int
input_width
=
3
;
int
filter_size
=
2
;
int
stride
=
1
;
int
padding
=
0
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
std
::
vector
<
int
>
stride
({
1
,
1
});
// stride_y, stride_x
std
::
vector
<
int
>
padding
(
{
0
,
0
,
0
,
0
});
// up_pad, left_pad, down_pad, right_pad
std
::
vector
<
int
>
dilation
({
1
,
1
});
// dilation_y, dilation_x
int
output_height
=
(
input_height
-
filter_size
+
padding
[
0
]
+
padding
[
1
])
/
stride
[
0
]
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
padding
[
2
]
+
padding
[
3
])
/
stride
[
1
]
+
1
;
float
*
input_ptr
=
input_tmp
.
mutable_data
<
float
>
(
{
1
,
input_height
,
input_width
},
paddle
::
platform
::
CPUPlace
());
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
...
...
@@ -85,10 +89,8 @@ void testIm2col() {
paddle
::
operators
::
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
im2col
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
im2col_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
im2col
(
*
context
,
input
,
dilation
,
stride
,
padding
,
&
output_cfo
);
im2col_ocf
(
*
context
,
input
,
dilation
,
stride
,
padding
,
&
output_ocf
);
float
out_cfo_data
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
};
float
out_ocf_data
[]
=
{
0
,
1
,
3
,
4
,
1
,
2
,
4
,
5
};
...
...
@@ -131,8 +133,7 @@ void testIm2col() {
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
}
col2im
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
col2im
(
*
context
,
output_cfo
,
dilation
,
stride
,
padding
,
&
input
);
float
*
in_ptr
;
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
...
@@ -153,8 +154,7 @@ void testIm2col() {
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
}
col2im_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
col2im_ocf
(
*
context
,
output_ocf
,
dilation
,
stride
,
padding
,
&
input
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
...
...
paddle/operators/math/vol2col.cc
浏览文件 @
4fc9f55e
...
...
@@ -28,28 +28,51 @@ template <class T>
class
Vol2ColFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
const
framework
::
Tensor
&
vol
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
col
)
const
{
PADDLE_ENFORCE
(
vol
.
dims
().
size
()
==
4
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
7
);
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
7
);
int
input_channels
=
vol
.
dims
()[
0
];
int
input_depth
=
vol
.
dims
()[
1
];
int
input_height
=
vol
.
dims
()[
2
];
int
input_width
=
vol
.
dims
()[
3
];
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
3
];
int
output_depth
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
5
];
int
output_width
=
col
.
dims
()[
6
];
int
filter_depth
=
col
->
dims
()[
1
];
int
filter_height
=
col
->
dims
()[
2
];
int
filter_width
=
col
->
dims
()[
3
];
int
output_depth
=
col
->
dims
()[
4
];
int
output_height
=
col
->
dims
()[
5
];
int
output_width
=
col
->
dims
()[
6
];
int
channels_col
=
input_channels
*
filter_depth
*
filter_height
*
filter_width
;
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
paddings
[
0
]
-
((
dilations
[
0
]
*
(
filter_depth
-
1
)
+
1
)))
/
strides
[
0
]
+
1
,
output_depth
,
"input_depth and output_depth are "
"mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
paddings
[
1
]
-
((
dilations
[
1
]
*
(
filter_height
-
1
)
+
1
)))
/
strides
[
1
]
+
1
,
output_height
,
"input_height and output_height are "
"mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
paddings
[
2
]
-
((
dilations
[
2
]
*
(
filter_width
-
1
)
+
1
)))
/
strides
[
2
]
+
1
,
output_width
,
"input_width and output_width are "
"mismatching."
);
const
T
*
vol_data
=
vol
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
T
*
col_data
=
col
->
data
<
T
>
();
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%
filter_width
;
...
...
@@ -57,24 +80,23 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int
d_offset
=
(
c
/
filter_width
/
filter_height
)
%
filter_depth
;
int
c_in
=
c
/
filter_width
/
filter_height
/
filter_depth
;
for
(
int
d
=
0
;
d
<
output_depth
;
++
d
)
{
int
d_pad
=
d
*
stride
_depth
-
padding_depth
+
d_offset
;
int
d_pad
=
d
*
stride
s
[
0
]
-
paddings
[
0
]
+
d_offset
*
dilations
[
0
]
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
h_pad
=
h
*
stride
_height
-
padding_height
+
h_offset
;
int
h_pad
=
h
*
stride
s
[
1
]
-
paddings
[
1
]
+
h_offset
*
dilations
[
1
]
;
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
w_pad
=
w
*
stride
_width
-
padding_width
+
w_offset
;
int
w_pad
=
w
*
stride
s
[
2
]
-
paddings
[
2
]
+
w_offset
*
dilations
[
2
]
;
int
col_idx
=
((
c
*
output_depth
+
d
)
*
output_height
+
h
)
*
output_width
+
w
;
if
(
h_pad
<
0
||
h_pad
>=
input_height
||
w_pad
<
0
||
w_pad
>=
input_width
||
d_pad
<
0
||
d_pad
>=
input_depth
)
{
col_data
[
col_idx
]
=
static_cast
<
T
>
(
0
);
}
else
{
int
vol_idx
=
((
c_in
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
input_width
+
w_pad
;
col_data
[
col_idx
]
=
vol_data
[
vol_idx
];
}
int
vol_idx
=
((
c_in
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
input_width
+
w_pad
;
col_data
[
col_idx
]
=
(
h_pad
<
0
||
h_pad
>=
input_height
||
w_pad
<
0
||
w_pad
>=
input_width
||
d_pad
<
0
||
d_pad
>=
input_depth
)
?
static_cast
<
T
>
(
0
)
:
vol_data
[
vol_idx
];
}
}
}
...
...
@@ -92,17 +114,18 @@ template <class T>
class
Col2VolFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
PADDLE_ENFORCE
(
vol
.
dims
().
size
()
==
4
);
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
vol
)
const
{
PADDLE_ENFORCE
(
vol
->
dims
().
size
()
==
4
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
7
);
int
input_channels
=
vol
.
dims
()[
0
];
int
input_depth
=
vol
.
dims
()[
1
];
int
input_height
=
vol
.
dims
()[
2
];
int
input_width
=
vol
.
dims
()[
3
];
int
input_channels
=
vol
->
dims
()[
0
];
int
input_depth
=
vol
->
dims
()[
1
];
int
input_height
=
vol
->
dims
()[
2
];
int
input_width
=
vol
->
dims
()[
3
];
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
3
];
...
...
@@ -112,7 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int
channels_col
=
input_channels
*
filter_depth
*
filter_height
*
filter_width
;
T
*
vol_data
=
vol
.
data
<
T
>
();
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
paddings
[
0
]
-
((
dilations
[
0
]
*
(
filter_depth
-
1
)
+
1
)))
/
strides
[
0
]
+
1
,
output_depth
,
"input_depth and output_depth are "
"mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
paddings
[
1
]
-
((
dilations
[
1
]
*
(
filter_height
-
1
)
+
1
)))
/
strides
[
1
]
+
1
,
output_height
,
"input_height and output_height are "
"mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
paddings
[
2
]
-
((
dilations
[
2
]
*
(
filter_width
-
1
)
+
1
)))
/
strides
[
2
]
+
1
,
output_width
,
"input_width and output_width are "
"mismatching."
);
T
*
vol_data
=
vol
->
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
...
...
@@ -121,11 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int
d_offset
=
(
c
/
filter_width
/
filter_height
)
%
filter_depth
;
int
cIm
=
c
/
filter_width
/
filter_height
/
filter_depth
;
for
(
int
d
=
0
;
d
<
output_depth
;
++
d
)
{
int
d_pad
=
d
*
stride
_depth
-
padding_depth
+
d_offset
;
int
d_pad
=
d
*
stride
s
[
0
]
-
paddings
[
0
]
+
d_offset
*
dilations
[
0
]
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
h_pad
=
h
*
stride
_height
-
padding_height
+
h_offset
;
int
h_pad
=
h
*
stride
s
[
1
]
-
paddings
[
1
]
+
h_offset
*
dilations
[
1
]
;
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
w_pad
=
w
*
stride
_width
-
padding_width
+
w_offset
;
int
w_pad
=
w
*
stride
s
[
2
]
-
paddings
[
2
]
+
w_offset
*
dilations
[
2
]
;
if
(
h_pad
>=
0
&&
h_pad
<
input_height
&&
w_pad
>=
0
&&
w_pad
<
input_width
&&
d_pad
>=
0
&&
d_pad
<
input_depth
)
{
...
...
@@ -133,6 +177,7 @@ class Col2VolFunctor<platform::CPUPlace, T> {
((
cIm
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
input_width
+
w_pad
;
int
col_idx
=
((
c
*
output_depth
+
d
)
*
output_height
+
h
)
*
output_width
+
w
;
...
...
paddle/operators/math/vol2col.cu
浏览文件 @
4fc9f55e
...
...
@@ -21,11 +21,12 @@ namespace math {
template
<
class
T
>
__global__
void
vol2col
(
int
num_kernels
,
const
T
*
data_vol
,
int
depth
,
int
height
,
int
width
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_col
)
{
int
height
,
int
width
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_col
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_out
=
index
%
output_width
;
...
...
@@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
for
(
int
k
=
0
;
k
<
filter_depth
;
++
k
)
{
for
(
int
i
=
0
;
i
<
filter_height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filter_width
;
++
j
)
{
int
d
=
d_in
+
k
;
int
h
=
h_in
+
i
;
int
w
=
w_in
+
j
;
int
d
=
d_in
+
k
*
dilation_d
;
int
h
=
h_in
+
i
*
dilation_h
;
int
w
=
w_in
+
j
*
dilation_w
;
int
col_idx
=
(
k
*
dilation_d
*
height
+
i
*
dilation_h
)
*
width
+
j
*
dilation_w
;
*
data_col
=
(
d
>=
0
&&
d
<
depth
&&
h
>=
0
&&
h
<
height
&&
w
>=
0
&&
w
<
width
)
?
data_vol
[
(
k
*
height
+
i
)
*
width
+
j
]
?
data_vol
[
col_idx
]
:
0
;
data_col
+=
output_detph
*
output_height
*
output_width
;
}
...
...
@@ -68,23 +71,46 @@ template <class T>
class
Vol2ColFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
const
framework
::
Tensor
&
vol
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
col
)
const
{
PADDLE_ENFORCE
(
vol
.
dims
().
size
()
==
4
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
7
);
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
7
);
int
input_channels
=
vol
.
dims
()[
0
];
int
input_depth
=
vol
.
dims
()[
1
];
int
input_height
=
vol
.
dims
()[
2
];
int
input_width
=
vol
.
dims
()[
3
];
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
3
];
int
output_depth
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
5
];
int
output_width
=
col
.
dims
()[
6
];
int
filter_depth
=
col
->
dims
()[
1
];
int
filter_height
=
col
->
dims
()[
2
];
int
filter_width
=
col
->
dims
()[
3
];
int
output_depth
=
col
->
dims
()[
4
];
int
output_height
=
col
->
dims
()[
5
];
int
output_width
=
col
->
dims
()[
6
];
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
paddings
[
0
]
-
((
dilations
[
0
]
*
(
filter_depth
-
1
)
+
1
)))
/
strides
[
0
]
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
paddings
[
1
]
-
((
dilations
[
1
]
*
(
filter_height
-
1
)
+
1
)))
/
strides
[
1
]
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
paddings
[
2
]
-
((
dilations
[
2
]
*
(
filter_width
-
1
)
+
1
)))
/
strides
[
2
]
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
int
num_outputs
=
input_channels
*
output_depth
*
output_height
*
output_width
;
...
...
@@ -95,19 +121,25 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
num_outputs
,
vol
.
data
<
T
>
(),
input_depth
,
input_height
,
input_width
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
col
.
data
<
T
>
());
dilations
[
0
],
dilations
[
1
],
dilations
[
2
],
filter_depth
,
filter_height
,
filter_width
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
],
output_depth
,
output_height
,
output_width
,
col
->
data
<
T
>
());
}
};
template
<
class
T
>
__global__
void
col2vol
(
int
num_kernels
,
const
T
*
data_col
,
int
depth
,
int
height
,
int
width
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_vol
)
{
int
height
,
int
width
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_vol
)
{
const
int
d_filter_depth
=
dilation_d
*
(
filter_depth
-
1
)
+
1
;
const
int
d_filter_height
=
dilation_h
*
(
filter_height
-
1
)
+
1
;
const
int
d_filter_width
=
dilation_w
*
(
filter_width
-
1
)
+
1
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
src_val
=
0
;
...
...
@@ -115,35 +147,41 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int
h
=
(
index
/
width
)
%
height
+
padding_height
;
int
d
=
(
index
/
width
/
height
)
%
depth
+
padding_depth
;
int
c
=
index
/
width
/
height
/
depth
;
// compute the start and end of the output
int
w_col_start
=
(
w
<
filter_width
)
?
0
:
(
w
-
filter_width
)
/
stride_width
+
1
;
(
w
<
d_filter_width
)
?
0
:
(
w
-
d_
filter_width
)
/
stride_width
+
1
;
int
w_col_end
=
min
(
w
/
stride_width
+
1
,
output_width
);
int
h_col_start
=
(
h
<
filter_height
)
?
0
:
(
h
-
filter_height
)
/
stride_height
+
1
;
(
h
<
d_filter_height
)
?
0
:
(
h
-
d_
filter_height
)
/
stride_height
+
1
;
int
h_col_end
=
min
(
h
/
stride_height
+
1
,
output_height
);
int
d_col_start
=
(
d
<
filter_depth
)
?
0
:
(
d
-
filter_depth
)
/
stride_depth
+
1
;
(
d
<
d_filter_depth
)
?
0
:
(
d
-
d_
filter_depth
)
/
stride_depth
+
1
;
int
d_col_end
=
min
(
d
/
stride_depth
+
1
,
output_detph
);
int
offset
=
(
c
*
filter_depth
*
filter_height
*
filter_width
+
d
*
filter_width
*
filter_height
+
h
*
filter_width
+
w
)
*
output_detph
*
output_height
*
output_width
;
int
coeff_d_col
=
(
1
-
stride_depth
*
filter_width
*
filter_height
*
output_detph
)
*
output_height
*
output_width
;
int
coeff_h_col
=
(
1
-
stride_height
*
filter_width
*
output_detph
*
output_height
)
*
output_width
;
int
coeff_w_col
=
(
1
-
stride_width
*
output_detph
*
output_height
*
output_width
);
for
(
int
d_col
=
d_col_start
;
d_col
<
d_col_end
;
++
d_col
)
{
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
src_val
+=
data_col
[
offset
+
d_col
*
coeff_d_col
+
h_col
*
coeff_h_col
+
w_col
*
coeff_w_col
];
int
d_off
=
(
d
-
d_col
*
stride_depth
);
int
h_off
=
(
h
-
h_col
*
stride_height
);
int
w_off
=
(
w
-
w_col
*
stride_width
);
if
(
d_off
%
dilation_d
==
0
&&
h_off
%
dilation_h
==
0
&&
w_off
%
dilation_w
==
0
)
{
d_off
/=
dilation_d
;
h_off
/=
dilation_h
;
w_off
/=
dilation_w
;
int
data_col_index
=
(((((
c
*
filter_depth
+
d_off
)
*
filter_height
+
h_off
)
*
filter_width
+
w_off
)));
data_col_index
=
((
data_col_index
*
output_detph
+
d_col
)
*
output_height
+
h_col
)
*
output_width
+
w_col
;
src_val
+=
data_col
[
data_col_index
];
}
}
}
}
...
...
@@ -161,17 +199,18 @@ template <class T>
class
Col2VolFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
PADDLE_ENFORCE
(
vol
.
dims
().
size
()
==
4
);
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
vol
)
const
{
PADDLE_ENFORCE
(
vol
->
dims
().
size
()
==
4
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
7
);
int
input_channels
=
vol
.
dims
()[
0
];
int
input_depth
=
vol
.
dims
()[
1
];
int
input_height
=
vol
.
dims
()[
2
];
int
input_width
=
vol
.
dims
()[
3
];
int
input_channels
=
vol
->
dims
()[
0
];
int
input_depth
=
vol
->
dims
()[
1
];
int
input_height
=
vol
->
dims
()[
2
];
int
input_width
=
vol
->
dims
()[
3
];
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
3
];
...
...
@@ -179,6 +218,28 @@ class Col2VolFunctor<platform::GPUPlace, T> {
int
output_height
=
col
.
dims
()[
5
];
int
output_width
=
col
.
dims
()[
6
];
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
paddings
[
0
]
-
((
dilations
[
0
]
*
(
filter_depth
-
1
)
+
1
)))
/
strides
[
0
]
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
paddings
[
1
]
-
((
dilations
[
1
]
*
(
filter_height
-
1
)
+
1
)))
/
strides
[
1
]
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
paddings
[
2
]
-
((
dilations
[
2
]
*
(
filter_width
-
1
)
+
1
)))
/
strides
[
2
]
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
int
num_kernels
=
input_channels
*
input_depth
*
input_height
*
input_width
;
const
int
threads
=
1024
;
...
...
@@ -188,9 +249,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
num_kernels
,
col
.
data
<
T
>
(),
input_depth
,
input_height
,
input_width
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
vol
.
data
<
T
>
());
dilations
[
0
],
dilations
[
1
],
dilations
[
2
],
filter_depth
,
filter_height
,
filter_width
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
],
output_depth
,
output_height
,
output_width
,
vol
->
data
<
T
>
());
}
};
...
...
paddle/operators/math/vol2col.h
浏览文件 @
4fc9f55e
...
...
@@ -31,6 +31,15 @@ namespace math {
* \param colData Column data.
* \param colShape The shape of colData.
*
* \param dilations dilation data.
* \param 3-dimension [dilation_depth, dilation_height, dilation_width].
*
* \param strides stride data.
* \param 3-dimension [stride_depth, stride_height, stride_width].
*
* \param paddings padding data.
* \param 3-dimension [d_pad, h_pad, w_pad].
*
* The shape of colData is:
* [input_channels, filter_depth, filter_height, filter_width, output_depth,
* output_height, output_width]
...
...
@@ -57,20 +66,22 @@ template <typename Place, typename T>
class
Vol2ColFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
;
const
framework
::
Tensor
&
vol
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
col
)
const
;
};
template
<
typename
Place
,
typename
T
>
class
Col2VolFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
;
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
vol
)
const
;
};
}
// namespace math
...
...
paddle/operators/math/vol2col_test.cc
浏览文件 @
4fc9f55e
...
...
@@ -62,11 +62,15 @@ void testVol2col() {
int
input_height
=
2
;
int
input_width
=
3
;
int
filter_size
=
2
;
int
stride
=
1
;
int
padding
=
0
;
int
output_depth
=
(
input_depth
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
std
::
vector
<
int
>
strides
({
1
,
1
,
1
});
std
::
vector
<
int
>
paddings
({
0
,
0
,
0
});
std
::
vector
<
int
>
dilations
({
1
,
1
,
1
});
int
output_depth
=
(
input_depth
-
filter_size
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
paddings
[
2
])
/
strides
[
2
]
+
1
;
// Vol2Col test
float
*
input_ptr
=
...
...
@@ -85,8 +89,7 @@ void testVol2col() {
*
place
);
paddle
::
operators
::
math
::
Vol2ColFunctor
<
Place
,
float
>
vol2col
;
vol2col
(
*
context
,
input
,
output
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
vol2col
(
*
context
,
input
,
dilations
,
strides
,
paddings
,
&
output
);
float
vol_2_col
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
10
,
10
,
11
};
float
*
out_cfo_ptr
;
...
...
@@ -111,8 +114,7 @@ void testVol2col() {
}
paddle
::
operators
::
math
::
Col2VolFunctor
<
Place
,
float
>
col2vol
;
col2vol
(
*
context
,
input
,
output
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
col2vol
(
*
context
,
output
,
dilations
,
strides
,
paddings
,
&
input
);
float
*
in_ptr
;
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
...
paddle/operators/sequence_conv_op.h
浏览文件 @
4fc9f55e
...
...
@@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> {
math
::
ContextProjectFunctor
<
Place
,
T
>
seq_project_functor
;
seq_project_functor
(
context
.
device_context
(),
*
in
,
*
padding_data
,
col
,
seq_project_functor
(
context
.
device_context
(),
*
in
,
*
padding_data
,
padding_trainable
,
context_start
,
context_length
,
context_stride
,
up_pad
,
down_pad
);
context_stride
,
up_pad
,
down_pad
,
&
col
);
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
col
,
false
,
filter
,
false
,
static_cast
<
T
>
(
1.0
),
out
,
static_cast
<
T
>
(
0.0
));
...
...
@@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
in_g
->
set_lod
(
in
->
lod
());
set_zero
(
context
.
device_context
(),
in_g
,
static_cast
<
T
>
(
0
));
seq_project_grad_functor
(
context
.
device_context
(),
*
in_g
,
*
padding_data_g
,
col
,
padding_trainable
,
context_start
,
context_
length
,
context_stride
,
up_pad
,
down_pad
,
true
,
false
);
seq_project_grad_functor
(
context
.
device_context
(),
*
in_g
,
padding_trainable
,
context_start
,
context_length
,
context_
stride
,
up_pad
,
down_pad
,
false
,
true
,
padding_data_g
,
&
col
);
}
if
(
padding_trainable
&&
padding_data_g
)
{
...
...
@@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
LoDTensor
*
input
=
const_cast
<
LoDTensor
*>
(
in
);
seq_project_grad_functor
(
context
.
device_context
(),
*
input
,
*
padding_data_g
,
col
,
padding_trainable
,
context_st
art
,
context_length
,
context_strid
e
,
up_pad
,
down_pad
,
false
,
true
);
padding_trainable
,
context_start
,
context_length
,
context_st
ride
,
up_pad
,
down_pad
,
true
,
fals
e
,
padding_data_g
,
&
col
);
}
if
(
filter_g
)
{
...
...
@@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
padding_data
=
context
.
Input
<
Tensor
>
(
"PaddingData"
);
}
seq_project_functor
(
context
.
device_context
(),
*
in
,
*
padding_data
,
col
,
seq_project_functor
(
context
.
device_context
(),
*
in
,
*
padding_data
,
padding_trainable
,
context_start
,
context_length
,
context_stride
,
up_pad
,
down_pad
);
context_stride
,
up_pad
,
down_pad
,
&
col
);
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
col
,
true
,
out_grad
,
false
,
T
(
1.0
),
&
filter_grad
,
T
(
1.0
));
...
...
python/paddle/v2/fluid/tests/test_conv2d_op.py
浏览文件 @
4fc9f55e
...
...
@@ -10,23 +10,33 @@ def conv2d_forward_naive(input, filter, group, conv_param):
assert
np
.
mod
(
out_c
,
group
)
==
0
sub_out_c
=
out_c
/
group
stride
,
pad
=
conv_param
[
'stride'
],
conv_param
[
'pad'
]
out_h
=
1
+
(
in_h
+
2
*
pad
[
0
]
-
f_h
)
/
stride
[
0
]
out_w
=
1
+
(
in_w
+
2
*
pad
[
1
]
-
f_w
)
/
stride
[
1
]
stride
,
pad
,
dilation
=
conv_param
[
'stride'
],
conv_param
[
'pad'
],
conv_param
[
'dilation'
]
out_h
=
1
+
(
in_h
+
2
*
pad
[
0
]
-
(
dilation
[
0
]
*
(
f_h
-
1
)
+
1
))
/
stride
[
0
]
out_w
=
1
+
(
in_w
+
2
*
pad
[
1
]
-
(
dilation
[
1
]
*
(
f_w
-
1
)
+
1
))
/
stride
[
1
]
out
=
np
.
zeros
((
in_n
,
out_c
,
out_h
,
out_w
))
d_bolck_w
=
(
dilation
[
0
]
*
(
f_h
-
1
)
+
1
)
d_bolck_h
=
(
dilation
[
1
]
*
(
f_w
-
1
)
+
1
)
input_pad
=
np
.
pad
(
input
,
((
0
,
),
(
0
,
),
(
pad
[
0
],
),
(
pad
[
1
],
)),
mode
=
'constant'
,
constant_values
=
0
)
filter_dilation
=
np
.
zeros
((
out_c
,
f_c
,
d_bolck_h
,
d_bolck_w
))
filter_dilation
[:,
:,
0
:
d_bolck_h
:
dilation
[
0
],
0
:
d_bolck_w
:
dilation
[
1
]]
=
filter
for
i
in
range
(
out_h
):
for
j
in
range
(
out_w
):
for
g
in
range
(
group
):
input_pad_masked
=
\
input_pad
[:,
g
*
f_c
:(
g
+
1
)
*
f_c
,
i
*
stride
[
0
]:
i
*
stride
[
0
]
+
f
_h
,
j
*
stride
[
1
]:
j
*
stride
[
1
]
+
f
_w
]
i
*
stride
[
0
]:
i
*
stride
[
0
]
+
d_bolck
_h
,
j
*
stride
[
1
]:
j
*
stride
[
1
]
+
d_bolck
_w
]
f_sub
=
filter
[
g
*
sub_out_c
:(
g
+
1
)
*
sub_out_c
,
:,
:,
:]
f_sub
=
filter_dilation
[
g
*
sub_out_c
:(
g
+
1
)
*
sub_out_c
,
:,
:,
:]
for
k
in
range
(
sub_out_c
):
out
[:,
g
*
sub_out_c
+
k
,
i
,
j
]
=
\
np
.
sum
(
input_pad_masked
*
f_sub
[
k
,
:,
:,
:],
...
...
@@ -39,9 +49,14 @@ class TestConv2dOp(OpTest):
def
setUp
(
self
):
self
.
init_op_type
()
self
.
init_group
()
self
.
init_dilation
()
self
.
init_test_case
()
conv2d_param
=
{
'stride'
:
self
.
stride
,
'pad'
:
self
.
pad
}
conv2d_param
=
{
'stride'
:
self
.
stride
,
'pad'
:
self
.
pad
,
'dilation'
:
self
.
dilations
}
input
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
"float32"
)
filter
=
np
.
random
.
random
(
self
.
filter_size
).
astype
(
"float32"
)
output
=
conv2d_forward_naive
(
input
,
filter
,
self
.
groups
,
...
...
@@ -80,12 +95,14 @@ class TestConv2dOp(OpTest):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
]
self
.
stride
=
[
1
,
1
]
self
.
dilations
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
def
init_dilation
(
self
):
self
.
dilations
=
[
1
,
1
]
def
init_group
(
self
):
self
.
groups
=
1
...
...
@@ -101,24 +118,66 @@ class TestWithGroup(TestConv2dOp):
self
.
op_type
=
"conv2d"
#----------------Conv2dCudnn----------------
class
TestWith1x1
(
TestConv2dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
]
self
.
stride
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
1
,
1
]
def
init_dilation
(
self
):
self
.
dilations
=
[
1
,
1
]
class
TestCudnn
(
TestConv2dOp
):
def
init_group
(
self
):
self
.
groups
=
1
self
.
groups
=
3
def
init_op_type
(
self
):
self
.
op_type
=
"conv
_cudnn
"
self
.
op_type
=
"conv
2d
"
class
TestCudnnWithGroup
(
TestConv2dOp
):
class
TestWithDilation
(
TestConv2dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
]
self
.
stride
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
10
,
10
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
def
init_dilation
(
self
):
self
.
dilations
=
[
2
,
2
]
def
init_group
(
self
):
self
.
groups
=
3
def
init_op_type
(
self
):
self
.
op_type
=
"conv2d"
#----------------Conv2dCudnn----------------
class
TestCudnn
(
TestConv2dOp
):
def
init_op_type
(
self
):
self
.
op_type
=
"conv_cudnn"
class
TestCudnnWithGroup
(
TestWithGroup
):
def
init_op_type
(
self
):
self
.
op_type
=
"conv_cudnn"
class
TestCudnnWith1x1
(
TestWith1x1
):
def
init_op_type
(
self
):
self
.
op_type
=
"conv_cudnn"
# cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/fluid/tests/test_conv3d_op.py
浏览文件 @
4fc9f55e
...
...
@@ -10,27 +10,40 @@ def conv3d_forward_naive(input, filter, group, conv_param):
assert
np
.
mod
(
out_c
,
group
)
==
0
sub_out_c
=
out_c
/
group
stride
,
pad
=
conv_param
[
'stride'
],
conv_param
[
'pad'
]
out_d
=
1
+
(
in_d
+
2
*
pad
[
0
]
-
f_h
)
/
stride
[
0
]
out_h
=
1
+
(
in_h
+
2
*
pad
[
1
]
-
f_h
)
/
stride
[
1
]
out_w
=
1
+
(
in_w
+
2
*
pad
[
2
]
-
f_w
)
/
stride
[
2
]
stride
,
pad
,
dilation
=
conv_param
[
'stride'
],
conv_param
[
'pad'
],
conv_param
[
'dilations'
]
out_d
=
1
+
(
in_d
+
2
*
pad
[
0
]
-
(
dilation
[
0
]
*
(
f_d
-
1
)
+
1
))
/
stride
[
0
]
out_h
=
1
+
(
in_h
+
2
*
pad
[
1
]
-
(
dilation
[
1
]
*
(
f_h
-
1
)
+
1
))
/
stride
[
1
]
out_w
=
1
+
(
in_w
+
2
*
pad
[
2
]
-
(
dilation
[
2
]
*
(
f_w
-
1
)
+
1
))
/
stride
[
2
]
out
=
np
.
zeros
((
in_n
,
out_c
,
out_d
,
out_h
,
out_w
))
d_bolck_d
=
(
dilation
[
0
]
*
(
f_d
-
1
)
+
1
)
d_bolck_h
=
(
dilation
[
1
]
*
(
f_h
-
1
)
+
1
)
d_bolck_w
=
(
dilation
[
2
]
*
(
f_w
-
1
)
+
1
)
input_pad
=
np
.
pad
(
input
,
((
0
,
),
(
0
,
),
(
pad
[
0
],
),
(
pad
[
1
],
),
(
pad
[
2
],
)),
mode
=
'constant'
,
constant_values
=
0
)
filter_dilation
=
np
.
zeros
((
out_c
,
f_c
,
d_bolck_d
,
d_bolck_h
,
d_bolck_w
))
filter_dilation
[:,
:,
0
:
d_bolck_d
:
dilation
[
0
],
0
:
d_bolck_h
:
dilation
[
1
],
0
:
d_bolck_w
:
dilation
[
2
]]
=
filter
for
d
in
range
(
out_d
):
for
i
in
range
(
out_h
):
for
j
in
range
(
out_w
):
for
g
in
range
(
group
):
input_pad_masked
=
\
input_pad
[:,
g
*
f_c
:(
g
+
1
)
*
f_c
,
d
*
stride
[
0
]:
d
*
stride
[
0
]
+
f_d
,
i
*
stride
[
1
]:
i
*
stride
[
1
]
+
f_h
,
j
*
stride
[
2
]:
j
*
stride
[
2
]
+
f_w
]
f_sub
=
filter
[
g
*
sub_out_c
:(
g
+
1
)
*
sub_out_c
,
:,
:,
:,
:]
d
*
stride
[
0
]:
d
*
stride
[
0
]
+
d_bolck_d
,
i
*
stride
[
1
]:
i
*
stride
[
1
]
+
d_bolck_h
,
j
*
stride
[
2
]:
j
*
stride
[
2
]
+
d_bolck_w
]
f_sub
=
filter_dilation
[
g
*
sub_out_c
:(
g
+
1
)
*
sub_out_c
,
:,
:,
:,
:]
for
k
in
range
(
sub_out_c
):
out
[:,
g
*
sub_out_c
+
k
,
d
,
i
,
j
]
=
\
np
.
sum
(
input_pad_masked
*
f_sub
[
k
,
:,
:,
:,
:],
...
...
@@ -43,9 +56,14 @@ class TestConv3dOp(OpTest):
def
setUp
(
self
):
self
.
init_group
()
self
.
init_op_type
()
self
.
init_dilation
()
self
.
init_test_case
()
conv3d_param
=
{
'stride'
:
self
.
stride
,
'pad'
:
self
.
pad
}
conv3d_param
=
{
'stride'
:
self
.
stride
,
'pad'
:
self
.
pad
,
'dilations'
:
self
.
dilations
}
input
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
"float32"
)
filter
=
np
.
random
.
random
(
self
.
filter_size
).
astype
(
"float32"
)
output
=
conv3d_forward_naive
(
input
,
filter
,
self
.
groups
,
...
...
@@ -55,7 +73,8 @@ class TestConv3dOp(OpTest):
self
.
attrs
=
{
'strides'
:
self
.
stride
,
'paddings'
:
self
.
pad
,
'groups'
:
self
.
groups
'groups'
:
self
.
groups
,
'dilations'
:
self
.
dilations
}
self
.
outputs
=
{
'Output'
:
output
}
...
...
@@ -88,6 +107,9 @@ class TestConv3dOp(OpTest):
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
,
3
]
def
init_dilation
(
self
):
self
.
dilations
=
[
1
,
1
,
1
]
def
init_group
(
self
):
self
.
groups
=
1
...
...
@@ -104,27 +126,47 @@ class TestCase1(TestConv3dOp):
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
,
3
]
def
init_group
(
self
):
self
.
groups
=
1
def
init_op_type
(
self
):
self
.
op_type
=
"conv3d"
class
TestWithGroup1
(
TestConv3dOp
):
def
init_group
(
self
):
self
.
groups
=
3
class
TestWithGroup
1
(
TestConv3dOp
):
class
TestWithGroup
2
(
TestCase1
):
def
init_group
(
self
):
self
.
groups
=
3
def
init_op_type
(
self
):
self
.
op_type
=
"conv3d"
class
TestWith1x1
(
TestConv3dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
,
0
]
self
.
stride
=
[
1
,
1
,
1
]
self
.
input_size
=
[
2
,
3
,
4
,
4
,
4
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
1
,
1
,
1
]
def
init_dilation
(
self
):
self
.
dilations
=
[
1
,
1
,
1
]
class
TestWithGroup2
(
TestCase1
):
def
init_group
(
self
):
self
.
groups
=
3
def
init_op_type
(
self
):
self
.
op_type
=
"conv3d"
class
TestWithDilation
(
TestConv3dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
,
0
]
self
.
stride
=
[
1
,
1
,
1
]
self
.
input_size
=
[
2
,
3
,
6
,
6
,
6
]
# NCDHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
2
,
2
,
2
]
def
init_dilation
(
self
):
self
.
dilations
=
[
2
,
2
,
2
]
def
init_group
(
self
):
self
.
groups
=
3
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录