# test_linspace.py (last changed by zhoukunsheng)
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard

from op_test import OpTest


class TestLinspaceOpCommonCase(OpTest):
    """Check the ``linspace`` operator on an ascending range.

    Start=0, Stop=10 and Num=11 are expected to yield 0, 1, ..., 10 as float32.
    """

    def setUp(self):
        # Operator under test and the Python API used by the eager check.
        self.op_type = "linspace"
        self.python_api = paddle.linspace

        float_type = 'float32'
        start, stop, count = 0, 10, 11
        self.inputs = {
            'Start': np.array([start]).astype(float_type),
            'Stop': np.array([stop]).astype(float_type),
            'Num': np.array([count]).astype('int32')
        }
        # The output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # 11 evenly spaced points between 0 and 10 are exactly 0..10.
        expected = np.arange(start, count).astype(float_type)
        self.outputs = {'Out': expected}

    def test_check_output(self):
        """Compare the operator result with ``self.outputs`` (check_eager=True)."""
        self.check_output(check_eager=True)


class TestLinspaceOpReverseCase(OpTest):
    """Check the ``linspace`` operator when Start > Stop (descending output).

    Start=10, Stop=0 and Num=11 are expected to yield 10, 9, ..., 0 as float32.
    """

    def setUp(self):
        # Operator under test and the Python API used by the eager check.
        self.op_type = "linspace"
        self.python_api = paddle.linspace

        float_type = 'float32'
        first, last, count = 10, 0, 11
        self.inputs = {
            'Start': np.array([first]).astype(float_type),
            'Stop': np.array([last]).astype(float_type),
            'Num': np.array([count]).astype('int32')
        }
        # The output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # Step of -1 from `first` down to `last`, inclusive.
        expected = np.arange(first, last - 1, -1).astype(float_type)
        self.outputs = {'Out': expected}

    def test_check_output(self):
        """Compare the operator result with ``self.outputs`` (check_eager=True)."""
        self.check_output(check_eager=True)


class TestLinspaceOpNumOneCase(OpTest):
    """Check the ``linspace`` operator with Num=1.

    With a single requested point the result contains only ``Start`` (10),
    regardless of ``Stop``.
    """

    def setUp(self):
        # Operator under test and the Python API used by the eager check.
        self.op_type = "linspace"
        self.python_api = paddle.linspace
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([1]).astype('int32')
        }
        # The output dtype attribute is passed as the integer enum value.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        # Use numpy's reference result: a 1-element array [10.] with shape
        # (1,). This is consistent with the other cases, whose expected
        # outputs have Num elements. A 0-d np.array(10) has shape () instead.
        self.outputs = {'Out': np.linspace(10, 0, 1).astype(dtype)}

    def test_check_output(self):
        """Compare the operator result with ``self.outputs`` (check_eager=True)."""
        self.check_output(check_eager=True)


