# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from op_test import OpTest
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.op import Operator
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _test_eager_guard


class TestWhereOp(OpTest):
    """Operator-level test for the ``where`` op on 1-D float64 inputs.

    ``init_config`` provides ``self.cond``, ``self.x`` and ``self.y``;
    subclasses override it to cover other shapes and conditions. The
    expected output is computed with ``np.where``.
    """

    def setUp(self):
        self.op_type = 'where'
        self.python_api = paddle.where
        self.init_config()
        self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
        self.outputs = {'Out': np.where(self.cond, self.x, self.y)}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        # Gradients are checked for X and Y only; Condition is a bool input.
        self.check_grad(['X', 'Y'], 'Out', check_eager=False)

    def init_config(self):
        # All-False condition: every output element is taken from Y.
        self.x = np.random.uniform(-3, 5, 100).astype('float64')
        self.y = np.random.uniform(-3, 5, 100).astype('float64')
        self.cond = np.zeros(100).astype('bool')


class TestWhereOp2(TestWhereOp):
    """``where`` op on a 2-D (60, 2) input with an all-True condition."""

    def init_config(self):
        # All-True condition: every output element is taken from X.
        self.x = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.y = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.cond = np.ones((60, 2)).astype('bool')


class TestWhereOp3(TestWhereOp):
    """``where`` op on a 3-D (20, 2, 4) input with a random condition."""

    def init_config(self):
        # Random mask, so the output mixes elements from X and Y.
        self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)


