diff --git a/paddle/fluid/operators/conv_op_xpu.cc b/paddle/fluid/operators/conv_op_xpu.cc
index ddfc6fe862c27a61c81e2bce694377ea1348a8b5..e4751f1f26008c3d443fc0126d3e6d68995a44e0 100644
--- a/paddle/fluid/operators/conv_op_xpu.cc
+++ b/paddle/fluid/operators/conv_op_xpu.cc
@@ -19,14 +19,16 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class GemmConvXPUKernel : public framework::OpKernel<T> {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *input = context.Input<Tensor>("Input");
     // The filter will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
+    Tensor *output = context.Output<Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
     const int f = static_cast<int>(filter.dims()[0]);
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::conv2d<float, float, float, int16_t>(
-        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-        output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize,
-        strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
+
+    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
+    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
+    XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
+
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
+        dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
+        img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
+        nullptr, nullptr, nullptr, true);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class GemmConvGradXPUKernel : public framework::OpKernel<T> {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *input = context.Input<Tensor>("Input");
+    const Tensor *output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
-    Tensor* input_grad =
+    Tensor *input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
+    Tensor *filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
     const int f = static_cast<int>(filter.dims()[0]);
+
+    const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
+    const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
+    const XPUT *output_grad_data =
+        reinterpret_cast<const XPUT *>(output_grad->data<T>());
+    XPUT *input_grad_data = nullptr;
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
+      input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
     }
+    XPUT *filter_grad_data = nullptr;
     if (filter_grad) {
       filter_grad->mutable_data<T>(context.GetPlace());
+      filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
     }
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::conv2d_grad<float, float, float, int16_t>(
-        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-        output_grad->data<float>(), input_grad ? input_grad->data<float>() : nullptr,
-        filter_grad ? filter_grad->data<float>() : nullptr, batch_size, img_c,
-        img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr,
-        nullptr, nullptr, nullptr, nullptr, true);
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
+        dev_ctx.x_context(), input_data, filter_data, output_grad_data,
+        input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
+        ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
+        nullptr, nullptr, true);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    depthwise_conv2d,
-    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     conv2d_grad,
-    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
+                               paddle::platform::float16>);
+REGISTER_OP_XPU_KERNEL(
+    depthwise_conv2d,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     depthwise_conv2d_grad,
-    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
+                               paddle::platform::float16>);
 #endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 14f516235a720c1fb8f46fe6606ac8f0bdb149f9..57d6c5e119ccfa51a40d9f34d47c070c347d8546 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -51,16 +51,20 @@ XPUOpMap& get_kl2_ops() {
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                  pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                             pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d_transpose_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d_transpose",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d_grad",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"depthwise_conv2d",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"dropout_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
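Note: the `using XPUT = typename XPUTypeTrait<T>::Type;` alias added above is what lets one templated kernel body serve both FP32 and FP16: the framework-side element type (e.g. paddle::platform::float16) and the half type expected by the xpu:: device library are distinct but layout-compatible, so the kernel resolves the device type once through the trait and reinterprets the tensor buffers before the device call. The snippet below is only a minimal, self-contained sketch of that pattern under assumed stand-in names (FrameworkFP16, DeviceFP16, TypeTraitSketch, ConvKernelSketch are hypothetical and are not Paddle's actual XPUTypeTrait or kernel code).

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the framework-side and device-side FP16 types;
// both are 16-bit wrappers around the same bit pattern (layout-compatible).
struct FrameworkFP16 { uint16_t bits; };
struct DeviceFP16 { uint16_t bits; };

// Sketch of the trait: every type maps to itself by default ...
template <typename T>
struct TypeTraitSketch {
  using Type = T;
};

// ... except the framework FP16 type, which maps to the device FP16 type.
template <>
struct TypeTraitSketch<FrameworkFP16> {
  using Type = DeviceFP16;
};

// A kernel templated on the framework type T resolves the device type once,
// then reinterprets the layout-compatible buffers, mirroring the cast the
// real kernel performs before calling into the device library.
template <typename T>
void ConvKernelSketch(const T *input, T *output, int n) {
  using XPUT = typename TypeTraitSketch<T>::Type;
  const XPUT *in = reinterpret_cast<const XPUT *>(input);
  XPUT *out = reinterpret_cast<XPUT *>(output);
  // A real kernel would now issue something like
  //   xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(ctx, in, ..., out, ...);
  // here we just copy the buffer to keep the sketch runnable on its own.
  for (int i = 0; i < n; ++i) out[i] = in[i];
}

int main() {
  float in_f[4] = {1.f, 2.f, 3.f, 4.f}, out_f[4] = {};
  ConvKernelSketch(in_f, out_f, 4);  // T = float, so XPUT = float

  FrameworkFP16 in_h[4] = {}, out_h[4] = {};
  ConvKernelSketch(in_h, out_h, 4);  // T = FrameworkFP16, so XPUT = DeviceFP16

  std::printf("out_f[2] = %g\n", out_f[2]);
  return 0;
}

With the trait in place, registering the same kernel template for both float and paddle::platform::float16 (as the REGISTER_OP_XPU_KERNEL changes above do) is all that is needed for one Compute body to cover both dtypes.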
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
index 78089d703891edac663cfd5a43c12c513cab7e92..5f954659c2d9a3ad7d5c2fbb69a0797afc6cc760 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
@@ -23,6 +23,7 @@ import paddle.fluid as fluid
 from op_test_xpu import XPUOpTest
 import paddle
 from paddle.fluid import Program, program_guard
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
 
 
 def conv2d_forward_naive(input,
@@ -159,320 +160,334 @@ def create_test_padding_VALID_class(parent):
     globals()[cls_name] = TestPaddingVALIDCase
 
 
-class TestConv2DOp(XPUOpTest):
-    def setUp(self):
-        self.op_type = "conv2d"
-        self.use_cudnn = False
-        self.exhaustive_search = False
-        self.use_cuda = False
-        self.use_mkldnn = False
-        self.fuse_relu_before_depthwise_conv = False
-        self.data_format = "AnyLayout"
-        self.dtype = np.float32
-        self.init_kernel_type()
-        self.init_group()
-        self.init_dilation()
-        self.init_test_case()
-
-        conv2d_param = {
-            'stride': self.stride,
-            'pad': self.pad,
-            'dilation': self.dilations
-        }
-
-        input = np.random.random(self.input_size).astype(self.dtype)
-        if not self.has_cuda():
+class XPUTestConv2DOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'conv2d'
+        self.use_dynamic_create_class = False
+
+    class TestConv2DOp(XPUOpTest):
+        def setUp(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.op_type = "conv2d"
+            self.use_cudnn = False
+            self.exhaustive_search = False
+            self.use_cuda = False
+            self.use_mkldnn = False
             self.fuse_relu_before_depthwise_conv = False
-        if self.fuse_relu_before_depthwise_conv:
-            input = input - 0.5
-            input -= (input < 0) * 0.1
-            input += (input >= 0) * 0.1
-            input2 = np.maximum(input, 0.0)
-        else:
-            input2 = input
-        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
-
-        output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
-                                                  conv2d_param)
-        output = output.astype(self.dtype)
-
-        self.inputs = {
-            'Input': XPUOpTest.np_dtype_to_fluid_dtype(input),
-            'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter)
-        }
-        self.attrs = {
-            'strides': self.stride,
-            'paddings': self.pad,
-            'groups': self.groups,
-            'dilations': self.dilations,
-            'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn,
-            'data_format': self.data_format,
-            'fuse_relu_before_depthwise_conv':
-            self.fuse_relu_before_depthwise_conv,
-            'exhaustive_search': self.exhaustive_search
-        }
-        self.outputs = {'Output': output}
-
-    def has_cuda(self):
-        return core.is_compiled_with_cuda() and (self.use_cudnn or
-                                                 self.use_cuda)
-
-    def test_check_output(self):
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-
-    def test_check_grad(self):
-        if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
-                                        self.no_need_check_grad == True):
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output')
-
-    def test_check_grad_no_filter(self):
-        if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
-                                        self.no_need_check_grad == True):
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['Input'], 'Output', no_grad_set=set(['Filter']))
-
-    def test_check_grad_no_input(self):
-        if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
-                                        self.no_need_check_grad == True):
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['Filter'], 'Output', no_grad_set=set(['Input']))
-
-    def init_test_case(self):
-        self.pad = [0, 0]
-        self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-
-    def init_test_case_2(self):
-        pass
-
-    def init_dilation(self):
-        self.dilations = [1, 1]
-
-    def init_group(self):
-        self.groups = 1
-
-    def init_kernel_type(self):
-        pass
-
-
-class TestWithPad(TestConv2DOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-
-
-class TestWithStride(TestConv2DOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [2, 2]
-        self.input_size = [2, 3, 6, 6]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-
-
-class TestWith1x1(TestConv2DOp):
-    def init_test_case(self):
-        self.pad = [0, 0]
-        self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [120, f_c, 1, 1]
-
-    def init_group(self):
-        self.groups = 1
-
-
-# Please Don't remove the following code.
-# Currently, CI use cudnn V5.0 which not support dilation conv.
-# class TestCUDNNWithDilation(TestWithDilation):
-#     def init_op_type(self):
-#         self.op_type = "conv_cudnn"
+            self.data_format = "AnyLayout"
+            self.init_kernel_type()
+            self.init_group()
+            self.init_dilation()
+            self.init_test_case()
+
+            conv2d_param = {
+                'stride': self.stride,
+                'pad': self.pad,
+                'dilation': self.dilations
+            }
+
+            np.random.seed(100)
+            input = np.random.random(self.input_size).astype(self.dtype)
+            if not self.has_cuda():
+                self.fuse_relu_before_depthwise_conv = False
+            if self.fuse_relu_before_depthwise_conv:
+                input = input - 0.5
+                input -= (input < 0) * 0.1
+                input += (input >= 0) * 0.1
+                input2 = np.maximum(input, 0.0)
+            else:
+                input2 = input
+            np.random.seed(1)
+            filter = np.random.uniform(-1, 1,
                                       self.filter_size).astype(self.dtype)
+
+            output, _, _, _, _ = conv2d_forward_naive(input2, filter,
+                                                      self.groups, conv2d_param)
+            output = output.astype(self.dtype)
+
+            self.inputs = {
+                'Input': XPUOpTest.np_dtype_to_fluid_dtype(input),
+                'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter)
+            }
+            self.attrs = {
+                'strides': self.stride,
+                'paddings': self.pad,
+                'groups': self.groups,
+                'dilations': self.dilations,
+                'use_cudnn': self.use_cudnn,
+                'use_mkldnn': self.use_mkldnn,
+                'data_format': self.data_format,
+                'fuse_relu_before_depthwise_conv':
+                self.fuse_relu_before_depthwise_conv,
+                'exhaustive_search': self.exhaustive_search
+            }
+            self.outputs = {'Output': output}
+
+        def has_cuda(self):
+            return core.is_compiled_with_cuda() and (self.use_cudnn or
+                                                     self.use_cuda)
+
+        def test_check_output(self):
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(self.place, {'Input', 'Filter'},
+                                           'Output')
+
+        def test_check_grad_no_filter(self):
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, ['Input'],
+                    'Output',
+                    no_grad_set=set(['Filter']))
+
+        def test_check_grad_no_input(self):
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, ['Filter'],
+                    'Output',
+                    no_grad_set=set(['Input']))
+
+        def init_test_case(self):
+            self.pad = [0, 0]
+            self.stride = [1, 1]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
 
-# ---- test asymmetric padding ----
+        def init_test_case_2(self):
+            pass
+
+        def init_dilation(self):
+            self.dilations = [1, 1]
+
+        def init_group(self):
+            self.groups = 1
+
+        def init_kernel_type(self):
+            pass
+
+    class TestWithPad(TestConv2DOp):
+        def init_test_case(self):
+            self.pad = [1, 1]
+            self.stride = [1, 1]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+
+    class TestWithStride(TestConv2DOp):
+        def init_test_case(self):
+            self.pad = [1, 1]
+            self.stride = [2, 2]
+            self.input_size = [2, 3, 6, 6]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+
+    class TestWith1x1(TestConv2DOp):
+        def init_test_case(self):
+            self.pad = [0, 0]
+            self.stride = [1, 1]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [120, f_c, 1, 1]
+
+        def init_group(self):
+            self.groups = 1
 
-class TestConv2DOp_v2(XPUOpTest):
-    def setUp(self):
-        self.op_type = "conv2d"
-        self.use_cudnn = False
-        self.exhaustive_search = False
-        self.use_cuda = False
-        self.use_mkldnn = False
-        self.fuse_relu_before_depthwise_conv = False
-        self.dtype = np.float32
-        self.init_kernel_type()
-        self.init_group()
-        self.init_dilation()
-        self.init_data_format()
-        self.init_test_case()
-        self.init_paddings()
-        self.init_test_case_2()
-
-        conv2d_param = {
-            'stride': self.stride,
-            'pad': self.pad,
-            'dilation': self.dilations
-        }
-
-        input = np.random.random(self.input_size).astype(self.dtype)
-        if not self.has_cuda():
+
+# ---- test asymmetric padding ----
+class XPUTestConv2DOp_v2(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'conv2d'
+        self.use_dynamic_create_class = False
+
+    class TestConv2DOp_v2(XPUOpTest):
+        def setUp(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.op_type = "conv2d"
+            self.use_cudnn = False
+            self.exhaustive_search = False
+            self.use_cuda = False
+            self.use_mkldnn = False
             self.fuse_relu_before_depthwise_conv = False
-        if self.fuse_relu_before_depthwise_conv:
-            input = input - 0.5
-            input -= (input < 0) * 0.1
-            input += (input >= 0) * 0.1
-            input2 = np.maximum(input, 0.0)
-        else:
-            input2 = input
-        filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
-        output, _, _, _, _ = conv2d_forward_naive(
-            input2, filter, self.groups, conv2d_param, self.padding_algorithm,
-            self.data_format)
-        output = output.astype(self.dtype)
-
-        self.inputs = {
-            'Input': XPUOpTest.np_dtype_to_fluid_dtype(input),
-            'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter)
-        }
-        self.attrs = {
-            'strides': self.stride,
-            'paddings': self.pad,
-            'padding_algorithm': self.padding_algorithm,
-            'groups': self.groups,
-            'dilations': self.dilations,
-            'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn,
-            'data_format': self.data_format,
-            'fuse_relu_before_depthwise_conv':
-            self.fuse_relu_before_depthwise_conv,
-            'exhaustive_search': self.exhaustive_search
-        }
-        self.outputs = {'Output': output}
-
-    def has_cuda(self):
-        return core.is_compiled_with_cuda() and (self.use_cudnn or
-                                                 self.use_cuda)
-
-    def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-
-    def test_check_grad(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        if self.dtype == np.float16:
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output')
-
-    def test_check_grad_no_filter(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        if self.dtype == np.float16:
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['Input'], 'Output', no_grad_set=set(['Filter']))
-
-    def test_check_grad_no_input(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        if self.dtype == np.float16:
-            return
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(
-                place, ['Filter'], 'Output', no_grad_set=set(['Input']))
-
-    def init_test_case(self):
-        self.pad = [0, 0]
-        self.stride = [1, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 4, 3]
-
-    def init_dilation(self):
-        self.dilations = [1, 1]
-
-    def init_group(self):
-        self.groups = 1
-
-    def init_kernel_type(self):
-        pass
-
-    def init_paddings(self):
-        self.pad = [0, 0]
-        self.padding_algorithm = "EXPLICIT"
-
-    def init_data_format(self):
-        self.data_format = "NCHW"
-
-    def init_test_case_2(self):
-        pass
-
-
-class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
-    def init_paddings(self):
-        self.pad = [0, 0, 0, 0]
-        self.padding_algorithm = "EXPLICIT"
-
-
-class TestWithPad_AsyPadding(TestConv2DOp_v2):
-    def init_test_case(self):
-        self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-
-    def init_paddings(self):
-        self.pad = [1, 1, 1, 1]
-        self.padding_algorithm = "EXPLICIT"
-
-
-class TestWithStride_AsyPadding(TestConv2DOp_v2):
-    def init_test_case(self):
-        self.stride = [2, 2]
-        self.input_size = [2, 3, 6, 6]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-
-    def init_paddings(self):
-        self.pad = [1, 1, 1, 1]
-        self.padding_algorithm = "EXPLICIT"
+            self.init_kernel_type()
+            self.init_group()
+            self.init_dilation()
+            self.init_data_format()
+            self.init_test_case()
+            self.init_paddings()
+            self.init_test_case_2()
+
+            conv2d_param = {
+                'stride': self.stride,
+                'pad': self.pad,
+                'dilation': self.dilations
+            }
+
+            np.random.seed(100)
+            input = np.random.random(self.input_size).astype(self.dtype)
+            if not self.has_cuda():
+                self.fuse_relu_before_depthwise_conv = False
+            if self.fuse_relu_before_depthwise_conv:
+                input = input - 0.5
+                input -= (input < 0) * 0.1
+                input += (input >= 0) * 0.1
+                input2 = np.maximum(input, 0.0)
+            else:
+                input2 = input
+            np.random.seed(8)
+            filter = np.random.uniform(-1, 1,
                                       self.filter_size).astype(self.dtype)
+            output, _, _, _, _ = conv2d_forward_naive(
+                input2, filter, self.groups, conv2d_param,
+                self.padding_algorithm, self.data_format)
+            output = output.astype(self.dtype)
+
+            self.inputs = {
+                'Input': XPUOpTest.np_dtype_to_fluid_dtype(input),
+                'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter)
+            }
+            self.attrs = {
+                'strides': self.stride,
+                'paddings': self.pad,
+                'padding_algorithm': self.padding_algorithm,
+                'groups': self.groups,
+                'dilations': self.dilations,
+                'use_cudnn': self.use_cudnn,
+                'use_mkldnn': self.use_mkldnn,
+                'data_format': self.data_format,
+                'fuse_relu_before_depthwise_conv':
+                self.fuse_relu_before_depthwise_conv,
+                'exhaustive_search': self.exhaustive_search
+            }
+            self.outputs = {'Output': output}
+
+        def has_cuda(self):
+            return core.is_compiled_with_cuda() and (self.use_cudnn or
+                                                     self.use_cuda)
+
+        def test_check_output(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(place=self.place)
+
+        def test_check_grad(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(self.place, {'Input', 'Filter'},
+                                           'Output')
+
+        def test_check_grad_no_filter(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, ['Input'],
+                    'Output',
+                    no_grad_set=set(['Filter']))
+
+        def test_check_grad_no_input(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if (hasattr(self, "no_need_check_grad") and
+                    self.no_need_check_grad == True):
+                return
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, ['Filter'],
+                    'Output',
+                    no_grad_set=set(['Input']))
+
+        def init_test_case(self):
+            self.pad = [0, 0]
+            self.stride = [1, 2]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 4, 3]
+
+        def init_dilation(self):
+            self.dilations = [1, 1]
+
+        def init_group(self):
+            self.groups = 1
+
+        def init_kernel_type(self):
+            pass
+
+        def init_paddings(self):
+            self.pad = [0, 0]
+            self.padding_algorithm = "EXPLICIT"
+
+        def init_data_format(self):
+            self.data_format = "NCHW"
+
+        def init_test_case_2(self):
+            pass
+
+    class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
+        def init_paddings(self):
+            self.pad = [0, 0, 0, 0]
+            self.padding_algorithm = "EXPLICIT"
+
+    class TestWithPad_AsyPadding(TestConv2DOp_v2):
+        def init_test_case(self):
+            self.stride = [1, 1]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+
+        def init_paddings(self):
+            self.pad = [1, 1, 1, 1]
+            self.padding_algorithm = "EXPLICIT"
+
+    class TestWithStride_AsyPadding(TestConv2DOp_v2):
+        def init_test_case(self):
+            self.stride = [2, 2]
+            self.input_size = [2, 3, 6, 6]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+
+        def init_paddings(self):
+            self.pad = [1, 1, 1, 1]
+            self.padding_algorithm = "EXPLICIT"
+
+
+support_types = get_xpu_op_support_types('conv2d')
+for stype in support_types:
+    create_test_class(globals(), XPUTestConv2DOp, stype)
+    create_test_class(globals(), XPUTestConv2DOp_v2, stype)
 
 #---------- test SAME VALID -----------
 #create_test_padding_SAME_class(TestConv2DOp_AsyPadding)