Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3c0aa0cc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3c0aa0cc
编写于
6月 02, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add GPU GemmConvFunction implementation
上级
3ce974b9
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
274 addition
and
27 deletion
+274
-27
paddle/function/ConvOp.h
paddle/function/ConvOp.h
+2
-0
paddle/function/ConvOpTest.cpp
paddle/function/ConvOpTest.cpp
+16
-10
paddle/function/GemmConvOp.cpp
paddle/function/GemmConvOp.cpp
+17
-17
paddle/function/GemmConvOp.h
paddle/function/GemmConvOp.h
+44
-0
paddle/function/GemmConvOpGpu.cu
paddle/function/GemmConvOpGpu.cu
+93
-0
paddle/function/GemmFunctor.h
paddle/function/GemmFunctor.h
+102
-0
未找到文件。
paddle/function/ConvOp.h
浏览文件 @
3c0aa0cc
...
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace
paddle
{
...
...
paddle/function/ConvOpTest.cpp
浏览文件 @
3c0aa0cc
...
...
@@ -19,8 +19,7 @@ limitations under the License. */
namespace
paddle
{
typedef
Compare2Function
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_CPU
>
Compare2CpuFunction
;
template
<
DeviceType
DType1
,
DeviceType
DType2
>
class
ConvolutionTest
{
public:
ConvolutionTest
(
const
std
::
string
&
conv1
,
...
...
@@ -50,13 +49,14 @@ public:
std
::
vector
<
size_t
>
paddings
=
{
padding
,
padding
};
std
::
vector
<
size_t
>
strides
=
{
stride
,
stride
};
Compare2CpuFunction
test
(
conv1
,
conv2
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"algo"
,
algo
));
Compare2Function
<
DType1
,
DType2
>
test
(
conv1
,
conv2
,
FuncConfig
()
.
set
(
"paddings"
,
paddings
)
.
set
(
"strides"
,
strides
)
.
set
(
"groups"
,
(
size_t
)
1
)
.
set
(
"algo"
,
algo
));
TensorShape
shape0
{
batchSize
,
inputChannels
,
inputSize
,
inputSize
};
...
...
@@ -79,7 +79,13 @@ public:
};
TEST
(
Convolution
,
GEMM
)
{
ConvolutionTest
test
(
"NaiveConv-CPU"
,
"GemmConv-CPU"
);
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_CPU
>
test
(
"NaiveConv-CPU"
,
"GemmConv-CPU"
);
}
TEST
(
Convolution
,
GEMM2
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConv-CPU"
,
"GemmConv-GPU"
);
}
}
// namespace paddle
paddle/function/GemmConvOp.cpp
浏览文件 @
3c0aa0cc
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "
paddle/math/MathFunctions
.h"
#include "
Gemm
ConvOp.h"
#include "
GemmFunctor
.h"
#include "paddle/math/MemoryHandle.h"
namespace
paddle
{
...
...
@@ -24,7 +24,7 @@ namespace paddle {
* output_height, output_width]
*/
template
<
class
T
>
class
Im2ColFunctor
{
class
Im2ColFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
T
*
imData
,
int
inputChannels
,
...
...
@@ -112,7 +112,8 @@ public:
resizeBuffer
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
Im2ColFunctor
<
real
>
im2col
;
Im2ColFunctor
<
Device
,
real
>
im2col
;
GemmFunctor
<
Device
,
real
>
gemm
;
size_t
inputOffset
=
(
inputChannels
/
groups_
)
*
inputHeight
*
inputWidth
;
size_t
outputOffset
=
(
outputChannels
/
groups_
)
*
outputHeight
*
outputWidth
;
...
...
@@ -136,19 +137,17 @@ public:
int
M
=
outputChannels
;
int
N
=
outputHeight
*
outputWidth
;
int
K
=
inputChannels
*
filterHeight
*
filterWidth
;
gemm
<
real
>
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
filterData
+
g
*
filterOffset
,
K
,
colData
,
N
,
0.0
f
,
outputData
+
g
*
outputOffset
,
N
);
gemm
(
M
,
N
,
K
,
1.0
f
,
filterData
+
g
*
filterOffset
,
K
,
colData
,
N
,
0.0
f
,
outputData
+
g
*
outputOffset
,
N
);
inputData
+=
inputChannels
*
inputHeight
*
inputWidth
;
outputData
+=
outputChannels
*
outputHeight
*
outputWidth
;
}
...
...
@@ -166,5 +165,6 @@ private:
};
REGISTER_TYPED_FUNC
(
GemmConv
,
CPU
,
GemmConvFunction
);
REGISTER_TYPED_FUNC
(
GemmConv
,
GPU
,
GemmConvFunction
);
}
// namespace paddle
paddle/function/GemmConvOp.h
0 → 100644
浏览文件 @
3c0aa0cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvOp.h"
namespace
paddle
{
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template
<
DeviceType
Device
,
class
T
>
class
Im2ColFunctor
{
public:
void
operator
()(
const
T
*
imData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
,
T
*
colData
);
};
}
// namespace paddle
paddle/function/GemmConvOpGpu.cu
0 → 100644
浏览文件 @
3c0aa0cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "GemmConvOp.h"
namespace
paddle
{
template
<
class
T
>
__global__
void
im2col
(
const
T
*
data_im
,
int
numOuts
,
int
height
,
int
width
,
int
blockH
,
int
blockW
,
int
strideH
,
int
strideW
,
int
paddingH
,
int
paddingW
,
int
height_col
,
int
width_col
,
T
*
data_col
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
numOuts
)
{
int
w_out
=
index
%
width_col
;
index
/=
width_col
;
int
h_out
=
index
%
height_col
;
int
channel_in
=
index
/
height_col
;
int
channel_out
=
channel_in
*
blockH
*
blockW
;
int
h_in
=
h_out
*
strideH
;
int
w_in
=
w_out
*
strideW
;
data_col
+=
(
channel_out
*
height_col
+
h_out
)
*
width_col
+
w_out
;
for
(
int
i
=
0
;
i
<
blockH
;
++
i
)
{
for
(
int
j
=
0
;
j
<
blockW
;
++
j
)
{
int
rIdx
=
int
(
h_in
+
i
);
int
cIdx
=
int
(
w_in
+
j
);
if
((
rIdx
-
(
int
)
paddingH
)
>=
(
int
)
height
||
(
rIdx
-
(
int
)
paddingH
)
<
0
||
(
cIdx
-
(
int
)
paddingW
)
>=
(
int
)
width
||
(
cIdx
-
(
int
)
paddingW
)
<
0
)
{
*
data_col
=
0
;
}
else
{
rIdx
=
rIdx
+
channel_in
*
height
-
paddingH
;
cIdx
=
cIdx
-
paddingW
;
*
data_col
=
data_im
[
rIdx
*
width
+
cIdx
];
}
data_col
+=
height_col
*
width_col
;
}
}
}
}
template
<
class
T
>
class
Im2ColFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
imData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
,
T
*
colData
)
{
int
numKernels
=
inputChannels
*
outputHeight
*
outputWidth
;
int
blocks
=
(
numKernels
+
1024
-
1
)
/
1024
;
int
blockX
=
512
;
int
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
im2col
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
imData
,
numKernels
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingWidth
,
outputHeight
,
outputWidth
,
colData
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
}
};
template
class
Im2ColFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
Im2ColFunctor
<
DEVICE_TYPE_GPU
,
double
>;
}
// namespace paddle
paddle/function/GemmFunctor.h
0 → 100644
浏览文件 @
3c0aa0cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/MathFunctions.h"
namespace
paddle
{
// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the
// cblas_dgemm interface's parameter format, it is necessary to introduce
// GemmFunctor as a new interface. Later, when considering the implementation
// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
// interface.
template
<
DeviceType
Device
,
class
T
>
class
GemmFunctor
{
public:
void
operator
()(
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
);
};
template
<
class
T
>
class
GemmFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
{
gemm
<
T
>
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
};
template
<
class
T
>
class
GemmFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
{
hl_matrix_mul
((
T
*
)
A
,
HPPL_OP_N
,
(
T
*
)
B
,
HPPL_OP_N
,
C
,
M
,
N
,
K
,
alpha
,
beta
,
lda
,
ldb
,
ldc
);
}
};
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录