class TestWhereAPI(unittest.TestCase):
    """Static-graph tests for ``paddle.where``.

    Covers the forward result, the gradients w.r.t. ``x``/``y`` for every
    ``stop_gradient`` combination, scalar ``x``/``y``, and broadcasting
    between ``condition``, ``x`` and ``y``. Each case runs on CPU and, when
    Paddle is compiled with CUDA, also on GPU 0.
    """

    def setUp(self):
        self.init_data()

    def init_data(self):
        """Create a random boolean mask and float32 operands of ``self.shape``."""
        self.shape = [10, 15]
        self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
        self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32)
        self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32)
        self.out = np.where(self.cond, self.x, self.y)

    def ref_x_backward(self, dout):
        """Reference gradient of ``x``: ``dout`` where ``cond`` is True, else 0."""
        return np.where(self.cond, dout, 0)

    def ref_y_backward(self, dout):
        """Reference gradient of ``y``: ``dout`` where ``cond`` is False, else 0."""
        return np.where(np.logical_not(self.cond), dout, 0)

    @staticmethod
    def _places():
        """Return CPUPlace, plus CUDAPlace(0) when Paddle is built with CUDA."""
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        return places

    def test_api(self, use_cuda=False):
        # ``use_cuda`` is ignored (kept for signature compatibility); every
        # place returned by ``_places()`` is exercised.
        paddle.enable_static()
        for x_stop_gradient in [False, True]:
            for y_stop_gradient in [False, True]:
                with fluid.program_guard(Program(), Program()):
                    cond = fluid.layers.data(name='cond',
                                             shape=self.shape,
                                             dtype='bool')
                    x = fluid.layers.data(name='x',
                                          shape=self.shape,
                                          dtype='float32')
                    y = fluid.layers.data(name='y',
                                          shape=self.shape,
                                          dtype='float32')
                    x.stop_gradient = x_stop_gradient
                    y.stop_gradient = y_stop_gradient
                    result = paddle.where(cond, x, y)
                    append_backward(paddle.mean(result))

                    # Fetch order: out, d(out), then d(x) and/or d(y) for
                    # whichever inputs have stop_gradient == False.
                    fetch_list = [result, result.grad_name]
                    if not x_stop_gradient:
                        fetch_list.append(x.grad_name)
                    if not y_stop_gradient:
                        fetch_list.append(y.grad_name)

                    for place in self._places():
                        exe = fluid.Executor(place)
                        out = exe.run(fluid.default_main_program(),
                                      feed={
                                          'cond': self.cond,
                                          'x': self.x,
                                          'y': self.y
                                      },
                                      fetch_list=fetch_list)
                        np.testing.assert_array_equal(out[0], self.out)
                        dout = out[1]
                        # Consume the input gradients in the order fetched.
                        input_grads = iter(out[2:])
                        if not x_stop_gradient:
                            np.testing.assert_array_equal(
                                next(input_grads), self.ref_x_backward(dout))
                        if not y_stop_gradient:
                            np.testing.assert_array_equal(
                                next(input_grads), self.ref_y_backward(dout))

    def test_api_broadcast(self, use_cuda=False):
        # ``use_cuda`` is ignored (kept for signature compatibility).
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
            y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
            y_i = np.array([[1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 1.0, 1.0]]).astype('float32')
            result = paddle.where((x > 1), x=x, y=y)
            expect = np.where((x_i > 1), x_i, y_i)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(fluid.default_main_program(),
                              feed={
                                  'x': x_i,
                                  'y': y_i
                              },
                              fetch_list=[result])
                np.testing.assert_array_equal(out[0], expect)

    def test_scalar(self):
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond_shape = [2, 4]
            cond = fluid.layers.data(name='cond',
                                     shape=cond_shape,
                                     dtype='bool')
            x_data = 1.0
            y_data = 2.0
            cond_data = np.array([False, False, True, True]).astype('bool')
            result = paddle.where(condition=cond, x=x_data, y=y_data)
            expect = np.where(cond_data, x_data, y_data)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(fluid.default_main_program(),
                              feed={'cond': cond_data},
                              fetch_list=[result])
                np.testing.assert_array_equal(out[0], expect)

    def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
        """Compare paddle.where against np.where for broadcastable shapes."""
        paddle.enable_static()
        main_program = Program()
        with fluid.program_guard(main_program):
            cond = fluid.layers.data(name='cond',
                                     shape=cond_shape,
                                     dtype='bool')
            x = fluid.layers.data(name='x', shape=x_shape, dtype='float32')
            y = fluid.layers.data(name='y', shape=y_shape, dtype='float32')
            cond_data = np.random.random(size=cond_shape).astype('float32') < 0.3
            x_data = np.random.random(size=x_shape).astype('float32')
            y_data = np.random.random(size=y_shape).astype('float32')
            result = paddle.where(condition=cond, x=x, y=y)
            expect = np.where(cond_data, x_data, y_data)
            for place in self._places():
                exe = fluid.Executor(place)
                out = exe.run(fluid.default_main_program(),
                              feed={
                                  'cond': cond_data,
                                  'x': x_data,
                                  'y': y_data
                              },
                              fetch_list=[result])
                np.testing.assert_array_equal(out[0], expect)

    def test_static_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)

    def test_static_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)


