# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.backward import append_backward


class TestWhereOp(OpTest):
    """OpTest harness for the ``where`` op.

    ``setUp`` builds the op inputs from ``self.cond``, ``self.x`` and
    ``self.y`` (populated by :meth:`init_config`) and computes the expected
    output with ``np.where``. Subclasses override :meth:`init_config` to test
    other shapes and condition patterns.
    """

    def setUp(self):
        self.op_type = 'where'
        self.python_api = paddle.where
        self.init_config()
        self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
        self.outputs = {'Out': np.where(self.cond, self.x, self.y)}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        # Gradients are checked only with respect to the value inputs X and Y;
        # the boolean Condition input is not part of the check.
        self.check_grad(['X', 'Y'], 'Out', check_eager=False)

    def init_config(self):
        # 1-D float64 inputs of length 100 with an all-False condition, so
        # every output element is selected from ``y``.
        self.x = np.random.uniform(-3, 5, 100).astype('float64')
        self.y = np.random.uniform(-3, 5, 100).astype('float64')
        self.cond = np.zeros(100).astype('bool')


class TestWhereOp2(TestWhereOp):
    """``where`` op on 2-D (60, 2) float64 inputs with an all-True condition."""

    def init_config(self):
        # All-True condition: every output element is selected from ``x``.
        self.x = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.y = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.cond = np.ones((60, 2)).astype('bool')


class TestWhereOp3(TestWhereOp):
    """``where`` op on 3-D (20, 2, 4) float64 inputs with a random condition."""

    def init_config(self):
        self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        # Random mix of True/False so both branches are exercised.
        self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)


class TestWhereAPI(unittest.TestCase):
    """Static-graph tests of ``paddle.where``: forward values, gradients,
    scalar operands and broadcasting between condition/x/y shapes."""

    def setUp(self):
        # Every test in this class builds static Programs with the static-only
        # ``paddle.static.data``; enable static mode here so the class also
        # passes when the module is run without its ``__main__`` block.
        paddle.enable_static()
        self.init_data()

    @staticmethod
    def _places():
        """Return the execution places to test: CPU, plus GPU 0 when Paddle
        was compiled with CUDA."""
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        return places

    def init_data(self):
        self.shape = [10, 15]
        self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
        self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32)
        self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32)
        self.out = np.where(self.cond, self.x, self.y)

    def ref_x_backward(self, dout):
        """Reference gradient for ``x``: ``dout`` where cond is True, else 0."""
        return np.where(self.cond, dout, 0)

    def ref_y_backward(self, dout):
        """Reference gradient for ``y``: ``dout`` where cond is False, else 0."""
        return np.where(~self.cond, dout, 0)

    def test_api(self, use_cuda=False):
        """Check forward output and x/y gradients for every stop_gradient mix.

        ``use_cuda`` is kept only for signature compatibility; all available
        places are always exercised.
        """
        for x_stop_gradient in [False, True]:
            for y_stop_gradient in [False, True]:
                with fluid.program_guard(Program(), Program()):
                    cond = paddle.static.data(
                        name='cond', shape=[-1] + self.shape, dtype='bool'
                    )
                    cond.desc.set_need_check_feed(False)
                    x = paddle.static.data(
                        name='x', shape=[-1] + self.shape, dtype='float32'
                    )
                    x.desc.set_need_check_feed(False)
                    y = paddle.static.data(
                        name='y', shape=[-1] + self.shape, dtype='float32'
                    )
                    y.desc.set_need_check_feed(False)
                    x.stop_gradient = x_stop_gradient
                    y.stop_gradient = y_stop_gradient
                    result = paddle.where(cond, x, y)
                    append_backward(paddle.mean(result))
                    for place in self._places():
                        exe = fluid.Executor(place)
                        fetch_list = [result, result.grad_name]
                        # Reference backward functions, in the same order as
                        # the gradient variables appended to fetch_list.
                        ref_backwards = []
                        if not x_stop_gradient:
                            fetch_list.append(x.grad_name)
                            ref_backwards.append(self.ref_x_backward)
                        if not y_stop_gradient:
                            fetch_list.append(y.grad_name)
                            ref_backwards.append(self.ref_y_backward)
                        out = exe.run(
                            fluid.default_main_program(),
                            feed={'cond': self.cond, 'x': self.x, 'y': self.y},
                            fetch_list=fetch_list,
                        )
                        np.testing.assert_array_equal(out[0], self.out)
                        # out[1] is d(mean)/d(result); out[2:] are the input
                        # gradients that were requested above.
                        for grad, ref_backward in zip(out[2:], ref_backwards):
                            np.testing.assert_array_equal(
                                grad, ref_backward(out[1])
                            )

    def test_api_broadcast(self, use_cuda=False):
        """Check that x and y of different shapes broadcast in paddle.where.

        ``use_cuda`` is kept only for signature compatibility.
        """
        main_program = Program()
        with fluid.program_guard(main_program):
            x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
            y.desc.set_need_check_feed(False)
            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
                'float32'
            )
            result = paddle.where((x > 1), x=x, y=y)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'x': x_i, 'y': y_i},
                    fetch_list=[result],
                )
                np.testing.assert_array_equal(
                    out[0], np.where((x_i > 1), x_i, y_i)
                )

    def test_scalar(self):
        """Check paddle.where with Python float scalars for x and y."""
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond_shape = [2, 4]
            cond = paddle.static.data(
                name='cond', shape=[-1] + cond_shape, dtype='bool'
            )
            cond.desc.set_need_check_feed(False)
            x_data = 1.0
            y_data = 2.0
            cond_data = np.array([False, False, True, True]).astype('bool')
            result = paddle.where(condition=cond, x=x_data, y=y_data)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'cond': cond_data},
                    fetch_list=[result],
                )
                expect = np.where(cond_data, x_data, y_data)
                np.testing.assert_array_equal(out[0], expect)

    def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
        """Run paddle.where on random data of the given shapes and compare
        the result with np.where on every available place."""
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond = paddle.static.data(
                name='cond', shape=[-1] + cond_shape, dtype='bool'
            )
            x = paddle.static.data(
                name='x', shape=[-1] + x_shape, dtype='float32'
            )
            y = paddle.static.data(
                name='y', shape=[-1] + y_shape, dtype='float32'
            )
            x.desc.set_need_check_feed(False)
            y.desc.set_need_check_feed(False)
            cond.desc.set_need_check_feed(False)
            # Roughly 30% True so both branches of where are selected.
            cond_data_tmp = np.random.random(size=cond_shape).astype('float32')
            cond_data = cond_data_tmp < 0.3
            x_data = np.random.random(size=x_shape).astype('float32')
            y_data = np.random.random(size=y_shape).astype('float32')
            result = paddle.where(condition=cond, x=x, y=y)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={'cond': cond_data, 'x': x_data, 'y': y_data},
                    fetch_list=[result],
                )
                expect = np.where(cond_data, x_data, y_data)
                np.testing.assert_array_equal(out[0], expect)

    def test_static_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