class TestLinspaceAPI(unittest.TestCase):
    """Tests for the ``paddle.linspace`` Python API in static and dygraph mode."""

    def test_variable_input1(self):
        """Static graph: start, stop and num are passed as tensors."""
        # Build into a private program so ops from other tests are not mixed
        # into (and re-executed from) the global default main program.
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
        exe = fluid.Executor(place=fluid.CPUPlace())
        res, = exe.run(main_program, fetch_list=[out])
        np_res = np.linspace(0, 10, 5, dtype='float32')
        np.testing.assert_array_equal(res, np_res)

    def test_variable_input2(self):
        """Dygraph: start, stop and num are passed as tensors."""
        paddle.disable_static()
        try:
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
            np_res = np.linspace(0, 10, 5, dtype='float32')
            np.testing.assert_array_equal(out.numpy(), np_res)
        finally:
            # Always restore static mode so a failure here cannot leak
            # dygraph mode into tests that run afterwards.
            paddle.enable_static()

    def test_dtype(self):
        """``dtype`` may be a string, a numpy type or a ``VarType`` enum."""
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            out_1 = paddle.linspace(0, 10, 5, dtype='float32')
            out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
            out_3 = paddle.linspace(0, 10, 5, dtype=core.VarDesc.VarType.FP32)
        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(main_program,
                                      fetch_list=[out_1, out_2, out_3])
        # All three spellings of float32 must produce identical results.
        np.testing.assert_array_equal(res_1, res_2)
        np.testing.assert_array_equal(res_1, res_3)

    def test_name(self):
        """The ``name`` argument is reflected in the output variable name."""
        with paddle.static.program_guard(paddle.static.Program()):
            out = paddle.linspace(0,
                                  10,
                                  5,
                                  dtype='float32',
                                  name='linspace_res')
            assert 'linspace_res' in out.name

    def test_imperative(self):
        """Dygraph: Python scalars with float and int output dtypes."""
        paddle.disable_static()
        try:
            out1 = paddle.linspace(0, 10, 5, dtype='float32')
            out2 = paddle.linspace(0, 10, 5, dtype='int32')
            out3 = paddle.linspace(0, 10, 200, dtype='int32')
            results = [out1.numpy(), out2.numpy(), out3.numpy()]
        finally:
            # Restore static mode even if op construction raises.
            paddle.enable_static()
        np.testing.assert_array_equal(
            results[0], np.linspace(0, 10, 5, dtype='float32'))
        np.testing.assert_array_equal(
            results[1], np.linspace(0, 10, 5, dtype='int32'))
        np.testing.assert_array_equal(
            results[2], np.linspace(0, 10, 200, dtype='int32'))

    def test_api_eager_dygraph(self):
        """Re-run the dygraph tests under the eager-mode guard."""
        with _test_eager_guard():
            self.test_variable_input2()
            self.test_imperative()

class TestLinspaceOpError(unittest.TestCase):
    """Invalid arguments to ``fluid.layers.linspace`` are expected to raise."""

    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_dtype():
                # int8 is expected to be rejected as an output dtype.
                fluid.layers.linspace(0, 10, 1, dtype="int8")

            self.assertRaises(TypeError, test_dtype)

            def test_dtype1():
                # A non-integer Python `num` is expected to be rejected.
                fluid.layers.linspace(0, 10, 1.33, dtype="int32")

            self.assertRaises(TypeError, test_dtype1)

            def test_start_type():
                # A list is expected to be rejected as `start`.
                fluid.layers.linspace([0], 10, 1, dtype="float32")

            self.assertRaises(TypeError, test_start_type)

            def test_end_type():
                # A list is expected to be rejected as `stop`.
                fluid.layers.linspace(0, [10], 1, dtype="float32")

            self.assertRaises(TypeError, test_end_type)

            def test_step_dtype():
                # A list is expected to be rejected as `num`.
                fluid.layers.linspace(0, 10, [0], dtype="float32")

            self.assertRaises(TypeError, test_step_dtype)

            def test_start_dtype():
                # float64 `start` with a float32 output is expected to raise
                # ValueError.
                start = fluid.data(shape=[1], dtype="float64", name="start")
                fluid.layers.linspace(start, 10, 1, dtype="float32")

            self.assertRaises(ValueError, test_start_dtype)

            def test_end_dtype():
                # float64 `stop` with a float32 output is expected to raise
                # ValueError.
                end = fluid.data(shape=[1], dtype="float64", name="end")
                fluid.layers.linspace(0, end, 1, dtype="float32")

            self.assertRaises(ValueError, test_end_dtype)

            def test_num_dtype():
                # A tensor `num` must be int32 (the valid type used by the
                # API tests); a float32 tensor is expected to be rejected.
                num = fluid.data(shape=[1], dtype="float32", name="num")
                fluid.layers.linspace(0, 10, num, dtype="float32")

            self.assertRaises(TypeError, test_num_dtype)

if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()