From c86a514074e486b1df5b81aa069b9945a866cdb4 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Thu, 16 Feb 2023 11:04:20 +0800
Subject: [PATCH] [XPU] add group_norm, sin, cos, linspace, randint kernels
 (#50465)

* [XPU] add group_norm kernel

* update

* add xpu sin, cos, randint, linspace kernels

* update

* update
---
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   6 +
 paddle/phi/kernels/xpu/activation_kernel.cc   |  30 +++++
 paddle/phi/kernels/xpu/group_norm_kernel.cc   |  93 +++++++++++++++
 paddle/phi/kernels/xpu/linspace_kernel.cc     |  84 ++++++++++++++
 paddle/phi/kernels/xpu/randint_kernel.cc      |  73 ++++++++++++
 .../unittests/xpu/test_activation_op_xpu.py   |  99 ++++++++++++++++
 .../unittests/xpu/test_group_norm_op_xpu.py   | 109 ++++++++++++++++++
 .../unittests/xpu/test_linspace_op_xpu.py     |  87 ++++++++++++++
 .../unittests/xpu/test_randint_op_xpu.py      |  82 +++++++++++++
 9 files changed, 663 insertions(+)
 create mode 100644 paddle/phi/kernels/xpu/group_norm_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/linspace_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/randint_kernel.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_linspace_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_randint_op_xpu.py

diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 8caa484b7bb..a5935c07b94 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -711,6 +711,12 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
                    phi::DataType::FLOAT32})},
+    {"sin", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"cos", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"linspace",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+    {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
+    {"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
     // AddMore
     {"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
index def5fbb65b8..2f0e90e7d3f 100644
--- a/paddle/phi/kernels/xpu/activation_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -457,6 +457,32 @@ struct XPUFloorFunctor : public funcs::BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct XPUSinFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::sin<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sin");
+  }
+};
+
+template <typename T>
+struct XPUCosFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::cos<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cos");
+  }
+};
+
 DEFINE_XPU_ACTIVATION_KERNEL(Exp, XPUExpFunctor)
 DEFINE_XPU_ACTIVATION_KERNEL(Floor, XPUFloorFunctor)
 DEFINE_XPU_ACTIVATION_KERNEL(Log, XPULogFunctor)
@@ -467,6 +493,8 @@ DEFINE_XPU_ACTIVATION_KERNEL(Square, XPUSquareFunctor)
 DEFINE_XPU_ACTIVATION_KERNEL(Sqrt, XPUSqrtFunctor)
 DEFINE_XPU_ACTIVATION_KERNEL(Tanh, XPUTanhFunctor)
 DEFINE_XPU_ACTIVATION_KERNEL(Silu, XPUSiluFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Sin, XPUSinFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Cos, XPUCosFunctor)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold) DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, @@ -531,3 +559,5 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) +PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) +PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) diff --git a/paddle/phi/kernels/xpu/group_norm_kernel.cc b/paddle/phi/kernels/xpu/group_norm_kernel.cc new file mode 100644 index 00000000000..7d82a5d18fe --- /dev/null +++ b/paddle/phi/kernels/xpu/group_norm_kernel.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/group_norm_kernel.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GroupNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + using XPUType = typename XPUTypeTrait::Type; + + const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + + const auto x_dims = phi::vectorize(x.dims()); + const int N = x_dims[0]; + const bool channel_first = + data_layout == DataLayout::kNCHW || data_layout == DataLayout::kNCDHW; + const int C = (channel_first ? x_dims[1] : x_dims[x_dims.size() - 1]); + const int L = + (channel_first + ? 
std::accumulate( + x_dims.begin() + 2, x_dims.end(), 1, std::multiplies()) + : std::accumulate(x_dims.begin() + 1, + x_dims.end() - 1, + 1, + std::multiplies())); + + dev_ctx.template Alloc(y); + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(var); + + auto* x_data = x.data(); + auto* y_data = y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + auto r = + xpu::group_norm(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + N, + C, + L, + 1, + groups, + static_cast(epsilon), + reinterpret_cast(scale_data), + reinterpret_cast(bias_data), + reinterpret_cast(mean_data), + reinterpret_cast(var_data), + channel_first); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(group_norm, XPU, ALL_LAYOUT, phi::GroupNormKernel, float) {} diff --git a/paddle/phi/kernels/xpu/linspace_kernel.cc b/paddle/phi/kernels/xpu/linspace_kernel.cc new file mode 100644 index 00000000000..e33a6d73f1c --- /dev/null +++ b/paddle/phi/kernels/xpu/linspace_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/linspace_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/data_type_transform.h" + +namespace phi { + +template +T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) { + switch (x.dtype()) { + case DataType::FLOAT32: + return static_cast(GetValue(ctx, x)); + case DataType::FLOAT64: + return static_cast(GetValue(ctx, x)); + case DataType::INT32: + return static_cast(GetValue(ctx, x)); + case DataType::INT64: + return static_cast(GetValue(ctx, x)); + case DataType::FLOAT16: + return static_cast(GetValue(ctx, x)); + case DataType::BFLOAT16: + return static_cast(GetValue(ctx, x)); + case DataType::BOOL: + return static_cast(GetValue(ctx, x)); + case DataType::INT16: + return static_cast(GetValue(ctx, x)); + case DataType::UINT8: + return static_cast(GetValue(ctx, x)); + default: + PADDLE_THROW(phi::errors::Unimplemented( + "Data type (%s) is not supported when casting data type.", + x.dtype())); + } +} + +template +void LinspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + DataType dtype, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + T start_value = GetValueOfExpectedType(ctx, start); + T stop_value = GetValueOfExpectedType(ctx, stop); + int32_t num = GetValueOfExpectedType(ctx, number); + + PADDLE_ENFORCE_GT( + num, + 0, + phi::errors::InvalidArgument("The num of linspace op should be larger " + "than 0, but received num is %d", + num)); + + out->Resize(phi::make_ddim({num})); + T* out_data = ctx.template Alloc(out); + + int r = xpu::linspace(ctx.x_context(), + reinterpret_cast(out_data), + static_cast(start_value), + static_cast(stop_value), + num); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "linspace"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t) {} diff --git a/paddle/phi/kernels/xpu/randint_kernel.cc b/paddle/phi/kernels/xpu/randint_kernel.cc new file mode 100644 index 00000000000..b6b43c17148 --- /dev/null +++ b/paddle/phi/kernels/xpu/randint_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/randint_kernel.h" + +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/generator.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RandintRawKernel(const Context& dev_ctx, + int low, + int high, + const IntArray& shape, + DataType dtype, + int seed, + DenseTensor* out) { + int64_t size = out->numel(); + out->Resize(phi::make_ddim(shape.GetData())); + T* data = dev_ctx.template Alloc(out); + auto numel = out->numel(); + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = dev_ctx.GetGenerator()->GetCPUEngine(); + } + std::unique_ptr data_cpu(new T[size]); + std::uniform_int_distribution dist(low, high - 1); + for (int64_t i = 0; i < numel; ++i) { + data_cpu[i] = dist(*engine); + } + paddle::memory::Copy(dev_ctx.GetPlace(), + data, + phi::CPUPlace(), + reinterpret_cast(data_cpu.get()), + size * sizeof(T)); +} + +template +void RandintKernel(const Context& dev_ctx, + int low, + int high, + const IntArray& shape, + DataType dtype, + DenseTensor* out) { + RandintRawKernel(dev_ctx, low, high, shape, dtype, 0, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + randint_raw, XPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {} + +PD_REGISTER_KERNEL(randint, XPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) { +} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py index 12302c582f3..04931e5fb53 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py @@ -1222,5 +1222,104 @@ def ref_mish(x, threshold=20): return out +class XPUTestSinOP(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'sin' + self.use_dynamic_create_class = False + + class XPUTestSinBase(TestActivationOPBase): + def set_case(self): + self.op_type = "sin" + self.dtype = self.in_type + + self.init_config() + out = np.sin(self.x) + + self.inputs = {'X': self.x} + self.outputs = {'Out': out} + self.attrs = {'use_xpu': True} + + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype( + self.dtype + ) + + class XPUTestSin_ZeroDim(XPUTestSinBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype) + + class XPUTestSin2(XPUTestSinBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype( + self.dtype + ) + + class XPUTestSin3(XPUTestSinBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype( + self.dtype + ) + + class XPUTestSin4(XPUTestSinBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype( + self.dtype + ) + + +support_types = get_xpu_op_support_types('sin') +for stype in support_types: + create_test_class(globals(), XPUTestSinOP, stype) + + +class XPUTestCosOP(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'cos' + self.use_dynamic_create_class = False + + class XPUTestCosBase(TestActivationOPBase): + def set_case(self): + self.op_type = "cos" + self.dtype = self.in_type + + self.init_config() + out = np.cos(self.x) + + self.inputs = {'X': self.x} + self.outputs = {'Out': out} + self.attrs = {'use_xpu': True} + + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype( + self.dtype + ) + + class 
XPUTestCos_ZeroDim(XPUTestCosBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype) + + class XPUTestCos2(XPUTestCosBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype( + self.dtype + ) + + class XPUTestCos3(XPUTestCosBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype( + self.dtype + ) + + class XPUTestCos4(XPUTestCosBase): + def init_config(self): + self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype( + self.dtype + ) + + +support_types = get_xpu_op_support_types('cos') +for stype in support_types: + create_test_class(globals(), XPUTestCosOP, stype) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py new file mode 100644 index 00000000000..7ff43265890 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +sys.path.append("..") + +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) + +import paddle + +paddle.enable_static() + + +def group_norm_naive(x, scale, bias, epsilon, groups, data_layout): + if data_layout == "NHWC": + x = np.transpose(x, (0, 3, 1, 2)) # NHWC => NCHW + N, C, H, W = x.shape + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var = np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape((N, C, H, W)) * scale.reshape( + (-1, 1, 1) + ) + bias.reshape((-1, 1, 1)) + if data_layout == "NHWC": + output = np.transpose(output, (0, 2, 3, 1)) # NCHW => NHWC + return output, mean.reshape((N, G)), var.reshape((N, G)) + + +class XPUTestGroupNormOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'group_norm' + self.use_dynamic_create_class = False + + class TestGroupNormOp(XPUOpTest): + def init_test_case(self): + self.data_format = "NCHW" + self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"} + + def setUp(self): + '''Test GroupNorm Op with supplied attributes''' + self.__class__.op_type = 'group_norm' + self.dtype = self.in_type + self.shape = (2, 100, 3, 5) + self.init_test_case() + input = np.random.random(self.shape).astype(self.dtype) + if self.data_format == "NHWC": + input = np.transpose(input, (0, 2, 3, 1)) + scale = np.random.random([self.shape[1]]).astype(self.dtype) + bias = np.random.random([self.shape[1]]).astype(self.dtype) + + output, mean, var = group_norm_naive( + input, + scale, + bias, + self.attrs['epsilon'], + self.attrs['groups'], + self.data_format, + ) + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(input), + 'Scale': 
OpTest.np_dtype_to_fluid_dtype(scale), + 'Bias': OpTest.np_dtype_to_fluid_dtype(bias), + } + self.outputs = {'Y': output, 'Mean': mean, 'Variance': var} + self.attrs['data_layout'] = self.data_format + + def test_check_output(self): + self.check_output_with_place(paddle.XPUPlace(0)) + + def test_check_grad(self): + pass + + class TestGroupNormOp2(TestGroupNormOp): + def init_test_case(self): + self.data_format = "NHWC" + self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NHWC"} + + +support_types = get_xpu_op_support_types('group_norm') +for stype in support_types: + create_test_class(globals(), XPUTestGroupNormOp, stype) + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_linspace_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_linspace_op_xpu.py new file mode 100644 index 00000000000..65247c5bec5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_linspace_op_xpu.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +sys.path.append("..") + +from op_test_xpu import XPUOpTest, convert_np_dtype_to_dtype_ +from xpu.get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) + +import paddle + +paddle.enable_static() + + +class XPUTestLinspaceOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'linspace' + self.use_dynamic_create_class = False + + class TestXPULinespaceOp(XPUOpTest): + def setUp(self): + self.op_type = "linspace" + self.dtype = self.in_type + self.set_attrs() + + self.atol = 1e-4 + np.random.seed(10) + self.inputs = { + 'Start': np.array([0]).astype(self.dtype), + 'Stop': np.array([10]).astype(self.dtype), + 'Num': np.array([11]).astype('int32'), + } + self.outputs = {'Out': np.arange(0, 11).astype(self.dtype)} + self.attrs = {'dtype': int(convert_np_dtype_to_dtype_(self.dtype))} + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output_with_place(paddle.XPUPlace(0), atol=self.atol) + + class TestXPULinespace2(TestXPULinespaceOp): + def set_attrs(self): + self.inputs = { + 'Start': np.array([10]).astype(self.dtype), + 'Stop': np.array([0]).astype(self.dtype), + 'Num': np.array([11]).astype('int32'), + } + + self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)} + + class TestXPULinespace3(TestXPULinespaceOp): + def set_attrs(self): + self.inputs = { + 'Start': np.array([10]).astype(self.dtype), + 'Stop': np.array([0]).astype(self.dtype), + 'Num': np.array([1]).astype('int32'), + } + self.outputs = {'Out': np.array(10, dtype=self.dtype)} + + +support_types = get_xpu_op_support_types('linspace') +for stype in support_types: + create_test_class(globals(), XPUTestLinspaceOp, stype) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_randint_op_xpu.py 
b/python/paddle/fluid/tests/unittests/xpu/test_randint_op_xpu.py new file mode 100644 index 00000000000..6e74b437e2d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_randint_op_xpu.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +sys.path.append("..") + +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) + +import paddle + +paddle.enable_static() + + +def output_hist(out): + hist, _ = np.histogram(out, range=(-10, 10)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones((10)) + return hist, prob + + +class XPUTestRandIntOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'randint' + self.use_dynamic_create_class = False + + class TestXPURandIntOp(XPUOpTest): + def setUp(self): + self.op_type = "randint" + self.dtype = self.in_type + self.set_attrs() + + self.atol = 1e-4 + np.random.seed(10) + self.inputs = {} + self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")} + self.attrs = { + "shape": [10000, 784], + "low": -10, + "high": 10, + "seed": 10, + } + self.output_hist = output_hist + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + np.testing.assert_allclose(hist, prob, rtol=0, atol=0.001) + + +support_types = get_xpu_op_support_types('randint') +for stype in support_types: + create_test_class(globals(), XPUTestRandIntOp, stype) + + +if __name__ == "__main__": + unittest.main() -- GitLab
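
As a quick end-to-end check of the patch, the five new kernels can also be exercised together from Python. The snippet below is a minimal smoke-test sketch and not part of the patch: it assumes a PaddlePaddle build with XPU support and a visible device ("xpu:0"), and the shapes, seeds, and tolerances are illustrative values borrowed from the unit tests above.

    # Smoke test for the sin, cos, linspace, randint, and group_norm XPU kernels.
    # Assumes an XPU-enabled PaddlePaddle build; values mirror the unit tests above.
    import numpy as np
    import paddle

    paddle.set_device("xpu:0")  # route eager ops to the XPU backend
    paddle.seed(10)

    # sin / cos: registered for float32 only (see xpu2_op_list.cc).
    x = paddle.to_tensor(
        np.random.uniform(-np.pi, np.pi, [11, 17]).astype("float32")
    )
    np.testing.assert_allclose(paddle.sin(x).numpy(), np.sin(x.numpy()), atol=1e-4)
    np.testing.assert_allclose(paddle.cos(x).numpy(), np.cos(x.numpy()), atol=1e-4)

    # linspace: float32 and int32 paths are registered.
    lin = paddle.linspace(0, 10, 11, dtype="float32")
    np.testing.assert_allclose(lin.numpy(), np.arange(0, 11, dtype="float32"), atol=1e-4)

    # randint: values are drawn on the host and copied to the device, so the
    # histogram over [-10, 10) should be roughly uniform, as in the unit test.
    r = paddle.randint(low=-10, high=10, shape=[10000, 784], dtype="int64")
    hist, _ = np.histogram(r.numpy(), range=(-10, 10))
    np.testing.assert_allclose(hist / r.numpy().size, np.full(10, 0.1), atol=1e-3)

    # group_norm: driven through the existing GroupNorm layer (NCHW, float32).
    gn = paddle.nn.GroupNorm(num_groups=2, num_channels=100, epsilon=1e-5)
    y = gn(paddle.rand([2, 100, 3, 5], dtype="float32"))
    print("group_norm output shape:", y.shape)  # [2, 100, 3, 5]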