Commit 64c5ecbe authored by: Z zchen0211

deconv

Parent 502e7259
@@ -18,13 +18,13 @@
 namespace paddle {
 namespace operators {
-void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
+void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of Deconv2DOp should not be null.");
+                 "Input(Input) of Conv2DTransposeOp should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("Filter"),
-                 "Input(Filter) of Deconv2DOp should not be null.");
+                 "Input(Filter) of Conv2DTransposeOp should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                 "Output(Output) of Deconv2DOp should not be null.");
+                 "Output(Output) of Conv2DTransposeOp should not be null.");
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
@@ -32,13 +32,14 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   for (size_t i = 0; i < paddings.size(); ++i) {
-    PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
+    PADDLE_ENFORCE_EQ(paddings[i], 0,
+                      "No Padding allowed in conv transpose op.");
   }
   PADDLE_ENFORCE_EQ(in_dims.size(), 4,
-                    "Deconv2DOp input should be 4-D tensor.");
+                    "Conv2DTransposeOp input should be 4-D tensor.");
   PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
-                    "Deconv2DOp filter should be 4-D tensor.");
+                    "Conv2DTransposeOp filter should be 4-D tensor.");
   PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                     "input and kernel input dimension should be equal.");
@@ -48,36 +49,39 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
       {in_dims[0], filter_dims[1], output_height, output_width});
 }
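For orientation (an inference from the surrounding code, not a line of the diff): since paddings are enforced to be zero, the elided output_height/output_width computation above presumably follows the standard transposed-convolution shape rule

$$o_h = (h - 1)\, s_h + k_h, \qquad o_w = (w - 1)\, s_w + k_w,$$

with input size $h \times w$, filter size $k_h \times k_w$, and strides $(s_h, s_w)$.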
-Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
-                                 framework::OpAttrChecker* op_checker)
+Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
+    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput(
       "Input",
-      "The input tensor of deconvolution operator. "
+      "The input tensor of convolution transpose operator. "
       "The format of input tensor is NCHW. Where N is batch size, C is the "
      "number of input channels, H and W is the height and width of image.");
   AddInput("Filter",
-           "The filter tensor of deconvolution operator."
-           "The format of the filter tensor is MCHW, where C is the number of "
+           "The filter tensor of convolution transpose operator."
+           "The format of the filter tensor is CMHW, where C is the number of "
            "output image channels, M is the number of input image channels, "
            "H and W is height and width of filter. "
            "We enforce groups number == 1 and padding == 0 in "
-           "deconvolution Scenario.");
+           "convolution transpose Scenario.");
   AddOutput("Output",
-            "The output tensor of deconvolution operator."
+            "The output tensor of convolution transpose operator."
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides", "strides of deconvolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "strides of convolution transpose operator.")
       .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings", "paddings of deconvolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "paddings of convolution transpose operator.")
       .SetDefault({0, 0});
   AddComment(R"DOC(
-The deconvolution operation calculates the output based on the input, filter
+The convolution transpose operation calculates the output based on the input, filter
 and strides, paddings, groups parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
 )DOC");
 }
-void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
+void Conv2DTransposeOpGrad::InferShape(
+    framework::InferShapeContext* ctx) const {
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
   if (ctx->HasOutput(framework::GradVarName("Input"))) {
@@ -92,11 +96,13 @@ void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
-            ops::Deconv2DOpGrad);
+REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
+            ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
+            ops::Conv2DTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
+    conv2dtranspose,
+    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    deconv2d_grad,
-    ops::GemmDeconvGrad2DKernel<paddle::platform::CPUPlace, float>);
+    conv2dtranspose_grad,
+    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
@@ -17,7 +17,8 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::GPUPlace, float>);
+    conv2dtranspose,
+    ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    deconv2d_grad,
-    ops::GemmDeconvGrad2DKernel<paddle::platform::GPUPlace, float>);
+    conv2dtranspose_grad,
+    ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
@@ -26,15 +26,15 @@ namespace operators {
 using Tensor = framework::Tensor;
 using DDim = framework::DDim;
-// Define Op classes in .h file so that other deconv
+// Define Op classes in .h file so that other conv transpose
 // operator implementations can reuse the code.
-class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker {
+class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Deconv2DOpMaker(framework::OpProto* proto,
+  Conv2DTransposeOpMaker(framework::OpProto* proto,
                   framework::OpAttrChecker* op_checker);
 };
-class Deconv2DOp : public framework::OperatorWithKernel {
+class Conv2DTransposeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -42,7 +42,7 @@ class Deconv2DOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
-class Deconv2DOpGrad : public framework::OperatorWithKernel {
+class Conv2DTransposeOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -51,7 +51,7 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel {
 };
 template <typename Place, typename T>
-class GemmDeconv2DKernel : public framework::OpKernel<T> {
+class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -64,27 +64,27 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
     // no paddings and groups allowed in deconv
-    int N = input->dims()[0];
-    int M = input->dims()[1];
-    int H = input->dims()[2];
-    int W = input->dims()[3];
+    const int batch_size = input->dims()[0];
+    const int m = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
-    int K_H = filter.dims()[2];
-    int K_W = filter.dims()[3];
+    const int k_h = filter.dims()[2];
+    const int k_w = filter.dims()[3];
-    int C = output->dims()[1];  // output channels
-    int O_H = output->dims()[2];
-    int O_W = output->dims()[3];
+    const int c = output->dims()[1];  // output channels
+    const int o_h = output->dims()[2];
+    const int o_w = output->dims()[3];
     paddle::operators::math::Col2ImFunctor<
         paddle::operators::math::ColFormat::kCFO, Place, T>
         col2im;
     // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {C, K_H, K_W, H, W};
+    DDim col_shape = {c, k_h, k_w, h, w};
     // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape = {C * K_H * K_W, H * W};
+    DDim col_matrix_shape = {c * k_h * k_w, h * w};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -94,10 +94,10 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
     Tensor col_matrix = col;
     col_matrix.Resize(col_matrix_shape);
-    DDim output_shape = {C, O_H, O_W};
-    DDim input_matrix_shape = {M, H * W};
+    DDim output_shape = {c, o_h, o_w};
+    DDim input_matrix_shape = {m, h * w};
-    DDim filter_matrix_shape = {M, C * K_H * K_W};
+    DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);
     // deconvolution: gemm + col2im (similar to conv-backward on input)
@@ -106,16 +106,16 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
     auto t = framework::EigenVector<T>::Flatten(*output);
     t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-    for (int i = 0; i < N; i++) {
-      // batch with size (M, H * W)
-      Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, C * K_H * K_W)
+    for (int i = 0; i < batch_size; i++) {
+      // batch with size (M, h * w)
+      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+      // filter size: (M, c * k_h * k_w)
-      // output size: (C, O_H, O_W)
-      Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
+      // output size: (c, o_h, o_w)
+      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
       // col_matrix = filter * input_batch
-      // of shape (C * K_H * K_W, H * W)
+      // of shape (c * k_h * k_w, h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
                              input_batch, false, T(1.0), &col_matrix, T(0.0));
       col2im(context.device_context(), output_batch, col, strides[0],
@@ -125,7 +125,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
 };
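A minimal, self-contained C++ sketch (my illustration, not Paddle code; the function name naive_conv2d_transpose is hypothetical) of what the gemm + col2im pipeline above computes for one batch element, assuming zero padding, groups == 1, and a filter laid out as (M, C, K_H, K_W) to match the (m, c * k_h * k_w) reshape:

#include <vector>

// Direct transposed convolution: each input pixel scatters a scaled
// K_H x K_W patch of the filter into the output, which is exactly the
// effect of col2im applied to the gemm result filter^T * input.
void naive_conv2d_transpose(const std::vector<float>& input,   // (M, H, W)
                            const std::vector<float>& filter,  // (M, C, KH, KW)
                            std::vector<float>* output,        // (C, OH, OW)
                            int M, int H, int W, int C, int KH, int KW,
                            int stride_h, int stride_w) {
  const int OH = (H - 1) * stride_h + KH;  // zero-padding shape rule
  const int OW = (W - 1) * stride_w + KW;
  output->assign(static_cast<size_t>(C) * OH * OW, 0.f);
  for (int m = 0; m < M; ++m)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w) {
        const float x = input[(m * H + h) * W + w];
        for (int c = 0; c < C; ++c)
          for (int kh = 0; kh < KH; ++kh)
            for (int kw = 0; kw < KW; ++kw)
              (*output)[(c * OH + h * stride_h + kh) * OW +
                        (w * stride_w + kw)] +=
                  x * filter[((m * C + c) * KH + kh) * KW + kw];
      }
}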
 template <typename Place, typename T>
-class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
+class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -145,17 +145,17 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in deconv.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int N = input->dims()[0];
-    int M = input->dims()[1];
-    int H = input->dims()[2];
-    int W = input->dims()[3];
+    const int batch_size = input->dims()[0];
+    const int m = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
-    int K_H = filter.dims()[2];
-    int K_W = filter.dims()[3];
+    const int k_h = filter.dims()[2];
+    const int k_w = filter.dims()[3];
-    int C = output_grad->dims()[1];  // output channels
-    int O_H = output_grad->dims()[2];
-    int O_W = output_grad->dims()[3];
+    const int c = output_grad->dims()[1];  // output channels
+    const int o_h = output_grad->dims()[2];
+    const int o_w = output_grad->dims()[3];
     // Only im2col functor required for bp to get to the right shape
     paddle::operators::math::Im2ColFunctor<
@@ -163,10 +163,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
         im2col;
     // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {C, K_H, K_W, H, W};
+    DDim col_shape = {c, k_h, k_w, h, w};
     // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
+    DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -174,10 +174,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
-    DDim output_shape = {C, O_H, O_W};
-    DDim input_matrix_shape = {M, H * W};
+    DDim output_shape = {c, o_h, o_w};
+    DDim input_matrix_shape = {m, h * w};
-    DDim filter_matrix_shape = {M, C * K_H * K_W};
+    DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);
     // deconvolution grad on input:
@@ -185,29 +185,29 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     // input need to compute gradient
     if (input_grad) {
       Tensor col_matrix = col;
-      DDim col_matrix_shape = {C * K_H * K_W, H * W};
+      DDim col_matrix_shape = {c * k_h * k_w, h * w};
       col_matrix.Resize(col_matrix_shape);
       input_grad->mutable_data<T>(context.GetPlace());
       auto t = framework::EigenVector<T>::Flatten(*input_grad);
       t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-      for (int i = 0; i < N; i++) {
-        // batch with size (C, O_H * O_W)
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_h * o_w)
         Tensor output_grad_batch =
-            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
-        // filter of size (M, C * K_H * K_W)
+            output_grad->Slice(i, i + 1).Resize(output_shape);
+        // filter of size (m, c * k_h * k_w)
-        // batch with size (M, H, W)
+        // batch with size (m, h, w)
         Tensor input_grad_batch =
-            input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-        // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W)
+        // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
         im2col(context.device_context(), output_grad_batch, col, strides[0],
                strides[1], paddings[0], paddings[1]);
         // gemm: dx = filter * dy
-        // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H)
+        // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
         math::matmul<Place, T>(context.device_context(), filter, false,
                                col_matrix, false, T(1.0), &input_grad_batch,
                                T(0.0));
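To spell out the shape bookkeeping of this branch (a reading of the code above, not a line of the diff): im2col unfolds each gradient slice $dy$ of shape $(c, o_h, o_w)$ into a matrix $\mathrm{col}(dy)$ of shape $(c \cdot k_h \cdot k_w,\; h \cdot w)$, and the gemm then computes

$$dx = W \, \mathrm{col}(dy), \qquad (m,\, c k_h k_w) \times (c k_h k_w,\, h w) \to (m,\, h w),$$

i.e. the gradient with respect to the input of a transposed convolution is an ordinary convolution of $dy$ with the same filter.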
@@ -217,7 +217,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     // filter gradient required
     if (filter_grad) {
       Tensor col_matrix_f = col;
-      DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
+      DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
       col_matrix_f.Resize(col_matrix_shape_f);
       filter_grad->mutable_data<T>(context.GetPlace());
@@ -226,19 +226,19 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
       auto t = framework::EigenVector<T>::Flatten(filter_grad_);
       t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-      for (int i = 0; i < N; ++i) {
-        // batch with size (C, O_H, O_W)
+      for (int i = 0; i < batch_size; ++i) {
+        // batch with size (c, o_h, o_w)
         Tensor output_grad_batch =
-            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
+            output_grad->Slice(i, i + 1).Resize(output_shape);
         // input batch
-        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-        // im2col: (C * H * W, K_H * K_W)
+        // im2col: (c * h * w, k_h * k_w)
         im2col(context.device_context(), output_grad_batch, col, strides[0],
                strides[1], paddings[0], paddings[1]);
         // gemm: d_filter = x * y_grad^T
-        // (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H)
+        // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
         math::matmul<Place, T>(context.device_context(), in_batch, false,
                                col_matrix_f, true, T(1.0), &filter_grad_,
                                T(1.0));
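One detail worth noting (my reading of the call above, guided by the "d_filter = x * y_grad^T" comment): the final T(1.0) argument is the gemm's beta, so unlike the input-gradient branch (beta = T(0.0)) each batch iteration accumulates into filter_grad_, giving in effect

$$dW = \sum_{i=0}^{N-1} x_i \, \mathrm{col}(dy_i)^{\top}.$$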