Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
97e9dd72
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97e9dd72
编写于
11月 08, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dilation for im2col
上级
91b72482
变更
9
展开全部
隐藏空白更改
内联
并排
Showing
9 changed file
with
395 addition
and
351 deletion
+395
-351
paddle/operators/conv_cudnn_op.cc
paddle/operators/conv_cudnn_op.cc
+0
-2
paddle/operators/conv_op.cc
paddle/operators/conv_op.cc
+12
-1
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+17
-12
paddle/operators/conv_transpose_op.h
paddle/operators/conv_transpose_op.h
+11
-5
paddle/operators/math/context_project.h
paddle/operators/math/context_project.h
+8
-2
paddle/operators/math/im2col.cc
paddle/operators/math/im2col.cc
+140
-141
paddle/operators/math/im2col.cu
paddle/operators/math/im2col.cu
+191
-175
paddle/operators/math/im2col.h
paddle/operators/math/im2col.h
+6
-5
paddle/operators/math/im2col_test.cc
paddle/operators/math/im2col_test.cc
+10
-8
未找到文件。
paddle/operators/conv_cudnn_op.cc
浏览文件 @
97e9dd72
...
@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
...
@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker
(
framework
::
OpProto
*
proto
,
CudnnConvOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
Conv2DOpMaker
(
proto
,
op_checker
)
{
:
Conv2DOpMaker
(
proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"dilations of convolution operator."
)
.
SetDefault
(
std
::
vector
<
int
>
{
1
,
1
});
AddAttr
<
int
>
(
"workspace_size_MB"
,
AddAttr
<
int
>
(
"workspace_size_MB"
,
"workspace size for cudnn, in MB, "
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"workspace is a section of GPU memory which will be "
...
...
paddle/operators/conv_op.cc
浏览文件 @
97e9dd72
...
@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
std
::
vector
<
int
>
dilations
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dilations"
);
int
input_channels
=
in_dims
[
1
];
int
input_channels
=
in_dims
[
1
];
int
output_channels
=
filter_dims
[
0
];
int
output_channels
=
filter_dims
[
0
];
...
@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std
::
vector
<
int64_t
>
output_shape
({
in_dims
[
0
],
filter_dims
[
0
]});
std
::
vector
<
int64_t
>
output_shape
({
in_dims
[
0
],
filter_dims
[
0
]});
for
(
size_t
i
=
0
;
i
<
paddings
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
paddings
.
size
();
++
i
)
{
output_shape
.
push_back
(
OutputSize
(
in_dims
[
i
+
2
],
filter_dims
[
i
+
2
],
output_shape
.
push_back
(
OutputSize
(
in_dims
[
i
+
2
],
filter_dims
[
i
+
2
],
paddings
[
i
],
strides
[
i
]));
dilations
[
i
],
paddings
[
i
],
paddings
[
i
],
strides
[
i
]));
}
}
ctx
->
SetOutputDim
(
"Output"
,
framework
::
make_ddim
(
output_shape
));
ctx
->
SetOutputDim
(
"Output"
,
framework
::
make_ddim
(
output_shape
));
}
}
...
@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
...
@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels."
)
"is only connected to the second half of the input channels."
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"(vector default:{1, 1}), the dilations of "
"convolution operator."
)
.
SetDefault
(
std
::
vector
<
int
>
{
1
,
1
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Convolution Operator.
Convolution Operator.
...
@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
...
@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels."
)
"is only connected to the second half of the input channels."
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"(vector default:{1, 1, 1}), the dilations of "
"convolution operator. Currently, conv3d doesn't "
"support dilation."
)
.
SetDefault
(
std
::
vector
<
int
>
{
1
,
1
,
1
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Convolution3D Operator.
Convolution3D Operator.
...
...
paddle/operators/conv_op.h
浏览文件 @
97e9dd72
...
@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
...
@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
// like operators to reuse the implementation.
inline
int
OutputSize
(
int
input_size
,
int
filter_size
,
int
padding
,
inline
int
OutputSize
(
int
input_size
,
int
filter_size
,
int
dilation
,
int
stride
)
{
int
padding_up
,
int
padding_down
,
int
stride
)
{
int
output_size
=
(
input_size
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_size
=
(
input_size
+
padding_up
+
padding_down
-
(
dilation
*
(
filter_size
-
1
)
+
1
))
/
stride
+
1
;
return
output_size
;
return
output_size
;
}
}
...
@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
...
@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
...
@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
...
@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
filter_shape_vec
.
size
()
==
2
)
{
// im2col
// im2col
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
in_slice
,
col
,
stride
s
[
0
],
im2col
(
context
.
device_context
(),
in_slice
,
col
,
dilation
s
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
dilations
[
1
],
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
]);
paddings
[
1
]
,
paddings
[
1
]
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// vol2col
// vol2col
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
...
@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
...
@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
int
groups
=
context
.
Attr
<
int
>
(
"groups"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
...
@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
...
@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
filter_shape_vec
.
size
()
==
2
)
{
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
col2im
(
context
.
device_context
(),
in_grad_slice
,
col
,
stride
s
[
0
],
col2im
(
context
.
device_context
(),
in_grad_slice
,
col
,
dilation
s
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
dilations
[
1
],
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
...
@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
...
@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
filter_shape_vec
.
size
()
==
2
)
{
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
in_slice
,
col
,
stride
s
[
0
],
im2col
(
context
.
device_context
(),
in_slice
,
col
,
dilation
s
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
dilations
[
1
],
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
...
...
paddle/operators/conv_transpose_op.h
浏览文件 @
97e9dd72
...
@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
...
@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// TODO(Zhuoyuan): Paddings can be added in future.
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
// groups will alway be disabled in conv2dtranspose.
int
dilation_h
=
1
;
int
dilation_w
=
1
;
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// input_shape_vec: {h, w} or {d, h, w}
// input_shape_vec: {h, w} or {d, h, w}
...
@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
...
@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
col2im
(
context
.
device_context
(),
output_batch
,
col
,
strides
[
0
]
,
col2im
(
context
.
device_context
(),
output_batch
,
col
,
dilation_h
,
strides
[
1
],
0
,
0
,
0
,
0
);
dilation_w
,
strides
[
0
],
strides
[
1
],
0
,
0
,
0
,
0
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// col2vol: col_matrix -> dy
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
...
@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
...
@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose.
// Actually, no paddings and groups allowed in conv transpose.
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
dilation_h
=
1
;
int
dilation_w
=
1
;
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// input_shape_vec: {h, w} or {d, h, w}
// input_shape_vec: {h, w} or {d, h, w}
...
@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
...
@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col: dy -> col matrix
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
strides
[
0
]
,
im2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
dilation_h
,
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
dilation_w
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
]);
paddings
[
1
]
,
paddings
[
1
]
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// vol2col: dy -> col_matrix
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
...
...
paddle/operators/math/context_project.h
浏览文件 @
97e9dd72
...
@@ -95,6 +95,9 @@ class ContextProjectFunctor {
...
@@ -95,6 +95,9 @@ class ContextProjectFunctor {
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
int
dilation_h
=
1
;
int
dilation_w
=
1
;
int
input_row_begin
,
input_row_end
;
int
input_row_begin
,
input_row_end
;
int
sequence_height
,
sequence_width
;
int
sequence_height
,
sequence_width
;
sequence_width
=
in
.
dims
()[
1
];
sequence_width
=
in
.
dims
()[
1
];
...
@@ -124,7 +127,7 @@ class ContextProjectFunctor {
...
@@ -124,7 +127,7 @@ class ContextProjectFunctor {
sequence_width
});
// input_channels, input_height, input_width
sequence_width
});
// input_channels, input_height, input_width
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
im2col_ocf
(
context
,
in_t
,
out_t
,
im2col_ocf
(
context
,
in_t
,
out_t
,
dilation_h
,
dilation_w
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
up_pad
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
up_pad
,
down_pad
,
0
,
0
);
down_pad
,
0
,
0
);
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
...
@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
...
@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
col2im_ocf
;
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kOCF
,
Place
,
float
>
col2im_ocf
;
int
dilation_h
=
1
;
int
dilation_w
=
1
;
int
input_row_begin
,
input_row_end
;
int
input_row_begin
,
input_row_end
;
int
sequence_height
,
sequence_width
;
int
sequence_height
,
sequence_width
;
sequence_width
=
in
.
dims
()[
1
];
sequence_width
=
in
.
dims
()[
1
];
...
@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
...
@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
sequence_width
});
// input_channels, input_height, input_width
sequence_width
});
// input_channels, input_height, input_width
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
in_t
.
Resize
(
framework
::
make_ddim
(
input_shape
));
col2im_ocf
(
context
,
in_t
,
out_t
,
col2im_ocf
(
context
,
in_t
,
out_t
,
dilation_h
,
dilation_w
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
/*stride_height*/
context_stride
,
/*stride_width*/
1
,
up_pad
,
down_pad
,
0
,
0
);
up_pad
,
down_pad
,
0
,
0
);
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
...
...
paddle/operators/math/im2col.cc
浏览文件 @
97e9dd72
...
@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
...
@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
dilation_h
,
int
dilation_w
,
int
stride_height
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
2
];
int
output
_height
=
col
.
dims
()[
3
];
int
col
_height
=
col
.
dims
()[
3
];
int
output
_width
=
col
.
dims
()[
4
];
int
col
_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
im_height
+
padding_up
+
padding_down
-
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
((
dilation_h
*
(
filter_height
-
1
)
+
1
))
)
/
stride_height
+
stride_height
+
1
,
1
,
output
_height
,
col
_height
,
"Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
im_width
+
padding_left
+
padding_right
-
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
((
dilation_w
*
(
filter_width
-
1
)
+
1
))
)
/
stride_width
+
stride_width
+
1
,
1
,
output
_width
,
col
_width
,
"output
_width and padding(padding_left, padding_right) are "
"col
_width and padding(padding_left, padding_right) are "
"inconsistent."
);
"inconsistent."
);
int
channels_col
=
i
nput
_channels
*
filter_height
*
filter_width
;
int
channels_col
=
i
m
_channels
*
filter_height
*
filter_width
;
const
T
*
im_data
=
im
.
data
<
T
>
();
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
...
@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
...
@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int
w_offset
=
c
%
filter_width
;
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
for
(
int
h
=
0
;
h
<
col_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
for
(
int
w
=
0
;
w
<
col_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride_height
+
h_offset
-
padding_up
;
int
im_row_idx
=
int
im_col_idx
=
w
*
stride_width
+
w_offset
-
padding_left
;
h
*
stride_height
-
padding_up
+
h_offset
*
dilation_h
;
int
im_col_idx
=
w
*
stride_width
-
padding_left
+
w_offset
*
dilation_w
;
if
(
im_row_idx
<
0
||
im_row_idx
>=
input_height
||
im_col_idx
<
0
||
col_data
[(
c
*
col_height
+
h
)
*
col_width
+
w
]
=
im_col_idx
>=
input_width
)
{
(
im_row_idx
<
0
||
im_row_idx
>=
im_height
||
im_col_idx
<
0
||
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
T
(
0
);
im_col_idx
>=
im_width
)
}
else
{
?
static_cast
<
T
>
(
0
)
im_row_idx
+=
c_im
*
input_height
;
:
im_data
[(
im_row_idx
+
c_im
*
im_height
)
*
im_width
+
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
im_col_idx
];
im_data
[
im_row_idx
*
input_width
+
im_col_idx
];
}
}
}
}
}
}
}
...
@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
...
@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform
::
CPUPlace
,
T
>
{
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
const
framework
::
Tensor
&
col
,
int
dilation_h
,
int
dilation_w
,
int
stride_
width
,
int
padding_up
,
int
padding_down
,
int
stride_
height
,
int
stride_width
,
int
padding_up
,
int
padding_left
,
int
padding_right
)
{
int
padding_
down
,
int
padding_
left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
2
];
int
output
_height
=
col
.
dims
()[
3
];
int
col
_height
=
col
.
dims
()[
3
];
int
output
_width
=
col
.
dims
()[
4
];
int
col
_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
im_height
+
padding_up
+
padding_down
-
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
((
dilation_h
*
(
filter_height
-
1
)
+
1
))
)
/
stride_height
+
stride_height
+
1
,
1
,
output
_height
,
col
_height
,
"Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
"inconsistent."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
im_width
+
padding_left
+
padding_right
-
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
((
dilation_w
*
(
filter_width
-
1
)
+
1
))
)
/
stride_width
+
stride_width
+
1
,
1
,
output
_width
,
col
_width
,
"output
_width and padding(padding_left, padding_right) are "
"col
_width and padding(padding_left, padding_right) are "
"inconsistent."
);
"inconsistent."
);
int
channels_col
=
i
nput
_channels
*
filter_height
*
filter_width
;
int
channels_col
=
i
m
_channels
*
filter_height
*
filter_width
;
T
*
im_data
=
im
.
data
<
T
>
();
T
*
im_data
=
im
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
...
@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
...
@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int
w_offset
=
c
%
filter_width
;
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
int
c_im
=
c
/
filter_width
/
filter_height
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
for
(
int
h
=
0
;
h
<
col_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
for
(
int
w
=
0
;
w
<
col_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride_height
+
h_offset
-
padding_up
;
int
im_row_idx
=
int
im_col_idx
=
w
*
stride_width
+
w_offset
-
padding_left
;
h
*
stride_height
-
padding_up
+
h_offset
*
dilation_h
;
int
im_col_idx
=
w
*
stride_width
-
padding_left
+
w_offset
*
dilation_w
;
if
((
im_row_idx
)
>=
0
&&
(
im_row_idx
)
<
i
nput
_height
&&
if
((
im_row_idx
)
>=
0
&&
(
im_row_idx
)
<
i
m
_height
&&
(
im_col_idx
)
>=
0
&&
(
im_col_idx
)
<
i
nput
_width
)
{
(
im_col_idx
)
>=
0
&&
(
im_col_idx
)
<
i
m
_width
)
{
im_row_idx
+=
c_im
*
i
nput
_height
;
im_row_idx
+=
c_im
*
i
m
_height
;
im_data
[
im_row_idx
*
i
nput
_width
+
im_col_idx
]
+=
im_data
[
im_row_idx
*
i
m
_width
+
im_col_idx
]
+=
col_data
[(
c
*
output_height
+
h
)
*
output
_width
+
w
];
col_data
[(
c
*
col_height
+
h
)
*
col
_width
+
w
];
}
}
}
}
}
}
...
@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
...
@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
dilation_h
,
int
dilation_w
,
int
stride_height
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
filter_width
=
col
.
dims
()[
4
];
int
output
_height
=
col
.
dims
()[
0
];
int
col
_height
=
col
.
dims
()[
0
];
int
output
_width
=
col
.
dims
()[
1
];
int
col
_width
=
col
.
dims
()[
1
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
((
im_height
+
padding_up
+
padding_down
-
filter_height
)
/
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
stride_height
+
1
,
1
,
col_height
,
output_height
,
"Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
"inconsistent."
);
PADDLE_ENFORCE_EQ
((
im_width
+
padding_left
+
padding_right
-
filter_width
)
/
PADDLE_ENFORCE_EQ
(
stride_width
+
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
1
,
stride_width
+
col_width
,
1
,
"col_width and padding(padding_left, padding_right) are "
output_width
,
"inconsistent."
);
"output_width and padding(padding_left, padding_right) are "
"inconsistent."
);
const
T
*
im_data
=
im
.
data
<
T
>
();
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output
_height
;
++
col_row_idx
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
col
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output
_width
;
++
col_col_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
col
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
nput
_channels
;
++
channel
)
{
for
(
int
channel
=
0
;
channel
<
i
m
_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
++
filter_row_idx
)
{
++
filter_row_idx
)
{
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
...
@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
...
@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_up
;
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_up
;
int
im_col_offset
=
int
im_col_offset
=
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
((((
col_row_idx
)
*
output_width
+
col_col_idx
)
*
int
col_offset
=
input_channels
+
((((
col_row_idx
)
*
col_width
+
col_col_idx
)
*
im_channels
+
channel
)
*
channel
)
*
filter_height
+
filter_height
+
filter_row_idx
)
*
filter_row_idx
)
*
filter_width
+
filter_width
+
filter_col_idx
;
filter_col_idx
;
if
(
im_row_offset
<
0
||
im_row_offset
>=
input_height
||
im_col_offset
<
0
||
im_col_offset
>=
input_width
)
{
int
im_offset
=
(
channel
*
im_height
+
im_row_offset
)
*
im_width
+
col_data
[
col_offset
]
=
T
(
0
);
im_col_offset
;
}
else
{
col_data
[
col_offset
]
=
int
im_offset
=
(
im_row_offset
<
0
||
im_row_offset
>=
im_height
||
(
channel
*
input_height
+
im_row_offset
)
*
input_width
+
im_col_offset
<
0
||
im_col_offset
>=
im_width
)
im_col_offset
;
?
static_cast
<
T
>
(
0
)
col_data
[
col_offset
]
=
im_data
[
im_offset
];
:
im_data
[
im_offset
];
}
}
}
}
}
}
}
...
@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
...
@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform
::
CPUPlace
,
T
>
{
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
const
framework
::
Tensor
&
col
,
int
dilation_h
,
int
dilation_w
,
int
stride_
width
,
int
padding_up
,
int
padding_down
,
int
stride_
height
,
int
stride_width
,
int
padding_up
,
int
padding_left
,
int
padding_right
)
{
int
padding_
down
,
int
padding_
left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
i
nput
_channels
=
im
.
dims
()[
0
];
int
i
m
_channels
=
im
.
dims
()[
0
];
int
i
nput
_height
=
im
.
dims
()[
1
];
int
i
m
_height
=
im
.
dims
()[
1
];
int
i
nput
_width
=
im
.
dims
()[
2
];
int
i
m
_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
filter_width
=
col
.
dims
()[
4
];
int
output
_height
=
col
.
dims
()[
0
];
int
col
_height
=
col
.
dims
()[
0
];
int
output
_width
=
col
.
dims
()[
1
];
int
col
_width
=
col
.
dims
()[
1
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
((
im_height
+
padding_up
+
padding_down
-
filter_height
)
/
(
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
stride_height
+
1
,
1
,
col_height
,
output_height
,
"Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent."
);
"inconsistent."
);
PADDLE_ENFORCE_EQ
((
im_width
+
padding_left
+
padding_right
-
filter_width
)
/
PADDLE_ENFORCE_EQ
(
stride_width
+
(
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
1
,
stride_width
+
col_width
,
1
,
"col_width and padding(padding_left, padding_right) are "
output_width
,
"inconsistent."
);
"output_width and padding(padding_left, padding_right) are "
"inconsistent."
);
T
*
im_data
=
im
.
data
<
T
>
();
T
*
im_data
=
im
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output
_height
;
++
col_row_idx
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
col
_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output
_width
;
++
col_col_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
col
_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
i
nput
_channels
;
++
channel
)
{
for
(
int
channel
=
0
;
channel
<
i
m
_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
++
filter_row_idx
)
{
++
filter_row_idx
)
{
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
...
@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
...
@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_up
;
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_up
;
int
im_col_offset
=
int
im_col_offset
=
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
(((
col_row_idx
*
output_width
+
col_col_idx
)
*
int
col_offset
=
input
_channels
+
(((
col_row_idx
*
col_width
+
col_col_idx
)
*
im
_channels
+
channel
)
*
channel
)
*
filter_height
+
filter_height
+
filter_row_idx
)
*
filter_row_idx
)
*
filter_width
+
filter_width
+
filter_col_idx
;
filter_col_idx
;
if
(
im_row_offset
>=
0
&&
im_row_offset
<
i
nput
_height
&&
if
(
im_row_offset
>=
0
&&
im_row_offset
<
i
m
_height
&&
im_col_offset
>=
0
&&
im_col_offset
<
i
nput
_width
)
{
im_col_offset
>=
0
&&
im_col_offset
<
i
m
_width
)
{
int
im_offset
=
int
im_offset
=
(
channel
*
i
nput_height
+
im_row_offset
)
*
input
_width
+
(
channel
*
i
m_height
+
im_row_offset
)
*
im
_width
+
im_col_offset
;
im_col_offset
;
im_data
[
im_offset
]
+=
col_data
[
col_offset
];
im_data
[
im_offset
]
+=
col_data
[
col_offset
];
}
}
...
...
paddle/operators/math/im2col.cu
浏览文件 @
97e9dd72
此差异已折叠。
点击以展开。
paddle/operators/math/im2col.h
浏览文件 @
97e9dd72
...
@@ -74,17 +74,18 @@ class Im2ColFunctor {
...
@@ -74,17 +74,18 @@ class Im2ColFunctor {
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
dilation_h
,
int
dilation_w
,
int
stride_height
,
int
padding_down
,
int
padding_left
,
int
padding_right
);
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
);
};
};
template
<
ColFormat
Format
,
typename
Place
,
typename
T
>
template
<
ColFormat
Format
,
typename
Place
,
typename
T
>
class
Col2ImFunctor
{
class
Col2ImFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
const
framework
::
Tensor
&
col
,
int
dilation_h
,
int
dilation_w
,
int
stride_
width
,
int
padding_up
,
int
padding_down
,
int
stride_
height
,
int
stride_width
,
int
padding_up
,
int
padding_left
,
int
padding_right
);
int
padding_
down
,
int
padding_
left
,
int
padding_right
);
};
};
}
// namespace math
}
// namespace math
...
...
paddle/operators/math/im2col_test.cc
浏览文件 @
97e9dd72
...
@@ -47,6 +47,8 @@ void testIm2col() {
...
@@ -47,6 +47,8 @@ void testIm2col() {
int
filter_size
=
2
;
int
filter_size
=
2
;
int
stride
=
1
;
int
stride
=
1
;
int
padding
=
0
;
int
padding
=
0
;
int
dilation_h
=
1
;
int
dilation_w
=
1
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
float
*
input_ptr
=
input_tmp
.
mutable_data
<
float
>
(
float
*
input_ptr
=
input_tmp
.
mutable_data
<
float
>
(
...
@@ -85,10 +87,10 @@ void testIm2col() {
...
@@ -85,10 +87,10 @@ void testIm2col() {
paddle
::
operators
::
math
::
ColFormat
::
kOCF
,
Place
,
float
>
paddle
::
operators
::
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
im2col_ocf
;
im2col
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
im2col
(
*
context
,
input
,
output_cfo
,
dilation_h
,
dilation_w
,
stride
,
stride
,
padding
);
padding
,
padding
,
padding
,
padding
);
im2col_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
im2col_ocf
(
*
context
,
input
,
output_ocf
,
dilation_h
,
dilation_w
,
stride
,
padding
,
padding
);
stride
,
padding
,
padding
,
padding
,
padding
);
float
out_cfo_data
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
};
float
out_cfo_data
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
};
float
out_ocf_data
[]
=
{
0
,
1
,
3
,
4
,
1
,
2
,
4
,
5
};
float
out_ocf_data
[]
=
{
0
,
1
,
3
,
4
,
1
,
2
,
4
,
5
};
...
@@ -131,8 +133,8 @@ void testIm2col() {
...
@@ -131,8 +133,8 @@ void testIm2col() {
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
}
}
col2im
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
col2im
(
*
context
,
input
,
output_cfo
,
dilation_h
,
dilation_w
,
stride
,
stride
,
padding
);
padding
,
padding
,
padding
,
padding
);
float
*
in_ptr
;
float
*
in_ptr
;
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
@@ -153,8 +155,8 @@ void testIm2col() {
...
@@ -153,8 +155,8 @@ void testIm2col() {
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
input
.
CopyFrom
(
input_tmp
,
*
place
,
*
context
);
}
}
col2im_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
col2im_ocf
(
*
context
,
input
,
output_ocf
,
dilation_h
,
dilation_w
,
stride
,
padding
,
padding
);
stride
,
padding
,
padding
,
padding
,
padding
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
in_ptr
=
input
.
data
<
float
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录