diff --git a/paddle/fluid/operators/masked_select_op_xpu.cc b/paddle/fluid/operators/masked_select_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..665ac937fdc05c61ab003ac3cce8fd53aa1b9f66 --- /dev/null +++ b/paddle/fluid/operators/masked_select_op_xpu.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/masked_select_op.h" + +namespace paddle { +namespace operators { + +template +class MaskedSelectXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto input = context.Input("X"); + auto mask = context.Input("Mask"); + auto out = context.Output("Y"); + auto* mask_data = mask->data(); + auto* input_data = input->data(); + auto input_dim = input->dims(); + auto mask_dim = mask->dims(); + PADDLE_ENFORCE_EQ( + input_dim, mask_dim, + platform::errors::InvalidArgument( + "The dim size of input and mask in OP(masked_selected) " + "must be equal, but got input dim:(%ld), mask dim: " + "(%ld). Please check input " + "value.", + input_dim, mask_dim)); + auto& dev_ctx = + context.template device_context(); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* out_size = RAII_GUARD.alloc_l3_or_gm(1); + int out_size_cpu; + + int ret = xpu::nonzero_count(dev_ctx.x_context(), mask_data, out_size, + mask->numel()); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External( + "XPU nonzero_count kernel return wrong value[%d %s]", + ret, XPUAPIErrorMsg[ret])); + + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + ret = xpu_memcpy(static_cast(&out_size_cpu), + static_cast(out_size), sizeof(int32_t), + XPU_DEVICE_TO_HOST); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External("XPU xpu_memcpy return wrong " + "value[%d %s]", + ret, XPUAPIErrorMsg[ret])); + + framework::DDim out_dim{out_size_cpu}; + out->Resize(out_dim); + auto out_data = out->mutable_data(context.GetPlace()); + + auto input_shape = framework::vectorize(input_dim); + auto mask_shape = framework::vectorize(mask_dim); + + ret = xpu::masked_select(dev_ctx.x_context(), input_data, mask_data, + out_data, input_shape, mask_shape); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External( + "XPU masked_select kernel return wrong value[%d %s]", + ret, XPUAPIErrorMsg[ret])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel, + ops::MaskedSelectXPUKernel, + ops::MaskedSelectXPUKernel); +#endif diff --git a/paddle/fluid/operators/where_index_op_xpu.cc b/paddle/fluid/operators/where_index_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..58f09e7381ed0d587d4f5c8ecf8a890b724eeb6e --- /dev/null +++ b/paddle/fluid/operators/where_index_op_xpu.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/where_index_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class WhereIndexXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* condition = context.Input("Condition"); + auto* out = context.Output("Out"); + + const T* cond_data = condition->data(); + auto numel = condition->numel(); + auto dims = condition->dims(); + const int rank = dims.size(); + + auto& dev_ctx = + context.template device_context(); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* true_num = RAII_GUARD.alloc_l3_or_gm(1); + int true_num_cpu; + int ret = + xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel); + PADDLE_ENFORCE_EQ( + ret, XPU_SUCCESS, + platform::errors::External( + "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex", + ret, XPUAPIErrorMsg[ret])); + + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + ret = xpu_memcpy(static_cast(&true_num_cpu), + static_cast(true_num), sizeof(int32_t), + XPU_DEVICE_TO_HOST); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External("XPU xpu_memcpy return wrong " + "value[%d %s]", + ret, XPUAPIErrorMsg[ret])); + + out->Resize( + framework::make_ddim({static_cast(true_num_cpu), rank})); + auto out_data = out->mutable_data(context.GetPlace()); + if (true_num_cpu == 0) { + return; + } + + auto condition_shape = framework::vectorize(dims); + ret = xpu::where(dev_ctx.x_context(), cond_data, out_data, condition_shape, + true_num_cpu); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External( + "XPU masked_select kernel return wrong value[%d %s]", + ret, XPUAPIErrorMsg[ret])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL(where_index, ops::WhereIndexXPUKernel, + ops::WhereIndexXPUKernel, + ops::WhereIndexXPUKernel); +#endif diff --git a/paddle/fluid/operators/where_op_xpu.cc b/paddle/fluid/operators/where_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c12bb55146f1168a40811bbcdb35bdfc355255d --- /dev/null +++ b/paddle/fluid/operators/where_op_xpu.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/where_op.h" + +namespace paddle { +namespace operators { + +template +class WhereXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* condition = context.Input("Condition"); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* out = context.Output("Out"); + + const bool* cond_data = condition->data(); + const T* x_data = X->data(); + const T* y_data = Y->data(); + T* out_data = out->mutable_data(context.GetPlace()); + + auto cond_dims = framework::vectorize(condition->dims()); + auto input_dims = framework::vectorize(X->dims()); + + auto& dev_ctx = context.template device_context(); + int ret = xpu::select(dev_ctx.x_context(), cond_data, x_data, y_data, + out_data, cond_dims, input_dims); + PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, + platform::errors::External( + "XPU select kernel return wrong value[%d %s]", ret, + XPUAPIErrorMsg[ret])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + where, ops::WhereXPUKernel, + ops::WhereXPUKernel, + ops::WhereXPUKernel); +#endif diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h index 5eb86a36f5167d0e799bd9b42a83b75c4ff4f371..58109092fbda686da7090817c3de83dbbf47b415 100644 --- a/paddle/fluid/platform/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/xpu/xpu2_op_list.h @@ -261,8 +261,17 @@ XPUOpMap& get_kl2_ops() { {"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})} - + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"masked_select", + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})} // AddMore }; diff --git a/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5b3f3d8a9afa1ac7e812c54b9f05bf33ec07ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + + +def np_masked_select(x, mask): + result = np.empty(shape=(0), dtype=x.dtype) + for ele, ma in zip(np.nditer(x), np.nditer(mask)): + if ma: + result = np.append(result, ele) + return result.flatten() + + +class TestMaskedSelectOp(XPUOpTest): + def set_xpu(self): + self.__class__.use_xpu = True + + def setUp(self): + self.set_xpu() + self.init() + self.init_dtype() + self.place = paddle.XPUPlace(0) + self.op_type = "masked_select" + x = np.random.random(self.shape).astype(self.dtype) + mask = np.array(np.random.randint(2, size=self.shape, dtype=bool)) + out = np_masked_select(x, mask) + self.inputs = {'X': x, 'Mask': mask} + self.outputs = {'Y': out} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + def init(self): + self.shape = (50, 3) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestMaskedSelectOp1(TestMaskedSelectOp): + def init(self): + self.shape = (6, 8, 9, 18) + + +class TestMaskedSelectOp2(TestMaskedSelectOp): + def init(self): + self.shape = (168, ) + + +class TestMaskedSelectOpInt32(TestMaskedSelectOp): + def init_dtype(self): + self.dtype = np.int32 + + # skip_check_grad_ci(reason="get_numeric_gradient not support int32") + def test_check_grad(self): + pass + + +class TestMaskedSelectOpInt64(TestMaskedSelectOp): + def init_dtype(self): + self.dtype = np.int64 + + # skip_check_grad_ci(reason="get_numeric_gradient not support int64") + def test_check_grad(self): + pass + + +class TestMaskedSelectAPI(unittest.TestCase): + def test_imperative_mode(self): + paddle.disable_static(paddle.XPUPlace(0)) + shape = (88, 6, 8) + np_x = np.random.random(shape).astype('float32') + np_mask = np.array(np.random.randint(2, size=shape, dtype=bool)) + x = paddle.to_tensor(np_x) + mask = paddle.to_tensor(np_mask) + out = paddle.masked_select(x, mask) + np_out = np_masked_select(np_x, np_mask) + self.assertEqual(np.allclose(out.numpy(), np_out), True) + paddle.enable_static() + + def test_static_mode(self): + shape = [8, 9, 6] + x = paddle.fluid.data(shape=shape, dtype='float32', name='x') + mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask') + np_x = np.random.random(shape).astype('float32') + np_mask = np.array(np.random.randint(2, size=shape, dtype=bool)) + + out = paddle.masked_select(x, mask) + np_out = np_masked_select(np_x, np_mask) + + exe = paddle.static.Executor(place=paddle.XPUPlace(0)) + + res = exe.run(paddle.static.default_main_program(), + feed={"x": np_x, + "mask": np_mask}, + fetch_list=[out]) + self.assertEqual(np.allclose(res, np_out), True) + + +class TestMaskedSelectError(unittest.TestCase): + def test_error(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.fluid.data(shape=shape, dtype='float32', name='x') + mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask') + mask_float = paddle.fluid.data( + shape=shape, dtype='float32', name='mask_float') + np_x = np.random.random(shape).astype('float32') + np_mask = np.array(np.random.randint(2, size=shape, dtype=bool)) + + def test_x_type(): + paddle.masked_select(np_x, mask) + + self.assertRaises(TypeError, test_x_type) + + def test_mask_type(): + paddle.masked_select(x, np_mask) + + self.assertRaises(TypeError, test_mask_type) + + def test_mask_dtype(): + paddle.masked_select(x, mask_float) + + self.assertRaises(TypeError, test_mask_dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_where_index_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_where_index_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..69b4f5a03ed18f92dce682a12ca79c98c16c01c5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_where_index_xpu.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import paddle +import sys +sys.path.append("..") +from op_test import OpTest +from op_test_xpu import XPUOpTest +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +class TestWhereIndexOp(XPUOpTest): + def setUp(self): + self.set_xpu() + self.op_type = "where_index" + self.place = paddle.XPUPlace(0) + self.init_config() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + def init_config(self): + self.inputs = {'Condition': np.array([True, False, True]), } + self.outputs = {'Out': np.array([[0], [2]], dtype='int64')} + + def set_xpu(self): + self.__class__.use_xpu = True + + +class TestNotBool(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([1, 0, 8]), } + + self.outputs = {'Out': np.array([[0], [2]], dtype='int64')} + + +class TestAllFalse(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([False, False, False]), } + self.outputs = {'Out': np.array([], dtype='int64')} + + +class TestRank2(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([[True, False], [False, True]]), } + self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')} + + +class TestRank3(TestWhereIndexOp): + def init_config(self): + self.inputs = { + 'Condition': np.array([[[True, False], [False, True]], + [[False, True], [True, False]], + [[False, False], [False, True]]]), + } + + self.outputs = { + 'Out': np.array( + [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]], + dtype='int64') + } + + +class TestWhereOpError(unittest.TestCase): + def test_api(self): + with program_guard(Program(), Program()): + cond = fluid.layers.data(name='cond', shape=[4], dtype='bool') + result = fluid.layers.where(cond) + + exe = fluid.Executor(paddle.XPUPlace(0)) + exe.run(fluid.default_startup_program()) + cond_i = np.array([True, False, False, False]).astype("bool") + out = exe.run(fluid.default_main_program(), feed={'cond': cond_i}) + + +class TestWhereRaiseError(unittest.TestCase): + def test_errors(self): + def test_type(): + fluid.layers.where([10]) + + self.assertRaises(TypeError, test_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_where_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_where_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..2161ec24dbf87349c6846afe4c3380fc5b288e07 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_where_op_xpu.py @@ -0,0 +1,166 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program +from paddle.fluid.backward import append_backward + +paddle.enable_static() + + +class TestXPUWhereOp(XPUOpTest): + def setUp(self): + self.op_type = "where" + self.set_xpu() + self.init_config() + self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y} + self.outputs = {'Out': np.where(self.cond, self.x, self.y)} + + def init_config(self): + self.x = np.random.uniform(-3, 5, (100)).astype("float32") + self.y = np.random.uniform(-3, 5, (100)).astype("float32") + self.cond = np.zeros((100)).astype("bool") + + def set_xpu(self): + self.__class__.use_xpu = True + self.place = paddle.XPUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') + + +class TestXPUWhereOp2(TestXPUWhereOp): + def init_config(self): + self.x = np.random.uniform(-5, 5, (60, 2)).astype("float32") + self.y = np.random.uniform(-5, 5, (60, 2)).astype("float32") + self.cond = np.ones((60, 2)).astype("bool") + + +class TestXPUWhereOp3(TestXPUWhereOp): + def init_config(self): + self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float32") + self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float32") + self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool) + + +class TestXPUWhereAPI(unittest.TestCase): + def setUp(self): + self.__class__.use_xpu = True + self.place = paddle.XPUPlace(0) + self.init_data() + + def init_data(self): + self.shape = [10, 15] + self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool) + self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.out = np.where(self.cond, self.x, self.y) + + def ref_x_backward(self, dout): + return np.where(self.cond == True, dout, 0) + + def ref_y_backward(self, dout): + return np.where(self.cond == False, dout, 0) + + def test_api(self): + for x_stop_gradient in [False, True]: + for y_stop_gradient in [False, True]: + train_prog = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(train_prog, startup): + cond = fluid.data( + name='cond', shape=self.shape, dtype='bool') + x = fluid.data(name='x', shape=self.shape, dtype='float32') + y = fluid.data(name='y', shape=self.shape, dtype='float32') + + x.stop_gradient = x_stop_gradient + y.stop_gradient = y_stop_gradient + + result = paddle.where(cond, x, y) + append_backward(fluid.layers.mean(result)) + + exe = fluid.Executor(self.place) + exe.run(startup) + + fetch_list = [result, result.grad_name] + if x_stop_gradient is False: + fetch_list.append(x.grad_name) + if y_stop_gradient is False: + fetch_list.append(y.grad_name) + out = exe.run( + train_prog, + feed={'cond': self.cond, + 'x': self.x, + 'y': self.y}, + fetch_list=fetch_list) + assert np.array_equal(out[0], self.out) + + if x_stop_gradient is False: + assert np.array_equal(out[2], + self.ref_x_backward(out[1])) + if y.stop_gradient is False: + assert np.array_equal(out[3], + self.ref_y_backward(out[1])) + elif y.stop_gradient is False: + assert np.array_equal(out[2], + self.ref_y_backward(out[1])) + + def test_api_broadcast(self, use_cuda=False): + train_prog = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(train_prog, startup): + x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32') + y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32') + x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") + y_i = np.array([[1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0]]).astype("float32") + result = paddle.where(x > 1, x=x, y=y) + + exe = fluid.Executor(self.place) + exe.run(startup) + + out = exe.run(train_prog, + feed={'x': x_i, + 'y': y_i}, + fetch_list=[result]) + assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) + + +class TestWhereDygraphAPI(unittest.TestCase): + def test_api(self): + with fluid.dygraph.guard(paddle.XPUPlace(0)): + x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32") + y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32") + cond_i = np.array([False, False, True, True]).astype("bool") + x = fluid.dygraph.to_variable(x_i) + y = fluid.dygraph.to_variable(y_i) + cond = fluid.dygraph.to_variable(cond_i) + out = paddle.where(cond, x, y) + assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) + + +if __name__ == '__main__': + unittest.main()