diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc
index a068daf9a812f6f54084f17ef513d5d96309c182..97f31bf22d7072d89bd043045045dcb5bb5518b8 100644
--- a/paddle/operators/conv_cudnn_op.cc
+++ b/paddle/operators/conv_cudnn_op.cc
@@ -41,8 +41,8 @@ namespace ops = paddle::operators;
 REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
             ops::ConvOpGrad);
 
-REGISTER_OP_CPU_KERNEL(
-    conv_cudnn, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(conv_cudnn,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     conv_cudnn_grad,
-    ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index 54ac4f4111445d8f7120f593de84ca8dff2582cf..a6f65f10165929316f971d195f3790fd9e7ed376 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -198,12 +198,12 @@ namespace ops = paddle::operators;
 REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
             ops::ConvOpGrad);
 
+REGISTER_OP_CPU_KERNEL(conv2d,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(conv3d,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    conv3d, ops::GemmConv3DKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::CPUPlace, float>);
+    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu
index d8c0bd9326bb9cdb2fb22439dee75404422e85a5..8e6f9da455b7291049aee57189dae15b8bcc2150 100644
--- a/paddle/operators/conv_op.cu
+++ b/paddle/operators/conv_op.cu
@@ -16,12 +16,12 @@ namespace ops = paddle::operators;
 
+REGISTER_OP_GPU_KERNEL(conv2d,
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
-    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(conv3d,
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    conv3d, ops::GemmConv3DKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
-    conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::GPUPlace, float>);
+    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
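Note on the registration changes above: the four dimension-specific kernels (`GemmConv2DKernel`, `GemmConvGrad2DKernel`, `GemmConv3DKernel`, `GemmConvGrad3DKernel`) collapse into a single `GemmConvKernel`/`GemmConvGradKernel` pair, so `conv2d`, `conv3d`, and the `conv_cudnn` CPU fallback all register the same templates. The unified kernel tells the 2-D and 3-D cases apart at run time from the filter's rank. A minimal standalone sketch of that dispatch rule (plain C++, not Paddle code; the helper name `SpatialRank` is invented here for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// A conv2d filter is 4-D: {o_c, i_c/g, k_h, k_w}; a conv3d filter is 5-D:
// {o_c, i_c/g, k_d, k_h, k_w}. Dropping the two leading channel dimensions
// (as the kernel does via filter_shape_vec.erase(begin, begin + 2)) leaves
// the spatial rank: 2 selects the im2col path, 3 the vol2col path.
int SpatialRank(const std::vector<int64_t>& filter_dims) {
  assert(filter_dims.size() == 4 || filter_dims.size() == 5);
  return static_cast<int>(filter_dims.size()) - 2;
}

int main() {
  std::vector<int64_t> conv2d_filter = {64, 3, 3, 3};
  std::vector<int64_t> conv3d_filter = {64, 3, 3, 3, 3};
  std::printf("%d\n", SpatialRank(conv2d_filter));  // prints 2
  std::printf("%d\n", SpatialRank(conv3d_filter));  // prints 3
}
```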
diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h
index 198e51e4ad4c45b6a7fc155a9eb1d1343e263a28..7c1729213bf3f5f3987afbf2d51d5b5339ae521d 100644
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -62,7 +62,7 @@ class ConvOpGrad : public framework::OperatorWithKernel {
 };
 
 template <typename Place, typename T>
-class GemmConv2DKernel : public framework::OpKernel<T> {
+class GemmConvKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -77,49 +77,78 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     int groups = context.Attr<int>("groups");
 
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output->dims()[1];
-    int output_height = output->dims()[2];
-    int output_width = output->dims()[3];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
+
+    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
+    output_shape_vec.erase(output_shape_vec.begin(),
+                           output_shape_vec.begin() + 2);
 
-    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
     // use col_shape in the im2col calculation
-    framework::DDim col_shape = {input_channels / groups, filter_height,
-                                 filter_width, output_height, output_width};
+    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
+    // o_h, o_w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+                         output_shape_vec.end());
+    framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+
     // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_height * filter_width,
-        output_height * output_width};
+    // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
+    // o_h * o_w)
+    framework::DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
+    Tensor col_matrix;
+    col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
 
-    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
-                                   input->dims()[3]};
+    framework::DDim input_shape = framework::slice_ddim(
+        input->dims(), 1, static_cast<int>(input->dims().size()));
+
     framework::DDim filter_matrix_shape = {filter.dims()[0],
                                            filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
-    framework::DDim output_matrix_shape = {output_channels,
-                                           output_height * output_width};
-    // convolution operator: im2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
+    framework::DDim output_matrix_shape = {
+        output->dims()[1],
+        output->numel() / (output->dims()[0] * output->dims()[1])};
+
+    // convolution operator: im2col(or vol2col) + gemm
+    int in_step = static_cast<int>(input->dims()[1]) / groups;
+    int out_step = static_cast<int>(output->dims()[1]) / groups;
+
     for (int i = 0; i < batch_size; i++) {
       Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
       Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
       for (int g = 0; g < groups; g++) {
-        // im2col
         Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-        im2col(context.device_context(), in_slice, col, strides[0], strides[1],
-               paddings[0], paddings[0], paddings[1], paddings[1]);
+
+        if (filter_shape_vec.size() == 2) {
+          // im2col
+          math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+          im2col(context.device_context(), in_slice, col, strides[0],
+                 strides[1], paddings[0], paddings[0], paddings[1],
+                 paddings[1]);
+        } else if (filter_shape_vec.size() == 3) {
+          // vol2col
+          math::Vol2ColFunctor<Place, T> vol2col;
+          vol2col(context.device_context(), in_slice, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+        }
 
         // gemm
         Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
@@ -132,7 +161,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
 };
 
 template <typename Place, typename T>
-class GemmConvGrad2DKernel : public framework::OpKernel<T> {
+class GemmConvGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
@@ -142,267 +171,74 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
-
     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
 
+    if (!input_grad && !filter_grad) return;
+
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     int groups = context.Attr<int>("groups");
 
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output_grad->dims()[1];
-    int output_height = output_grad->dims()[2];
-    int output_width = output_grad->dims()[3];
-
-    math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-    // use col_shape in the im2col and col2im calculation
-    framework::DDim col_shape = {input_channels / groups, filter_height,
-                                 filter_width, output_height, output_width};
-    // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_height * filter_width,
-        output_height * output_width};
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
-    col_matrix.Resize(col_matrix_shape);
-
-    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
-                                   input->dims()[3]};
-    framework::DDim output_matrix_shape = {
-        output_grad->dims()[1],
-        output_grad->dims()[2] * output_grad->dims()[3]};
-
-    framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                           filter.numel() / filter.dims()[0]};
-    filter.Resize(filter_matrix_shape);
-
-    // convolution backward input operator: gemm + col2im
-    // convolution backward weight operator: im2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
-    math::SetConstant<Place, T> set_zero;
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      set_zero(context.device_context(), input_grad, static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        Tensor out_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
-        for (int g = 0; g < groups; g++) {
-          // gemm
-          Tensor out_grad_slice =
-              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), filter_slice, true,
-                                 out_grad_slice, false, T(1.0), &col_matrix,
-                                 T(0.0));
-
-          // col2im
-          Tensor in_grad_slice =
-              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
-          col2im(context.device_context(), in_grad_slice, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
-        }
-      }
-    }
-
-    if (filter_grad) {
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+    const int batch_size = static_cast<int>(input->dims()[0]);
 
-      for (int i = 0; i < batch_size; i++) {
-        Tensor out_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-        for (int g = 0; g < groups; g++) {
-          // im2col
-          Tensor out_grad_slice =
-              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-          im2col(context.device_context(), in_slice, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
-
-          // gemm
-          Tensor filter_grad_slice =
-              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), out_grad_slice,
-                                 false, col_matrix, true, T(1.0),
-                                 &filter_grad_slice, T(1.0));
-        }
-      }
-    }
-  }
-};
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
 
-template <typename Place, typename T>
-class GemmConv3DKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    // The filter will be reshaped in the calculations,
-    // so here use an assignment operation,
-    // that avoids modifying the variable in the Scope.
-    Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
+    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    std::vector<int64_t> output_shape_vec(
+        framework::vectorize(output_grad->dims()));
+    output_shape_vec.erase(output_shape_vec.begin(),
+                           output_shape_vec.begin() + 2);
 
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
+    // use col_shape in the im2col calculation
+    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
+    // o_h, o_w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+                         output_shape_vec.end());
+    framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_depth = filter.dims()[filter.dims().size() - 3];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output->dims()[1];
-    int output_depth = output->dims()[2];
-    int output_height = output->dims()[3];
-    int output_width = output->dims()[4];
-
-    math::Vol2ColFunctor<Place, T> vol2col;
-    // use col_shape in the vol2col calculation
-    framework::DDim col_shape = {input_channels / groups,
-                                 filter_depth,
-                                 filter_height,
-                                 filter_width,
-                                 output_depth,
-                                 output_height,
-                                 output_width};
     // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_depth * filter_height * filter_width,
-        output_depth * output_height * output_width};
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
-    col_matrix.Resize(col_matrix_shape);
+    // size: (i_c/g * k_h * k_w, o_h * o_w)
+    // or
+    // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
+    framework::DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+
+    framework::DDim input_shape = framework::slice_ddim(
+        input->dims(), 1, static_cast<int>(input->dims().size()));
 
-    framework::DDim input_shape = {
-        input->dims()[1], input->dims()[2], input->dims()[3],
-        input->dims()[4]};  // channel, depth, height, width
-    framework::DDim filter_matrix_shape = {
-        filter.dims()[0],
-        filter.numel() / filter.dims()[0]};  // filter_out_channel,
-    // filter_in_channel*filter_depth*filter_height*filter_width
+    framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                           filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
     framework::DDim output_matrix_shape = {
-        output_channels, output_depth * output_height * output_width};
-
-    // convolution operator: vol2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-      for (int g = 0; g < groups; g++) {
-        // vol2col
-        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-        vol2col(context.device_context(), in_slice, col, strides[0], strides[1],
-                strides[2], paddings[0], paddings[1], paddings[2]);
-
-        // gemm
-        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-        math::matmul<Place, T>(context.device_context(), filter_slice, false,
-                               col_matrix, false, T(1.0), &out_slice, T(0.0));
-      }
-    }
-  }
-};
-
-template <typename Place, typename T>
-class GemmConvGrad3DKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
-        context.Input<Tensor>(framework::GradVarName("Output"));
-    Tensor* input_grad =
-        context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
-        context.Output<Tensor>(framework::GradVarName("Filter"));
-
-    // The filter and filter_grad will be reshaped in the calculations,
-    // so here use an assignment operation,
-    // that avoids modifying the variable in the Scope.
-    Tensor filter = *context.Input<Tensor>("Filter");
+        output_grad->dims()[1],
+        output_grad->numel() /
+            (output_grad->dims()[0] * output_grad->dims()[1])};
 
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
+    // convolution backward input operator: gemm + col2im(or col2vol)
+    // convolution backward weight operator: im2col(or vol2col) + gemm
+    int in_step = static_cast<int>(input->dims()[1]) / groups;
+    int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
 
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_depth = filter.dims()[filter.dims().size() - 3];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output_grad->dims()[1];
-    int output_depth = output_grad->dims()[2];
-    int output_height = output_grad->dims()[3];
-    int output_width = output_grad->dims()[4];
-
-    math::Col2VolFunctor<Place, T> col2vol;
-    math::Vol2ColFunctor<Place, T> vol2col;
-    // use col_shape in the vol2col and col2vol calculation
-    framework::DDim col_shape = {input_channels / groups,
-                                 filter_depth,
-                                 filter_height,
-                                 filter_width,
-                                 output_depth,
-                                 output_height,
-                                 output_width};
-    // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_depth * filter_height * filter_width,
-        output_depth * output_height * output_width};
     Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
+    Tensor col_matrix;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
 
-    framework::DDim input_shape = {
-        input->dims()[1], input->dims()[2], input->dims()[3],
-        input->dims()[4]};  // channel, depth, height, width
-    framework::DDim output_matrix_shape = {output_grad->dims()[1],
-                                           output_grad->dims()[2] *
-                                               output_grad->dims()[3] *
-                                               output_grad->dims()[4]};
-
-    framework::DDim filter_matrix_shape = {
-        filter.dims()[0],
-        filter.numel() / filter.dims()[0]};  // filter_out_channel,
-    // filter_in_channel*filter_depth*filter_height*filter_width
-    filter.Resize(filter_matrix_shape);
-
-    // convolution backward input operator: gemm + col2vol
-    // convolution backward weight operator: vol2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
     math::SetConstant<Place, T> set_zero;
 
     if (input_grad) {
@@ -421,13 +257,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
           math::matmul<Place, T>(context.device_context(), filter_slice, true,
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
-
-          // col2vol
+          // col2im
           Tensor in_grad_slice =
               in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
-          col2vol(context.device_context(), in_grad_slice, col, strides[0],
-                  strides[1], strides[2], paddings[0], paddings[1],
-                  paddings[2]);
+
+          if (filter_shape_vec.size() == 2) {
+            math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+            col2im(context.device_context(), in_grad_slice, col, strides[0],
+                   strides[1], paddings[0], paddings[0], paddings[1],
+                   paddings[1]);
+
+          } else if (filter_shape_vec.size() == 3) {
+            math::Col2VolFunctor<Place, T> col2vol;
+            col2vol(context.device_context(), in_grad_slice, col, strides[0],
+                    strides[1], strides[2], paddings[0], paddings[1],
+                    paddings[2]);
+          }
         }
       }
     }
@@ -443,13 +288,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
             output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
         Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
         for (int g = 0; g < groups; g++) {
-          // vol2col
+          // im2col
           Tensor out_grad_slice =
               out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-          vol2col(context.device_context(), in_slice, col, strides[0],
-                  strides[1], strides[2], paddings[0], paddings[1],
-                  paddings[2]);
+
+          if (filter_shape_vec.size() == 2) {
+            math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+            im2col(context.device_context(), in_slice, col, strides[0],
+                   strides[1], paddings[0], paddings[0], paddings[1],
+                   paddings[1]);
+          } else if (filter_shape_vec.size() == 3) {
+            math::Vol2ColFunctor<Place, T> vol2col;
+            vol2col(context.device_context(), in_slice, col, strides[0],
+                    strides[1], strides[2], paddings[0], paddings[1],
+                    paddings[2]);
+          }
 
           // gemm
           Tensor filter_grad_slice =
@@ -462,6 +316,5 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
     }
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py
index 6bd4bad8e2db53907a0e1e442215ca8912d2c300..04ae7f294c27fdceaaff2e9a7ed854213e643945 100644
--- a/python/paddle/v2/framework/tests/test_conv2d_op.py
+++ b/python/paddle/v2/framework/tests/test_conv2d_op.py
@@ -61,25 +61,23 @@ class TestConv2dOp(OpTest):
 
     def test_check_grad(self):
         self.check_grad(
-            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
+            set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
 
     def test_check_grad_no_filter(self):
         self.check_grad(
             ['Input'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.02,
             no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
         self.check_grad(
             ['Filter'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.02,
             no_grad_set=set(['Input']))
 
     def init_test_case(self):
-        # self.groups = 1
-        # self.op_type = "conv2d"
         self.pad = [0, 0]
         self.stride = [1, 1]
         self.dilations = [1, 1]
diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py
index f8e07fc562602a631f6e27bbe921d65238910d9b..44c192f58d25f8ddaa38d2ba7c7c19b9a5bd7dc1 100644
--- a/python/paddle/v2/framework/tests/test_conv3d_op.py
+++ b/python/paddle/v2/framework/tests/test_conv3d_op.py
@@ -64,20 +64,20 @@ class TestConv3dOp(OpTest):
 
     def test_check_grad(self):
         self.check_grad(
-            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
+            set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
 
     def test_check_grad_no_filter(self):
         self.check_grad(
             ['Input'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.03,
             no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
         self.check_grad(
             ['Filter'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.03,
             no_grad_set=set(['Input']))
 
     def init_test_case(self):
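The test diffs above only tighten the gradient-check tolerance (0.05 to 0.02 for conv2d, 0.03 for conv3d) and drop two stale commented-out lines; the operator's math is unchanged. For reference, the GEMM that the unified kernels hand the col buffer to is identical in both ranks: per sample and per group, `out_slice (M x N) = filter_slice (M x K) * col_matrix (K x N)` with `M = o_c/g`, `K = i_c/g * prod(k)`, `N = prod(o)`, and the backward pass reuses the same operands transposed. A toy row-major triple loop standing in for that multiplication (illustrative only; `NaiveGemm` is not Paddle's `math::matmul`):

```cpp
#include <cstdio>
#include <vector>

// out[M x N] = a[M x K] * b[K x N], row-major; stands in for the per-group
// forward call math::matmul(filter_slice, false, col_matrix, false, ...).
void NaiveGemm(const std::vector<float>& a, const std::vector<float>& b,
               std::vector<float>* out, int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      (*out)[i * n + j] = acc;
    }
  }
}

int main() {
  const int m = 2, k = 3, n = 4;                 // toy sizes
  std::vector<float> filter_slice(m * k, 1.0f);  // all ones
  std::vector<float> col_matrix(k * n, 2.0f);    // all twos
  std::vector<float> out_slice(m * n, 0.0f);
  NaiveGemm(filter_slice, col_matrix, &out_slice, m, k, n);
  std::printf("%g\n", out_slice[0]);  // 6: sum over K of 1 * 2
}
```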