Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d99faf31
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d99faf31
编写于
6月 06, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the calculation implementation of GemmConvGradInputFunction.
上级
90326198
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
242 addition
and
18 deletion
+242
-18
paddle/function/ConvOpTest.cpp
paddle/function/ConvOpTest.cpp
+5
-2
paddle/function/GemmConvOp.cpp
paddle/function/GemmConvOp.cpp
+126
-16
paddle/function/GemmConvOp.h
paddle/function/GemmConvOp.h
+18
-0
paddle/function/GemmConvOpGpu.cu
paddle/function/GemmConvOpGpu.cu
+93
-0
未找到文件。
paddle/function/ConvOpTest.cpp
浏览文件 @
d99faf31
...
...
@@ -78,12 +78,10 @@ public:
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
output
));
test
.
run
();
}
else
if
(
type
==
BACKWARD_INPUT_TEST
)
{
#if 0
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
output
));
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
filter
));
test
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
input
));
test
.
run
();
#endif
}
else
if
(
type
==
BACKWARD_FILTER_TEST
)
{
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
output
));
test
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
input
));
...
...
@@ -111,6 +109,11 @@ TEST(Forward, GEMM2) {
"GemmConv-CPU"
,
"GemmConv-GPU"
,
FORWARD_TEST
);
}
TEST
(
BackwardInput
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradInput-CPU"
,
"GemmConvGradInput-GPU"
,
BACKWARD_INPUT_TEST
);
}
TEST
(
BackwardFilter
,
GEMM
)
{
ConvolutionTest
<
DEVICE_TYPE_CPU
,
DEVICE_TYPE_GPU
>
test
(
"GemmConvGradFilter-CPU"
,
"GemmConvGradFilter-GPU"
,
BACKWARD_FILTER_TEST
);
...
...
paddle/function/GemmConvOp.cpp
浏览文件 @
d99faf31
...
...
@@ -44,22 +44,62 @@ public:
for
(
int
c
=
0
;
c
<
channelsCol
;
++
c
)
{
int
wOffset
=
c
%
filterWidth
;
int
hOffset
=
(
c
/
filterWidth
)
%
filterHeight
;
int
c_im
=
c
/
filter
Height
/
filterWidth
;
int
c_im
=
c
/
filter
Width
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
// no c_im*height to Exclude the channel number
int
imgRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imgColIdx
=
w
*
strideWidth
+
wOffset
;
if
((
imgRowIdx
-
paddingHeight
)
<
0
||
(
imgRowIdx
-
paddingHeight
)
>=
inputHeight
||
(
imgColIdx
-
paddingWidth
)
<
0
||
(
imgColIdx
-
paddingWidth
)
>=
inputWidth
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
if
((
imRowIdx
-
paddingHeight
)
<
0
||
(
imRowIdx
-
paddingHeight
)
>=
inputHeight
||
(
imColIdx
-
paddingWidth
)
<
0
||
(
imColIdx
-
paddingWidth
)
>=
inputWidth
)
{
colData
[(
c
*
outputHeight
+
h
)
*
outputWidth
+
w
]
=
T
(
0
);
}
else
{
im
g
RowIdx
+=
c_im
*
inputHeight
-
paddingHeight
;
im
g
ColIdx
-=
paddingWidth
;
imRowIdx
+=
c_im
*
inputHeight
-
paddingHeight
;
imColIdx
-=
paddingWidth
;
colData
[(
c
*
outputHeight
+
h
)
*
outputWidth
+
w
]
=
imData
[
imgRowIdx
*
inputWidth
+
imgColIdx
];
imData
[
imRowIdx
*
inputWidth
+
imColIdx
];
}
}
}
}
}
};
template
<
class
T
>
class
Col2ImFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
T
*
colData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
,
T
*
imData
)
{
int
channelsCol
=
inputChannels
*
filterHeight
*
filterWidth
;
for
(
int
c
=
0
;
c
<
channelsCol
;
++
c
)
{
int
wOffset
=
c
%
filterWidth
;
int
hOffset
=
(
c
/
filterWidth
)
%
filterHeight
;
int
c_im
=
c
/
filterWidth
/
filterHeight
;
for
(
int
h
=
0
;
h
<
outputHeight
;
++
h
)
{
for
(
int
w
=
0
;
w
<
outputWidth
;
++
w
)
{
int
imRowIdx
=
h
*
strideHeight
+
hOffset
;
int
imColIdx
=
w
*
strideWidth
+
wOffset
;
if
((
imRowIdx
-
paddingHeight
)
>=
0
&&
(
imRowIdx
-
paddingHeight
)
<
inputHeight
&&
(
imColIdx
-
paddingWidth
)
>=
0
&&
(
imColIdx
-
paddingWidth
)
<
inputWidth
)
{
imRowIdx
+=
c_im
*
inputHeight
-
paddingHeight
;
imColIdx
-=
paddingWidth
;
imData
[
imRowIdx
*
inputWidth
+
imColIdx
]
+=
colData
[(
c
*
outputHeight
+
h
)
*
outputWidth
+
w
];
}
}
}
...
...
@@ -171,10 +211,74 @@ public:
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
const
TensorShape
&
outputGrad
=
inputs
[
0
].
shape
();
// CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
filter
=
inputs
[
1
].
shape
();
const
TensorShape
&
inputGrad
=
outputs
[
0
].
shape
();
check
(
inputGrad
,
filter
,
outputGrad
);
const
TensorShape
&
input
=
outputs
[
0
].
shape
();
check
(
input
,
filter
,
output
);
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
filter
[
2
];
size_t
filterWidth
=
filter
[
3
];
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
filterData
=
inputs
[
1
].
data
<
real
>
();
real
*
inputGrad
=
outputs
[
0
].
data
<
real
>
();
size_t
size
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
resizeBuffer
<
Device
>
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
Col2ImFunctor
<
Device
,
real
>
col2im
;
GemmFunctor
<
Device
,
real
>
gemm
;
size_t
inputOffset
=
(
inputChannels
/
groups_
)
*
inputHeight
*
inputWidth
;
size_t
outputOffset
=
(
outputChannels
/
groups_
)
*
outputHeight
*
outputWidth
;
size_t
filterOffset
=
filter
.
getElements
()
/
groups_
;
for
(
size_t
i
=
0
;
i
<
batchSize
;
i
++
)
{
for
(
size_t
g
=
0
;
g
<
groups_
;
g
++
)
{
int
K
=
outputChannels
/
groups_
;
int
N
=
outputHeight
*
outputWidth
;
int
M
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
;
gemm
(
CblasTrans
,
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
filterData
+
g
*
filterOffset
,
M
,
outputGrad
+
g
*
outputOffset
,
N
,
0.0
f
,
colData
,
N
);
col2im
(
colData
,
inputChannels
/
groups_
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
outputHeight
,
outputWidth
,
inputGrad
+
g
*
inputOffset
);
}
inputGrad
+=
inputChannels
*
inputHeight
*
inputWidth
;
outputGrad
+=
outputChannels
*
outputHeight
*
outputWidth
;
}
}
};
...
...
@@ -191,12 +295,18 @@ public:
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ASSIGN_TO
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
check
(
input
,
filter
,
output
);
real
beta
;
if
(
outputs
[
0
].
getArgType
()
==
ADD_TO
)
{
beta
=
1.0
;
}
else
{
beta
=
0.0
;
}
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
...
...
@@ -251,7 +361,7 @@ public:
K
,
colData
,
K
,
1.0
f
,
i
==
0
?
beta
:
1.0
f
,
filterGrad
+
g
*
filterOffset
,
N
);
}
...
...
paddle/function/GemmConvOp.h
浏览文件 @
d99faf31
...
...
@@ -41,4 +41,22 @@ public:
T
*
colData
);
};
template
<
DeviceType
Device
,
class
T
>
class
Col2ImFunctor
{
public:
void
operator
()(
const
T
*
colData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
,
T
*
imData
);
};
}
// namespace paddle
paddle/function/GemmConvOpGpu.cu
浏览文件 @
d99faf31
...
...
@@ -87,7 +87,100 @@ public:
}
};
template
<
class
T
>
__global__
void
col2im
(
size_t
n
,
const
T
*
data_col
,
size_t
height
,
size_t
width
,
size_t
channels
,
size_t
blockH
,
size_t
blockW
,
size_t
strideH
,
size_t
strideW
,
size_t
paddingH
,
size_t
paddingW
,
size_t
height_col
,
size_t
width_col
,
T
*
data_im
)
{
size_t
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
n
)
{
T
val
=
0
;
int
w
=
int
(
index
%
width
);
int
h
=
int
((
index
/
width
)
%
height
);
int
c
=
int
(
index
/
(
width
*
height
));
if
((
w
-
(
int
)
paddingW
)
>=
0
&&
(
w
-
(
int
)
paddingW
)
<
(
width
-
2
*
paddingW
)
&&
(
h
-
(
int
)
paddingH
)
>=
0
&&
(
h
-
paddingH
)
<
(
height
-
2
*
paddingH
))
{
// compute the start and end of the output
int
w_col_start
=
(
w
<
(
int
)
blockW
)
?
0
:
(
w
-
int
(
blockW
))
/
(
int
)
strideW
+
1
;
int
w_col_end
=
min
((
int
)(
w
/
(
int
)
strideW
+
1
),
(
int
)(
width_col
));
int
h_col_start
=
(
h
<
(
int
)
blockH
)
?
0
:
(
h
-
(
int
)
blockH
)
/
(
int
)
strideH
+
1
;
int
h_col_end
=
min
(
int
(
h
/
strideH
+
1
),
int
(
height_col
));
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
// the col location: [c * width * height + h_out, w_out]
int
c_col
=
int
(
c
*
blockH
*
blockW
)
+
\
(
h
-
h_col
*
(
int
)
strideH
)
*
(
int
)
blockW
+
(
w
-
w_col
*
(
int
)
strideW
);
val
+=
data_col
[(
c_col
*
height_col
+
h_col
)
*
width_col
+
w_col
];
}
}
h
-=
paddingH
;
w
-=
paddingW
;
data_im
[
c
*
((
width
-
2
*
paddingW
)
*
(
height
-
2
*
paddingH
))
+
h
*
(
width
-
2
*
paddingW
)
+
w
]
+=
val
;
}
}
}
template
<
class
T
>
class
Col2ImFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
colData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
,
T
*
imData
)
{
size_t
numKernels
=
inputChannels
*
(
inputHeight
+
2
*
paddingHeight
)
*
(
inputWidth
+
2
*
paddingWidth
);
size_t
blocks
=
(
numKernels
+
1024
-
1
)
/
1024
;
size_t
blockX
=
512
;
size_t
blockY
=
(
blocks
+
512
-
1
)
/
512
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blockX
,
blockY
);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
numKernels
,
colData
,
inputHeight
+
2
*
paddingHeight
,
inputWidth
+
2
*
paddingWidth
,
inputChannels
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingWidth
,
outputHeight
,
outputWidth
,
imData
);
CHECK_SYNC
(
"Col2ImFunctor GPU failed"
);
}
};
template
class
Im2ColFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
Im2ColFunctor
<
DEVICE_TYPE_GPU
,
double
>;
template
class
Col2ImFunctor
<
DEVICE_TYPE_GPU
,
float
>;
template
class
Col2ImFunctor
<
DEVICE_TYPE_GPU
,
double
>;
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录