#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard


class TestLinspaceOpCommonCase(OpTest):
    """Check the ``linspace`` op on an ascending range: 0 -> 10 in 11 points."""

    def setUp(self):
        # Operator under test, plus the Python API used for the eager check.
        self.python_api = paddle.linspace
        self.op_type = "linspace"

        fp_type = 'float32'
        first, last, count = 0, 10, 11

        # Start/Stop are 1-element float32 arrays; Num is a 1-element int32.
        self.inputs = dict(
            Start=np.array([first]).astype(fp_type),
            Stop=np.array([last]).astype(fp_type),
            Num=np.array([count]).astype('int32'),
        )
        # Output dtype attribute, given as the integer value of the FP32 enum.
        self.attrs = dict(dtype=int(core.VarDesc.VarType.FP32))
        # With a step of 1 the expected samples are 0, 1, ..., 10.
        self.outputs = dict(Out=np.arange(first, last + 1).astype(fp_type))

    def test_check_output(self):
        # check_eager=True presumably also exercises the eager path through
        # python_api -- TODO confirm against OpTest.check_output.
        self.check_output(check_eager=True)


class TestLinspaceOpReverseCase(OpTest):
    """Check the ``linspace`` op on a descending range: 10 -> 0 in 11 points."""

    def setUp(self):
        # Operator under test, plus the Python API used for the eager check.
        self.python_api = paddle.linspace
        self.op_type = "linspace"

        fp_type = 'float32'
        first, last, count = 10, 0, 11

        # Start/Stop are 1-element float32 arrays; Num is a 1-element int32.
        self.inputs = dict(
            Start=np.array([first]).astype(fp_type),
            Stop=np.array([last]).astype(fp_type),
            Num=np.array([count]).astype('int32'),
        )
        # Output dtype attribute, given as the integer value of the FP32 enum.
        self.attrs = dict(dtype=int(core.VarDesc.VarType.FP32))
        # Counting down by 1: 10, 9, ..., 0.
        self.outputs = dict(
            Out=np.arange(first, last - 1, -1).astype(fp_type)
        )

    def test_check_output(self):
        # check_eager=True presumably also exercises the eager path through
        # python_api -- TODO confirm against OpTest.check_output.
        self.check_output(check_eager=True)


class TestLinspaceOpNumOneCase(OpTest):
    """Check the ``linspace`` op when only a single point is requested."""

    def setUp(self):
        # Operator under test, plus the Python API used for the eager check.
        self.python_api = paddle.linspace
        self.op_type = "linspace"

        fp_type = 'float32'
        first, last, count = 10, 0, 1

        # Start/Stop are 1-element float32 arrays; Num is a 1-element int32.
        self.inputs = dict(
            Start=np.array([first]).astype(fp_type),
            Stop=np.array([last]).astype(fp_type),
            Num=np.array([count]).astype('int32'),
        )
        # Output dtype attribute, given as the integer value of the FP32 enum.
        self.attrs = dict(dtype=int(core.VarDesc.VarType.FP32))
        # With one sample the expected output is just the start value.
        self.outputs = dict(Out=np.array(first, dtype=fp_type))

    def test_check_output(self):
        # check_eager=True presumably also exercises the eager path through
        # python_api -- TODO confirm against OpTest.check_output.
        self.check_output(check_eager=True)


