# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard


class TestTakeAPI(unittest.TestCase):
    """Compare ``paddle.take`` against ``np.take`` in static and dygraph mode.

    Subclasses customise a case by overriding ``set_mode`` (index mode passed
    to both ``paddle.take`` and ``np.take``), ``set_dtype`` (input / index
    dtypes) and ``set_input`` (shapes and numpy data).
    """

    def set_mode(self):
        self.mode = 'raise'

    def set_dtype(self):
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [2, 3]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        # Indices -4..1: includes negative indices that are still in range.
        self.index_np = (
            np.arange(-4, 2).reshape(self.index_shape).astype(self.index_dtype)
        )

    def setUp(self):
        self.set_mode()
        self.set_dtype()
        self.set_input()
        self.place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )

    def test_static_graph(self):
        """Build ``paddle.take`` in a static program and compare with numpy."""
        paddle.enable_static()
        main_program = Program()
        startup_program = Program()
        # program_guard's signature is (main_program, startup_program); the
        # ops below are recorded into ``main_program``.
        with program_guard(main_program, startup_program):
            x = fluid.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            index = fluid.data(
                name='index', dtype=self.index_dtype, shape=self.index_shape
            )
            out = paddle.take(x, index, mode=self.mode)

        exe = fluid.Executor(self.place)
        # Run the program the ops were built into explicitly instead of
        # depending on what default_main_program() resolves to at this point.
        st_result = exe.run(
            main_program,
            feed={'input': self.input_np, 'index': self.index_np},
            fetch_list=[out],
        )
        np.testing.assert_allclose(
            st_result[0],
            np.take(self.input_np, self.index_np, mode=self.mode),
        )

    def test_dygraph(self):
        """Run ``paddle.take`` eagerly and compare with numpy."""
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np)
        dy_result = paddle.take(x, index, mode=self.mode)
        # assert_allclose(actual, desired): the Paddle result is the actual.
        np.testing.assert_allclose(
            dy_result.numpy(),
            np.take(self.input_np, self.index_np, mode=self.mode),
        )


class TestTakeInt32(TestTakeAPI):
    """Test take API with data type int32"""

    def set_dtype(self):
        self.input_dtype = 'int32'
        self.index_dtype = 'int64'


class TestTakeInt64(TestTakeAPI):
    """Test take API with data type int64"""

    def set_dtype(self):
        self.input_dtype = 'int64'
        self.index_dtype = 'int64'


class TestTakeFloat32(TestTakeAPI):
    """Test take API with data type float32"""

    def set_dtype(self):
        self.input_dtype = 'float32'
        self.index_dtype = 'int64'


class TestTakeTypeError(TestTakeAPI):
    """Test take Type Error

    Inherits from TestTakeAPI, so the base static/dygraph comparison tests
    also run for this class in addition to the error tests below.
    """

    def test_static_type_error(self):
        """Argument 'index' must be Tensor"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            self.assertRaises(
                TypeError, paddle.take, x, self.index_np, self.mode
            )

    def test_dygraph_type_error(self):
        """A numpy array passed as 'index' is rejected in dygraph mode."""
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)

    def test_static_dtype_error(self):
        """Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(
                name='input', dtype='float64', shape=self.input_shape
            )
            index = fluid.data(
                name='index', dtype='float32', shape=self.index_shape
            )
            self.assertRaises(TypeError, paddle.take, x, index, self.mode)

    def test_dygraph_dtype_error(self):
        """A float32 'index' tensor is rejected in dygraph mode."""
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np, dtype='float32')
        self.assertRaises(TypeError, paddle.take, x, index, self.mode)


class TestTakeModeRaisePos(unittest.TestCase):
    """Test positive index out of range error"""

    def set_mode(self):
        self.mode = 'raise'

    def set_dtype(self):
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        # Indices -10..19: the positive end exceeds the 12 input elements.
        self.index_np = (
            np.arange(-10, 20)
            .reshape(self.index_shape)
            .astype(self.index_dtype)
        )  # positive indices are out of range

    def setUp(self):
        self.set_mode()
        self.set_dtype()
        self.set_input()
        self.place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )

    # NOTE(review): both tests below call paddle.index_select (not
    # paddle.take) with a 2-D index of shape [5, 6]. The ValueError may come
    # from index_select's index-rank validation rather than from the
    # out-of-range values -- confirm this really exercises the out-of-range
    # path of take's 'raise' mode.

    def test_static_index_error(self):
        """When the index is out of range, an error is reported directly
        through `paddle.index_select`"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            index = fluid.data(
                name='index', dtype=self.index_dtype, shape=self.index_shape
            )
            self.assertRaises(ValueError, paddle.index_select, x, index)

    def test_dygraph_index_error(self):
        """Dygraph counterpart of ``test_static_index_error``."""
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
        self.assertRaises(ValueError, paddle.index_select, x, index)


class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
    """Test negative index out of range error

    Only the input differs from TestTakeModeRaisePos; mode, dtypes and
    setUp are inherited unchanged.
    """

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        # Indices -20..9: the negative end is below -12.
        self.index_np = (
            np.arange(-20, 10)
            .reshape(self.index_shape)
            .astype(self.index_dtype)
        )  # negative indices are out of range


class TestTakeModeWrap(TestTakeAPI):
    """Test take index out of range mode"""

    def set_mode(self):
        self.mode = 'wrap'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        self.index_np = (
            np.arange(-20, 20)
            .reshape(self.index_shape)
            .astype(self.index_dtype)
        )  # Both ends of the index are out of bounds


class TestTakeModeClip(TestTakeAPI):
    """Test take index out of range mode"""

    def set_mode(self):
        self.mode = 'clip'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        self.index_np = (
            np.arange(-20, 20)
            .reshape(self.index_shape)
            .astype(self.index_dtype)
        )  # Both ends of the index are out of bounds


if __name__ == "__main__":
    unittest.main()