Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
40fe0a8c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
40fe0a8c
编写于
9月 12, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add backward of convolution.
上级
c9d8cb4e
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
146 addition
and
21 deletion
+146
-21
paddle/operators/conv_op.cc
paddle/operators/conv_op.cc
+15
-9
paddle/operators/gemm_conv_op.h
paddle/operators/gemm_conv_op.h
+93
-12
python/paddle/v2/framework/tests/test_conv2d_op.py
python/paddle/v2/framework/tests/test_conv2d_op.py
+38
-0
未找到文件。
paddle/operators/conv_op.cc
浏览文件 @
40fe0a8c
...
@@ -28,9 +28,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
...
@@ -28,9 +28,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
);
auto
in
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
filter
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Filter"
);
auto
filter
=
ctx
.
Input
<
Tensor
>
(
"Filter"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Output"
);
auto
out
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
PADDLE_ENFORCE_EQ
(
in
->
dims
().
size
(),
4
,
"Conv2DOp intput should be 4-D."
);
PADDLE_ENFORCE_EQ
(
in
->
dims
().
size
(),
4
,
"Conv2DOp intput should be 4-D."
);
PADDLE_ENFORCE_EQ
(
filter
->
dims
().
size
(),
4
,
PADDLE_ENFORCE_EQ
(
filter
->
dims
().
size
(),
4
,
"Conv2DOp filter should be 4-D."
);
"Conv2DOp filter should be 4-D."
);
...
@@ -46,10 +46,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
...
@@ -46,10 +46,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
}
}
};
};
class
Conv2DOp
p
Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Conv2DOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
Conv2DOppMaker
(
framework
::
OpProto
*
proto
,
Conv2DOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
AddInput
(
"Input"
,
"Input"
,
...
@@ -62,7 +61,7 @@ class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -62,7 +61,7 @@ class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker {
"The format of the filter tensor is MCHW, where M is the number of "
"The format of the filter tensor is MCHW, where M is the number of "
"output "
"output "
"image channels, C is the number of input image channels, H and W is "
"image channels, C is the number of input image channels, H and W is "
"
height and width of filter."
);
"height and width of filter."
);
AddOutput
(
"Output"
,
AddOutput
(
"Output"
,
"The output tensor of convolution operator."
"The output tensor of convolution operator."
"The format of output tensor is also NCHW."
);
"The format of output tensor is also NCHW."
);
...
@@ -80,14 +79,21 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
...
@@ -80,14 +79,21 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
in
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
filter
=
ctx
.
Input
<
Tensor
>
(
"Filter"
);
auto
d_in
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
d_filter
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter"
));
d_in
->
Resize
(
in
->
dims
());
d_filter
->
Resize
(
filter
->
dims
());
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
conv2d
,
ops
::
Conv2DOp
,
ops
::
Conv2DOp
p
Maker
,
conv2d_grad
,
REGISTER_OP
(
conv2d
,
ops
::
Conv2DOp
,
ops
::
Conv2DOpMaker
,
conv2d_grad
,
ops
::
Conv2DOpGrad
);
ops
::
Conv2DOpGrad
);
REGISTER_OP_CPU_KERNEL
(
conv2d
,
REGISTER_OP_CPU_KERNEL
(
conv2d
,
...
...
paddle/operators/gemm_conv_op.h
浏览文件 @
40fe0a8c
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/math_function.h"
...
@@ -31,12 +32,10 @@ class GemmConvKernel : public framework::OpKernel {
...
@@ -31,12 +32,10 @@ class GemmConvKernel : public framework::OpKernel {
Tensor
*
filter
=
const_cast
<
Tensor
*>
(
context
.
Input
<
Tensor
>
(
"Filter"
));
Tensor
*
filter
=
const_cast
<
Tensor
*>
(
context
.
Input
<
Tensor
>
(
"Filter"
));
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
paddle
::
framework
::
Tensor
col
;
paddle
::
framework
::
Tensor
in_slice
;
paddle
::
framework
::
Tensor
out_slice
;
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
auto
filter_dims
=
filter
->
dims
();
int
batch_size
=
input
->
dims
()[
0
];
int
batch_size
=
input
->
dims
()[
0
];
int
input_channels
=
input
->
dims
()[
1
];
int
input_channels
=
input
->
dims
()[
1
];
...
@@ -50,6 +49,7 @@ class GemmConvKernel : public framework::OpKernel {
...
@@ -50,6 +49,7 @@ class GemmConvKernel : public framework::OpKernel {
im2col
;
im2col
;
framework
::
DDim
col_shape
=
{
input_channels
,
filter_height
,
filter_width
,
framework
::
DDim
col_shape
=
{
input_channels
,
filter_height
,
filter_width
,
output_height
,
output_width
};
output_height
,
output_width
};
Tensor
col
;
col
.
mutable_data
<
float
>
(
col_shape
,
context
.
GetPlace
());
col
.
mutable_data
<
float
>
(
col_shape
,
context
.
GetPlace
());
auto
*
device_context
=
auto
*
device_context
=
...
@@ -67,22 +67,23 @@ class GemmConvKernel : public framework::OpKernel {
...
@@ -67,22 +67,23 @@ class GemmConvKernel : public framework::OpKernel {
output
->
dims
()[
1
],
output
->
dims
()[
2
]
*
output
->
dims
()[
3
]};
output
->
dims
()[
1
],
output
->
dims
()[
2
]
*
output
->
dims
()[
3
]};
filter
->
Resize
(
filter_matrix_shape
);
filter
->
Resize
(
filter_matrix_shape
);
// convolution op
p
erator: im2col + gemm
// convolution operator: im2col + gemm
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// im2col
// im2col
in_slice
=
input
->
Slice
<
T
>
(
i
,
i
+
1
);
Tensor
in_slice
=
input
->
Slice
<
T
>
(
i
,
i
+
1
);
in_slice
.
Resize
(
input_shape
);
in_slice
.
Resize
(
input_shape
);
col
.
Resize
(
col_shape
);
col
.
Resize
(
col_shape
);
im2col
(
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
],
im2col
(
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
],
device_context
);
device_context
);
// gemm
// gemm
out_slice
=
output
->
Slice
<
T
>
(
i
,
i
+
1
);
Tensor
out_slice
=
output
->
Slice
<
T
>
(
i
,
i
+
1
);
out_slice
.
Resize
(
output_matrix_shape
);
out_slice
.
Resize
(
output_matrix_shape
);
col
.
Resize
(
col_matrix_shape
);
col
.
Resize
(
col_matrix_shape
);
math
::
matmul
<
Place
,
T
>
(
*
filter
,
false
,
col
,
false
,
T
(
1.0
),
&
out_slice
,
math
::
matmul
<
Place
,
T
>
(
*
filter
,
false
,
col
,
false
,
T
(
1.0
),
&
out_slice
,
T
(
0.0
),
device_context
);
T
(
0.0
),
device_context
);
}
}
filter
->
Resize
(
filter_dims
);
}
}
};
};
...
@@ -90,12 +91,92 @@ template <typename Place, typename T>
...
@@ -90,12 +91,92 @@ template <typename Place, typename T>
class
GemmConvGradKernel
:
public
framework
::
OpKernel
{
class
GemmConvGradKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
#if 0
const
Tensor
*
input
=
context
.
Input
<
Tensor
>
(
"Input"
);
auto input = context.Input<Tensor>("Input");
Tensor
*
filter
=
const_cast
<
Tensor
*>
(
context
.
Input
<
Tensor
>
(
"Filter"
));
auto filter = context.Input<Tensor>("Filter");
const
Tensor
*
output_grad
=
auto output = context.Output<Tensor>("Output");
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
output->mutable_data<T>(context.GetPlace());
Tensor
*
input_grad
=
#endif
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
Tensor
*
filter_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter"
));
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
filter_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
auto
filter_dims
=
filter
->
dims
();
int
batch_size
=
input
->
dims
()[
0
];
int
input_channels
=
input
->
dims
()[
1
];
int
filter_height
=
filter
->
dims
()[
filter
->
dims
().
size
()
-
2
];
int
filter_width
=
filter
->
dims
()[
filter
->
dims
().
size
()
-
1
];
int
output_height
=
output_grad
->
dims
()[
2
];
int
output_width
=
output_grad
->
dims
()[
3
];
paddle
::
operators
::
math
::
Col2ImFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Place
,
T
>
col2im
;
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
Tensor
col
;
framework
::
DDim
col_shape
=
{
input_channels
,
filter_height
,
filter_width
,
output_height
,
output_width
};
col
.
mutable_data
<
float
>
(
col_shape
,
context
.
GetPlace
());
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
framework
::
DDim
input_shape
=
{
input
->
dims
()[
1
],
input
->
dims
()[
2
],
input
->
dims
()[
3
]};
framework
::
DDim
filter_matrix_shape
=
{
filter
->
dims
()[
0
],
filter
->
dims
()[
1
]
*
filter
->
dims
()[
2
]
*
filter
->
dims
()[
3
]};
framework
::
DDim
col_matrix_shape
=
{
input_channels
*
filter_height
*
filter_width
,
output_height
*
output_width
};
framework
::
DDim
output_matrix_shape
=
{
output_grad
->
dims
()[
1
],
output_grad
->
dims
()[
2
]
*
output_grad
->
dims
()[
3
]};
filter
->
Resize
(
filter_matrix_shape
);
filter_grad
->
Resize
(
filter_matrix_shape
);
auto
t1
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
filter_grad
);
t1
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t1
.
constant
(
static_cast
<
T
>
(
0
));
auto
t2
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
input_grad
);
t2
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t2
.
constant
(
static_cast
<
T
>
(
0
));
// convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// gemm
Tensor
out_slice
=
output_grad
->
Slice
<
T
>
(
i
,
i
+
1
);
out_slice
.
Resize
(
output_matrix_shape
);
col
.
Resize
(
col_matrix_shape
);
math
::
matmul
<
Place
,
T
>
(
*
filter
,
true
,
out_slice
,
false
,
T
(
1.0
),
&
col
,
T
(
0.0
),
device_context
);
// col2im
Tensor
in_grad_slice
=
input_grad
->
Slice
<
T
>
(
i
,
i
+
1
);
in_grad_slice
.
Resize
(
input_shape
);
col
.
Resize
(
col_shape
);
col2im
(
in_grad_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
],
device_context
);
// im2col
Tensor
in_slice
=
input
->
Slice
<
T
>
(
i
,
i
+
1
);
in_slice
.
Resize
(
input_shape
);
col
.
Resize
(
col_shape
);
im2col
(
in_slice
,
col
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
],
device_context
);
// gemm
col
.
Resize
(
col_matrix_shape
);
math
::
matmul
<
Place
,
T
>
(
out_slice
,
false
,
col
,
true
,
T
(
1.0
),
filter_grad
,
T
(
1.0
),
device_context
);
}
filter
->
Resize
(
filter_dims
);
filter_grad
->
Resize
(
filter_dims
);
}
}
};
};
...
...
python/paddle/v2/framework/tests/test_conv2d_op.py
浏览文件 @
40fe0a8c
...
@@ -2,6 +2,7 @@ import unittest
...
@@ -2,6 +2,7 @@ import unittest
import
numpy
as
np
import
numpy
as
np
from
gradient_checker
import
GradientChecker
,
create_op
from
gradient_checker
import
GradientChecker
,
create_op
from
op_test_util
import
OpTestMeta
from
op_test_util
import
OpTestMeta
from
paddle.v2.framework.op
import
Operator
class
TestConv2dOp
(
unittest
.
TestCase
):
class
TestConv2dOp
(
unittest
.
TestCase
):
...
@@ -58,5 +59,42 @@ class TestConv2dOp(unittest.TestCase):
...
@@ -58,5 +59,42 @@ class TestConv2dOp(unittest.TestCase):
self
.
attrs
=
{
'strides'
:
[
1
,
1
],
'paddings'
:
[
0
,
0
]}
self
.
attrs
=
{
'strides'
:
[
1
,
1
],
'paddings'
:
[
0
,
0
]}
class
TestConv2dGradOp
(
GradientChecker
):
def
setUp
(
self
):
batch_size
=
2
input_channels
=
3
input_height
=
5
input_width
=
5
output_channels
=
6
filter_height
=
3
filter_width
=
3
stride
=
1
padding
=
0
output_height
=
(
input_height
-
filter_height
+
2
*
padding
)
/
stride
+
1
output_width
=
(
input_width
-
filter_width
+
2
*
padding
)
/
stride
+
1
input
=
np
.
random
.
random
((
batch_size
,
input_channels
,
input_height
,
input_width
)).
astype
(
"float32"
)
filter
=
np
.
random
.
random
(
(
output_channels
,
input_channels
,
filter_height
,
filter_width
)).
astype
(
"float32"
)
self
.
inputs
=
{
'Input'
:
input
,
'Filter'
:
filter
}
self
.
op
=
Operator
(
"conv2d"
,
Input
=
'Input'
,
Filter
=
'Filter'
,
Output
=
'Output'
,
strides
=
[
1
,
1
],
paddings
=
[
0
,
0
])
def
test_compare_grad
(
self
):
self
.
compare_grad
(
self
.
op
,
self
.
inputs
)
def
test_check_grad
(
self
):
self
.
check_grad
(
self
.
op
,
self
.
inputs
,
set
([
'Input'
,
'Filter'
]),
'Output'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录