From d89e1449b7701d759b6e3180f12ea430320db18d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 22:41:54 +0800 Subject: [PATCH] optimize test --- .../tests/test_python_operator_overriding.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index b985ae3e297..94f3fc958e0 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -14,40 +14,52 @@ import unittest -import numpy +import numpy as np + +import paddle.v2.fluid.layers as layers import paddle.v2.fluid.framework as framework import paddle.v2.fluid as fluid class TestPythonOperatorOverride(unittest.TestCase): - def check_result(self, fn, place, dtype='float32'): + def check_result(self, fn, x_val, y_val, place, dtype): shape = [9, 10] - x_data = numpy.random.random(size=shape).astype(dtype) - y_data = numpy.random.random(size=shape).astype(dtype) + x_data = np.full(shape, x_val).astype(dtype) + y_data = np.full(shape, y_val).astype(dtype) python_out = fn(x_data, y_data) - x_var = fluid.layers.data(name='x', shape=shape, dtype=dtype) - y_var = fluid.layers.data(name='y', shape=shape, dtype=dtype) + x_var = layers.create_global_var( + shape=shape, value=x_val, dtype=dtype, persistable=True) + y_var = layers.create_global_var( + shape=shape, value=y_val, dtype=dtype, persistable=True) out = fn(x_var, y_var) exe = fluid.Executor(place) - feeder = fluid.DataFeeder(feed_list=[x_var, y_var], place=place) exe.run(fluid.default_startup_program()) fluid_out = exe.run(fluid.default_main_program(), - feed=feeder.feed([x_data, y_data]), + feed=[], fetch_list=[out]) - print(python_out) - self.assertAlmostEqual(python_out, fluid_out[0]) + np.testing.assert_array_equal(python_out, fluid_out[0]) def test_override(self): + cpu_place = fluid.CPUPlace() + test_data = [(lambda _a, _b: _a == _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a == _b, 1.2, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a < _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a < _b, 2.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a <= _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a <= _b, 1.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a >= _b, 1.1, 1.1, cpu_place, 'float32')] + main_program = framework.Program() startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): - place = fluid.CPUPlace() - self.check_result(lambda _a, _b: _a == _b, place) + for fn, x_val, y_val, place, dtype in test_data: + self.check_result(fn, x_val, y_val, place, dtype) if __name__ == '__main__': -- GitLab