Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
dc7d0735
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dc7d0735
编写于
10月 21, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add padding up, down, left, right
上级
d2c1408f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
158 addition
and
134 deletion
+158
-134
paddle/operators/conv2d_op.h
paddle/operators/conv2d_op.h
+5
-3
paddle/operators/math/im2col.cc
paddle/operators/math/im2col.cc
+78
-64
paddle/operators/math/im2col.cu
paddle/operators/math/im2col.cu
+63
-56
paddle/operators/math/im2col.h
paddle/operators/math/im2col.h
+4
-3
paddle/operators/math/im2col_test.cc
paddle/operators/math/im2col_test.cc
+8
-8
未找到文件。
paddle/operators/conv2d_op.h
浏览文件 @
dc7d0735
...
...
@@ -116,7 +116,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
// im2col
Tensor
in_slice
=
in_batch
.
Slice
<
T
>
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
im2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
// gemm
Tensor
out_slice
=
out_batch
.
Slice
<
T
>
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
...
...
@@ -217,7 +217,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
Tensor
in_grad_slice
=
in_grad_batch
.
Slice
<
T
>
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
col2im
(
context
.
device_context
(),
in_grad_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
}
}
}
...
...
@@ -239,7 +240,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
out_grad_batch
.
Slice
<
T
>
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
in_slice
=
in_batch
.
Slice
<
T
>
(
g
*
in_step
,
(
g
+
1
)
*
in_step
);
im2col
(
context
.
device_context
(),
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
strides
[
1
],
paddings
[
0
],
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
// gemm
Tensor
filter_grad_slice
=
...
...
paddle/operators/math/im2col.cc
浏览文件 @
dc7d0735
...
...
@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_
height
,
int
padding_
width
)
{
int
stride_height
,
int
stride_width
,
int
padding_
up
,
int
padding_
down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
...
...
@@ -41,6 +41,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int
filter_width
=
col
.
dims
()[
2
];
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
int
channels_col
=
input_channels
*
filter_height
*
filter_width
;
const
T
*
im_data
=
im
.
data
<
T
>
();
...
...
@@ -54,14 +64,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride_height
+
h_offset
;
int
im_col_idx
=
w
*
stride_width
+
w_offset
;
if
((
im_row_idx
-
padding_
height
)
<
0
||
(
im_row_idx
-
padding_
height
)
>=
input_height
||
(
im_col_idx
-
padding_
width
)
<
0
||
(
im_col_idx
-
padding_
width
)
>=
input_width
)
{
if
((
im_row_idx
-
padding_
up
)
<
0
||
(
im_row_idx
-
padding_
up
)
>=
input_height
||
(
im_col_idx
-
padding_
left
)
<
0
||
(
im_col_idx
-
padding_
left
)
>=
input_width
)
{
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
T
(
0
);
}
else
{
im_row_idx
+=
c_im
*
input_height
-
padding_
height
;
im_col_idx
-=
padding_
width
;
im_row_idx
+=
c_im
*
input_height
-
padding_
up
;
im_col_idx
-=
padding_
left
;
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
]
=
im_data
[
im_row_idx
*
input_width
+
im_col_idx
];
}
...
...
@@ -82,7 +92,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
...
...
@@ -92,6 +103,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int
filter_width
=
col
.
dims
()[
2
];
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
int
channels_col
=
input_channels
*
filter_height
*
filter_width
;
T
*
im_data
=
im
.
data
<
T
>
();
...
...
@@ -105,12 +126,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
im_row_idx
=
h
*
stride_height
+
h_offset
;
int
im_col_idx
=
w
*
stride_width
+
w_offset
;
if
((
im_row_idx
-
padding_
height
)
>=
0
&&
(
im_row_idx
-
padding_
height
)
<
input_height
&&
(
im_col_idx
-
padding_
width
)
>=
0
&&
(
im_col_idx
-
padding_
width
)
<
input_width
)
{
im_row_idx
+=
c_im
*
input_height
-
padding_
height
;
im_col_idx
-=
padding_
width
;
if
((
im_row_idx
-
padding_
up
)
>=
0
&&
(
im_row_idx
-
padding_
up
)
<
input_height
&&
(
im_col_idx
-
padding_
left
)
>=
0
&&
(
im_col_idx
-
padding_
left
)
<
input_width
)
{
im_row_idx
+=
c_im
*
input_height
-
padding_
up
;
im_col_idx
-=
padding_
left
;
im_data
[
im_row_idx
*
input_width
+
im_col_idx
]
+=
col_data
[(
c
*
output_height
+
h
)
*
output_width
+
w
];
}
...
...
@@ -140,8 +161,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
up_pad
,
int
down_pad
)
{
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
...
...
@@ -149,25 +170,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
//
int output_height = col.dims()[0];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
row_begin
,
row_end
;
int
padding_height
=
std
::
max
(
up_pad
,
down_pad
);
int
padding_width
=
0
;
if
(
up_pad
>=
down_pad
)
{
row_begin
=
0
;
}
else
{
row_begin
=
down_pad
-
up_pad
;
}
row_end
=
row_begin
+
((
input_height
+
up_pad
+
down_pad
-
filter_height
)
/
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
);
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
col_row_idx
=
row_begin
;
col_row_idx
<
row_end
;
++
col_row_idx
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
input_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
...
...
@@ -175,11 +193,10 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
++
filter_col_idx
)
{
int
im_row_offset
=
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_
height
;
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_
up
;
int
im_col_offset
=
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_width
;
int
col_offset
=
((((
col_row_idx
-
row_begin
)
*
output_width
+
col_col_idx
)
*
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
((((
col_row_idx
)
*
output_width
+
col_col_idx
)
*
input_channels
+
channel
)
*
filter_height
+
...
...
@@ -214,7 +231,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
up_pad
,
int
down_pad
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
...
...
@@ -222,25 +240,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
//
int output_height = col.dims()[0];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
row_begin
,
row_end
;
int
padding_height
=
std
::
max
(
up_pad
,
down_pad
);
int
padding_width
=
0
;
if
(
up_pad
>=
down_pad
)
{
row_begin
=
0
;
}
else
{
row_begin
=
down_pad
-
up_pad
;
}
row_end
=
row_begin
+
((
input_height
+
up_pad
+
down_pad
-
filter_height
)
/
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
);
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
T
*
im_data
=
im
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
for
(
int
col_row_idx
=
row_begin
;
col_row_idx
<
row_end
;
++
col_row_idx
)
{
for
(
int
col_row_idx
=
0
;
col_row_idx
<
output_height
;
++
col_row_idx
)
{
for
(
int
col_col_idx
=
0
;
col_col_idx
<
output_width
;
++
col_col_idx
)
{
for
(
int
channel
=
0
;
channel
<
input_channels
;
++
channel
)
{
for
(
int
filter_row_idx
=
0
;
filter_row_idx
<
filter_height
;
...
...
@@ -248,11 +263,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for
(
int
filter_col_idx
=
0
;
filter_col_idx
<
filter_width
;
++
filter_col_idx
)
{
int
im_row_offset
=
// change or not ???
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_
height
;
col_row_idx
*
stride_height
+
filter_row_idx
-
padding_
up
;
int
im_col_offset
=
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_width
;
int
col_offset
=
((((
col_row_idx
-
row_begin
)
*
output_width
+
col_col_idx
)
*
col_col_idx
*
stride_width
+
filter_col_idx
-
padding_left
;
int
col_offset
=
(((
col_row_idx
*
output_width
+
col_col_idx
)
*
input_channels
+
channel
)
*
filter_height
+
...
...
paddle/operators/math/im2col.cu
浏览文件 @
dc7d0735
...
...
@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_
height
,
int
padding_
width
)
{
int
stride_height
,
int
stride_width
,
int
padding_
up
,
int
padding_
down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
...
...
@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
int
num_outputs
=
input_channels
*
output_height
*
output_width
;
int
blocks
=
(
num_outputs
+
1024
-
1
)
/
1024
;
int
block_x
=
512
;
...
...
@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
im
.
data
<
T
>
(),
num_outputs
,
input_height
,
input_width
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_
heigh
t
,
padding_width
,
output_height
,
output_width
,
col
.
data
<
T
>
());
filter_width
,
stride_height
,
stride_width
,
padding_
up
,
padding_lef
t
,
output_height
,
output_width
,
col
.
data
<
T
>
());
}
};
...
...
@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
...
...
@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
size_t
num_kernels
=
input_channels
*
(
input_height
+
2
*
padding_height
)
*
(
input_width
+
2
*
padding_width
);
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
size_t
num_kernels
=
input_channels
*
(
input_height
+
padding_up
+
padding_down
)
*
(
input_width
+
padding_left
+
padding_right
);
size_t
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
size_t
block_x
=
512
;
...
...
@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
num_kernels
,
col
.
data
<
T
>
(),
input_height
+
2
*
padding_height
,
input_width
+
2
*
padding_width
,
input_channels
,
filter_height
,
filter_
width
,
stride_height
,
stride_width
,
padding_height
,
padding_
width
,
output_height
,
output_width
,
im
.
data
<
T
>
());
num_kernels
,
col
.
data
<
T
>
(),
input_height
+
padding_up
+
padding_down
,
input_width
+
padding_left
+
padding_left
,
input_channels
,
filter_
height
,
filter_width
,
stride_height
,
stride_width
,
padding_up
,
padding_
left
,
output_height
,
output_width
,
im
.
data
<
T
>
());
}
};
...
...
@@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int
input_height
,
int
input_width
,
int
filter_height
,
int
filter_width
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
,
int
output_height
,
int
output_width
,
int
row_begin
,
int
row_end
)
{
int
output_height
,
int
output_width
)
{
int
swid
=
blockIdx
.
x
;
int
shid
=
blockIdx
.
y
;
for
(
int
channelid
=
threadIdx
.
z
;
channelid
<
input_channels
;
...
...
@@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter_height
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter_width
;
idx
+=
blockDim
.
x
)
{
int
width_offset
=
idx
+
swid
*
stride_width
-
padding_width
;
int
height_offset
=
idy
+
(
shid
+
row_begin
)
*
stride_height
-
padding_height
;
int
height_offset
=
idy
+
shid
*
stride_height
-
padding_height
;
int
im_offset
=
width_offset
+
height_offset
*
input_width
+
channelid
*
input_height
*
input_width
;
...
...
@@ -240,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
up_pad
,
int
down_pad
)
{
int
stride_height
,
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
...
...
@@ -249,21 +267,17 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
row_begin
,
row_end
;
int
padding_height
=
std
::
max
(
up_pad
,
down_pad
);
int
padding_width
=
0
;
if
(
up_pad
>=
down_pad
)
{
row_begin
=
0
;
}
else
{
row_begin
=
down_pad
-
up_pad
;
}
row_end
=
row_begin
+
((
input_height
+
up_pad
+
down_pad
-
filter_height
)
/
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
);
int
output_height
=
row_end
-
row_begin
;
// col.dims()[0];
int
output_width
=
col
.
dims
()[
1
];
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
int
block_dim_x
=
0
;
int
block_dim_y
=
0
;
...
...
@@ -289,9 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
im
.
data
<
T
>
(),
col
.
data
<
T
>
(),
input_channels
,
input_height
,
input_width
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
,
row_begin
,
row_end
);
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_up
,
padding_left
,
output_height
,
output_width
);
}
};
...
...
@@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int
input_height
,
int
input_width
,
int
filter_height
,
int
filter_width
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
,
int
output_height
,
int
output_width
,
int
row_begin
,
int
row_end
)
{
int
output_height
,
int
output_width
)
{
int
swid
=
blockIdx
.
x
;
int
shid
=
blockIdx
.
y
;
for
(
int
channelid
=
threadIdx
.
z
;
channelid
<
input_channels
;
...
...
@@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter_height
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter_width
;
idx
+=
blockDim
.
x
)
{
int
width_offset
=
idx
+
swid
*
stride_width
-
padding_width
;
int
height_offset
=
idy
+
(
shid
+
row_begin
)
*
stride_height
-
padding_height
;
int
height_offset
=
idy
+
shid
*
stride_height
-
padding_height
;
int
im_offset
=
width_offset
+
height_offset
*
input_width
+
channelid
*
input_height
*
input_width
;
...
...
@@ -340,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
up_pad
,
int
down_pad
)
{
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
...
...
@@ -348,21 +360,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
row_begin
,
row_end
;
int
padding_height
=
std
::
max
(
up_pad
,
down_pad
);
int
padding_width
=
0
;
if
(
up_pad
>=
down_pad
)
{
row_begin
=
0
;
}
else
{
row_begin
=
down_pad
-
up_pad
;
}
row_end
=
row_begin
+
((
input_height
+
up_pad
+
down_pad
-
filter_height
)
/
PADDLE_ENFORCE
((
input_height
+
padding_up
+
padding_down
-
filter_height
)
/
stride_height
+
1
);
int
output_height
=
row_end
-
row_begin
;
// col.dims()[0];
int
output_width
=
col
.
dims
()[
1
];
1
==
output_height
);
PADDLE_ENFORCE
((
input_width
+
padding_left
+
padding_right
-
filter_width
)
/
stride_width
+
1
==
output_width
);
int
block_dim_x
=
0
;
int
block_dim_y
=
0
;
...
...
@@ -388,9 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
im
.
data
<
T
>
(),
col
.
data
<
T
>
(),
input_channels
,
input_height
,
input_width
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
,
row_begin
,
row_end
);
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_up
,
padding_left
,
output_height
,
output_width
);
}
};
...
...
paddle/operators/math/im2col.h
浏览文件 @
dc7d0735
...
...
@@ -74,8 +74,8 @@ class Im2ColFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_
height
,
int
padding_
width
);
int
stride_height
,
int
stride_width
,
int
padding_
up
,
int
padding_
down
,
int
padding_left
,
int
padding_right
);
};
template
<
ColFormat
Format
,
typename
Place
,
typename
T
>
...
...
@@ -83,7 +83,8 @@ class Col2ImFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
);
int
stride_width
,
int
padding_up
,
int
padding_down
,
int
padding_left
,
int
padding_right
);
};
}
// namespace math
...
...
paddle/operators/math/im2col_test.cc
浏览文件 @
dc7d0735
...
...
@@ -85,10 +85,10 @@ void testIm2col() {
paddle
::
operators
::
math
::
ColFormat
::
kOCF
,
Place
,
float
>
im2col_ocf
;
im2col
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
);
im2col_ocf
(
*
context
,
input
,
output_ocf
,
/*stride_height*/
stride
,
/*stride_width*/
stride
,
/*up_pad*/
padding
,
/*down_pad*/
padding
);
im2col
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
im2col_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
float
out_cfo_data
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
};
float
out_ocf_data
[]
=
{
0
,
1
,
3
,
4
,
1
,
2
,
4
,
5
};
...
...
@@ -133,7 +133,8 @@ void testIm2col() {
input
.
CopyFrom
<
float
>
(
input_tmp
,
*
place
,
*
context
);
}
col2im
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
);
col2im
(
*
context
,
input
,
output_cfo
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
float
*
in_ptr
;
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
...
@@ -154,9 +155,8 @@ void testIm2col() {
input
.
CopyFrom
<
float
>
(
input_tmp
,
*
place
,
*
context
);
}
col2im_ocf
(
*
context
,
input
,
output_ocf
,
/*stride_height*/
stride
,
/*stride_width*/
stride
,
/*up_pad*/
padding
,
/*down_pad*/
padding
);
col2im_ocf
(
*
context
,
input
,
output_ocf
,
stride
,
stride
,
padding
,
padding
,
padding
,
padding
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录