Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e967645c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e967645c
编写于
8月 30, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine the gpu code.
上级
f7be9cb9
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
236 addition
and
219 deletion
+236
-219
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-1
paddle/operators/math/im2col.cu
paddle/operators/math/im2col.cu
+234
-218
未找到文件。
paddle/operators/math/CMakeLists.txt
浏览文件 @
e967645c
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc DEPS cblas device_context
)
nv_library
(
math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context
)
endif
()
...
...
paddle/operators/math/im2col.cu
浏览文件 @
e967645c
...
...
@@ -12,86 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "
Im2C
ol.h"
#include "
hl_device_functions.cu
h"
#include "
paddle/operators/math/im2c
ol.h"
#include "
paddle/platform/cuda_helper.
h"
namespace
paddle
{
template
<
class
T
>
__global__
void
im2col
(
const
T
*
data_im
,
int
num
O
uts
,
int
height
,
int
width
,
int
blockH
,
int
blockW
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
int
height_col
,
int
width_col
,
T
*
data_col
)
{
__global__
void
im2col
(
const
T
*
data_im
,
int
num
_o
uts
,
int
height
,
int
width
,
int
filter_height
,
int
filter_width
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
,
int
output_height
,
int
output_width
,
T
*
data_col
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
num
O
uts
)
{
int
w_out
=
index
%
width_col
;
index
/=
width_col
;
int
h_out
=
index
%
height_col
;
int
channel_in
=
index
/
height_col
;
int
channel_out
=
channel_in
*
blockH
*
blockW
;
int
h_in
=
h_out
*
stride
H
;
int
w_in
=
w_out
*
stride
W
;
if
(
index
<
num
_o
uts
)
{
int
w_out
=
index
%
output_width
;
index
/=
output_width
;
int
h_out
=
index
%
output_height
;
int
channel_in
=
index
/
output_height
;
int
channel_out
=
channel_in
*
filter_height
*
filter_width
;
int
h_in
=
h_out
*
stride
_height
;
int
w_in
=
w_out
*
stride
_width
;
data_col
+=
(
channel_out
*
height_col
+
h_out
)
*
width_col
+
w_out
;
for
(
int
i
=
0
;
i
<
blockH
;
++
i
)
{
for
(
int
j
=
0
;
j
<
blockW
;
++
j
)
{
data_col
+=
(
channel_out
*
output_height
+
h_out
)
*
output_width
+
w_out
;
for
(
int
i
=
0
;
i
<
filter_height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filter_width
;
++
j
)
{
int
rIdx
=
int
(
h_in
+
i
);
int
cIdx
=
int
(
w_in
+
j
);
if
((
rIdx
-
(
int
)
padding
H
)
>=
(
int
)
height
||
(
rIdx
-
(
int
)
padding
H
)
<
0
||
(
cIdx
-
(
int
)
padding
W
)
>=
(
int
)
width
||
(
cIdx
-
(
int
)
padding
W
)
<
0
)
{
if
((
rIdx
-
(
int
)
padding
_height
)
>=
(
int
)
height
||
(
rIdx
-
(
int
)
padding
_height
)
<
0
||
(
cIdx
-
(
int
)
padding
_width
)
>=
(
int
)
width
||
(
cIdx
-
(
int
)
padding
_width
)
<
0
)
{
*
data_col
=
0
;
}
else
{
rIdx
=
rIdx
+
channel_in
*
height
-
padding
H
;
cIdx
=
cIdx
-
padding
W
;
rIdx
=
rIdx
+
channel_in
*
height
-
padding
_height
;
cIdx
=
cIdx
-
padding
_width
;
*
data_col
=
data_im
[
rIdx
*
width
+
cIdx
];
}
data_col
+=
height_col
*
width_col
;
data_col
+=
output_height
*
output_width
;
}
}
}
}
/*
* im
Shape = [inputChannels, inputHeight, inputW
idth]
* col
Shape
=
* [input
Channels, filterHeight, filterWidth, outputHeight, outputW
idth]
* im
= [input_channels, input_height, input_w
idth]
* col =
* [input
_channels, filter_height, filter_width, output_height, output_w
idth]
*/
template
<
class
T
>
class
Im2ColFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
T
>
{
class
Im2ColFunctor
<
kCFO
,
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
T
*
imData
,
const
TensorShape
&
imShape
,
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
filterHeight
=
colShape
[
1
];
int
filterWidth
=
colShape
[
2
];
int
outputHeight
=
colShape
[
3
];
int
outputWidth
=
colShape
[
4
];
void
operator
()(
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
numKernels
=
inputChannels
*
outputHeight
*
outputWidth
;
int
blocks
=
(
numKernels
+
1024
-
1
)
/
1024
;
int
blockX
=
512
;
int
blockY
=
(
blocks
+
512
-
1
)
/
512
;
int
input_channels
=
im
.
dims
()[
0
];
int
input_height
=
im
.
dims
()[
1
];
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
int
num_outputs
=
input_channels
*
output_height
*
output_width
;
int
blocks
=
(
num_outputs
+
1024
-
1
)
/
1024
;
int
block_x
=
512
;
int
block_y
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
im2col
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
imData
,
numKernels
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingWidth
,
outputHeight
,
outputWidth
,
colData
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
dim3
grid
(
block_x
,
block_y
);
im2col
<
T
><<<
grid
,
threads
>>>
(
im
.
data
<
T
>
(),
num_outputs
,
input_height
,
input_width
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
,
col
.
data
<
T
>
());
}
};
template
<
class
T
>
__global__
void
col2im
(
size_t
n
,
const
T
*
data_col
,
size_t
height
,
size_t
width
,
size_t
channels
,
size_t
blockH
,
size_t
blockW
,
size_t
strideH
,
size_t
strideW
,
size_t
paddingH
,
size_t
paddingW
,
size_t
height_col
,
size_t
width_col
,
T
*
data_im
)
{
size_t
channels
,
size_t
filter_height
,
size_t
filter_width
,
size_t
stride_height
,
size_t
stride_width
,
size_t
padding_height
,
size_t
padding_width
,
size_t
output_height
,
size_t
output_width
,
T
*
data_im
)
{
size_t
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
n
)
{
...
...
@@ -99,104 +102,112 @@ __global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
int
w
=
int
(
index
%
width
);
int
h
=
int
((
index
/
width
)
%
height
);
int
c
=
int
(
index
/
(
width
*
height
));
if
((
w
-
(
int
)
paddingW
)
>=
0
&&
(
w
-
(
int
)
paddingW
)
<
(
width
-
2
*
paddingW
)
&&
(
h
-
(
int
)
paddingH
)
>=
0
&&
(
h
-
paddingH
)
<
(
height
-
2
*
paddingH
))
{
if
((
w
-
(
int
)
padding_width
)
>=
0
&&
(
w
-
(
int
)
padding_width
)
<
(
width
-
2
*
padding_width
)
&&
(
h
-
(
int
)
padding_height
)
>=
0
&&
(
h
-
padding_height
)
<
(
height
-
2
*
padding_height
))
{
// compute the start and end of the output
int
w_col_start
=
(
w
<
(
int
)
blockW
)
?
0
:
(
w
-
int
(
blockW
))
/
(
int
)
strideW
+
1
;
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
strideW
+
1
),
(
int
)(
width_col
));
int
h_col_start
=
(
h
<
(
int
)
blockH
)
?
0
:
(
h
-
(
int
)
blockH
)
/
(
int
)
strideH
+
1
;
int
h_col_end
=
min
(
int
(
h
/
strideH
+
1
),
int
(
height_col
));
int
w_col_start
=
(
w
<
(
int
)
filter_width
)
?
0
:
(
w
-
int
(
filter_width
))
/
(
int
)
stride_width
+
1
;
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
stride_width
+
1
),
(
int
)(
output_width
));
int
h_col_start
=
(
h
<
(
int
)
filter_height
)
?
0
:
(
h
-
(
int
)
filter_height
)
/
(
int
)
stride_height
+
1
;
int
h_col_end
=
min
(
int
(
h
/
stride_height
+
1
),
int
(
output_height
));
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
// the col location: [c * width * height + h_out, w_out]
int
c_col
=
int
(
c
*
blockH
*
blockW
)
+
(
h
-
h_col
*
(
int
)
strideH
)
*
(
int
)
blockW
+
(
w
-
w_col
*
(
int
)
strideW
);
val
+=
data_col
[(
c_col
*
height_col
+
h_col
)
*
width_col
+
w_col
];
int
c_col
=
int
(
c
*
filter_height
*
filter_width
)
+
(
h
-
h_col
*
(
int
)
stride_height
)
*
(
int
)
filter_width
+
(
w
-
w_col
*
(
int
)
stride_width
);
val
+=
data_col
[(
c_col
*
output_height
+
h_col
)
*
output_width
+
w_col
];
}
}
h
-=
paddingH
;
w
-=
paddingW
;
data_im
[
c
*
((
width
-
2
*
paddingW
)
*
(
height
-
2
*
paddingH
))
+
h
*
(
width
-
2
*
paddingW
)
+
w
]
+=
val
;
h
-=
padding_height
;
w
-=
padding_width
;
data_im
[
c
*
((
width
-
2
*
padding_width
)
*
(
height
-
2
*
padding_height
))
+
h
*
(
width
-
2
*
padding_width
)
+
w
]
+=
val
;
}
}
}
/*
* im
Shape = [inputChannels, inputHeight, inputW
idth]
* col
Shape
=
* [input
Channels, filterHeight, filterWidth, outputHeight, outputW
idth]
* im
= [input_channels, input_height, input_w
idth]
* col =
* [input
_channels, filter_height, filter_width, output_height, output_w
idth]
*/
template
<
class
T
>
class
Col2ImFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
T
>
{
class
Col2ImFunctor
<
kCFO
,
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
T
*
imData
,
const
TensorShape
&
imShape
,
const
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
filterHeight
=
colShape
[
1
];
int
filterWidth
=
colShape
[
2
];
int
outputHeight
=
colShape
[
3
];
int
outputWidth
=
colShape
[
4
];
void
operator
()(
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
int
input_height
=
im
.
dims
()[
1
];
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
1
];
int
filter_width
=
col
.
dims
()[
2
];
int
output_height
=
col
.
dims
()[
3
];
int
output_width
=
col
.
dims
()[
4
];
size_t
num
Kernels
=
inputChannels
*
(
inputHeight
+
2
*
paddingH
eight
)
*
(
inputWidth
+
2
*
paddingW
idth
);
size_t
num
_kernels
=
input_channels
*
(
input_height
+
2
*
padding_h
eight
)
*
(
input_width
+
2
*
padding_w
idth
);
size_t
blocks
=
(
num
K
ernels
+
1024
-
1
)
/
1024
;
size_t
block
X
=
512
;
size_t
block
Y
=
(
blocks
+
512
-
1
)
/
512
;
size_t
blocks
=
(
num
_k
ernels
+
1024
-
1
)
/
1024
;
size_t
block
_x
=
512
;
size_t
block
_y
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
block
X
,
blockY
);
dim3
grid
(
block
_x
,
block_y
);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
numKernels
,
colData
,
inputHeight
+
2
*
paddingHeight
,
inputWidth
+
2
*
paddingWidth
,
inputChannels
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingWidth
,
outputHeight
,
outputWidth
,
imData
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
col2im
<
T
><<<
grid
,
threads
>>>
(
num_kernels
,
col
.
data
<
T
>
(),
input_height
+
2
*
padding_height
,
input_width
+
2
*
padding_width
,
input_channels
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
,
im
.
data
<
T
>
());
}
};
template
class
Im2ColFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
float
>;
template
class
Im2ColFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
double
>;
template
class
Col2ImFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
float
>;
template
class
Col2ImFunctor
<
kCFO
,
DEVICE_TYPE_GPU
,
double
>;
template
class
Im2ColFunctor
<
kCFO
,
platform
::
GPUPlace
,
float
>;
template
class
Im2ColFunctor
<
kCFO
,
platform
::
GPUPlace
,
double
>;
template
class
Col2ImFunctor
<
kCFO
,
platform
::
GPUPlace
,
float
>;
template
class
Col2ImFunctor
<
kCFO
,
platform
::
GPUPlace
,
double
>;
template
<
class
T
>
__global__
void
im2colOCF
(
const
T
*
im
Data
,
T
*
colData
,
int
inputC
hannels
,
int
input
Height
,
int
inputWidth
,
int
filterH
eight
,
int
filter
Width
,
int
strideHeight
,
int
strideW
idth
,
int
padding
Height
,
int
paddingWidth
,
int
outputHeight
,
int
output
W
idth
)
{
int
sw
I
d
=
blockIdx
.
x
;
int
sh
I
d
=
blockIdx
.
y
;
for
(
int
channel
Id
=
threadIdx
.
z
;
channelId
<
inputC
hannels
;
channel
I
d
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter
H
eight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter
W
idth
;
idx
+=
blockDim
.
x
)
{
int
width
Offset
=
idx
+
swId
*
strideWidth
-
paddingW
idth
;
int
height
Offset
=
idy
+
shId
*
strideHeight
-
paddingH
eight
;
int
im
Offset
=
widthOffset
+
heightOffset
*
inputW
idth
+
channelId
*
inputHeight
*
inputW
idth
;
__global__
void
im2colOCF
(
const
T
*
im
_data
,
T
*
col_data
,
int
input_c
hannels
,
int
input
_height
,
int
input_width
,
int
filter_h
eight
,
int
filter
_width
,
int
stride_height
,
int
stride_w
idth
,
int
padding
_height
,
int
padding_width
,
int
output
_height
,
int
output_w
idth
)
{
int
sw
i
d
=
blockIdx
.
x
;
int
sh
i
d
=
blockIdx
.
y
;
for
(
int
channel
id
=
threadIdx
.
z
;
channelid
<
input_c
hannels
;
channel
i
d
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter
_h
eight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter
_w
idth
;
idx
+=
blockDim
.
x
)
{
int
width
_offset
=
idx
+
swid
*
stride_width
-
padding_w
idth
;
int
height
_offset
=
idy
+
shid
*
stride_height
-
padding_h
eight
;
int
im
_offset
=
width_offset
+
height_offset
*
input_w
idth
+
channelid
*
input_height
*
input_w
idth
;
int
col
Offset
=
idx
+
idy
*
filterW
idth
+
channelId
*
filterHeight
*
filterW
idth
+
(
shId
*
outputWidth
+
swI
d
)
*
(
inputChannels
*
filterHeight
*
filterW
idth
);
int
col
_offset
=
idx
+
idy
*
filter_w
idth
+
channelid
*
filter_height
*
filter_w
idth
+
(
shid
*
output_width
+
swi
d
)
*
(
input_channels
*
filter_height
*
filter_w
idth
);
if
(
height
Offset
>=
inputHeight
||
heightO
ffset
<
0
||
width
Offset
>=
inputWidth
||
widthO
ffset
<
0
)
{
col
Data
[
colO
ffset
]
=
T
(
0
);
if
(
height
_offset
>=
input_height
||
height_o
ffset
<
0
||
width
_offset
>=
input_width
||
width_o
ffset
<
0
)
{
col
_data
[
col_o
ffset
]
=
T
(
0
);
}
else
{
col
Data
[
colOffset
]
=
imData
[
imO
ffset
];
col
_data
[
col_offset
]
=
im_data
[
im_o
ffset
];
}
}
}
...
...
@@ -204,76 +215,79 @@ __global__ void im2colOCF(const T* imData, T* colData, int inputChannels,
}
/*
* im
Shape = [inputChannels, inputHeight, inputW
idth]
* col
Shape
=
* [output
Height, outputWidth, inputChannels, filterHeight, filterW
idth]
* im
= [input_channels, input_height, input_w
idth]
* col =
* [output
_height, output_width, input_channels, filter_height, filter_w
idth]
*/
template
<
class
T
>
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
T
>
{
class
Im2ColFunctor
<
kOCF
,
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
T
*
imData
,
const
TensorShape
&
imShape
,
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
filterHeight
=
colShape
[
3
];
int
filterWidth
=
colShape
[
4
];
int
outputHeight
=
colShape
[
0
];
int
outputWidth
=
colShape
[
1
];
void
operator
()(
const
framework
::
Tensor
&
im
,
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
int
input_height
=
im
.
dims
()[
1
];
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
block
DimX
=
0
;
int
block
DimY
=
0
;
if
(
filter
Height
<=
4
&&
filterW
idth
<=
4
)
{
block
DimX
=
4
;
block
DimY
=
4
;
}
else
if
(
filter
Height
<=
8
&&
filterW
idth
<=
8
)
{
block
DimX
=
8
;
block
DimY
=
8
;
}
else
if
(
filter
Height
<=
16
&&
filterW
idth
<=
16
)
{
block
DimX
=
16
;
block
DimY
=
16
;
int
block
_dim_x
=
0
;
int
block
_dim_y
=
0
;
if
(
filter
_height
<=
4
&&
filter_w
idth
<=
4
)
{
block
_dim_x
=
4
;
block
_dim_y
=
4
;
}
else
if
(
filter
_height
<=
8
&&
filter_w
idth
<=
8
)
{
block
_dim_x
=
8
;
block
_dim_y
=
8
;
}
else
if
(
filter
_height
<=
16
&&
filter_w
idth
<=
16
)
{
block
_dim_x
=
16
;
block
_dim_y
=
16
;
}
else
{
block
DimX
=
32
;
block
DimY
=
32
;
block
_dim_x
=
32
;
block
_dim_y
=
32
;
}
int
block
DimZ
=
1024
/
blockDimX
/
blockDimY
;
dim3
threads
(
block
DimX
,
blockDimY
,
std
::
min
(
blockDimZ
,
inputChannels
));
dim3
grid
(
outputWidth
,
outputHeight
);
im2colOCF
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
imData
,
colData
,
inputChannels
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingW
idth
,
outputHeight
,
outputWidth
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
int
block
_dim_z
=
1024
/
block_dim_x
/
block_dim_y
;
dim3
threads
(
block
_dim_x
,
block_dim_y
,
std
::
min
(
block_dim_z
,
input_channels
)
);
dim3
grid
(
output_width
,
output_height
);
im2colOCF
<
T
><<<
grid
,
threads
>>>
(
im
.
data
<
T
>
(),
col
.
data
<
T
>
(),
input_channels
,
input_height
,
input_w
idth
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
);
}
};
template
<
class
T
>
__global__
void
col2imOCF
(
T
*
im
Data
,
const
T
*
colData
,
int
inputC
hannels
,
int
input
Height
,
int
inputWidth
,
int
filterH
eight
,
int
filter
Width
,
int
strideHeight
,
int
strideW
idth
,
int
padding
Height
,
int
paddingWidth
,
int
outputHeight
,
int
output
W
idth
)
{
int
sw
I
d
=
blockIdx
.
x
;
int
sh
I
d
=
blockIdx
.
y
;
for
(
int
channel
Id
=
threadIdx
.
z
;
channelId
<
inputC
hannels
;
channel
I
d
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter
H
eight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter
W
idth
;
idx
+=
blockDim
.
x
)
{
int
width
Offset
=
idx
+
swId
*
strideWidth
-
paddingW
idth
;
int
height
Offset
=
idy
+
shId
*
strideHeight
-
paddingH
eight
;
int
im
Offset
=
widthOffset
+
heightOffset
*
inputW
idth
+
channelId
*
inputHeight
*
inputW
idth
;
__global__
void
col2imOCF
(
T
*
im
_data
,
const
T
*
col_data
,
int
input_c
hannels
,
int
input
_height
,
int
input_width
,
int
filter_h
eight
,
int
filter
_width
,
int
stride_height
,
int
stride_w
idth
,
int
padding
_height
,
int
padding_width
,
int
output
_height
,
int
output_w
idth
)
{
int
sw
i
d
=
blockIdx
.
x
;
int
sh
i
d
=
blockIdx
.
y
;
for
(
int
channel
id
=
threadIdx
.
z
;
channelid
<
input_c
hannels
;
channel
i
d
+=
blockDim
.
z
)
{
for
(
int
idy
=
threadIdx
.
y
;
idy
<
filter
_h
eight
;
idy
+=
blockDim
.
y
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
filter
_w
idth
;
idx
+=
blockDim
.
x
)
{
int
width
_offset
=
idx
+
swid
*
stride_width
-
padding_w
idth
;
int
height
_offset
=
idy
+
shid
*
stride_height
-
padding_h
eight
;
int
im
_offset
=
width_offset
+
height_offset
*
input_w
idth
+
channelid
*
input_height
*
input_w
idth
;
int
col
Offset
=
idx
+
idy
*
filterW
idth
+
channelId
*
filterHeight
*
filterW
idth
+
(
shId
*
outputWidth
+
swI
d
)
*
(
inputChannels
*
filterHeight
*
filterW
idth
);
int
col
_offset
=
idx
+
idy
*
filter_w
idth
+
channelid
*
filter_height
*
filter_w
idth
+
(
shid
*
output_width
+
swi
d
)
*
(
input_channels
*
filter_height
*
filter_w
idth
);
if
(
heightOffset
>=
0
&&
heightOffset
<
inputHeight
&&
widthOffset
>=
0
&&
widthOffset
<
inputWidth
)
{
paddle
::
paddleAtomicAdd
(
imData
+
imOffset
,
colData
[
colOffset
]);
if
(
height_offset
>=
0
&&
height_offset
<
input_height
&&
width_offset
>=
0
&&
width_offset
<
input_width
)
{
paddle
::
platform
::
CudaAtomicAdd
(
im_data
+
im_offset
,
col_data
[
col_offset
]);
}
}
}
...
...
@@ -281,54 +295,56 @@ __global__ void col2imOCF(T* imData, const T* colData, int inputChannels,
}
/*
* im
Shape = [inputChannels, inputHeight, inputW
idth]
* col
Shape
=
* [output
Height, outputWidth, inputChannels, filterHeight, filterW
idth]
* im
= [input_channels, input_height, input_w
idth]
* col =
* [output
_height, output_width, input_channels, filter_height, filter_w
idth]
*/
template
<
class
T
>
class
Col2ImFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
T
>
{
class
Col2ImFunctor
<
kOCF
,
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
T
*
imData
,
const
TensorShape
&
imShape
,
const
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
filterHeight
=
colShape
[
3
];
int
filterWidth
=
colShape
[
4
];
int
outputHeight
=
colShape
[
0
];
int
outputWidth
=
colShape
[
1
];
void
operator
()(
framework
::
Tensor
&
im
,
const
framework
::
Tensor
&
col
,
int
stride_height
,
int
stride_width
,
int
padding_height
,
int
padding_width
)
{
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
.
dims
().
size
()
==
5
);
int
input_channels
=
im
.
dims
()[
0
];
int
input_height
=
im
.
dims
()[
1
];
int
input_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
4
];
int
output_height
=
col
.
dims
()[
0
];
int
output_width
=
col
.
dims
()[
1
];
int
block
DimX
=
0
;
int
block
DimY
=
0
;
if
(
filter
Height
<=
4
&&
filterW
idth
<=
4
)
{
block
DimX
=
4
;
block
DimY
=
4
;
}
else
if
(
filter
Height
<=
8
&&
filterW
idth
<=
8
)
{
block
DimX
=
8
;
block
DimY
=
8
;
}
else
if
(
filter
Height
<=
16
&&
filterW
idth
<=
16
)
{
block
DimX
=
16
;
block
DimY
=
16
;
int
block
_dim_x
=
0
;
int
block
_dim_y
=
0
;
if
(
filter
_height
<=
4
&&
filter_w
idth
<=
4
)
{
block
_dim_x
=
4
;
block
_dim_y
=
4
;
}
else
if
(
filter
_height
<=
8
&&
filter_w
idth
<=
8
)
{
block
_dim_x
=
8
;
block
_dim_y
=
8
;
}
else
if
(
filter
_height
<=
16
&&
filter_w
idth
<=
16
)
{
block
_dim_x
=
16
;
block
_dim_y
=
16
;
}
else
{
block
DimX
=
32
;
block
DimY
=
32
;
block
_dim_x
=
32
;
block
_dim_y
=
32
;
}
int
block
DimZ
=
1024
/
blockDimX
/
blockDimY
;
dim3
threads
(
block
DimX
,
blockDimY
,
std
::
min
(
blockDimZ
,
inputChannels
));
dim3
grid
(
outputWidth
,
outputHeight
);
col2imOCF
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
imData
,
colData
,
inputChannels
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingW
idth
,
outputHeight
,
outputWidth
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
int
block
_dim_z
=
1024
/
block_dim_x
/
block_dim_y
;
dim3
threads
(
block
_dim_x
,
block_dim_y
,
std
::
min
(
block_dim_z
,
input_channels
)
);
dim3
grid
(
output_width
,
output_height
);
col2imOCF
<
T
><<<
grid
,
threads
,
0
>>>
(
im
.
data
<
T
>
(),
col
.
data
<
T
>
(),
input_channels
,
input_height
,
input_w
idth
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_height
,
output_width
);
}
};
template
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
float
>;
template
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
double
>;
template
class
Col2ImFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
float
>;
template
class
Col2ImFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
double
>;
template
class
Im2ColFunctor
<
kOCF
,
platform
::
GPUPlace
,
float
>;
template
class
Im2ColFunctor
<
kOCF
,
platform
::
GPUPlace
,
double
>;
template
class
Col2ImFunctor
<
kOCF
,
platform
::
GPUPlace
,
float
>;
template
class
Col2ImFunctor
<
kOCF
,
platform
::
GPUPlace
,
double
>;
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录