test_linspace.py 6.2 KB
Newer Older
Z
zhoukunsheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
20 21 22
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
23
from paddle.fluid import core
Z
zhoukunsheng 已提交
24 25 26 27 28 29 30 31 32 33 34


class TestLinspaceOpCommonCase(OpTest):
    """Check the ``linspace`` op on an ascending float32 range.

    Feeds Start=0, Stop=10, Num=11 and expects the eleven values 0..10.
    """

    def setUp(self):
        """Declare the op type, its inputs, attributes and expected output."""
        self.op_type = "linspace"
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([0]).astype(dtype),
            'Stop': np.array([10]).astype(dtype),
            'Num': np.array([11]).astype('int32')
        }
        # The op receives the output dtype as the integer value of the enum.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        self.outputs = {'Out': np.arange(0, 11).astype(dtype)}

    def test_check_output(self):
        """Compare the op's output against ``self.outputs``."""
        self.check_output()


class TestLinspaceOpReverseCase(OpTest):
    """Check the ``linspace`` op on a descending float32 range.

    Feeds Start=10, Stop=0, Num=11 and expects the eleven values 10..0.
    """

    def setUp(self):
        """Declare the op type, its inputs, attributes and expected output."""
        self.op_type = "linspace"
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([11]).astype('int32')
        }
        # The op receives the output dtype as the integer value of the enum.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}

    def test_check_output(self):
        """Compare the op's output against ``self.outputs``."""
        self.check_output()


class TestLinspaceOpNumOneCase(OpTest):
    """Check the ``linspace`` op when a single point is requested.

    Feeds Start=10, Stop=0, Num=1 and expects the start value 10.
    """

    def setUp(self):
        """Declare the op type, its inputs, attributes and expected output."""
        self.op_type = "linspace"
        dtype = 'float32'
        self.inputs = {
            'Start': np.array([10]).astype(dtype),
            'Stop': np.array([0]).astype(dtype),
            'Num': np.array([1]).astype('int32')
        }
        # The op receives the output dtype as the integer value of the enum.
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}

        # Expected value is built from a scalar, i.e. a 0-d array holding 10.
        self.outputs = {'Out': np.array(10, dtype=dtype)}

    def test_check_output(self):
        """Compare the op's output against ``self.outputs``."""
        self.check_output()


77
class TestLinspaceAPI(unittest.TestCase):
    """Exercise ``paddle.linspace`` through the Python API.

    Results are compared against ``numpy.linspace`` with the same arguments.
    Tests that switch to dynamic mode restore static mode in a ``finally``
    clause so a failure cannot leak the mode change into other tests.
    """

    def test_variable_input1(self):
        """Static mode: start, stop and num are passed as tensors."""
        start = paddle.full(shape=[1], fill_value=0, dtype='float32')
        stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
        num = paddle.full(shape=[1], fill_value=5, dtype='int32')
        out = paddle.linspace(start, stop, num, dtype='float32')
        exe = fluid.Executor(place=fluid.CPUPlace())
        res = exe.run(fluid.default_main_program(), fetch_list=[out])
        np_res = np.linspace(0, 10, 5, dtype='float32')
        self.assertEqual((res == np_res).all(), True)

    def test_variable_input2(self):
        """Dynamic mode: start, stop and num are passed as tensors."""
        paddle.disable_static()
        try:
            start = paddle.full(shape=[1], fill_value=0, dtype='float32')
            stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
            num = paddle.full(shape=[1], fill_value=5, dtype='int32')
            out = paddle.linspace(start, stop, num, dtype='float32')
            np_res = np.linspace(0, 10, 5, dtype='float32')
            self.assertEqual((out.numpy() == np_res).all(), True)
        finally:
            # Restore static mode even when the assertion above fails.
            paddle.enable_static()

    def test_dtype(self):
        """The string, numpy and VarType spellings of float32 must agree."""
        out_1 = paddle.linspace(0, 10, 5, dtype='float32')
        out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
        out_3 = paddle.linspace(0, 10, 5, dtype=core.VarDesc.VarType.FP32)
        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
                                      fetch_list=[out_1, out_2, out_3])
        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run under ``python -O``.
        self.assertTrue(np.array_equal(res_1, res_2))
        # res_3 was fetched but previously never compared.
        self.assertTrue(np.array_equal(res_1, res_3))

    def test_name(self):
        """A user-supplied ``name`` must appear in the output variable name."""
        with paddle.static.program_guard(paddle.static.Program()):
            out = paddle.linspace(
                0, 10, 5, dtype='float32', name='linspace_res')
            self.assertIn('linspace_res', out.name)

    def test_imperative(self):
        """Dynamic mode with Python scalars, for float32 and int32 outputs."""
        paddle.disable_static()
        try:
            out1 = paddle.linspace(0, 10, 5, dtype='float32')
            np_out1 = np.linspace(0, 10, 5, dtype='float32')
            out2 = paddle.linspace(0, 10, 5, dtype='int32')
            np_out2 = np.linspace(0, 10, 5, dtype='int32')
            out3 = paddle.linspace(0, 10, 200, dtype='int32')
            np_out3 = np.linspace(0, 10, 200, dtype='int32')
        finally:
            # Restore static mode even if one of the calls above raises.
            paddle.enable_static()
        self.assertEqual((out1.numpy() == np_out1).all(), True)
        self.assertEqual((out2.numpy() == np_out2).all(), True)
        self.assertEqual((out3.numpy() == np_out3).all(), True)
