# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
from paddle.fluid import Program, core, program_guard
from paddle.fluid.backward import append_backward


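# OpTest for the `where` op: the forward result is checked against np.where and
# gradients w.r.t. X and Y are verified; the all-False condition here means the
# expected output equals Y element-wise.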
class TestWhereOp(OpTest):
    def setUp(self):
        self.op_type = 'where'
        self.python_api = paddle.where
        self.check_cinn = True
        self.init_config()
        self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
        self.outputs = {'Out': np.where(self.cond, self.x, self.y)}

    def test_check_output(self):
        self.check_output(check_cinn=self.check_cinn)

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', check_cinn=self.check_cinn)

    def init_config(self):
        self.x = np.random.uniform((-3), 5, 100).astype('float64')
        self.y = np.random.uniform((-3), 5, 100).astype('float64')
        self.cond = np.zeros(100).astype('bool')


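# Same check with 2-D inputs and an all-True condition, so the output equals X.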
class TestWhereOp2(TestWhereOp):
    def init_config(self):
        self.x = np.random.uniform((-5), 5, (60, 2)).astype('float64')
        self.y = np.random.uniform((-5), 5, (60, 2)).astype('float64')
        self.cond = np.ones((60, 2)).astype('bool')


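# float16 variant of the basic test.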
class TestWhereFP16OP(TestWhereOp):
    def init_config(self):
        self.dtype = np.float16
        self.x = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype)
        self.y = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype)
        self.cond = np.ones((60, 2)).astype('bool')


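# bfloat16 variant: inputs are stored as uint16 via convert_float_to_uint16 and
# the test only runs on CUDA devices with bfloat16 support.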
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the bfloat16",
)
class TestWhereBF16OP(OpTest):
    def setUp(self):
        self.op_type = 'where'
        self.dtype = np.uint16
        self.python_api = paddle.where
        self.check_cinn = True
        self.init_config()
        self.inputs = {
            'Condition': self.cond,
            'X': convert_float_to_uint16(self.x),
            'Y': convert_float_to_uint16(self.y),
        }
        self.outputs = {
            'Out': convert_float_to_uint16(np.where(self.cond, self.x, self.y))
        }

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, check_cinn=self.check_cinn)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place,
            ['X', 'Y'],
            'Out',
            numeric_grad_delta=0.05,
            check_cinn=self.check_cinn,
        )

    def init_config(self):
        self.x = np.random.uniform((-5), 5, (60, 2)).astype(np.float32)
        self.y = np.random.uniform((-5), 5, (60, 2)).astype(np.float32)
        self.cond = np.random.randint(2, size=(60, 2)).astype('bool')


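# 3-D inputs with a randomly mixed boolean condition.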
class TestWhereOp3(TestWhereOp):
    def init_config(self):
        self.x = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')
        self.y = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')
        self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)


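# Static-graph tests for the paddle.where API: forward results, gradients,
# broadcasting, and scalar x/y.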
class TestWhereAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.shape = [10, 15]
        self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
        self.x = np.random.uniform((-2), 3, self.shape).astype(np.float32)
        self.y = np.random.uniform((-2), 3, self.shape).astype(np.float32)
        self.out = np.where(self.cond, self.x, self.y)

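    # Reference gradients: the upstream gradient dout is routed to X where the
    # condition is True and to Y where it is False.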
    def ref_x_backward(self, dout):
        return np.where(self.cond, dout, 0)

    def ref_y_backward(self, dout):
        return np.where(~self.cond, dout, 0)

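    # Runs every combination of stop_gradient on x and y, on CPU and (when
    # available) CUDA, and checks the forward output and the fetched gradients
    # against the references above.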
    def test_api(self, use_cuda=False):
        for x_stop_gradient in [False, True]:
            for y_stop_gradient in [False, True]:
                with fluid.program_guard(Program(), Program()):
                    cond = paddle.static.data(
                        name='cond', shape=[-1] + self.shape, dtype='bool'
                    )
                    cond.desc.set_need_check_feed(False)
                    x = paddle.static.data(
                        name='x', shape=[-1] + self.shape, dtype='float32'
                    )
                    x.desc.set_need_check_feed(False)
                    y = paddle.static.data(
                        name='y', shape=[-1] + self.shape, dtype='float32'
                    )
                    y.desc.set_need_check_feed(False)
                    x.stop_gradient = x_stop_gradient
                    x.desc.set_need_check_feed(False)
                    y.stop_gradient = y_stop_gradient
                    y.desc.set_need_check_feed(False)
                    result = paddle.where(cond, x, y)
                    result.stop_gradient = False
                    append_backward(paddle.mean(result))
                    for use_cuda in [False, True]:
                        if use_cuda and (
                            not fluid.core.is_compiled_with_cuda()
                        ):
                            break
                        place = (
                            fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
                        )
                        exe = fluid.Executor(place)
                        fetch_list = [result, result.grad_name]
                        if x_stop_gradient is False:
                            fetch_list.append(x.grad_name)
                        if y_stop_gradient is False:
                            fetch_list.append(y.grad_name)
                        out = exe.run(
                            fluid.default_main_program(),
                            feed={'cond': self.cond, 'x': self.x, 'y': self.y},
                            fetch_list=fetch_list,
                        )
                        np.testing.assert_array_equal(out[0], self.out)
                        if x_stop_gradient is False:
                            np.testing.assert_array_equal(
                                out[2], self.ref_x_backward(out[1])
                            )
                            if y.stop_gradient is False:
                                np.testing.assert_array_equal(
                                    out[3], self.ref_y_backward(out[1])
                                )
                        elif y.stop_gradient is False:
                            np.testing.assert_array_equal(
                                out[2], self.ref_y_backward(out[1])
                            )

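    # Broadcasting between x and y in a static program; the result must match
    # np.where applied to the fed arrays.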
    def test_api_broadcast(self, use_cuda=False):
        main_program = Program()
        with fluid.program_guard(main_program):
            x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
            y.desc.set_need_check_feed(False)
            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
                'float32'
            )
            result = paddle.where((x > 1), x=x, y=y)
            for use_cuda in [False, True]:
                if use_cuda and (not fluid.core.is_compiled_with_cuda()):
                    return
                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'x': x_i, 'y': y_i},
                    fetch_list=[result],
                )
                np.testing.assert_array_equal(
                    out[0], np.where((x_i > 1), x_i, y_i)
                )

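    # x and y given as Python scalars; only the condition is fed.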
    def test_scalar(self):
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond_shape = [2, 4]
            cond = paddle.static.data(
                name='cond', shape=[-1] + cond_shape, dtype='bool'
            )
            cond.desc.set_need_check_feed(False)
            x_data = 1.0
            y_data = 2.0
            cond_data = np.array([False, False, True, True]).astype('bool')
            result = paddle.where(condition=cond, x=x_data, y=y_data)
            for use_cuda in [False, True]:
                if use_cuda and (not fluid.core.is_compiled_with_cuda()):
                    return
                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'cond': cond_data},
                    fetch_list=[result],
                )
                expect = np.where(cond_data, x_data, y_data)
                np.testing.assert_array_equal(out[0], expect)

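    # Helper shared by the test_static_api_broadcast_* cases below: cond, x and
    # y are declared with different shapes and the result must match np.where's
    # broadcasting.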
    def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond = paddle.static.data(
                name='cond', shape=[-1] + cond_shape, dtype='bool'
            )
            x = paddle.static.data(
                name='x', shape=[-1] + x_shape, dtype='float32'
            )
            y = paddle.static.data(
                name='y', shape=[-1] + y_shape, dtype='float32'
            )
            x.desc.set_need_check_feed(False)
            y.desc.set_need_check_feed(False)
            cond.desc.set_need_check_feed(False)
            cond_data_tmp = np.random.random(size=cond_shape).astype('float32')
            cond_data = cond_data_tmp < 0.3
            x_data = np.random.random(size=x_shape).astype('float32')
            y_data = np.random.random(size=y_shape).astype('float32')
            result = paddle.where(condition=cond, x=x, y=y)
            for use_cuda in [False, True]:
                if use_cuda and (not fluid.core.is_compiled_with_cuda()):
                    return
                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'cond': cond_data, 'x': x_data, 'y': y_data},
                    fetch_list=[result],
                )
                expect = np.where(cond_data, x_data, y_data)
                np.testing.assert_array_equal(out[0], expect)

    def test_static_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)


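# Dygraph (imperative) counterparts of the static API tests above.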
class TestWhereDygraphAPI(unittest.TestCase):
    def test_api(self):
        with fluid.dygraph.guard():
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = fluid.dygraph.to_variable(x_i)
            y = fluid.dygraph.to_variable(y_i)
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(
                out.numpy(), np.where(cond_i, x_i, y_i)
            )

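    # Scalar x and y in dygraph mode.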
    def test_scalar(self):
        with fluid.dygraph.guard():
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = 1.0
            y = 2.0
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(out.numpy(), np.where(cond_i, x, y))

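    # Helper shared by the test_dygraph_api_broadcast_* cases below.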
    def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
        with fluid.dygraph.guard():
            cond_tmp = paddle.rand(cond_shape)
            cond = cond_tmp < 0.3
            a = paddle.rand(a_shape)
            b = paddle.rand(b_shape)
            result = paddle.where(cond, a, b)
            result = result.numpy()
            expect = np.where(cond, a, b)
            np.testing.assert_array_equal(expect, result)

    def test_dygraph_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

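    # Single-argument form: paddle.where(x) behaves like numpy.nonzero and
    # returns a tuple with one index tensor per dimension of the input.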
    def test_where_condition(self):
        data = np.array([[True, False], [False, True]])
        with program_guard(Program(), Program()):
            x = paddle.static.data(name='x', shape=[(-1), 2], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.where(x)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 2)
            z = paddle.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            (res,) = exe.run(
                feed={'x': data}, fetch_list=[z.name], return_numpy=False
            )
        expect_out = np.array([[0, 0], [1, 1]])
        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
        data = np.array([True, True, False])
        with program_guard(Program(), Program()):
            x = paddle.static.data(name='x', shape=[(-1)], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.where(x)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 1)
            z = paddle.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            (res,) = exe.run(
                feed={'x': data}, fetch_list=[z.name], return_numpy=False
            )
        expect_out = np.array([[0], [1]])
        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)


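# Error paths: raw numpy inputs instead of Tensors and an int32 condition are
# both rejected with TypeError.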
class TestWhereOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')

            def test_Variable():
                paddle.where(cond_i, x_i, y_i)

            self.assertRaises(TypeError, test_Variable)

            def test_type():
                x = paddle.static.data(name='x', shape=[-1, 4], dtype='bool')
                x.desc.set_need_check_feed(False)
                y = paddle.static.data(name='y', shape=[-1, 4], dtype='float16')
                y.desc.set_need_check_feed(False)
                cond = paddle.static.data(
                    name='cond', shape=[-1, 4], dtype='int32'
                )
                cond.desc.set_need_check_feed(False)
                paddle.where(cond, x, y)

            self.assertRaises(TypeError, test_type)

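    # Supplying x without y raises ValueError.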
    def test_value_error(self):
        with fluid.dygraph.guard():
            cond_shape = [2, 2, 4]
            cond_tmp = paddle.rand(cond_shape)
            cond = cond_tmp < 0.3
            a = paddle.rand(cond_shape)
            self.assertRaises(ValueError, paddle.where, cond, a)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()