#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard


class TestLinspaceOpCommonCase(OpTest):
    """Operator test: 11 evenly spaced float32 points from 0 to 10 inclusive."""

    def setUp(self):
        self.op_type = "linspace"
        # Python-level API handed to OpTest alongside the op (presumably used
        # for the eager-mode comparison enabled by check_eager=True).
        self.python_api = paddle.linspace
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([0]).astype(dtype),
            'Stop': np.array([10]).astype(dtype),
            'Num': np.array([11]).astype('int32')
        }
        # The op's output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # With step 1 the expected result is simply 0, 1, ..., 10.
        self.outputs = {'Out': np.arange(0, 11).astype(dtype)}

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceOpReverseCase(OpTest):
    """Operator test: descending range, 11 float32 points from 10 down to 0."""

    def setUp(self):
        self.op_type = "linspace"
        # Python-level API handed to OpTest alongside the op (presumably used
        # for the eager-mode comparison enabled by check_eager=True).
        self.python_api = paddle.linspace
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([11]).astype('int32')
        }
        # The op's output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # Start > Stop, so the expected sequence counts down: 10, 9, ..., 0.
        self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceOpNumOneCase(OpTest):
    """Operator test: Num == 1 yields a single element equal to Start."""

    def setUp(self):
        self.op_type = "linspace"
        # Python-level API handed to OpTest alongside the op (presumably used
        # for the eager-mode comparison enabled by check_eager=True).
        self.python_api = paddle.linspace
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([1]).astype('int32')
        }
        # The op's output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # A one-point linspace is [Start] (cf. np.linspace(10, 0, 1)); keep the
        # expectation 1-D so its shape matches the op's [Num]-shaped output.
        self.outputs = {'Out': np.array([10], dtype=dtype)}

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceAPI(unittest.TestCase):
    """Checks ``paddle.linspace`` results against ``np.linspace``.

    Covers static-graph and dynamic (imperative) modes, the ways of spelling
    the output dtype, and the ``name`` argument.
    """

    def test_variable_input1(self):
        # Static graph: start/stop/num are supplied as tensors.
        start = paddle.full(shape=[1], fill_value=0, dtype='float32')
        stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
        num = paddle.full(shape=[1], fill_value=5, dtype='int32')
        out = paddle.linspace(start, stop, num, dtype='float32')
        exe = fluid.Executor(place=fluid.CPUPlace())
        res = exe.run(fluid.default_main_program(), fetch_list=[out])
        np_res = np.linspace(0, 10, 5, dtype='float32')
        # exe.run returns one array per fetched variable.
        np.testing.assert_array_equal(res[0], np_res)

    def test_variable_input2(self):
        # Dynamic mode with tensor inputs. Static mode is restored in
        # `finally` so a failure here cannot leak dygraph mode into other tests.
        paddle.disable_static()
        try:
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
            np_res = np.linspace(0, 10, 5, dtype='float32')
            np.testing.assert_array_equal(out.numpy(), np_res)
        finally:
            paddle.enable_static()

    def test_dtype(self):
        # The same dtype given as a string, a numpy type and a VarType enum
        # must all produce identical results.
        out_1 = paddle.linspace(0, 10, 5, dtype='float32')
        out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
        out_3 = paddle.linspace(0, 10, 5, dtype=core.VarDesc.VarType.FP32)
        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
                                      fetch_list=[out_1, out_2, out_3])
        np.testing.assert_array_equal(res_1, res_2)
        np.testing.assert_array_equal(res_1, res_3)

    def test_name(self):
        with paddle.static.program_guard(paddle.static.Program()):
            out = paddle.linspace(
                0, 10, 5, dtype='float32', name='linspace_res')
            assert 'linspace_res' in out.name

    def test_imperative(self):
        # Dynamic mode with Python scalar inputs, including integer outputs
        # that must match numpy's truncation behaviour.
        paddle.disable_static()
        try:
            out1 = paddle.linspace(0, 10, 5, dtype='float32')
            np_out1 = np.linspace(0, 10, 5, dtype='float32')
            out2 = paddle.linspace(0, 10, 5, dtype='int32')
            np_out2 = np.linspace(0, 10, 5, dtype='int32')
            out3 = paddle.linspace(0, 10, 200, dtype='int32')
            np_out3 = np.linspace(0, 10, 200, dtype='int32')
        finally:
            paddle.enable_static()
        np.testing.assert_array_equal(out1.numpy(), np_out1)
        np.testing.assert_array_equal(out2.numpy(), np_out2)
        np.testing.assert_array_equal(out3.numpy(), np_out3)

    def test_api_eager_dygraph(self):
        # Re-run the dynamic-mode checks under the eager execution guard.
        with _test_eager_guard():
            self.test_variable_input2()
            self.test_imperative()

class TestLinspaceOpError(unittest.TestCase):
    """Argument validation tests for ``fluid.layers.linspace`` in static mode."""

    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_dtype():
                # int8 is expected to be rejected as an output dtype.
                fluid.layers.linspace(0, 10, 1, dtype="int8")

            self.assertRaises(TypeError, test_dtype)

            def test_dtype1():
                # A non-integer Python `num` is expected to be rejected.
                fluid.layers.linspace(0, 10, 1.33, dtype="int32")

            self.assertRaises(TypeError, test_dtype1)

            def test_start_type():
                # `start` given as a list rather than a scalar or tensor.
                fluid.layers.linspace([0], 10, 1, dtype="float32")

            self.assertRaises(TypeError, test_start_type)

            def test_end_type():
                # `stop` given as a list rather than a scalar or tensor.
                fluid.layers.linspace(0, [10], 1, dtype="float32")

            self.assertRaises(TypeError, test_end_type)

            def test_step_dtype():
                # `num` given as a list rather than an int or tensor.
                fluid.layers.linspace(0, 10, [0], dtype="float32")

            self.assertRaises(TypeError, test_step_dtype)

            def test_start_dtype():
                # A float64 `start` tensor with a float32 output dtype is
                # expected to raise ValueError.
                start = fluid.data(shape=[1], dtype="float64", name="start")
                fluid.layers.linspace(start, 10, 1, dtype="float32")

            self.assertRaises(ValueError, test_start_dtype)

            def test_end_dtype():
                # A float64 `stop` tensor with a float32 output dtype is
                # expected to raise ValueError.
                end = fluid.data(shape=[1], dtype="float64", name="end")
                fluid.layers.linspace(0, end, 1, dtype="float32")

            self.assertRaises(ValueError, test_end_dtype)

            def test_num_dtype():
                # A `num` tensor that is not int32 is expected to be rejected.
                num = fluid.data(shape=[1], dtype="float32", name="num")
                fluid.layers.linspace(0, 10, num, dtype="float32")

            self.assertRaises(TypeError, test_num_dtype)

# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()