# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid as fluid


class TestPythonOperatorOverride(unittest.TestCase):
    """Check that comparison operators on fluid variables agree with numpy.

    Each comparison operator (``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``) is
    applied once to numpy arrays and once to variables created with
    ``layers.create_global_var`` that are fed the same data; the two results
    must be element-wise equal.
    """

    def check_result(self, fn, place, dtype):
        """Evaluate ``fn`` with numpy and with fluid, then compare results.

        Args:
            fn: Binary callable applied to ``(x, y)``. It is called once with
                numpy arrays and once with fluid variables.
            place: Place object handed to ``fluid.Executor``.
            dtype: Dtype name used for both the input data and the variables.

        Raises:
            AssertionError: If the numpy result differs from the fetched
                fluid result.
        """
        shape = [9, 10]

        # np.random.random() samples from [0, 1); casting that directly to an
        # integer dtype would turn every element into 0 and make each
        # comparison trivially constant. Scale first so the values spread out.
        x_data = (np.random.random(size=shape) * 10).astype(dtype)
        y_data = (np.random.random(size=shape) * 10).astype(dtype)
        # Force one row of exact ties so that `==`, `!=`, `<=` and `>=` are
        # exercised on equal operands for floating point dtypes as well.
        y_data[0] = x_data[0]
        python_out = fn(x_data, y_data)

        x_var = layers.create_global_var(
            name='x', shape=shape, value=0.0, dtype=dtype, persistable=True)
        y_var = layers.create_global_var(
            name='y', shape=shape, value=0.0, dtype=dtype, persistable=True)
        out = fn(x_var, y_var)

        exe = fluid.Executor(place)

        exe.run(fluid.default_startup_program())
        fluid_out = exe.run(fluid.default_main_program(),
                            feed={'x': x_data,
                                  'y': y_data},
                            fetch_list=[out])

        np.testing.assert_array_equal(python_out, fluid_out[0])

    def test_override(self):
        """Run every comparison operator for each place and dtype."""
        # compare func to check
        compare_fns = [
            lambda _a, _b: _a == _b,
            lambda _a, _b: _a != _b,
            lambda _a, _b: _a < _b,
            lambda _a, _b: _a <= _b,
            lambda _a, _b: _a > _b,
            lambda _a, _b: _a >= _b,
        ]

        # places to check
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        # dtypes to check
        dtypes = ['int32', 'float32']

        for place in places:
            for dtype in dtypes:
                for compare_fn in compare_fns:
                    # Fresh main/startup programs per case, so the 'x' and
                    # 'y' variables are created anew for every combination.
                    with framework.program_guard(framework.Program(),
                                                 framework.Program()):
                        self.check_result(compare_fn, place, dtype)


# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()