Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit 97e9dd72
Authored Nov 08, 2017 by chengduoZH
Parent: 91b72482

add dilation for im2col
Showing 9 changed files with 395 additions and 351 deletions.
paddle/operators/conv_cudnn_op.cc        +0   -2
paddle/operators/conv_op.cc              +12  -1
paddle/operators/conv_op.h               +17  -12
paddle/operators/conv_transpose_op.h     +11  -5
paddle/operators/math/context_project.h  +8   -2
paddle/operators/math/im2col.cc          +140 -141
paddle/operators/math/im2col.cu          +191 -175
paddle/operators/math/im2col.h           +6   -5
paddle/operators/math/im2col_test.cc     +10  -8
paddle/operators/conv_cudnn_op.cc

@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
   CudnnConvOpMaker(framework::OpProto* proto,
                    framework::OpAttrChecker* op_checker)
       : Conv2DOpMaker(proto, op_checker) {
-    AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
-        .SetDefault(std::vector<int>{1, 1});
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
                  "workspace is a section of GPU memory which will be "

paddle/operators/conv_op.cc

@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   int groups = ctx->Attrs().Get<int>("groups");
+  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
   int input_channels = in_dims[1];
   int output_channels = filter_dims[0];

@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
   for (size_t i = 0; i < paddings.size(); ++i) {
     output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                      paddings[i], strides[i]));
+                                      dilations[i], paddings[i], paddings[i],
+                                      strides[i]));
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }

@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
       "first half of the input channels, while the second half of the filters "
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
+  AddAttr<std::vector<int>>("dilations",
+                            "(vector default:{1, 1}), the dilations of "
+                            "convolution operator.")
+      .SetDefault(std::vector<int>{1, 1});
   AddComment(R"DOC(
 Convolution Operator.

@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
       "first half of the input channels, while the second half of the filters "
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
+  AddAttr<std::vector<int>>("dilations",
+                            "(vector default:{1, 1, 1}), the dilations of "
+                            "convolution operator. Currently, conv3d doesn't "
+                            "support dilation.")
+      .SetDefault(std::vector<int>{1, 1, 1});
   AddComment(R"DOC(
 Convolution3D Operator.

paddle/operators/conv_op.h

@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
 // Base convolution operator definations for other conv
 // like operators to reuse the implementation.
-inline int OutputSize(int input_size, int filter_size, int padding,
-                      int stride) {
-  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+inline int OutputSize(int input_size, int filter_size, int dilation,
+                      int padding_up, int padding_down, int stride) {
+  int output_size = (input_size + padding_up + padding_down -
+                     (dilation * (filter_size - 1) + 1)) /
+                        stride +
+                    1;
   return output_size;
 }
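
The reworked OutputSize folds the effective (dilated) filter extent, dilation * (filter_size - 1) + 1, into the usual convolution arithmetic and takes the two paddings separately instead of assuming they are equal. A minimal standalone sketch of the same arithmetic (the helper name below is illustrative, not part of the patch):

#include <cstdio>

// Mirrors the formula in OutputSize() above: a dilated filter covers
// dilation * (filter_size - 1) + 1 input pixels.
int DilatedOutputSize(int input_size, int filter_size, int dilation,
                      int padding_up, int padding_down, int stride) {
  int effective_filter = dilation * (filter_size - 1) + 1;
  return (input_size + padding_up + padding_down - effective_filter) / stride +
         1;
}

int main() {
  // 32-pixel input, 3-tap filter, padding 1 on each side, stride 1.
  std::printf("%d\n", DilatedOutputSize(32, 3, 1, 1, 1, 1));  // 32
  std::printf("%d\n", DilatedOutputSize(32, 3, 2, 1, 1, 1));  // 30
  return 0;
}

With a 3-tap filter and dilation 2 the effective extent is 5, so the padded 34-pixel extent yields 30 output positions instead of 32.
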
@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     int groups = context.Attr<int>("groups");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

     const int batch_size = static_cast<int>(input->dims()[0]);

@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
       if (filter_shape_vec.size() == 2) {
         // im2col
         math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-        im2col(context.device_context(), in_slice, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1],
-               paddings[1]);
+        im2col(context.device_context(), in_slice, col, dilations[0],
+               dilations[1], strides[0], strides[1], paddings[0], paddings[0],
+               paddings[1], paddings[1]);
       } else if (filter_shape_vec.size() == 3) {
         // vol2col
         math::Vol2ColFunctor<Place, T> vol2col;

@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     int groups = context.Attr<int>("groups");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

     const int batch_size = static_cast<int>(input->dims()[0]);

@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
         if (filter_shape_vec.size() == 2) {
           math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-          col2im(context.device_context(), in_grad_slice, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
+          col2im(context.device_context(), in_grad_slice, col, dilations[0],
+                 dilations[1], strides[0], strides[1], paddings[0],
+                 paddings[0], paddings[1], paddings[1]);
         } else if (filter_shape_vec.size() == 3) {
           math::Col2VolFunctor<Place, T> col2vol;

@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
         if (filter_shape_vec.size() == 2) {
           math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-          im2col(context.device_context(), in_slice, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
+          im2col(context.device_context(), in_slice, col, dilations[0],
+                 dilations[1], strides[0], strides[1], paddings[0],
+                 paddings[0], paddings[1], paddings[1]);
         } else if (filter_shape_vec.size() == 3) {
           math::Vol2ColFunctor<Place, T> vol2col;
           vol2col(context.device_context(), in_slice, col, strides[0],

paddle/operators/conv_transpose_op.h

@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     // TODO(Zhuoyuan): Paddings can be added in future.
     // groups will alway be disabled in conv2dtranspose.

+    int dilation_h = 1;
+    int dilation_w = 1;
+
     const int batch_size = static_cast<int>(input->dims()[0]);

     // input_shape_vec: {h, w} or {d, h, w}

@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-        col2im(context.device_context(), output_batch, col, strides[0],
-               strides[1], 0, 0, 0, 0);
+        col2im(context.device_context(), output_batch, col, dilation_h,
+               dilation_w, strides[0], strides[1], 0, 0, 0, 0);
       } else if (filter_shape_vec.size() == 3) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)

@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

+    int dilation_h = 1;
+    int dilation_w = 1;
+
     const int batch_size = static_cast<int>(input->dims()[0]);

     // input_shape_vec: {h, w} or {d, h, w}

@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
           math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-          im2col(context.device_context(), output_grad_batch, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
+          im2col(context.device_context(), output_grad_batch, col, dilation_h,
+                 dilation_w, strides[0], strides[1], paddings[0], paddings[0],
+                 paddings[1], paddings[1]);
         } else if (filter_shape_vec.size() == 3) {
           // vol2col: dy -> col_matrix
           // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)

paddle/operators/math/context_project.h

@@ -95,6 +95,9 @@ class ContextProjectFunctor {
     math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;

+    int dilation_h = 1;
+    int dilation_w = 1;
+
     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
     sequence_width = in.dims()[1];

@@ -124,7 +127,7 @@ class ContextProjectFunctor {
             sequence_width});  // input_channels, input_height, input_width
         in_t.Resize(framework::make_ddim(input_shape));

-        im2col_ocf(context, in_t, out_t,
+        im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
                    /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
                    down_pad, 0, 0);
         out_t.Resize({sequence_height, context_length * sequence_width});

@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
     math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;

+    int dilation_h = 1;
+    int dilation_w = 1;
+
     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
     sequence_width = in.dims()[1];

@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
             sequence_width});  // input_channels, input_height, input_width
         in_t.Resize(framework::make_ddim(input_shape));

-        col2im_ocf(context, in_t, out_t,
+        col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
                    /*stride_height*/ context_stride, /*stride_width*/ 1,
                    up_pad, down_pad, 0, 0);
         out_t.Resize({sequence_height, context_length * sequence_width});

paddle/operators/math/im2col.cc

@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& im, framework::Tensor& col,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right) {
+                  int dilation_h, int dilation_w, int stride_height,
+                  int stride_width, int padding_up, int padding_down,
+                  int padding_left, int padding_right) {
     PADDLE_ENFORCE(im.dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int input_channels = im.dims()[0];
-    int input_height = im.dims()[1];
-    int input_width = im.dims()[2];
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
-    int output_height = col.dims()[3];
-    int output_width = col.dims()[4];
+    int col_height = col.dims()[3];
+    int col_width = col.dims()[4];

-    PADDLE_ENFORCE_EQ(
-        (input_height + padding_up + padding_down - filter_height) /
-                stride_height +
-            1,
-        output_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (input_width + padding_left + padding_right - filter_width) /
-                stride_width +
-            1,
-        output_width,
-        "output_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
+                       ((dilation_h * (filter_height - 1) + 1))) /
+                              stride_height +
+                          1,
+                      col_height,
+                      "Output_height and padding(padding_up, padding_down) are "
+                      "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
+                       ((dilation_w * (filter_width - 1) + 1))) /
+                              stride_width +
+                          1,
+                      col_width,
+                      "col_width and padding(padding_left, padding_right) are "
+                      "inconsistent.");

-    int channels_col = input_channels * filter_height * filter_width;
+    int channels_col = im_channels * filter_height * filter_width;

     const T* im_data = im.data<T>();
     T* col_data = col.data<T>();

@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
       int w_offset = c % filter_width;
       int h_offset = (c / filter_width) % filter_height;
       int c_im = c / filter_width / filter_height;
-      for (int h = 0; h < output_height; ++h) {
-        for (int w = 0; w < output_width; ++w) {
-          int im_row_idx = h * stride_height + h_offset - padding_up;
-          int im_col_idx = w * stride_width + w_offset - padding_left;
-          if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
-              im_col_idx >= input_width) {
-            col_data[(c * output_height + h) * output_width + w] = T(0);
-          } else {
-            im_row_idx += c_im * input_height;
-            col_data[(c * output_height + h) * output_width + w] =
-                im_data[im_row_idx * input_width + im_col_idx];
-          }
+      for (int h = 0; h < col_height; ++h) {
+        for (int w = 0; w < col_width; ++w) {
+          int im_row_idx =
+              h * stride_height - padding_up + h_offset * dilation_h;
+          int im_col_idx =
+              w * stride_width - padding_left + w_offset * dilation_w;
+          col_data[(c * col_height + h) * col_width + w] =
+              (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 ||
+               im_col_idx >= im_width)
+                  ? static_cast<T>(0)
+                  : im_data[(im_row_idx + c_im * im_height) * im_width +
+                            im_col_idx];
         }
       }
     }
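
The rewritten CPU kernel locates each source pixel at h * stride_height - padding_up + h_offset * dilation_h (and the analogous expression for columns), so consecutive filter taps are spaced dilation pixels apart and out-of-range taps read as zero. A standalone sketch of that mapping for a single-channel image, assuming a plain row-major float buffer rather than Paddle's Tensor (the function name is illustrative):

#include <cstdio>
#include <vector>

// Dilated im2col in the kCFO layout for one channel: one row per filter tap
// (h_offset, w_offset), one column per output position. Taps that fall into
// the padding region are written as zero, as in the patched kernel.
std::vector<float> Im2ColDilated(const std::vector<float>& im, int im_height,
                                 int im_width, int filter_height,
                                 int filter_width, int dilation_h,
                                 int dilation_w, int stride_h, int stride_w,
                                 int pad_up, int pad_left, int col_height,
                                 int col_width) {
  std::vector<float> col(filter_height * filter_width * col_height * col_width);
  for (int c = 0; c < filter_height * filter_width; ++c) {
    int w_offset = c % filter_width;
    int h_offset = c / filter_width;
    for (int h = 0; h < col_height; ++h) {
      for (int w = 0; w < col_width; ++w) {
        int im_row_idx = h * stride_h - pad_up + h_offset * dilation_h;
        int im_col_idx = w * stride_w - pad_left + w_offset * dilation_w;
        bool out_of_range = im_row_idx < 0 || im_row_idx >= im_height ||
                            im_col_idx < 0 || im_col_idx >= im_width;
        col[(c * col_height + h) * col_width + w] =
            out_of_range ? 0.0f : im[im_row_idx * im_width + im_col_idx];
      }
    }
  }
  return col;
}

int main() {
  // 5x5 ramp image, 3x3 filter, dilation 2, stride 1, no padding: the
  // effective filter extent is 5, so there is a single output position that
  // samples every other pixel of the input.
  std::vector<float> im(25);
  for (int i = 0; i < 25; ++i) im[i] = static_cast<float>(i);
  std::vector<float> col =
      Im2ColDilated(im, 5, 5, 3, 3, /*dilation*/ 2, 2, /*stride*/ 1, 1,
                    /*padding*/ 0, 0, /*col_height*/ 1, /*col_width*/ 1);
  for (float v : col) std::printf("%.0f ", v);  // 0 2 4 10 12 14 20 22 24
  std::printf("\n");
  return 0;
}

With dilation 1 the index expression reduces to the previous h * stride_height + h_offset - padding_up mapping.
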
@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                     platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& col, int dilation_h, int dilation_w,
+                  int stride_height, int stride_width, int padding_up,
+                  int padding_down, int padding_left, int padding_right) {
     PADDLE_ENFORCE(im.dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int input_channels = im.dims()[0];
-    int input_height = im.dims()[1];
-    int input_width = im.dims()[2];
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
-    int output_height = col.dims()[3];
-    int output_width = col.dims()[4];
+    int col_height = col.dims()[3];
+    int col_width = col.dims()[4];

-    PADDLE_ENFORCE_EQ(
-        (input_height + padding_up + padding_down - filter_height) /
-                stride_height +
-            1,
-        output_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (input_width + padding_left + padding_right - filter_width) /
-                stride_width +
-            1,
-        output_width,
-        "output_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
+                       ((dilation_h * (filter_height - 1) + 1))) /
+                              stride_height +
+                          1,
+                      col_height,
+                      "Output_height and padding(padding_up, padding_down) are "
+                      "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
+                       ((dilation_w * (filter_width - 1) + 1))) /
+                              stride_width +
+                          1,
+                      col_width,
+                      "col_width and padding(padding_left, padding_right) are "
+                      "inconsistent.");

-    int channels_col = input_channels * filter_height * filter_width;
+    int channels_col = im_channels * filter_height * filter_width;

     T* im_data = im.data<T>();
     const T* col_data = col.data<T>();

@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
       int w_offset = c % filter_width;
       int h_offset = (c / filter_width) % filter_height;
       int c_im = c / filter_width / filter_height;
-      for (int h = 0; h < output_height; ++h) {
-        for (int w = 0; w < output_width; ++w) {
-          int im_row_idx = h * stride_height + h_offset - padding_up;
-          int im_col_idx = w * stride_width + w_offset - padding_left;
-          if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
-              (im_col_idx) >= 0 && (im_col_idx) < input_width) {
-            im_row_idx += c_im * input_height;
-            im_data[im_row_idx * input_width + im_col_idx] +=
-                col_data[(c * output_height + h) * output_width + w];
+      for (int h = 0; h < col_height; ++h) {
+        for (int w = 0; w < col_width; ++w) {
+          int im_row_idx =
+              h * stride_height - padding_up + h_offset * dilation_h;
+          int im_col_idx =
+              w * stride_width - padding_left + w_offset * dilation_w;
+          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
+              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
+            im_row_idx += c_im * im_height;
+            im_data[im_row_idx * im_width + im_col_idx] +=
+                col_data[(c * col_height + h) * col_width + w];
           }
         }
       }

@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& im, framework::Tensor& col,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right) {
+                  int dilation_h, int dilation_w, int stride_height,
+                  int stride_width, int padding_up, int padding_down,
+                  int padding_left, int padding_right) {
     PADDLE_ENFORCE(im.dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int input_channels = im.dims()[0];
-    int input_height = im.dims()[1];
-    int input_width = im.dims()[2];
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
     int filter_height = col.dims()[3];
     int filter_width = col.dims()[4];
-    int output_height = col.dims()[0];
-    int output_width = col.dims()[1];
+    int col_height = col.dims()[0];
+    int col_width = col.dims()[1];

-    PADDLE_ENFORCE_EQ(
-        (input_height + padding_up + padding_down - filter_height) /
-                stride_height +
-            1,
-        output_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (input_width + padding_left + padding_right - filter_width) /
-                stride_width +
-            1,
-        output_width,
-        "output_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
+                              stride_height +
+                          1,
+                      col_height,
+                      "Output_height and padding(padding_up, padding_down) are "
+                      "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
+                              stride_width +
+                          1,
+                      col_width,
+                      "col_width and padding(padding_left, padding_right) are "
+                      "inconsistent.");

     const T* im_data = im.data<T>();
     T* col_data = col.data<T>();

-    for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
-      for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
-        for (int channel = 0; channel < input_channels; ++channel) {
+    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
+      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
+        for (int channel = 0; channel < im_channels; ++channel) {
           for (int filter_row_idx = 0; filter_row_idx < filter_height;
                ++filter_row_idx) {
             for (int filter_col_idx = 0; filter_col_idx < filter_width;

@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                   col_row_idx * stride_height + filter_row_idx - padding_up;
               int im_col_offset =
                   col_col_idx * stride_width + filter_col_idx - padding_left;
-              int col_offset = ((((col_row_idx) * output_width + col_col_idx) *
-                                    input_channels +
+              int col_offset =
+                  ((((col_row_idx) * col_width + col_col_idx) * im_channels +
                         channel) *
                             filter_height +
                         filter_row_idx) *
                             filter_width +
                         filter_col_idx;
-              if (im_row_offset < 0 || im_row_offset >= input_height ||
-                  im_col_offset < 0 || im_col_offset >= input_width) {
-                col_data[col_offset] = T(0);
-              } else {
-                int im_offset =
-                    (channel * input_height + im_row_offset) * input_width +
-                    im_col_offset;
-                col_data[col_offset] = im_data[im_offset];
-              }
+              int im_offset = (channel * im_height + im_row_offset) * im_width +
+                              im_col_offset;
+              col_data[col_offset] =
+                  (im_row_offset < 0 || im_row_offset >= im_height ||
+                   im_col_offset < 0 || im_col_offset >= im_width)
+                      ? static_cast<T>(0)
+                      : im_data[im_offset];
             }
           }
         }

@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                     platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& col, int dilation_h, int dilation_w,
+                  int stride_height, int stride_width, int padding_up,
+                  int padding_down, int padding_left, int padding_right) {
     PADDLE_ENFORCE(im.dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int input_channels = im.dims()[0];
-    int input_height = im.dims()[1];
-    int input_width = im.dims()[2];
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
     int filter_height = col.dims()[3];
     int filter_width = col.dims()[4];
-    int output_height = col.dims()[0];
-    int output_width = col.dims()[1];
+    int col_height = col.dims()[0];
+    int col_width = col.dims()[1];

-    PADDLE_ENFORCE_EQ(
-        (input_height + padding_up + padding_down - filter_height) /
-                stride_height +
-            1,
-        output_height,
-        "Output_height and padding(padding_up, padding_down) are "
-        "inconsistent.");
-    PADDLE_ENFORCE_EQ(
-        (input_width + padding_left + padding_right - filter_width) /
-                stride_width +
-            1,
-        output_width,
-        "output_width and padding(padding_left, padding_right) are "
-        "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
+                              stride_height +
+                          1,
+                      col_height,
+                      "Output_height and padding(padding_up, padding_down) are "
+                      "inconsistent.");
+    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
+                              stride_width +
+                          1,
+                      col_width,
+                      "col_width and padding(padding_left, padding_right) are "
+                      "inconsistent.");

     T* im_data = im.data<T>();
     const T* col_data = col.data<T>();

-    for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
-      for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
-        for (int channel = 0; channel < input_channels; ++channel) {
+    for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
+      for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
+        for (int channel = 0; channel < im_channels; ++channel) {
           for (int filter_row_idx = 0; filter_row_idx < filter_height;
                ++filter_row_idx) {
             for (int filter_col_idx = 0; filter_col_idx < filter_width;

@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                   col_row_idx * stride_height + filter_row_idx - padding_up;
               int im_col_offset =
                   col_col_idx * stride_width + filter_col_idx - padding_left;
-              int col_offset = (((col_row_idx * output_width + col_col_idx) *
-                                    input_channels +
+              int col_offset =
+                  (((col_row_idx * col_width + col_col_idx) * im_channels +
                         channel) *
                             filter_height +
                         filter_row_idx) *
                             filter_width +
                         filter_col_idx;
-              if (im_row_offset >= 0 && im_row_offset < input_height &&
-                  im_col_offset >= 0 && im_col_offset < input_width) {
+              if (im_row_offset >= 0 && im_row_offset < im_height &&
+                  im_col_offset >= 0 && im_col_offset < im_width) {
                 int im_offset =
-                    (channel * input_height + im_row_offset) * input_width +
+                    (channel * im_height + im_row_offset) * im_width +
                     im_col_offset;
                 im_data[im_offset] += col_data[col_offset];
               }

paddle/operators/math/im2col.cu

(This diff is collapsed in the original view: +191 -175.)

paddle/operators/math/im2col.h

@@ -74,17 +74,18 @@ class Im2ColFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& im, framework::Tensor& col,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right);
+                  int dilation_h, int dilation_w, int stride_height,
+                  int stride_width, int padding_up, int padding_down,
+                  int padding_left, int padding_right);
 };

 template <ColFormat Format, typename Place, typename T>
 class Col2ImFunctor {
  public:
   void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right);
+                  const framework::Tensor& col, int dilation_h, int dilation_w,
+                  int stride_height, int stride_width, int padding_up,
+                  int padding_down, int padding_left, int padding_right);
 };

 }  // namespace math
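
Col2ImFunctor is the scatter-add inverse of Im2ColFunctor: every column entry is accumulated back into the image location it was gathered from, which is how the backward pass sums contributions from overlapping (and, with dilation, spread-out) filter taps. A small standalone sketch of that accumulation under the same index mapping, again on a plain buffer instead of Paddle's Tensor (the function name is illustrative):

#include <cstdio>
#include <vector>

// Scatter-add a single-channel kCFO column matrix back into the image,
// mirroring the "+=" in Col2ImFunctor; taps that fall into padding are skipped.
void Col2ImDilated(std::vector<float>* im, const std::vector<float>& col,
                   int im_height, int im_width, int filter_height,
                   int filter_width, int dilation_h, int dilation_w,
                   int stride_h, int stride_w, int pad_up, int pad_left,
                   int col_height, int col_width) {
  for (int c = 0; c < filter_height * filter_width; ++c) {
    int w_offset = c % filter_width;
    int h_offset = c / filter_width;
    for (int h = 0; h < col_height; ++h) {
      for (int w = 0; w < col_width; ++w) {
        int im_row_idx = h * stride_h - pad_up + h_offset * dilation_h;
        int im_col_idx = w * stride_w - pad_left + w_offset * dilation_w;
        if (im_row_idx >= 0 && im_row_idx < im_height && im_col_idx >= 0 &&
            im_col_idx < im_width) {
          (*im)[im_row_idx * im_width + im_col_idx] +=
              col[(c * col_height + h) * col_width + w];
        }
      }
    }
  }
}

int main() {
  // An all-ones column matrix for a 2x2 filter sliding over a 2x3 image
  // (stride 1, dilation 1, no padding, so the output is 1x2).
  std::vector<float> im(2 * 3, 0.0f);
  std::vector<float> col(2 * 2 * 1 * 2, 1.0f);
  Col2ImDilated(&im, col, 2, 3, 2, 2, 1, 1, 1, 1, 0, 0, 1, 2);
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 3; ++c) std::printf("%.0f ", im[r * 3 + c]);
    std::printf("\n");  // 1 2 1  /  1 2 1
  }
  return 0;
}

The middle column of the image receives two contributions because it is covered by both sliding windows; the corner pixels receive one each.
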
paddle/operators/math/im2col_test.cc

@@ -47,6 +47,8 @@ void testIm2col() {
   int filter_size = 2;
   int stride = 1;
   int padding = 0;
+  int dilation_h = 1;
+  int dilation_w = 1;
   int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
   int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
   float* input_ptr = input_tmp.mutable_data<float>(

@@ -85,10 +87,10 @@ void testIm2col() {
       paddle::operators::math::ColFormat::kOCF, Place, float>
       im2col_ocf;

-  im2col(*context, input, output_cfo, stride, stride, padding, padding, padding,
-         padding);
-  im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
-             padding, padding);
+  im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
+         padding, padding, padding, padding);
+  im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
+             stride, padding, padding, padding, padding);

   float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
   float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};

@@ -131,8 +133,8 @@ void testIm2col() {
     input.CopyFrom(input_tmp, *place, *context);
   }

-  col2im(*context, input, output_cfo, stride, stride, padding, padding, padding,
-         padding);
+  col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
+         padding, padding, padding, padding);

   float* in_ptr;
   if (paddle::platform::is_cpu_place(*place)) {

@@ -153,8 +155,8 @@ void testIm2col() {
     input.CopyFrom(input_tmp, *place, *context);
   }

-  col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding,
-             padding, padding);
+  col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
+             stride, padding, padding, padding, padding);

   if (paddle::platform::is_cpu_place(*place)) {
     in_ptr = input.data<float>();
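
The expected arrays in this test follow from sliding a 2x2 window over a 2x3 input with stride 1, no padding and dilation 1: kCFO lays the result out as (channel, filter row, filter column, output position), while kOCF uses (output position, channel, filter row, filter column). A quick standalone re-derivation, assuming the test input is the 2x3 tensor {0, 1, 2, 3, 4, 5} (consistent with the expected values):

#include <cstdio>

int main() {
  // One-channel 2x3 input, 2x2 filter, stride 1, no padding, dilation 1,
  // so there are 1x2 sliding-window positions.
  const float im[2][3] = {{0, 1, 2}, {3, 4, 5}};
  const int out_h = 1, out_w = 2, k = 2;

  std::printf("kCFO: ");  // filter taps outermost, output positions innermost
  for (int kh = 0; kh < k; ++kh)
    for (int kw = 0; kw < k; ++kw)
      for (int oh = 0; oh < out_h; ++oh)
        for (int ow = 0; ow < out_w; ++ow)
          std::printf("%.0f ", im[oh + kh][ow + kw]);
  std::printf("\n");  // 0 1 1 2 3 4 4 5

  std::printf("kOCF: ");  // output positions outermost, filter taps innermost
  for (int oh = 0; oh < out_h; ++oh)
    for (int ow = 0; ow < out_w; ++ow)
      for (int kh = 0; kh < k; ++kh)
        for (int kw = 0; kw < k; ++kw)
          std::printf("%.0f ", im[oh + kh][ow + kw]);
  std::printf("\n");  // 0 1 3 4 1 2 4 5
  return 0;
}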