class TestWhereDygraphAPI(unittest.TestCase):
    """Dynamic-graph tests for ``paddle.where``.

    ``test_eager`` re-runs the dygraph cases under ``_test_eager_guard``.
    ``test_where_condition`` checks the condition-only form
    ``paddle.where(x)`` in a static program.
    """

    def test_api(self):
        with fluid.dygraph.guard():
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = fluid.dygraph.to_variable(x_i)
            y = fluid.dygraph.to_variable(y_i)
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(out.numpy(),
                                          np.where(cond_i, x_i, y_i))

    def test_scalar(self):
        with fluid.dygraph.guard():
            cond_i = np.array([False, False, True, True]).astype('bool')
            x = 1.0
            y = 2.0
            cond = fluid.dygraph.to_variable(cond_i)
            out = paddle.where(cond, x, y)
            np.testing.assert_array_equal(out.numpy(), np.where(cond_i, x, y))

    def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
        """Compare paddle.where against np.where for broadcastable shapes."""
        with fluid.dygraph.guard():
            cond = paddle.rand(cond_shape) < 0.3
            a = paddle.rand(a_shape)
            b = paddle.rand(b_shape)
            result = paddle.where(cond, a, b).numpy()
            # Build the reference from explicit ndarrays rather than handing
            # Paddle tensors directly to np.where.
            expect = np.where(cond.numpy(), a.numpy(), b.numpy())
            np.testing.assert_array_equal(expect, result)

    def test_dygraph_api_broadcast_1(self):
        cond_shape = [2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_2(self):
        cond_shape = [2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_3(self):
        cond_shape = [2, 2, 1]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_4(self):
        cond_shape = [2, 1, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_5(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 4]
        b_shape = [2, 2, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_6(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_7(self):
        cond_shape = [2, 2, 4]
        a_shape = [2, 1, 4]
        b_shape = [2, 1, 4]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def test_dygraph_api_broadcast_8(self):
        cond_shape = [3, 2, 2, 4]
        a_shape = [2, 2, 1]
        b_shape = [2, 2, 1]
        self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)

    def _run_where_condition(self, data, shape, num_outputs):
        """Run ``paddle.where(x)`` (condition-only form) on ``data``.

        Asserts that the call returns a tuple of ``num_outputs`` tensors,
        concatenates them along axis 1, and returns the fetched result as a
        NumPy array.
        """
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=shape)
            y = paddle.where(x)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), num_outputs)
            z = fluid.layers.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            (res, ) = exe.run(feed={'x': data},
                              fetch_list=[z.name],
                              return_numpy=False)
        return np.array(res)

    def test_where_condition(self):
        # This case builds static programs, so do not rely on the global
        # mode set elsewhere.
        paddle.enable_static()
        res = self._run_where_condition(
            np.array([[True, False], [False, True]]), [-1, 2], 2)
        np.testing.assert_allclose(np.array([[0, 0], [1, 1]]), res, rtol=1e-05)

        res = self._run_where_condition(np.array([True, True, False]), [-1], 1)
        np.testing.assert_allclose(np.array([[0], [1]]), res, rtol=1e-05)

    def test_eager(self):
        with _test_eager_guard():
            self.test_api()
            self.test_dygraph_api_broadcast_1()
            self.test_dygraph_api_broadcast_2()
            self.test_dygraph_api_broadcast_3()
            self.test_dygraph_api_broadcast_4()
            self.test_dygraph_api_broadcast_5()
            self.test_dygraph_api_broadcast_6()
            self.test_dygraph_api_broadcast_7()
            self.test_dygraph_api_broadcast_8()


class TestWhereOpError(unittest.TestCase):
    """Argument-validation tests for ``paddle.where``."""

    def test_errors(self):
        # The checks below build a static program; enable static mode
        # explicitly instead of relying on the global default.
        paddle.enable_static()
        with program_guard(Program(), Program()):
            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
            cond_i = np.array([False, False, True, True]).astype('bool')

            # Raw NumPy arrays instead of graph Variables.
            with self.assertRaises(TypeError):
                paddle.where(cond_i, x_i, y_i)

            # bool x, float16 y and an int32 (non-bool) condition.
            with self.assertRaises(TypeError):
                x = fluid.layers.data(name='x', shape=[4], dtype='bool')
                y = fluid.layers.data(name='y', shape=[4], dtype='float16')
                cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
                paddle.where(cond, x, y)

    def test_value_error(self):
        with fluid.dygraph.guard():
            cond_shape = [2, 2, 4]
            cond = paddle.rand(cond_shape) < 0.3
            a = paddle.rand(cond_shape)
            # Passing x without y must raise ValueError.
            self.assertRaises(ValueError, paddle.where, cond, a)

    def test_eager(self):
        with _test_eager_guard():
            self.test_value_error()

if __name__ == "__main__":
    # Static mode by default; dygraph cases switch modes locally with
    # fluid.dygraph.guard().
    paddle.enable_static()
    unittest.main()