diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index cdd86479f44b99e88c27684b657b4d6986e56f33..d1c1c361a9b3ba4dd3f9c85e121fb44bc47d1eb4 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -131,6 +131,9 @@ XPUOpMap& get_kl2_ops() { {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"conv3d_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"conv3d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, @@ -282,6 +285,12 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"unfold", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"unfold_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"floor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), @@ -523,6 +532,12 @@ XPUOpMap& get_kl2_ops() { {"sgd_dense_param_sparse_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"silu_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"silu", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"sigmoid_cross_entropy_with_logits_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sigmoid_cross_entropy_with_logits", diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index 9585e2264db6762fcfb8912419312ec4d0ff50c8..4ab540a570577a1edb31332594f50fb266cf8b41 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -367,6 +367,26 @@ struct XPURelu6GradFunctor : public funcs::BaseActivationFunctor { } }; +template +struct XPUSiluGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + dev_ctx.template Alloc(dx); + const XPUType* x_data = reinterpret_cast(x->data()); + const XPUType* y_grad = reinterpret_cast(dout->data()); + XPUType* x_grad = reinterpret_cast(dx->data()); + + int r = xpu::swish_grad( + dev_ctx.x_context(), x_data, y_grad, x_grad, dx->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad"); + } +}; + template struct XPUSigmoidGradFunctor : public funcs::BaseActivationFunctor { using XPUType = typename XPUTypeTrait::Type; @@ -552,6 +572,7 @@ DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, XPUSqrtGradFunctor); DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, XPUTanhGradFunctor); DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, XPUReluGradFunctor); +DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, XPUSiluGradFunctor); DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, XPULogGradFunctor); DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, XPUSquareGradFunctor); @@ -603,6 +624,12 @@ PD_REGISTER_KERNEL(relu_grad, phi::ReluGradKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(silu_grad, + XPU, + ALL_LAYOUT, + phi::SiluGradKernel, + float, + phi::dtype::float16) {} #define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 39f928eb114733f7437900c7f4dc66f6e9eb2b09..def5fbb65b84dfb35b0f264041292a45072b1293 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -322,6 +322,24 @@ struct XPURelu6Functor : public funcs::BaseActivationFunctor { } }; +template +struct XPUSiluFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) const { + dev_ctx.template Alloc(out); + const XPUType* x_data = reinterpret_cast(x.data()); + XPUType* y_data = reinterpret_cast(out->data()); + + auto xpu_context = dev_ctx.x_context(); + int r = + xpu::swish(xpu_context, x_data, y_data, x.numel(), nullptr, nullptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish"); + } +}; + template struct XPUSigmoidFunctor : public funcs::BaseActivationFunctor { using XPUType = typename XPUTypeTrait::Type; @@ -448,6 +466,7 @@ DEFINE_XPU_ACTIVATION_KERNEL(Sigmoid, XPUSigmoidFunctor) DEFINE_XPU_ACTIVATION_KERNEL(Square, XPUSquareFunctor) DEFINE_XPU_ACTIVATION_KERNEL(Sqrt, XPUSqrtFunctor) DEFINE_XPU_ACTIVATION_KERNEL(Tanh, XPUTanhFunctor) +DEFINE_XPU_ACTIVATION_KERNEL(Silu, XPUSiluFunctor) DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold) DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, @@ -486,6 +505,8 @@ void HardSwishRawKernel(const Context& dev_ctx, PD_REGISTER_KERNEL( relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL( + silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {} #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} diff --git a/paddle/phi/kernels/xpu/conv_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_grad_kernel.cc index de4c573b375f687674bb7969045c2c0672391785..8ce6103d47e2c17572109a4d9f3f6eef46481f36 100644 --- a/paddle/phi/kernels/xpu/conv_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_grad_kernel.cc @@ -168,6 +168,127 @@ void DepthwiseConvGradKernel(const Context& dev_ctx, filter_grad); } +template +void Conv3DGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings_t, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations_t, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad) { + using XPUT = typename XPUTypeTrait::Type; + std::vector paddings = paddings_t; + std::vector dilations = dilations_t; + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + if (!input_grad && !filter_grad) return; + + phi::DDim in_data_dims = + phi::slice_ddim(input.dims(), 2, input.dims().size()); + phi::DDim filter_data_dims = + phi::slice_ddim(filter.dims(), 2, filter.dims().size()); + std::vector ksize = phi::vectorize(filter_data_dims); + std::vector filter_shape = phi::vectorize(filter.dims()); + UpdatePaddingAndDilation( + &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); + + int batch_size = static_cast(input.dims()[0]); + int img_c = static_cast(input.dims()[1]); + int img_d = static_cast(input.dims()[2]); + int img_h = static_cast(input.dims()[3]); + int img_w = static_cast(input.dims()[4]); + int f = static_cast(filter.dims()[0]); + bool is_ncdhw = true; + if (data_format == "NDHWC") { + img_c = static_cast(input.dims()[4]); + img_d = static_cast(input.dims()[1]); + img_h = static_cast(input.dims()[2]); + img_w = static_cast(input.dims()[3]); + is_ncdhw = false; + } + + const XPUT* input_data = reinterpret_cast(input.data()); + const XPUT* filter_data = reinterpret_cast(filter.data()); + const XPUT* output_grad_data = + reinterpret_cast(out_grad.data()); + XPUT* input_grad_data = nullptr; + if (input_grad) { + dev_ctx.template Alloc(input_grad); + input_grad_data = reinterpret_cast(input_grad->data()); + } + XPUT* filter_grad_data = nullptr; + if (filter_grad) { + dev_ctx.template Alloc(filter_grad); + filter_grad_data = reinterpret_cast(filter_grad->data()); + } + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + + XPUT* filter_data_tmp; + XPUT* filter_grad_data_tmp; + const XPUT* filter_data_ptr = filter_data; + XPUT* filter_grad_data_ptr = filter_grad_data; + if (data_format == "NDHWC") { + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 4, 1}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + filter_data_ptr = reinterpret_cast(filter_data_tmp); + + if (filter_grad_data != nullptr) { + filter_grad_data_tmp = RAII_GUARD.alloc(filter.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(filter_grad_data_tmp); + filter_grad_data_ptr = filter_grad_data_tmp; + } + } + int r = xpu::conv3d_grad(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_ncdhw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); + + if ((filter_grad_data_ptr != nullptr) && (data_format == "NDHWC")) { + std::vector filter_shape_fhwc = {filter_shape[0], + filter_shape[2], + filter_shape[3], + filter_shape[4], + filter_shape[1]}; + int r = xpu::transpose(dev_ctx.x_context(), + filter_grad_data_ptr, + filter_grad_data, + filter_shape_fhwc, + {0, 4, 1, 2, 3}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + } +} } // namespace phi PD_REGISTER_KERNEL(conv2d_grad, @@ -182,3 +303,9 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad, ALL_LAYOUT, phi::DepthwiseConvGradKernel, float) {} +PD_REGISTER_KERNEL(conv3d_grad, + XPU, + ALL_LAYOUT, + phi::Conv3DGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/unfold_grad_kernel.cc b/paddle/phi/kernels/xpu/unfold_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..298d6655331da0b65239fe058b761e6a54021b58 --- /dev/null +++ b/paddle/phi/kernels/xpu/unfold_grad_kernel.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/unfold_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/unfold_functor.h" + +namespace phi { + +template +void UnfoldGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& kernel_sizes, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + ctx.template Alloc(x_grad); + const std::string data_format = phi::DataLayoutToString(x.layout()); + bool is_nchw = data_format == "NCHW"; + PADDLE_ENFORCE_EQ(is_nchw, + true, + phi::errors::PreconditionNotMet( + "Unfold grad op only supports datalayout == NCHW")); + + auto x_dims = x_grad->dims(); + int n = static_cast(x_dims[0]); + int c = static_cast(x_dims[1]); + int h = static_cast(x_dims[2]); + int w = static_cast(x_dims[3]); + + int out_height = phi::funcs::CalcOutputSize(x_dims[2], + kernel_sizes[0], + dilations[0], + paddings[0], + paddings[2], + strides[0]); + int out_width = phi::funcs::CalcOutputSize(x_dims[3], + kernel_sizes[1], + dilations[1], + paddings[1], + paddings[3], + strides[1]); + + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + XPUType* out_grad_trans = + RAII_GUARD.alloc_l3_or_gm(out_grad.numel()); + + int r = xpu::transpose( + ctx.x_context(), + reinterpret_cast(out_grad.data()), + out_grad_trans, + {n, c, kernel_sizes[0], kernel_sizes[1], out_height, out_width}, + {0, 4, 5, 1, 2, 3}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + r = xpu::col2im(ctx.x_context(), + out_grad_trans, + reinterpret_cast(x_grad->data()), + n, + c, + h, + w, + kernel_sizes, + strides, + paddings, + dilations, + is_nchw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "col2im"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(unfold_grad, + XPU, + ALL_LAYOUT, + phi::UnfoldGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/unfold_kernel.cc b/paddle/phi/kernels/xpu/unfold_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..64a12b2881296e218d904453709e389e7533f3db --- /dev/null +++ b/paddle/phi/kernels/xpu/unfold_kernel.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/unfold_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/unfold_functor.h" + +namespace phi { + +template +void UnfoldKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& kernel_sizes, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + ctx.template Alloc(out); + const std::string data_format = phi::DataLayoutToString(x.layout()); + bool is_nchw = data_format == "NCHW"; + PADDLE_ENFORCE_EQ(is_nchw, + true, + phi::errors::PreconditionNotMet( + "Unfold op only supports datalayout == NCHW")); + auto x_dims = x.dims(); + int n = static_cast(x_dims[0]); + int c = static_cast(x_dims[1]); + int h = static_cast(x_dims[2]); + int w = static_cast(x_dims[3]); + + int out_height = phi::funcs::CalcOutputSize(x_dims[2], + kernel_sizes[0], + dilations[0], + paddings[0], + paddings[2], + strides[0]); + int out_width = phi::funcs::CalcOutputSize(x_dims[3], + kernel_sizes[1], + dilations[1], + paddings[1], + paddings[3], + strides[1]); + + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + XPUType* out_pre_trans = RAII_GUARD.alloc_l3_or_gm(out->numel()); + int r = xpu::im2col(ctx.x_context(), + reinterpret_cast(x.data()), + out_pre_trans, + n, + c, + h, + w, + kernel_sizes, + strides, + paddings, + dilations, + is_nchw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "im2col"); + + r = xpu::transpose( + ctx.x_context(), + out_pre_trans, + reinterpret_cast(out->data()), + {n, out_height, out_width, c, kernel_sizes[0], kernel_sizes[1]}, + {0, 3, 4, 5, 1, 2}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); +} +} // namespace phi + +PD_REGISTER_KERNEL( + unfold, XPU, ALL_LAYOUT, phi::UnfoldKernel, float, phi::dtype::float16) {} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py index d5afa4a1b130db0cfd11acfefa63b912ab008906..bb2cddf04b13aae4d5da95ec52f953689ab5565f 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py @@ -19,6 +19,9 @@ import numpy as np sys.path.append("..") +import paddle +import paddle.nn.functional as F + from op_test import OpTest from op_test_xpu import XPUOpTest from xpu.get_test_cover_info import ( @@ -86,6 +89,87 @@ for stype in support_types: create_test_class(globals(), XPUTestExpOP, stype) +class XPUTestSiluOP(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'silu' + self.use_dynamic_create_class = False + + class XPUTestSilu(TestActivationOPBase): + def set_case(self): + self.op_type = "silu" + self.dtype = self.in_type + self.init_shape() + + np.random.seed(1024) + x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + out = x / (np.exp(-x) + 1) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'use_xpu': True} + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def init_shape(self): + self.shape = [11, 17] + + class TestSilu_ZeroDim(XPUTestSilu): + def init_shape(self): + self.shape = [] + + +class TestSiluAPI(unittest.TestCase): + # test paddle.nn.Silu, paddle.nn.functional.silu + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32') + self.place = paddle.XPUPlace(0) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [11, 17]) + out1 = F.silu(x) + m = paddle.nn.Silu() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = self.x_np / (1 + np.exp(-self.x_np)) + for r in res: + np.testing.assert_allclose(out_ref, r, rtol=1e-05) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.silu(x) + m = paddle.nn.Silu() + out2 = m(x) + out_ref = self.x_np / (1 + np.exp(-self.x_np)) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.silu, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[11, 17], dtype='int32' + ) + self.assertRaises(TypeError, F.silu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[11, 17], dtype='float16' + ) + F.silu(x_fp16) + + +support_types = get_xpu_op_support_types('silu') +for stype in support_types: + create_test_class(globals(), XPUTestSiluOP, stype) + + class XPUTestSigmoidOP(XPUOpTestWrapper): def __init__(self): self.op_name = 'sigmoid' diff --git a/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py index 46dcd0c1302220e444f42a12aadd539df1c87240..915fb249514a964523589c95245efed158cfa93c 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py @@ -258,6 +258,38 @@ class XPUTestConv3DOp(XPUOpTestWrapper): place = paddle.XPUPlace(0) self.check_output_with_place(place) + def test_check_grad(self): + place = paddle.XPUPlace(0) + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_grad_with_place( + place, + {'Input', 'Filter'}, + 'Output', + max_relative_error=0.03, + ) + + def test_check_grad_no_filter(self): + place = paddle.XPUPlace(0) + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_grad_with_place( + place, + ['Input'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Filter']), + ) + + def test_check_grad_no_input(self): + place = paddle.XPUPlace(0) + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_grad_with_place( + place, + ['Filter'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Input']), + ) + def init_test_case(self): self.pad = [0, 0, 0] self.stride = [1, 1, 1] @@ -401,6 +433,32 @@ class XPUTestConv3DOp_v2(XPUOpTestWrapper): place = paddle.XPUPlace(0) self.check_output_with_place(place) + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03 + ) + + def test_check_grad_no_filter(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, + ['Input'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Filter']), + ) + + def test_check_grad_no_input(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, + ['Filter'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Input']), + ) + def init_test_case(self): self.stride = [1, 1, 1] self.input_size = [2, 3, 4, 4, 4] # NCDHW diff --git a/python/paddle/fluid/tests/unittests/xpu/test_unfold_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_unfold_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..cce01f1aebf3be8dd3a35dc2645f40e8e430c812 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_unfold_op_xpu.py @@ -0,0 +1,180 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import sys +import unittest + +sys.path.append("..") +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import ( + create_test_class, + get_xpu_op_support_types, + XPUOpTestWrapper, +) + +paddle.enable_static() + + +class XPUTestUnfoldOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'unfold' + self.use_dynamic_create_class = False + + class TestUnfoldOp(XPUOpTest): + """ + This is for test on unfold Op + """ + + def init_data(self): + self.batch_size = 3 + self.input_channels = 3 + self.input_height = 20 + self.input_width = 20 + + self.kernel_sizes = [2, 2] + self.strides = [1, 1] + self.paddings = [1, 1, 1, 1] + self.dilations = [1, 1] + input_shape = [ + self.batch_size, + self.input_channels, + self.input_height, + self.input_width, + ] + self.x = np.random.rand(*input_shape).astype(self.dtype) + + def calc_unfold(self): + output_shape = [0] * 3 + output_shape[0] = self.batch_size + output_shape[1] = ( + self.input_channels + * self.kernel_sizes[0] + * self.kernel_sizes[1] + ) + dkernel_h = self.dilations[0] * (self.kernel_sizes[0] - 1) + 1 + dkernel_w = self.dilations[1] * (self.kernel_sizes[1] - 1) + 1 + out_height = ( + int( + ( + self.input_height + + self.paddings[0] + + self.paddings[2] + - dkernel_h + ) + / self.strides[0] + ) + + 1 + ) + out_width = ( + int( + ( + self.input_width + + self.paddings[1] + + self.paddings[3] + - dkernel_w + ) + / self.strides[1] + ) + + 1 + ) + output_shape[2] = out_height * out_width + output = np.zeros(output_shape).astype(np.float64) + # ------------ calculate output -------------- # + for i in range(output_shape[0]): + for j in range(output_shape[1]): + for k in range(output_shape[2]): + h_out = int(k / out_width) + w_out = k % out_width + w_offset = j % self.kernel_sizes[1] + h_offset = ( + int(j / self.kernel_sizes[1]) % self.kernel_sizes[0] + ) + c_in = int( + j / (self.kernel_sizes[0] * self.kernel_sizes[1]) + ) + h_in = ( + h_offset * self.dilations[0] + + h_out * self.strides[0] + - self.paddings[0] + ) + w_in = ( + w_offset * self.dilations[1] + + w_out * self.strides[1] + - self.paddings[1] + ) + if (h_in >= 0 and h_in < self.input_height) and ( + w_in >= 0 and w_in < self.input_width + ): + output[i, j, k] = self.x[i, c_in, h_in, w_in] + + self.outputs = output + + def set_data(self): + self.init_data() + self.calc_unfold() + + self.inputs = {'X': self.x} + self.attrs = { + 'kernel_sizes': self.kernel_sizes, + 'paddings': self.paddings, + 'dilations': self.dilations, + 'strides': self.strides, + } + self.outputs = {'Y': self.outputs} + + def setUp(self): + self.op_type = 'unfold' + self.dtype = self.in_type + self.set_data() + + def test_check_output(self): + self.check_output_with_place(paddle.XPUPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Y') + + class TestUnfoldAPI(TestUnfoldOp): + """ + This is for test on paddle.nn.Unfold + """ + + def setUp(self): + self.op_type = 'unfold' + self.set_data() + self.places = [paddle.XPUPlace(0)] + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input = fluid.dygraph.to_variable(self.inputs['X']) + m = paddle.nn.Unfold(**self.attrs) + m.eval() + result = m(input) + np.testing.assert_allclose( + result.numpy(), self.outputs['Y'], rtol=1e-05 + ) + + def test_info(self): + str(paddle.nn.Unfold(**self.attrs)) + + +support_types = get_xpu_op_support_types('unfold') +for stype in support_types: + create_test_class(globals(), XPUTestUnfoldOp, stype) + +if __name__ == "__main__": + unittest.main()