class TestLinspaceAPI(unittest.TestCase):
    """Tests for the public ``paddle.linspace`` API in static and dygraph mode.

    The static-graph tests run an Executor and therefore assume the process
    is in static mode when they start. NOTE(review): static mode is presumably
    enabled by the test harness (e.g. on ``op_test`` import) -- confirm.
    The dygraph tests switch to dygraph mode and always switch back.
    """

    def test_variable_input1(self):
        # Build into a private Program so ops left in the global default
        # program by other tests are neither re-executed nor polluted here.
        main_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog):
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
        exe = fluid.Executor(place=fluid.CPUPlace())
        (res,) = exe.run(main_prog, fetch_list=[out])
        np_res = np.linspace(0, 10, 5, dtype='float32')
        np.testing.assert_array_equal(res, np_res)

    def test_variable_input2(self):
        paddle.disable_static()
        try:
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
            np_res = np.linspace(0, 10, 5, dtype='float32')
            np.testing.assert_array_equal(out.numpy(), np_res)
        finally:
            # Restore static mode even on failure so later tests are not
            # silently run in dygraph mode.
            paddle.enable_static()

    def test_dtype(self):
        # The same request expressed with three dtype spellings (string,
        # numpy type, VarType enum) must produce identical results.
        main_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog):
            out_1 = paddle.linspace(0, 10, 5, dtype='float32')
            out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
            out_3 = paddle.linspace(0, 10, 5, dtype=core.VarDesc.VarType.FP32)
        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(
            main_prog, fetch_list=[out_1, out_2, out_3]
        )
        np.testing.assert_array_equal(res_1, res_2)
        np.testing.assert_array_equal(res_1, res_3)

    def test_name(self):
        with paddle.static.program_guard(paddle.static.Program()):
            out = paddle.linspace(
                0, 10, 5, dtype='float32', name='linspace_res'
            )
            assert 'linspace_res' in out.name

    def test_imperative(self):
        paddle.disable_static()
        try:
            out1 = paddle.linspace(0, 10, 5, dtype='float32')
            np_out1 = np.linspace(0, 10, 5, dtype='float32')
            out2 = paddle.linspace(0, 10, 5, dtype='int32')
            np_out2 = np.linspace(0, 10, 5, dtype='int32')
            out3 = paddle.linspace(0, 10, 200, dtype='int32')
            np_out3 = np.linspace(0, 10, 200, dtype='int32')
            np.testing.assert_array_equal(out1.numpy(), np_out1)
            np.testing.assert_array_equal(out2.numpy(), np_out2)
            np.testing.assert_array_equal(out3.numpy(), np_out3)
        finally:
            # Restore static mode even on failure (see test_variable_input2).
            paddle.enable_static()

    def test_api_eager_dygraph(self):
        # Re-run the dygraph tests under the eager-mode guard.
        with _test_eager_guard():
            self.test_variable_input2()
            self.test_imperative()


class TestLinspaceOpError(unittest.TestCase):
    """Argument validation of ``fluid.layers.linspace`` in static mode."""

    def test_errors(self):
        # A throwaway main/startup Program pair keeps the placeholder
        # variables created below out of the global default programs.
        with program_guard(Program(), Program()):

            def test_dtype():
                # int8 output dtype is expected to be rejected.
                fluid.layers.linspace(0, 10, 1, dtype="int8")

            self.assertRaises(TypeError, test_dtype)

            def test_dtype1():
                # A non-integer Python `num` is expected to be rejected.
                fluid.layers.linspace(0, 10, 1.33, dtype="int32")

            self.assertRaises(TypeError, test_dtype1)

            def test_start_type():
                # `start` given as a list.
                fluid.layers.linspace([0], 10, 1, dtype="float32")

            self.assertRaises(TypeError, test_start_type)

            def test_end_type():
                # `stop` given as a list.
                fluid.layers.linspace(0, [10], 1, dtype="float32")

            self.assertRaises(TypeError, test_end_type)

            def test_step_dtype():
                # `num` given as a list.
                fluid.layers.linspace(0, 10, [0], dtype="float32")

            self.assertRaises(TypeError, test_step_dtype)

            def test_start_dtype():
                # float64 `start` with a float32 output dtype.
                start = fluid.data(shape=[1], dtype="float64", name="start")
                fluid.layers.linspace(start, 10, 1, dtype="float32")

            self.assertRaises(ValueError, test_start_dtype)

            def test_end_dtype():
                # float64 `stop` with a float32 output dtype.
                end = fluid.data(shape=[1], dtype="float64", name="end")
                fluid.layers.linspace(0, end, 1, dtype="float32")

            self.assertRaises(ValueError, test_end_dtype)

            def test_num_dtype():
                # `num` as a float32 Variable. NOTE(review): linspace
                # presumably accepts only int32 `num` tensors and raises
                # TypeError otherwise -- confirm against its dtype checks.
                num = fluid.data(shape=[1], dtype="float32", name="num")
                fluid.layers.linspace(0, 10, num, dtype="float32")

            self.assertRaises(TypeError, test_num_dtype)


if __name__ == "__main__":
    # Run every TestCase in this module with the standard unittest runner.
    unittest.main()