125

126 127 128 129 130

class TestLinspaceOpError(unittest.TestCase):
    """Check that ``fluid.layers.linspace`` rejects invalid arguments."""

    def test_errors(self):
        """Run each invalid call inside a fresh program and expect an error."""
        with program_guard(Program(), Program()):

            # Output dtype "int8" is expected to be rejected.
            def test_dtype():
                fluid.layers.linspace(0, 10, 1, dtype="int8")

            self.assertRaises(TypeError, test_dtype)

            # A float ``num`` is expected to be rejected.
            def test_dtype1():
                fluid.layers.linspace(0, 10, 1.33, dtype="int32")

            self.assertRaises(TypeError, test_dtype1)

            # A list passed as ``start`` is expected to be rejected.
            def test_start_type():
                fluid.layers.linspace([0], 10, 1, dtype="float32")

            self.assertRaises(TypeError, test_start_type)

            # A list passed as ``stop`` is expected to be rejected.
            def test_end_type():
                fluid.layers.linspace(0, [10], 1, dtype="float32")

            self.assertRaises(TypeError, test_end_type)

            # A list passed as ``num`` is expected to be rejected.
            def test_step_dtype():
                fluid.layers.linspace(0, 10, [0], dtype="float32")

            self.assertRaises(TypeError, test_step_dtype)

            # A float64 ``start`` tensor with a float32 output dtype is
            # expected to raise ValueError.
            def test_start_dtype():
                start = fluid.data(shape=[1], dtype="float64", name="start")
                fluid.layers.linspace(start, 10, 1, dtype="float32")

            self.assertRaises(ValueError, test_start_dtype)

            # A float64 ``stop`` tensor with a float32 output dtype is
            # expected to raise ValueError.
            def test_end_dtype():
                end = fluid.data(shape=[1], dtype="float64", name="end")
                fluid.layers.linspace(0, end, 1, dtype="float32")

            self.assertRaises(ValueError, test_end_dtype)

            def test_num_dtype():
                num = fluid.data(shape=[1], dtype="int32", name="step")
                fluid.layers.linspace(0, 10, num, dtype="float32")

            # NOTE(review): test_num_dtype above is defined but never run;
            # the assertion below re-checks test_step_dtype instead. It is
            # left unchanged because nothing in this file shows that an int32
            # tensor ``num`` should raise -- confirm the intent before
            # wiring test_num_dtype in.
            self.assertRaises(TypeError, test_step_dtype)

174

Z
zhoukunsheng 已提交
175 176
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()