# test_linspace.py: tests for the linspace operator and the paddle.linspace API.
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, core, program_guard
from paddle.fluid.framework import _test_eager_guard


class TestLinspaceOpCommonCase(OpTest):
    """linspace op: 11 float32 points from 0 to 10 inclusive (step 1)."""

    def setUp(self):
        self.op_type = "linspace"
        self.python_api = paddle.linspace
        np_dtype = 'float32'
        first, last, count = 0, 10, 11
        # Start/Stop are 1-element float tensors; Num is an int32 count.
        self.inputs = {
            'Start': np.array([first]).astype(np_dtype),
            'Stop': np.array([last]).astype(np_dtype),
            'Num': np.array([count]).astype('int32'),
        }
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        # With a unit step the reference is simply 0, 1, ..., 10.
        self.outputs = {'Out': np.arange(first, last + 1).astype(np_dtype)}

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceOpReverseCase(OpTest):
    """linspace op with Start > Stop: counts down from 10 to 0 in 11 points."""

    def setUp(self):
        self.op_type = "linspace"
        self.python_api = paddle.linspace
        np_dtype = 'float32'
        self.inputs = dict(
            Start=np.array([10]).astype(np_dtype),
            Stop=np.array([0]).astype(np_dtype),
            Num=np.array([11]).astype('int32'),
        )
        self.attrs = dict(dtype=int(core.VarDesc.VarType.FP32))
        # Descending reference sequence: 10, 9, ..., 0.
        self.outputs = dict(Out=np.arange(11)[::-1].astype(np_dtype))

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceOpNumOneCase(OpTest):
    """linspace op with Num == 1: the single output element equals Start."""

    def setUp(self):
        self.op_type = "linspace"
        self.python_api = paddle.linspace
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([1]).astype('int32'),
        }
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        # As in the sibling cases, the reference is a 1-D array of length Num
        # (this is also what np.linspace(10, 0, 1) returns), not a 0-D scalar,
        # so a shape-sensitive output check compares like with like.
        self.outputs = {'Out': np.array([10], dtype=dtype)}

    def test_check_output(self):
        self.check_output(check_eager=True)


class TestLinspaceAPI(unittest.TestCase):
    """Compares paddle.linspace with np.linspace in static and dynamic mode."""

    def test_variable_input1(self):
        # Build into a fresh Program so that ops other tests appended to the
        # global default program are not executed as part of this test.
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
        exe = fluid.Executor(place=fluid.CPUPlace())
        # exe.run returns a list with one array per fetch target.
        (res,) = exe.run(main_program, fetch_list=[out])
        np_res = np.linspace(0, 10, 5, dtype='float32')
        np.testing.assert_array_equal(res, np_res)

    def test_variable_input2(self):
        paddle.disable_static()
        try:
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
            np_res = np.linspace(0, 10, 5, dtype='float32')
            np.testing.assert_array_equal(out.numpy(), np_res)
        finally:
            # Restore static mode even if the check fails, so later tests
            # are not silently run in dynamic mode.
            paddle.enable_static()

    def test_dtype(self):
        # The output dtype may be given as a string, a numpy type or a VarType;
        # all three spellings must produce identical results.
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            out_1 = paddle.linspace(0, 10, 5, dtype='float32')
            out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
            out_3 = paddle.linspace(0, 10, 5, dtype=core.VarDesc.VarType.FP32)
        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(
            main_program, fetch_list=[out_1, out_2, out_3]
        )
        np.testing.assert_array_equal(res_1, res_2)
        np.testing.assert_array_equal(res_1, res_3)

    def test_name(self):
        with paddle.static.program_guard(paddle.static.Program()):
            out = paddle.linspace(
                0, 10, 5, dtype='float32', name='linspace_res'
            )
            assert 'linspace_res' in out.name

    def test_imperative(self):
        paddle.disable_static()
        try:
            out1 = paddle.linspace(0, 10, 5, dtype='float32')
            out2 = paddle.linspace(0, 10, 5, dtype='int32')
            out3 = paddle.linspace(0, 10, 200, dtype='int32')
        finally:
            # Always switch back to static mode, even if linspace raises.
            paddle.enable_static()
        np.testing.assert_array_equal(
            out1.numpy(), np.linspace(0, 10, 5, dtype='float32')
        )
        np.testing.assert_array_equal(
            out2.numpy(), np.linspace(0, 10, 5, dtype='int32')
        )
        np.testing.assert_array_equal(
            out3.numpy(), np.linspace(0, 10, 200, dtype='int32')
        )

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_variable_input2()
            self.test_imperative()


class TestLinspaceOpError(unittest.TestCase):
    """Checks that fluid.layers.linspace rejects invalid argument types/dtypes."""

    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_dtype():
                # int8 output dtype: expected to raise TypeError.
                fluid.layers.linspace(0, 10, 1, dtype="int8")

            self.assertRaises(TypeError, test_dtype)

            def test_dtype1():
                # Non-integer Python `num`: expected to raise TypeError.
                fluid.layers.linspace(0, 10, 1.33, dtype="int32")

            self.assertRaises(TypeError, test_dtype1)

            def test_start_type():
                # `start` given as a list: expected to raise TypeError.
                fluid.layers.linspace([0], 10, 1, dtype="float32")

            self.assertRaises(TypeError, test_start_type)

            def test_end_type():
                # `stop` given as a list: expected to raise TypeError.
                fluid.layers.linspace(0, [10], 1, dtype="float32")

            self.assertRaises(TypeError, test_end_type)

            def test_step_dtype():
                # `num` given as a list: expected to raise TypeError.
                fluid.layers.linspace(0, 10, [0], dtype="float32")

            self.assertRaises(TypeError, test_step_dtype)

            def test_start_dtype():
                # float64 start with float32 output: expected ValueError.
                start = fluid.data(shape=[1], dtype="float64", name="start")
                fluid.layers.linspace(start, 10, 1, dtype="float32")

            self.assertRaises(ValueError, test_start_dtype)

            def test_end_dtype():
                # float64 stop with float32 output: expected ValueError.
                end = fluid.data(shape=[1], dtype="float64", name="end")
                fluid.layers.linspace(0, end, 1, dtype="float32")

            self.assertRaises(ValueError, test_end_dtype)

            def test_num_dtype():
                # A Tensor `num` must be int32, so a float32 one is invalid.
                # NOTE(review): assumes linspace validates num's dtype with a
                # TypeError, like the other dtype checks above — confirm.
                num = fluid.data(shape=[1], dtype="float32", name="num")
                fluid.layers.linspace(0, 10, num, dtype="float32")

            self.assertRaises(TypeError, test_num_dtype)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()