class TestWhereDygraphAPI(unittest.TestCase):
    """Dygraph tests of ``paddle.where``, plus the condition-only form."""

    def test_api(self):
        with fluid.dygraph.guard():
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = fluid.dygraph.to_variable(x_i)
            y = fluid.dygraph.to_variable(y_i)
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(
                out.numpy(), np.where(cond_i, x_i, y_i)
            )

    def test_scalar(self):
        with fluid.dygraph.guard():
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = 1.0
            y = 2.0
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(out.numpy(), np.where(cond_i, x, y))

    def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
        """Compare paddle.where with np.where on random tensors of the given
        (mutually broadcastable) shapes."""
        with fluid.dygraph.guard():
            cond_tmp = paddle.rand(cond_shape)
            # Roughly 30% True so both branches of where are selected.
            cond = cond_tmp < 0.3
            a = paddle.rand(a_shape)
            b = paddle.rand(b_shape)
            result = paddle.where(cond, a, b).numpy()
            # Convert explicitly rather than relying on numpy's implicit
            # array conversion of Paddle tensors.
            expect = np.where(cond.numpy(), a.numpy(), b.numpy())
            np.testing.assert_array_equal(expect, result)

    def test_dygraph_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_where_condition(self):
        """paddle.where(x) with only a condition returns a tuple of index
        tensors (one per input dimension), checked in static mode."""
        # This test builds static Programs with the static-only
        # ``paddle.static.data``; enable static mode explicitly instead of
        # relying on the module's ``__main__`` block.
        paddle.enable_static()
        data = np.array([[True, False], [False, True]])
        with program_guard(Program(), Program()):
            x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.where(x)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 2)
            z = fluid.layers.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            (res,) = exe.run(
                feed={'x': data}, fetch_list=[z.name], return_numpy=False
            )
        expect_out = np.array([[0, 0], [1, 1]])
        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

        data = np.array([True, True, False])
        with program_guard(Program(), Program()):
            x = paddle.static.data(name='x', shape=[-1], dtype='float32')
            x.desc.set_need_check_feed(False)
            y = paddle.where(x)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 1)
            z = fluid.layers.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            (res,) = exe.run(
                feed={'x': data}, fetch_list=[z.name], return_numpy=False
            )
        expect_out = np.array([[0], [1]])
        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

class TestWhereOpError(unittest.TestCase):
    """Error handling of ``paddle.where`` for invalid argument types/arity."""

    def test_errors(self):
        # The checks below build static variables with the static-only
        # ``paddle.static.data``; enable static mode explicitly.
        paddle.enable_static()
        with program_guard(Program(), Program()):
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')

            def test_Variable():
                # Raw numpy arrays instead of Paddle variables.
                paddle.where(cond_i, x_i, y_i)

            self.assertRaises(TypeError, test_Variable)

            def test_type():
                # Unsupported dtypes: bool x, float16 y, int32 condition.
                x = paddle.static.data(name='x', shape=[-1, 4], dtype='bool')
                x.desc.set_need_check_feed(False)
                y = paddle.static.data(name='y', shape=[-1, 4], dtype='float16')
                y.desc.set_need_check_feed(False)
                cond = paddle.static.data(
                    name='cond', shape=[-1, 4], dtype='int32'
                )
                cond.desc.set_need_check_feed(False)
                paddle.where(cond, x, y)

            self.assertRaises(TypeError, test_type)

    def test_value_error(self):
        # Passing x without y is rejected with ValueError.
        with fluid.dygraph.guard():
            cond_shape = [2, 2, 4]
            cond_tmp = paddle.rand(cond_shape)
            cond = cond_tmp < 0.3
            a = paddle.rand(cond_shape)
            self.assertRaises(ValueError, paddle.where, cond, a)
if __name__ == "__main__":
    # Several tests build static Programs; switch Paddle to static mode before
    # handing control to unittest.
    paddle.enable_static()
    unittest.main()