Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
271fc9c1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
271fc9c1
编写于
11月 10, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add dilation for vol2col
上级
93551bd2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
189 addition
and
70 deletion
+189
-70
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+8
-7
paddle/operators/conv_transpose_op.h
paddle/operators/conv_transpose_op.h
+8
-5
paddle/operators/math/im2col.cu
paddle/operators/math/im2col.cu
+1
-0
paddle/operators/math/vol2col.cc
paddle/operators/math/vol2col.cc
+64
-16
paddle/operators/math/vol2col.cu
paddle/operators/math/vol2col.cu
+101
-38
paddle/operators/math/vol2col.h
paddle/operators/math/vol2col.h
+2
-0
paddle/operators/math/vol2col_test.cc
paddle/operators/math/vol2col_test.cc
+5
-4
未找到文件。
paddle/operators/conv_op.h
浏览文件 @
271fc9c1
...
...
@@ -165,9 +165,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
// vol2col
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
stride
s
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
padding
s
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
dilation
s
[
0
],
dilations
[
1
],
dilations
[
2
],
strides
[
0
],
stride
s
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
}
// gemm
...
...
@@ -314,7 +314,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
col2vol
(
context
.
device_context
(),
in_grad_slice
,
col
,
strides
[
0
],
col2vol
(
context
.
device_context
(),
in_grad_slice
,
col
,
dilations
[
0
],
dilations
[
1
],
dilations
[
2
],
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
}
...
...
@@ -371,9 +372,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
paddings
[
0
],
paddings
[
1
],
paddings
[
1
]);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
stride
s
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
padding
s
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
in_slice
,
col
,
dilation
s
[
0
],
dilations
[
1
],
dilations
[
2
],
strides
[
0
],
stride
s
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
}
// gemm
...
...
paddle/operators/conv_transpose_op.h
浏览文件 @
271fc9c1
...
...
@@ -69,6 +69,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
int
dilaiton_d
=
1
;
int
dilation_h
=
1
;
int
dilation_w
=
1
;
...
...
@@ -149,8 +150,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
math
::
Col2VolFunctor
<
Place
,
T
>
col2vol
;
col2vol
(
context
.
device_context
(),
output_batch
,
col
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
0
,
0
,
0
);
col2vol
(
context
.
device_context
(),
output_batch
,
col
,
dilaiton_d
,
dilation_h
,
dilation_w
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
0
,
0
,
0
);
}
}
}
...
...
@@ -177,6 +179,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose.
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
int
dilaiton_d
=
1
;
int
dilation_h
=
1
;
int
dilation_w
=
1
;
...
...
@@ -261,9 +264,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
math
::
Vol2ColFunctor
<
Place
,
T
>
vol2col
;
vol2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
strides
[
0
]
,
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
vol2col
(
context
.
device_context
(),
output_grad_batch
,
col
,
dilaiton_d
,
dilation_h
,
dilation_w
,
strides
[
0
],
strides
[
1
],
strides
[
2
],
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
}
if
(
input_grad
)
{
...
...
paddle/operators/math/im2col.cu
浏览文件 @
271fc9c1
...
...
@@ -145,6 +145,7 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
h_col
)
*
col_width
+
w_col
;
val
+=
data_col
[
data_col_index
];
}
}
...
...
paddle/operators/math/vol2col.cc
浏览文件 @
271fc9c1
...
...
@@ -29,6 +29,7 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
...
...
@@ -48,6 +49,28 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int
channels_col
=
input_channels
*
filter_depth
*
filter_height
*
filter_width
;
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
padding_depth
-
((
dilation_d
*
(
filter_depth
-
1
)
+
1
)))
/
stride_depth
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
padding_height
-
((
dilation_h
*
(
filter_height
-
1
)
+
1
)))
/
stride_height
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
padding_width
-
((
dilation_w
*
(
filter_width
-
1
)
+
1
)))
/
stride_width
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
const
T
*
vol_data
=
vol
.
data
<
T
>
();
T
*
col_data
=
col
.
data
<
T
>
();
...
...
@@ -57,24 +80,25 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int
d_offset
=
(
c
/
filter_width
/
filter_height
)
%
filter_depth
;
int
c_in
=
c
/
filter_width
/
filter_height
/
filter_depth
;
for
(
int
d
=
0
;
d
<
output_depth
;
++
d
)
{
int
d_pad
=
d
*
stride_depth
-
padding_depth
+
d_offset
;
int
d_pad
=
d
*
stride_depth
-
padding_depth
+
d_offset
*
dilation_d
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
h_pad
=
h
*
stride_height
-
padding_height
+
h_offset
;
int
h_pad
=
h
*
stride_height
-
padding_height
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
w_pad
=
w
*
stride_width
-
padding_width
+
w_offset
;
int
w_pad
=
w
*
stride_width
-
padding_width
+
w_offset
*
dilation_w
;
int
col_idx
=
((
c
*
output_depth
+
d
)
*
output_height
+
h
)
*
output_width
+
w
;
if
(
h_pad
<
0
||
h_pad
>=
input_height
||
w_pad
<
0
||
w_pad
>=
input_width
||
d_pad
<
0
||
d_pad
>=
input_depth
)
{
col_data
[
col_idx
]
=
static_cast
<
T
>
(
0
);
}
else
{
int
vol_idx
=
((
c_in
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
input_width
+
w_pad
;
col_data
[
col_idx
]
=
vol_data
[
vol_idx
];
}
int
vol_idx
=
((
c_in
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
input_width
+
w_pad
;
col_data
[
col_idx
]
=
(
h_pad
<
0
||
h_pad
>=
input_height
||
w_pad
<
0
||
w_pad
>=
input_width
||
d_pad
<
0
||
d_pad
>=
input_depth
)
?
static_cast
<
T
>
(
0
)
:
vol_data
[
vol_idx
];
}
}
}
...
...
@@ -93,6 +117,7 @@ class Col2VolFunctor<platform::CPUPlace, T> {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
...
...
@@ -112,6 +137,27 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int
channels_col
=
input_channels
*
filter_depth
*
filter_height
*
filter_width
;
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
padding_depth
-
((
dilation_d
*
(
filter_depth
-
1
)
+
1
)))
/
stride_depth
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
padding_height
-
((
dilation_h
*
(
filter_height
-
1
)
+
1
)))
/
stride_height
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
padding_width
-
((
dilation_w
*
(
filter_width
-
1
)
+
1
)))
/
stride_width
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
T
*
vol_data
=
vol
.
data
<
T
>
();
const
T
*
col_data
=
col
.
data
<
T
>
();
...
...
@@ -121,11 +167,13 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int
d_offset
=
(
c
/
filter_width
/
filter_height
)
%
filter_depth
;
int
cIm
=
c
/
filter_width
/
filter_height
/
filter_depth
;
for
(
int
d
=
0
;
d
<
output_depth
;
++
d
)
{
int
d_pad
=
d
*
stride_depth
-
padding_depth
+
d_offset
;
int
d_pad
=
d
*
stride_depth
-
padding_depth
+
d_offset
*
dilation_d
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
h_pad
=
h
*
stride_height
-
padding_height
+
h_offset
;
int
h_pad
=
h
*
stride_height
-
padding_height
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
w_pad
=
w
*
stride_width
-
padding_width
+
w_offset
;
int
w_pad
=
w
*
stride_width
-
padding_width
+
w_offset
*
dilation_w
;
if
(
h_pad
>=
0
&&
h_pad
<
input_height
&&
w_pad
>=
0
&&
w_pad
<
input_width
&&
d_pad
>=
0
&&
d_pad
<
input_depth
)
{
...
...
paddle/operators/math/vol2col.cu
浏览文件 @
271fc9c1
...
...
@@ -21,11 +21,12 @@ namespace math {
template
<
class
T
>
__global__
void
vol2col
(
int
num_kernels
,
const
T
*
data_vol
,
int
depth
,
int
height
,
int
width
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_col
)
{
int
height
,
int
width
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_col
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_out
=
index
%
output_width
;
...
...
@@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
for
(
int
k
=
0
;
k
<
filter_depth
;
++
k
)
{
for
(
int
i
=
0
;
i
<
filter_height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filter_width
;
++
j
)
{
int
d
=
d_in
+
k
;
int
h
=
h_in
+
i
;
int
w
=
w_in
+
j
;
int
d
=
d_in
+
k
*
dilation_d
;
int
h
=
h_in
+
i
*
dilation_h
;
int
w
=
w_in
+
j
*
dilation_w
;
int
col_idx
=
(
k
*
dilation_d
*
height
+
i
*
dilation_h
)
*
width
+
j
*
dilation_w
;
*
data_col
=
(
d
>=
0
&&
d
<
depth
&&
h
>=
0
&&
h
<
height
&&
w
>=
0
&&
w
<
width
)
?
data_vol
[
(
k
*
height
+
i
)
*
width
+
j
]
?
data_vol
[
col_idx
]
:
0
;
data_col
+=
output_detph
*
output_height
*
output_width
;
}
...
...
@@ -69,6 +72,7 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
...
...
@@ -86,6 +90,28 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
int
output_height
=
col
.
dims
()[
5
];
int
output_width
=
col
.
dims
()[
6
];
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
padding_depth
-
((
dilation_d
*
(
filter_depth
-
1
)
+
1
)))
/
stride_depth
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
padding_height
-
((
dilation_h
*
(
filter_height
-
1
)
+
1
)))
/
stride_height
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
padding_width
-
((
dilation_w
*
(
filter_width
-
1
)
+
1
)))
/
stride_width
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
int
num_outputs
=
input_channels
*
output_depth
*
output_height
*
output_width
;
...
...
@@ -95,19 +121,25 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
num_outputs
,
vol
.
data
<
T
>
(),
input_depth
,
input_height
,
input_width
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
col
.
data
<
T
>
());
dilation_d
,
dilation_h
,
dilation_w
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
col
.
data
<
T
>
());
}
};
template
<
class
T
>
__global__
void
col2vol
(
int
num_kernels
,
const
T
*
data_col
,
int
depth
,
int
height
,
int
width
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_vol
)
{
int
height
,
int
width
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
filter_depth
,
int
filter_height
,
int
filter_width
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
,
int
output_detph
,
int
output_height
,
int
output_width
,
T
*
data_vol
)
{
const
int
d_filter_depth
=
dilation_d
*
(
filter_depth
-
1
)
+
1
;
const
int
d_filter_height
=
dilation_h
*
(
filter_height
-
1
)
+
1
;
const
int
d_filter_width
=
dilation_w
*
(
filter_width
-
1
)
+
1
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
src_val
=
0
;
...
...
@@ -115,35 +147,42 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int
h
=
(
index
/
width
)
%
height
+
padding_height
;
int
d
=
(
index
/
width
/
height
)
%
depth
+
padding_depth
;
int
c
=
index
/
width
/
height
/
depth
;
// compute the start and end of the output
int
w_col_start
=
(
w
<
filter_width
)
?
0
:
(
w
-
filter_width
)
/
stride_width
+
1
;
(
w
<
d_filter_width
)
?
0
:
(
w
-
d_
filter_width
)
/
stride_width
+
1
;
int
w_col_end
=
min
(
w
/
stride_width
+
1
,
output_width
);
int
h_col_start
=
(
h
<
filter_height
)
?
0
:
(
h
-
filter_height
)
/
stride_height
+
1
;
(
h
<
d_filter_height
)
?
0
:
(
h
-
d_
filter_height
)
/
stride_height
+
1
;
int
h_col_end
=
min
(
h
/
stride_height
+
1
,
output_height
);
int
d_col_start
=
(
d
<
filter_depth
)
?
0
:
(
d
-
filter_depth
)
/
stride_depth
+
1
;
(
d
<
d_filter_depth
)
?
0
:
(
d
-
d_
filter_depth
)
/
stride_depth
+
1
;
int
d_col_end
=
min
(
d
/
stride_depth
+
1
,
output_detph
);
int
offset
=
(
c
*
filter_depth
*
filter_height
*
filter_width
+
d
*
filter_width
*
filter_height
+
h
*
filter_width
+
w
)
*
output_detph
*
output_height
*
output_width
;
int
coeff_d_col
=
(
1
-
stride_depth
*
filter_width
*
filter_height
*
output_detph
)
*
output_height
*
output_width
;
int
coeff_h_col
=
(
1
-
stride_height
*
filter_width
*
output_detph
*
output_height
)
*
output_width
;
int
coeff_w_col
=
(
1
-
stride_width
*
output_detph
*
output_height
*
output_width
);
for
(
int
d_col
=
d_col_start
;
d_col
<
d_col_end
;
++
d_col
)
{
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
src_val
+=
data_col
[
offset
+
d_col
*
coeff_d_col
+
h_col
*
coeff_h_col
+
w_col
*
coeff_w_col
];
int
d_off
=
(
d
-
d_col
*
stride_depth
);
int
h_off
=
(
h
-
h_col
*
stride_height
);
int
w_off
=
(
w
-
w_col
*
stride_width
);
if
(
d_off
%
dilation_d
==
0
&&
h_off
%
dilation_h
==
0
&&
w_off
%
dilation_w
==
0
)
{
d_off
/=
dilation_d
;
h_off
/=
dilation_h
;
w_off
/=
dilation_w
;
int
data_col_index
=
(((((
c
*
filter_depth
+
d_off
)
*
filter_height
+
h_off
)
*
filter_width
+
w_off
)
*
output_detph
+
d_col
)
*
output_height
+
h_col
)
*
output_width
+
w_col
;
src_val
+=
data_col
[
data_col_index
];
}
}
}
}
...
...
@@ -162,6 +201,7 @@ class Col2VolFunctor<platform::GPUPlace, T> {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
{
...
...
@@ -179,6 +219,28 @@ class Col2VolFunctor<platform::GPUPlace, T> {
int
output_height
=
col
.
dims
()[
5
];
int
output_width
=
col
.
dims
()[
6
];
PADDLE_ENFORCE_EQ
((
input_depth
+
2
*
padding_depth
-
((
dilation_d
*
(
filter_depth
-
1
)
+
1
)))
/
stride_depth
+
1
,
output_depth
,
"input_depth and output_depth are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_height
+
2
*
padding_height
-
((
dilation_h
*
(
filter_height
-
1
)
+
1
)))
/
stride_height
+
1
,
output_height
,
"input_height and output_height are "
"Mismatching."
);
PADDLE_ENFORCE_EQ
((
input_width
+
2
*
padding_width
-
((
dilation_w
*
(
filter_width
-
1
)
+
1
)))
/
stride_width
+
1
,
output_width
,
"input_width and output_width are "
"Mismatching."
);
int
num_kernels
=
input_channels
*
input_depth
*
input_height
*
input_width
;
const
int
threads
=
1024
;
...
...
@@ -188,9 +250,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
num_kernels
,
col
.
data
<
T
>
(),
input_depth
,
input_height
,
input_width
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
vol
.
data
<
T
>
());
dilation_d
,
dilation_h
,
dilation_w
,
filter_depth
,
filter_height
,
filter_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_depth
,
output_height
,
output_width
,
vol
.
data
<
T
>
());
}
};
...
...
paddle/operators/math/vol2col.h
浏览文件 @
271fc9c1
...
...
@@ -58,6 +58,7 @@ class Vol2ColFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
vol
,
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
;
...
...
@@ -68,6 +69,7 @@ class Col2VolFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
vol
,
const
framework
::
Tensor
&
col
,
int
dilation_d
,
int
dilation_h
,
int
dilation_w
,
int
stride_depth
,
int
stride_height
,
int
stride_width
,
int
padding_depth
,
int
padding_height
,
int
padding_width
)
const
;
...
...
paddle/operators/math/vol2col_test.cc
浏览文件 @
271fc9c1
...
...
@@ -64,6 +64,7 @@ void testVol2col() {
int
filter_size
=
2
;
int
stride
=
1
;
int
padding
=
0
;
int
dilation
=
1
;
int
output_depth
=
(
input_depth
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_height
=
(
input_height
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
int
output_width
=
(
input_width
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
...
...
@@ -85,8 +86,8 @@ void testVol2col() {
*
place
);
paddle
::
operators
::
math
::
Vol2ColFunctor
<
Place
,
float
>
vol2col
;
vol2col
(
*
context
,
input
,
output
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
vol2col
(
*
context
,
input
,
output
,
dilation
,
dilation
,
dilation
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
float
vol_2_col
[]
=
{
0
,
1
,
1
,
2
,
3
,
4
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
10
,
10
,
11
};
float
*
out_cfo_ptr
;
...
...
@@ -111,8 +112,8 @@ void testVol2col() {
}
paddle
::
operators
::
math
::
Col2VolFunctor
<
Place
,
float
>
col2vol
;
col2vol
(
*
context
,
input
,
output
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
col2vol
(
*
context
,
input
,
output
,
dilation
,
dilation
,
dilation
,
stride
,
stride
,
stride
,
padding
,
padding
,
padding
);
float
*
in_ptr
;
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录