diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index d89ecd27c0954d7ec5ec8a307cf153253d348a4b..9041feb10c87dae5a39ba2601ff213165f01071a 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -36,7 +36,7 @@ ENDIF()
 if(NOT DEFINED XPU_BASE_URL)
   SET(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211129")
+  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211226")
 else()
   SET(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
diff --git a/paddle/fluid/operators/argsort_op_xpu.cc b/paddle/fluid/operators/argsort_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6fee1e8adccf15792931d2aeddb6ef2357cfa738
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op_xpu.cc
@@ -0,0 +1,207 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/argsort_op.h"
+
+namespace paddle {
+namespace operators {
+
+const int XPU_SORT_MAX_SIZE = 16384;
+
+template <typename T, typename TID>
+static inline void xpu_argsort(xpu::Context* ctx, const T* input_data,
+                               T* output_data, TID* indices_data, int m, int n,
+                               bool descending) {
+  int ret =
+      xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
+  PADDLE_ENFORCE_EQ(
+      ret, XPU_SUCCESS,
+      platform::errors::External("XPU sort kernel return wrong value[%d %s].",
+                                 ret, XPUAPIErrorMsg[ret]));
+}
+
+template <typename T>
+static inline void xpu_transpose(xpu::Context* ctx, const T* x, T* y,
+                                 const std::vector<int>& xshape,
+                                 const std::vector<int>& permute) {
+  int ret = xpu::transpose(ctx, x, y, xshape, permute);
+  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
+                    platform::errors::External(
+                        "XPU transpose kernel return wrong value[%d %s]", ret,
+                        XPUAPIErrorMsg[ret]));
+}
+
+template <typename TX, typename TY>
+static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
+  int ret = xpu::cast_v2(ctx, x, y, len);
+  PADDLE_ENFORCE_EQ(
+      ret, XPU_SUCCESS,
+      platform::errors::External("XPU cast kernel return wrong value[%d %s]",
+                                 ret, XPUAPIErrorMsg[ret]));
+}
+
+// Generic case: the sort axis is moved innermost by a transpose, the rows
+// are sorted, and the results are transposed back. No casting is needed.
+template <typename T, bool VALUE_NEED_CAST = false,
+          bool INDEX_NEED_CAST = false>
+struct XPUArgsort {
+  void operator()(xpu::Context* ctx, const T* input_data, T* output_data,
+                  int64_t* indices_data, const std::vector<int>& data_shape,
+                  const std::vector<int>& permute, bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
+                                      data_shape[1]};
+
+    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans,
+                m, n, descending);
+    xpu_transpose(ctx, output_data_trans, output_data, trans_data_shape,
+                  permute);
+    xpu_transpose(ctx, indices_data_trans, indices_data, trans_data_shape,
+                  permute);
+  }
+};
+
+// Indices are sorted as int32 and cast back to int64 afterwards.
+template <typename T>
+struct XPUArgsort<T, false, true> {
+  void operator()(xpu::Context* ctx, const T* input_data, T* output_data,
+                  int64_t* indices_data, const std::vector<int>& data_shape,
+                  const std::vector<int>& permute, bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
+                                      data_shape[1]};
+
+    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans,
+                m, n, descending);
+    xpu_transpose(ctx, output_data_trans, output_data, trans_data_shape,
+                  permute);
+    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
+    xpu_transpose(ctx, cast_data_int64, indices_data, trans_data_shape,
+                  permute);
+  }
+};
+
+// Both values and indices are sorted as int32 and cast back to int64.
+template <>
+struct XPUArgsort<int64_t, true, true> {
+  void operator()(xpu::Context* ctx, const int64_t* input_data,
+                  int64_t* output_data, int64_t* indices_data,
+                  const std::vector<int>& data_shape,
+                  const std::vector<int>& permute, bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{data_shape[0], data_shape[2],
+                                      data_shape[1]};
+
+    int* input_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* output_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* cast_data_int = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_cast(ctx, input_data, cast_data_int, len);
+    xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx, input_data_trans, output_data_trans, indices_data_trans,
+                m, n, descending);
+
+    xpu_cast(ctx, output_data_trans, cast_data_int64, len);
+    xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape,
+                  permute);
+    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
+    xpu_transpose(ctx, cast_data_int64, indices_data, trans_data_shape,
+                  permute);
+  }
+};
+
+template <typename T>
+class ArgsortXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
+    int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
+
+    auto in_dims = input->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+    int n = in_dims[axis];
+
+    PADDLE_ENFORCE_LT(
+        n, XPU_SORT_MAX_SIZE,
+        platform::errors::InvalidArgument(
+            "The axis dimension of Input should be less than %d, but got %d.",
+            XPU_SORT_MAX_SIZE, in_dims[axis]));
+
+    auto input_data = input->data<T>();
+    auto output_data = output->mutable_data<T>(ctx.GetPlace());
+    auto indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::XPUDeviceContext>();
+    int len_before =
+        framework::product(framework::slice_ddim(in_dims, 0, axis));
+    int len_after = framework::product(
+        framework::slice_ddim(in_dims, axis + 1, in_dims.size()));
+    bool int64_need_cast =
+        (std::is_same<T, int64_t>::value && n > (XPU_SORT_MAX_SIZE / 2))
+            ? true
+            : false;
+    bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false;
+    std::vector<int> permute_vec{0, 2, 1};
+    std::vector<int> data_shape{len_before, n, len_after};
+
+    if (int64_need_cast) {
+      XPUArgsort<T, true, true>()(dev_ctx.x_context(), input_data, output_data,
+                                  indices_data, data_shape, permute_vec,
+                                  descending);
+    } else if (index_need_cast) {
+      XPUArgsort<T, false, true>()(dev_ctx.x_context(), input_data,
+                                   output_data, indices_data, data_shape,
+                                   permute_vec, descending);
+    } else {
+      XPUArgsort<T, false, false>()(dev_ctx.x_context(), input_data,
+                                    output_data, indices_data, data_shape,
+                                    permute_vec, descending);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(argsort, ops::ArgsortXPUKernel<float>,
+                       ops::ArgsortXPUKernel<int>,
+                       ops::ArgsortXPUKernel<int64_t>);
+
+#endif
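For intuition, the transpose-sort-transpose scheme used by the XPUArgsort functors above can be sketched in NumPy. This is an illustrative analogue only, not part of the patch: argsort_via_transpose is a hypothetical name, and xpu::sort takes the descending flag directly rather than the negation trick used here. The kernel reshapes the input to (len_before, n, len_after), permutes with {0, 2, 1} so the sort axis is innermost, sorts m = len_before * len_after rows of n elements each, then permutes back.

import numpy as np

def argsort_via_transpose(x, axis, descending=False):
    """Illustrative NumPy analogue of the XPU kernel's data movement."""
    axis = axis if axis >= 0 else x.ndim + axis
    n = x.shape[axis]
    len_before = int(np.prod(x.shape[:axis], dtype=np.int64))
    len_after = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    # View as (len_before, n, len_after), then move the sort axis innermost.
    rows = (x.reshape(len_before, n, len_after)
            .transpose(0, 2, 1)
            .reshape(len_before * len_after, n))
    if descending:
        rows = -rows  # numeric types only; the real kernel passes a flag
    idx = np.argsort(rows, kind="stable", axis=1)
    out = np.take_along_axis(rows, idx, axis=1)
    if descending:
        out = -out

    def undo(a):  # transpose back to the original layout
        return (a.reshape(len_before, len_after, n)
                .transpose(0, 2, 1)
                .reshape(x.shape))

    return undo(out), undo(idx.astype(np.int64))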
diff --git a/paddle/fluid/operators/scatter_op_xpu.cc b/paddle/fluid/operators/scatter_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fadf063bc5bd65649bc7245a35c0d810980ce871
--- /dev/null
+++ b/paddle/fluid/operators/scatter_op_xpu.cc
@@ -0,0 +1,114 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/operators/scatter_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ScatterOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<Tensor>("X");
+    auto *index = ctx.Input<Tensor>("Ids");
+    auto *updates = ctx.Input<Tensor>("Updates");
+    auto *out = ctx.Output<Tensor>("Out");
+    bool overwrite = ctx.Attr<bool>("overwrite");
+
+    // In place output: Out = X, Out[ids] = Updates
+    framework::TensorCopy(*x, ctx.GetPlace(), out);
+    // Apply ScatterUpdate: Out[index] = Updates[:]
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
+
+    // check that index is 1-D (or 2-D with a trailing dimension of 1)
+    PADDLE_ENFORCE_EQ(
+        index->dims().size() == 1 ||
+            (index->dims().size() == 2 && index->dims()[1] == 1),
+        true, platform::errors::InvalidArgument(
+                  "index's shape is wrong, "
+                  "expect index's dims shape is 1 or 2 and index.dims[1] is 1, "
+                  "but got index's dims shape is %d",
+                  index->dims().size()));
+
+    int index_size = static_cast<int>(index->dims()[0]);
+    auto x_dims = x->dims();
+    auto update_dims = updates->dims();
+    for (int i = 1; i < x_dims.size(); i++)
+      PADDLE_ENFORCE_EQ(
+          x_dims[i], update_dims[i],
+          platform::errors::InvalidArgument(
+              "The dimensions of the source tensor and target tensor should"
+              " match, but received source tensor's %d-th dimension is %d, "
+              "target tensor's %d-th dimension is %d.",
+              i, x_dims[i], i, update_dims[i]));
+
+    int dim0 = static_cast<int>(x->dims()[0]);
+    int dim1 = static_cast<int>(
+        framework::product(framework::slice_ddim(x_dims, 1, x_dims.size())));
+    T *out_data = out->data<T>();
+    const T *updates_data = updates->data<T>();
+
+    auto &dev_ctx =
+        ctx.template device_context<platform::XPUDeviceContext>();
+    int r = XPU_SUCCESS;
+
+    Tensor indices_cpu(index->type());
+    framework::TensorCopy(*index, platform::CPUPlace(), &indices_cpu);
+
+    if (index_type == framework::proto::VarType::INT32) {
+      auto index_data = const_cast<int *>(index->data<int>());
+      xpu::VectorParam<int> indices{indices_cpu.data<int>(), index_size,
+                                    index_data};
+      r = xpu::scatter(dev_ctx.x_context(), updates_data, out_data, indices,
+                       dim0, dim1, overwrite);
+    } else {
+      auto index_data = const_cast<int64_t *>(index->data<int64_t>());
+      xpu::VectorParam<int64_t> indices{indices_cpu.data<int64_t>(),
+                                        index_size, index_data};
+      r = xpu::scatter(dev_ctx.x_context(), updates_data, out_data, indices,
+                       dim0, dim1, overwrite);
+    }
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU scatter kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_XPU_KERNEL(scatter, ops::ScatterOpXPUKernel<float>,
+                       ops::ScatterOpXPUKernel<int64_t>);
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 115250b3db76a8d1adf7cb855f2aeda475d4aac4..0742c2a3968b9d412ce65d60d7149616649b65b1 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -32,6 +32,9 @@ XPUOpMap& get_kl2_ops() {
       {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"argsort", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
+                                pOpKernelType(vartype::INT64, XPUPlace()),
+                                pOpKernelType(vartype::FP32, XPUPlace())})},
       {"assign_value",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"batch_norm_grad",
@@ -263,6 +266,8 @@ XPUOpMap& get_kl2_ops() {
       {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                               pOpKernelType(vartype::FP16, XPUPlace()),
                               pOpKernelType(vartype::INT64, XPUPlace())})},
+      {"scatter", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                                pOpKernelType(vartype::FP32, XPUPlace())})},
       {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                               pOpKernelType(vartype::INT64, XPUPlace())})},
       {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
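For reference, the semantics that the scatter kernel implements (and that the unit tests below verify) fit in a few lines of NumPy. scatter_reference is a hypothetical helper name, not an API from this patch; it mirrors the operator's overwrite attribute.

import numpy as np

def scatter_reference(x, ids, updates, overwrite=True):
    """out = x; out[ids] = updates (overwrite) or out[ids] += updates."""
    out = np.copy(x)
    if overwrite:
        out[ids] = updates  # with duplicate ids, the last update wins
    else:
        out[ids] = 0  # duplicate rows are zeroed first, then accumulated
        np.add.at(out, ids, updates)
    return out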
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c77d6304302c982b11a1710c000dc5570e33f23
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+from paddle.fluid import ParamAttr
+from paddle.fluid.framework import Program, grad_var_name
+from paddle.fluid.executor import Executor
+from paddle.fluid.backward import append_backward
+
+paddle.enable_static()
+
+
+class TestArgsortOp(XPUOpTest):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "argsort"
+        self.place = paddle.XPUPlace(0)
+        self.init_dtype()
+        self.init_inputshape()
+        self.init_axis()
+        self.init_direction()
+
+        self.x = np.random.random(self.input_shape).astype(self.dtype)
+        self.inputs = {"X": self.x}
+        self.attrs = {"axis": self.axis, "descending": self.descending}
+        self.get_output()
+        self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
+
+    def get_output(self):
+        if self.descending:
+            self.indices = np.flip(
+                np.argsort(
+                    self.x, kind='heapsort', axis=self.axis), self.axis)
+            self.sorted_x = np.flip(
+                np.sort(
+                    self.x, kind='heapsort', axis=self.axis), self.axis)
+        else:
+            self.indices = np.argsort(self.x, kind='heapsort', axis=self.axis)
+            self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
+
+    def set_xpu(self):
+        self.__class__.use_xpu = True
+        self.__class__.no_need_check_grad = True
+
+    def init_inputshape(self):
+        self.input_shape = (2, 2, 2, 3, 3)
+
+    def init_dtype(self):
+        self.dtype = 'float32'
+
+    def init_axis(self):
+        self.axis = -1
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def init_direction(self):
+        self.descending = False
+
+
+class TestArgsortOpAxis0XPU(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestArgsortOpAxis1XPU(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 1
+
+
+class TestArgsortOpAxis2XPU(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpAxisNeg1XPU(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -1
+
+
+class TestArgsortOpAxisNeg2XPU(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -2
+
+
+class TestArgsortOpDescendingAxisXPU(TestArgsortOp):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis0XPU(TestArgsortOpAxis0XPU):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis1XPU(TestArgsortOpAxis1XPU):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis2XPU(TestArgsortOpAxis2XPU):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg1XPU(TestArgsortOpAxisNeg1XPU):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg2XPU(TestArgsortOpAxisNeg2XPU):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpAxis0XPUINT64(TestArgsortOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "argsort"
+        self.place = paddle.XPUPlace(0)
+        self.init_dtype()
+        self.init_inputshape()
+        self.init_axis()
+        self.init_direction()
+
+        self.x = np.random.randint(
+            low=-1000, high=1000, size=self.input_shape).astype(self.dtype)
+        self.inputs = {"X": self.x}
+        self.attrs = {"axis": self.axis, "descending": self.descending}
+        self.get_output()
+        self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
+
+    def init_axis(self):
+        self.axis = 0
+
+    def init_dtype(self):
+        self.dtype = 'int64'
+
+
+class TestArgsortOpAxis1XPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_axis(self):
+        self.axis = 1
+
+
+class TestArgsortOpAxis2XPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpAxisNeg1XPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_axis(self):
+        self.axis = -1
+
+
+class TestArgsortOpAxisNeg2XPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_axis(self):
+        self.axis = -2
+
+
+class TestArgsortOpDescendingAxisXPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis0XPUINT64(TestArgsortOpAxis0XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis1XPUINT64(TestArgsortOpAxis1XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis2XPUINT64(TestArgsortOpAxis2XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg1XPUINT64(TestArgsortOpAxisNeg1XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg2XPUINT64(TestArgsortOpAxisNeg2XPUINT64):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpAxis0XPUINT(TestArgsortOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "argsort"
+        self.place = paddle.XPUPlace(0)
+        self.init_dtype()
+        self.init_inputshape()
+        self.init_axis()
+        self.init_direction()
+
+        self.x = np.random.randint(
+            low=-1000, high=1000, size=self.input_shape).astype(self.dtype)
+        self.inputs = {"X": self.x}
+        self.attrs = {"axis": self.axis, "descending": self.descending}
+        self.get_output()
+        self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
+
+    def init_axis(self):
+        self.axis = 0
+
+    def init_dtype(self):
+        self.dtype = 'int'
+
+
+if __name__ == '__main__':
+    unittest.main()
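Note that every shape above keeps the sorted axis well below XPU_SORT_MAX_SIZE / 2, so only the no-cast XPUArgsort path is exercised. A sketch of an extra case that would drive the index-cast branch (the class name and shape are hypothetical, not part of the patch):

class TestArgsortOpLargeAxisXPU(TestArgsortOp):
    # Hypothetical: 10000 > XPU_SORT_MAX_SIZE / 2 (8192), so the kernel would
    # take the int32-index branch; n must still stay below XPU_SORT_MAX_SIZE
    # (16384) to pass the PADDLE_ENFORCE_LT check in ArgsortXPUKernel::Compute.
    def init_inputshape(self):
        self.input_shape = (2, 10000)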
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_scatter_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_scatter_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..16b75cd3f0145dc64938cedba0a08635dd95b72d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_scatter_op_xpu.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+paddle.enable_static()
+
+
+class TestScatterOp(XPUOpTest):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 50)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 50)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def set_xpu(self):
+        self.__class__.use_xpu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        pass
+
+
+class TestScatterOp0(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.attrs = {'overwrite': True}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp1(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        zeros_np = np.zeros([2, 3]).astype('float32')
+        index_np = np.array([1, 1]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = zeros_np
+        for i in range(0, len(index_np)):
+            output_np[index_np[i]] += updates_np[i]
+        self.attrs = {'overwrite': False}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp2(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp3(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        zeros_np = np.zeros([2, 3]).astype('float32')
+        index_np = np.array([1, 1]).astype("int32")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = zeros_np
+        for i in range(0, len(index_np)):
+            output_np[index_np[i]] += updates_np[i]
+        self.attrs = {'overwrite': False}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp4(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp5(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+class TestScatterOp6(TestScatterOp):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "scatter"
+        self.place = paddle.XPUPlace(0)
+
+        ref_np = np.ones((3, 3)).astype("int64")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("int64")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+
+if __name__ == '__main__':
+    unittest.main()
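On an XPU build, the two new kernels can be exercised end to end through the imperative Python API. A minimal sketch, assuming PaddlePaddle was built with PADDLE_WITH_XPU and an XPU device 0 is visible:

import numpy as np
import paddle

paddle.set_device("xpu:0")  # assumes an XPU build and a visible device 0

# argsort op: "Out" is returned by paddle.sort, "Indices" by paddle.argsort.
x = paddle.to_tensor(np.random.random((4, 8)).astype("float32"))
sorted_x = paddle.sort(x, axis=1, descending=True)
indices = paddle.argsort(x, axis=1, descending=True)

# scatter op: out = ref; out[ids] = updates (overwrite=True).
ref = paddle.ones([3, 5], dtype="float32")
ids = paddle.to_tensor(np.array([1, 2], dtype="int32"))
updates = paddle.to_tensor(np.random.random((2, 5)).astype("float32"))
out = paddle.scatter(ref, ids, updates, overwrite=True)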