From f3669ca3f18eee7c817f4b72f163734f0daaa001 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 18 Sep 2017 23:48:49 +0800 Subject: [PATCH] Support input_grad = null or filter_grad = null. --- paddle/operators/conv2d_op.cc | 11 ++- paddle/operators/gemm_conv2d_op.h | 84 ++++++++++++------- .../v2/framework/tests/test_conv2d_op.py | 6 ++ 3 files changed, 68 insertions(+), 33 deletions(-) diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc index 10091ec6a5..12db65b5cb 100644 --- a/paddle/operators/conv2d_op.cc +++ b/paddle/operators/conv2d_op.cc @@ -28,6 +28,13 @@ class Conv2DOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), + "Input(Input) of Conv2DOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"), + "Input(Filter) of Conv2DOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), + "Output(Output) of Conv2DOp should not be null."); + auto in = ctx.Input("Input"); auto filter = ctx.Input("Filter"); auto out = ctx.Output("Output"); @@ -108,8 +115,8 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { ctx.Output(framework::GradVarName("Input")); auto d_filter = ctx.Output(framework::GradVarName("Filter")); - d_in->Resize(in->dims()); - d_filter->Resize(filter->dims()); + if (d_in) d_in->Resize(in->dims()); + if (d_filter) d_filter->Resize(filter->dims()); } }; diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h index a4df7b9cb9..96f4c06005 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -111,14 +111,16 @@ class GemmConvGrad2DKernel : public framework::OpKernel { context.Output(framework::GradVarName("Input")); Tensor* filter_grad_ = context.Output(framework::GradVarName("Filter")); - input_grad->mutable_data(context.GetPlace()); - filter_grad_->mutable_data(context.GetPlace()); // The filter and filter_grad will be reshaped in the calculations, // so here use an assignment operation, // that avoids modifying the variable in the Scope. Tensor filter = *context.Input("Filter"); - Tensor filter_grad = *filter_grad_; + Tensor filter_grad; + if (filter_grad_) { + filter_grad_->mutable_data(context.GetPlace()); + filter_grad = *filter_grad_; + } std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); @@ -162,12 +164,20 @@ class GemmConvGrad2DKernel : public framework::OpKernel { framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); - filter_grad.Resize(filter_matrix_shape); - auto t1 = framework::EigenVector::Flatten(filter_grad); - t1.device(context.GetEigenDevice()) = t1.constant(static_cast(0)); - auto t2 = framework::EigenVector::Flatten(*input_grad); - t2.device(context.GetEigenDevice()) = t2.constant(static_cast(0)); + if (filter_grad_) { + filter_grad.Resize(filter_matrix_shape); + auto t1 = framework::EigenVector::Flatten(filter_grad); + t1.device(context.GetEigenDevice()) = + t1.constant(static_cast(0)); + } + + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + auto t2 = framework::EigenVector::Flatten(*input_grad); + t2.device(context.GetEigenDevice()) = + t2.constant(static_cast(0)); + } auto* device_context = const_cast(context.device_context_); @@ -176,35 +186,47 @@ class GemmConvGrad2DKernel : public framework::OpKernel { // convolution backward weight operator: im2col + gemm int in_step = input_channels / groups; int out_step = output_channels / groups; + Tensor in_grad_batch; + Tensor in_batch; for (int i = 0; i < batch_size; i++) { Tensor out_grad_batch = output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + if (input_grad) { + in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + } + if (filter_grad_) { + in_batch = input->Slice(i, i + 1).Resize(input_shape); + } for (int g = 0; g < groups; g++) { - // gemm Tensor out_grad_slice = out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, true, out_grad_slice, false, - T(1.0), &col_matrix, T(0.0), device_context); - - // col2im - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); - - // im2col - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - device_context); - - // gemm - Tensor filter_grad_slice = - filter_grad.Slice(g * out_step, (g + 1) * out_step); - math::matmul(out_grad_slice, false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0), device_context); + if (input_grad) { + // gemm + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, true, out_grad_slice, false, + T(1.0), &col_matrix, T(0.0), device_context); + + // col2im + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + } + + if (filter_grad_) { + // im2col + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + im2col(in_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + + // gemm + Tensor filter_grad_slice = + filter_grad.Slice(g * out_step, (g + 1) * out_step); + math::matmul(out_grad_slice, false, col_matrix, true, + T(1.0), &filter_grad_slice, T(1.0), + device_context); + } } } } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 64aeb6e8a9..3142a60a1a 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -75,6 +75,12 @@ class TestConv2dOp(OpTest): def test_check_grad(self): self.check_grad(set(['Input', 'Filter']), 'Output') + def test_check_grad_no_filter(self): + self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input'])) + def init_groups(self): self.groups = 1 -- GitLab