提交 40fe0a8c 编写于 作者: H hedaoyuan

Add backward of convolution.

上级 c9d8cb4e
......@@ -28,9 +28,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *in = ctx.Input<framework::Tensor>("Input");
auto *filter = ctx.Input<framework::Tensor>("Filter");
auto *out = ctx.Output<framework::Tensor>("Output");
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto out = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D.");
PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
"Conv2DOp filter should be 4-D.");
......@@ -46,10 +46,9 @@ class Conv2DOp : public framework::OperatorWithKernel {
}
};
class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker {
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOppMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
......@@ -62,7 +61,7 @@ class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker {
"The format of the filter tensor is MCHW, where M is the number of "
"output "
"image channels, C is the number of input image channels, H and W is "
" height and width of filter.");
"height and width of filter.");
AddOutput("Output",
"The output tensor of convolution operator."
"The format of output tensor is also NCHW.");
......@@ -80,14 +79,21 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto d_in = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto d_filter = ctx.Output<Tensor>(framework::GradVarName("Filter"));
d_in->Resize(in->dims());
d_filter->Resize(filter->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOppMaker, conv2d_grad,
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
ops::Conv2DOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
......@@ -31,12 +32,10 @@ class GemmConvKernel : public framework::OpKernel {
Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
paddle::framework::Tensor col;
paddle::framework::Tensor in_slice;
paddle::framework::Tensor out_slice;
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
auto filter_dims = filter->dims();
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
......@@ -50,6 +49,7 @@ class GemmConvKernel : public framework::OpKernel {
im2col;
framework::DDim col_shape = {input_channels, filter_height, filter_width,
output_height, output_width};
Tensor col;
col.mutable_data<float>(col_shape, context.GetPlace());
auto* device_context =
......@@ -67,22 +67,23 @@ class GemmConvKernel : public framework::OpKernel {
output->dims()[1], output->dims()[2] * output->dims()[3]};
filter->Resize(filter_matrix_shape);
// convolution opperator: im2col + gemm
// convolution operator: im2col + gemm
for (int i = 0; i < batch_size; i++) {
// im2col
in_slice = input->Slice<T>(i, i + 1);
Tensor in_slice = input->Slice<T>(i, i + 1);
in_slice.Resize(input_shape);
col.Resize(col_shape);
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
device_context);
// gemm
out_slice = output->Slice<T>(i, i + 1);
Tensor out_slice = output->Slice<T>(i, i + 1);
out_slice.Resize(output_matrix_shape);
col.Resize(col_matrix_shape);
math::matmul<Place, T>(*filter, false, col, false, T(1.0), &out_slice,
T(0.0), device_context);
}
filter->Resize(filter_dims);
}
};
......@@ -90,12 +91,92 @@ template <typename Place, typename T>
class GemmConvGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
#if 0
auto input = context.Input<Tensor>("Input");
auto filter = context.Input<Tensor>("Filter");
auto output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
#endif
const Tensor* input = context.Input<Tensor>("Input");
Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
input_grad->mutable_data<T>(context.GetPlace());
filter_grad->mutable_data<T>(context.GetPlace());
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
auto filter_dims = filter->dims();
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
int filter_height = filter->dims()[filter->dims().size() - 2];
int filter_width = filter->dims()[filter->dims().size() - 1];
int output_height = output_grad->dims()[2];
int output_width = output_grad->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
Tensor col;
framework::DDim col_shape = {input_channels, filter_height, filter_width,
output_height, output_width};
col.mutable_data<float>(col_shape, context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3]};
framework::DDim filter_matrix_shape = {
filter->dims()[0],
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
framework::DDim col_matrix_shape = {
input_channels * filter_height * filter_width,
output_height * output_width};
framework::DDim output_matrix_shape = {
output_grad->dims()[1],
output_grad->dims()[2] * output_grad->dims()[3]};
filter->Resize(filter_matrix_shape);
filter_grad->Resize(filter_matrix_shape);
auto t1 = framework::EigenVector<T>::Flatten(*filter_grad);
t1.device(context.GetEigenDevice<Place>()) = t1.constant(static_cast<T>(0));
auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
t2.device(context.GetEigenDevice<Place>()) = t2.constant(static_cast<T>(0));
// convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm
for (int i = 0; i < batch_size; i++) {
// gemm
Tensor out_slice = output_grad->Slice<T>(i, i + 1);
out_slice.Resize(output_matrix_shape);
col.Resize(col_matrix_shape);
math::matmul<Place, T>(*filter, true, out_slice, false, T(1.0), &col,
T(0.0), device_context);
// col2im
Tensor in_grad_slice = input_grad->Slice<T>(i, i + 1);
in_grad_slice.Resize(input_shape);
col.Resize(col_shape);
col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
paddings[1], device_context);
// im2col
Tensor in_slice = input->Slice<T>(i, i + 1);
in_slice.Resize(input_shape);
col.Resize(col_shape);
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
device_context);
// gemm
col.Resize(col_matrix_shape);
math::matmul<Place, T>(out_slice, false, col, true, T(1.0), filter_grad,
T(1.0), device_context);
}
filter->Resize(filter_dims);
filter_grad->Resize(filter_dims);
}
};
......
......@@ -2,6 +2,7 @@ import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
from paddle.v2.framework.op import Operator
class TestConv2dOp(unittest.TestCase):
......@@ -58,5 +59,42 @@ class TestConv2dOp(unittest.TestCase):
self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
class TestConv2dGradOp(GradientChecker):
def setUp(self):
batch_size = 2
input_channels = 3
input_height = 5
input_width = 5
output_channels = 6
filter_height = 3
filter_width = 3
stride = 1
padding = 0
output_height = (input_height - filter_height + 2 * padding
) / stride + 1
output_width = (input_width - filter_width + 2 * padding) / stride + 1
input = np.random.random((batch_size, input_channels, input_height,
input_width)).astype("float32")
filter = np.random.random(
(output_channels, input_channels, filter_height,
filter_width)).astype("float32")
self.inputs = {'Input': input, 'Filter': filter}
self.op = Operator(
"conv2d",
Input='Input',
Filter='Filter',
Output='Output',
strides=[1, 1],
paddings=[0, 0])
def test_compare_grad(self):
self.compare_grad(self.op, self.inputs)
def test_check_grad(self):
self.check_grad(self.op, self.inputs,
set(['Input', 'Filter']), 'Output')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册