Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6a93f0f3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a93f0f3
编写于
6月 05, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the calculation implementation of GemmConvGradFilterFunction
上级
afbe556e
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
99 addition
and
37 deletion
+99
-37
paddle/function/ConvOp.h
paddle/function/ConvOp.h
+16
-0
paddle/function/GemmConvOp.cpp
paddle/function/GemmConvOp.cpp
+71
-19
paddle/function/GemmFunctor.h
paddle/function/GemmFunctor.h
+12
-18
未找到文件。
paddle/function/ConvOp.h
浏览文件 @
6a93f0f3
...
...
@@ -89,11 +89,13 @@ public:
protected:
std
::
vector
<
size_t
>
strides_
;
std
::
vector
<
size_t
>
paddings_
;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
/// and the second half only connected to the second half.
size_t
groups_
;
inline
int
strideH
()
const
{
return
strides_
[
0
];
}
inline
int
strideW
()
const
{
return
strides_
[
1
];
}
...
...
@@ -101,6 +103,20 @@ protected:
inline
int
paddingH
()
const
{
return
paddings_
[
0
];
}
inline
int
paddingW
()
const
{
return
paddings_
[
1
];
}
// A temporary memory in convolution calculation.
MemoryHandlePtr
memory_
;
template
<
DeviceType
Device
>
void
resizeBuffer
(
size_t
newSize
)
{
if
(
!
memory_
||
newSize
*
sizeof
(
real
)
>
memory_
->
getAllocSize
())
{
if
(
Device
==
DEVICE_TYPE_CPU
)
{
memory_
=
std
::
make_shared
<
CpuMemoryHandle
>
(
newSize
*
sizeof
(
real
));
}
else
{
memory_
=
std
::
make_shared
<
GpuMemoryHandle
>
(
newSize
*
sizeof
(
real
));
}
}
}
};
}
// namespace paddle
paddle/function/GemmConvOp.cpp
浏览文件 @
6a93f0f3
...
...
@@ -110,7 +110,7 @@ public:
size_t
size
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
resizeBuffer
(
size
);
resizeBuffer
<
Device
>
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
Im2ColFunctor
<
Device
,
real
>
im2col
;
...
...
@@ -120,7 +120,7 @@ public:
(
outputChannels
/
groups_
)
*
outputHeight
*
outputWidth
;
size_t
filterOffset
=
inputs
[
1
].
shape
().
getElements
()
/
groups_
;
for
(
size_t
i
=
0
;
i
<
batchSize
;
i
++
)
{
for
(
in
t
g
=
0
;
g
<
groups_
;
g
++
)
{
for
(
size_
t
g
=
0
;
g
<
groups_
;
g
++
)
{
im2col
(
inputData
+
g
*
inputOffset
,
inputChannels
/
groups_
,
inputHeight
,
...
...
@@ -138,7 +138,9 @@ public:
int
M
=
outputChannels
/
groups_
;
int
N
=
outputHeight
*
outputWidth
;
int
K
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
;
gemm
(
M
,
gemm
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
...
...
@@ -154,19 +156,6 @@ public:
outputData
+=
outputChannels
*
outputHeight
*
outputWidth
;
}
}
void
resizeBuffer
(
size_t
newSize
)
{
if
(
!
memory_
||
newSize
*
sizeof
(
real
)
>
memory_
->
getAllocSize
())
{
if
(
Device
==
DEVICE_TYPE_CPU
)
{
memory_
=
std
::
make_shared
<
CpuMemoryHandle
>
(
newSize
*
sizeof
(
real
));
}
else
{
memory_
=
std
::
make_shared
<
GpuMemoryHandle
>
(
newSize
*
sizeof
(
real
));
}
}
}
private:
MemoryHandlePtr
memory_
;
};
/*
...
...
@@ -202,10 +191,73 @@ public:
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
numInputs_
,
inputs
.
size
());
CHECK_EQ
(
numOutputs_
,
outputs
.
size
());
const
TensorShape
&
outputGrad
=
inputs
[
0
].
shape
();
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ASSIGN_TO
);
const
TensorShape
&
output
=
inputs
[
0
].
shape
();
const
TensorShape
&
input
=
inputs
[
1
].
shape
();
const
TensorShape
&
filterGrad
=
outputs
[
0
].
shape
();
check
(
input
,
filterGrad
,
outputGrad
);
const
TensorShape
&
filter
=
outputs
[
0
].
shape
();
check
(
input
,
filter
,
output
);
size_t
batchSize
=
input
[
0
];
size_t
inputChannels
=
input
[
1
];
size_t
inputHeight
=
input
[
2
];
size_t
inputWidth
=
input
[
3
];
size_t
filterHeight
=
filter
[
2
];
size_t
filterWidth
=
filter
[
3
];
size_t
outputChannels
=
output
[
1
];
size_t
outputHeight
=
output
[
2
];
size_t
outputWidth
=
output
[
3
];
real
*
outputGrad
=
inputs
[
0
].
data
<
real
>
();
real
*
inputData
=
inputs
[
1
].
data
<
real
>
();
real
*
filterGrad
=
outputs
[
0
].
data
<
real
>
();
size_t
size
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
*
outputHeight
*
outputWidth
;
resizeBuffer
<
Device
>
(
size
);
real
*
colData
=
reinterpret_cast
<
real
*>
(
memory_
->
getBuf
());
Im2ColFunctor
<
Device
,
real
>
im2col
;
GemmFunctor
<
Device
,
real
>
gemm
;
size_t
inputOffset
=
(
inputChannels
/
groups_
)
*
inputHeight
*
inputWidth
;
size_t
outputOffset
=
(
outputChannels
/
groups_
)
*
outputHeight
*
outputWidth
;
size_t
filterOffset
=
filter
.
getElements
()
/
groups_
;
for
(
size_t
i
=
0
;
i
<
batchSize
;
i
++
)
{
for
(
size_t
g
=
0
;
g
<
groups_
;
g
++
)
{
im2col
(
inputData
+
g
*
inputOffset
,
inputChannels
/
groups_
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideH
(),
strideW
(),
paddingH
(),
paddingW
(),
outputHeight
,
outputWidth
,
colData
);
int
M
=
outputChannels
/
groups_
;
int
K
=
outputHeight
*
outputWidth
;
int
N
=
inputChannels
/
groups_
*
filterHeight
*
filterWidth
;
gemm
(
CblasNoTrans
,
CblasTrans
,
M
,
N
,
K
,
1.0
f
,
outputGrad
+
g
*
outputOffset
,
K
,
colData
,
K
,
1.0
f
,
filterGrad
+
g
*
filterOffset
,
N
);
}
}
inputData
+=
inputChannels
*
inputHeight
*
inputWidth
;
outputGrad
+=
outputChannels
*
outputHeight
*
outputWidth
;
}
};
...
...
paddle/function/GemmFunctor.h
浏览文件 @
6a93f0f3
...
...
@@ -26,7 +26,9 @@ namespace paddle {
template
<
DeviceType
Device
,
class
T
>
class
GemmFunctor
{
public:
void
operator
()(
const
int
M
,
void
operator
()(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
...
...
@@ -42,7 +44,9 @@ public:
template
<
class
T
>
class
GemmFunctor
<
DEVICE_TYPE_CPU
,
T
>
{
public:
void
operator
()(
const
int
M
,
void
operator
()(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
...
...
@@ -53,26 +57,16 @@ public:
const
T
beta
,
T
*
C
,
const
int
ldc
)
{
gemm
<
T
>
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
gemm
<
T
>
(
transA
,
TransB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
};
template
<
class
T
>
class
GemmFunctor
<
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
int
M
,
void
operator
()(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
...
...
@@ -84,9 +78,9 @@ public:
T
*
C
,
const
int
ldc
)
{
hl_matrix_mul
((
T
*
)
A
,
HPPL_OP_N
,
transA
==
CblasNoTrans
?
HPPL_OP_N
:
HPPL_OP_T
,
(
T
*
)
B
,
HPPL_OP_N
,
TransB
==
CblasNoTrans
?
HPPL_OP_N
:
HPPL_OP_T
,
C
,
M
,
N
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录