# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test parameter """
import numpy as np
import pytest
from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore._checkparam import _check_str_by_regular
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

def test_parameter_init():
    """Constructing a Parameter from a Tensor with explicit flags must not raise."""
    tensor = Tensor(np.array([[1, 2, 3], [2, 3, 4]]))
    Parameter(tensor, name="testParameter", requires_grad=True, layerwise_parallel=False)


def test_parameter_tuple_illegal():
    """ParameterTuple accepts a list or tuple of Parameters and rejects other input.

    Each illegal input below must make ``ParameterTuple`` raise ``TypeError``.
    """
    p1 = Parameter(initializer(0, [1], mstype.int32), name="global_step1")
    p2 = Parameter(initializer(0, [1], mstype.int32), name="global_step2")

    # Legal inputs: a list and a tuple whose elements are all Parameters.
    ParameterTuple([p1, p2])
    ParameterTuple((p1, p2))

    illegal_inputs = (
        p1,           # a single Parameter instead of a sequence
        [p1, "str"],  # a list mixing a Parameter with a non-Parameter
        ("2", "1"),   # a tuple of strings
        "[2,3]",      # a plain string
        3,            # a plain int
    )
    for illegal in illegal_inputs:
        with pytest.raises(TypeError):
            ParameterTuple(illegal)


def test_parameter_init_illegal():
    """Check the argument validation performed when constructing a Parameter.

    * data: a Tensor, an int and a numpy array are accepted; a bool raises
      ``ValueError``.
    * name: ``None`` and a str are accepted; the other types tried here raise
      ``ValueError``.
    * requires_grad / layerwise_parallel: a bool is accepted; every other
      value tried here (including ``None``) raises ``TypeError``.
    """
    dat = np.array([[1, 2, 3], [2, 3, 4]])
    tensor = Tensor(dat)
    data_none = None
    data_bool = True
    data_str = "nicai"
    data_int = 3
    data_list = [1, "2", True]
    data_tuple = (1, 2, 3)

    # test data
    Parameter(tensor, name=data_str)
    Parameter(data_int, name=data_str)
    Parameter(dat, name=data_str)
    with pytest.raises(ValueError):
        Parameter(data_bool, name=data_str)

    # test name
    Parameter(tensor, name=data_none)
    for bad_name in (dat, tensor, data_bool, data_int, data_list, data_tuple):
        with pytest.raises(ValueError):
            Parameter(tensor, name=bad_name)

    # test requires_grad
    Parameter(tensor, name=data_str, requires_grad=data_bool)
    for bad_flag in (data_none, dat, tensor, data_str, data_int, data_list, data_tuple):
        with pytest.raises(TypeError):
            Parameter(tensor, name=data_str, requires_grad=bad_flag)

    # test layerwise_parallel
    Parameter(tensor, name=data_str, requires_grad=data_bool, layerwise_parallel=data_bool)
    for bad_flag in (dat, tensor, data_none, data_str, data_int, data_list, data_tuple):
        with pytest.raises(TypeError):
            Parameter(tensor, name=data_str, requires_grad=data_bool,
                      layerwise_parallel=bad_flag)


def test_check_str_by_regular():
    """``_check_str_by_regular`` accepts the first group of strings and rejects the second with ValueError."""
    accepted = ("12_sf.asdf_", "x12_sf.asdf.", "_x12_sf.asdf")
    rejected = (
        ".12_sf.asdf",   # starts with '.'
        "12_sf.a$sdf.",  # contains '$'
        "12+sf.asdf",    # contains '+'
    )
    for text in accepted:
        _check_str_by_regular(text)
    for text in rejected:
        with pytest.raises(ValueError):
            _check_str_by_regular(text)


def test_parameter_compute():
    """Parameters support ``+`` and ``*`` with Parameters and Tensors, and are Tensor instances."""
    para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')
    para_2 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test2')
    t3 = Tensor(np.ones((1, 2, 3)))

    expected_sum = np.ones((1, 2, 3)) * 2
    expected_product = np.ones((1, 2, 3))
    # First Parameter (op) Parameter, then Parameter (op) Tensor.
    for rhs in (para_2, t3):
        assert np.array_equal((para_1 + rhs).asnumpy(), expected_sum)
        assert np.array_equal((para_1 * rhs).asnumpy(), expected_product)

    assert isinstance(para_1, Tensor)


def test_scalar_parameter_update():
    """Check how assignments to ``default_input`` are converted or rejected.

    Scalar parameters: a float parameter accepts float and int values (read
    back as float32 values); an int parameter accepts an int (read back as
    int32) and rejects a float with ``TypeError``. Tensor parameters: a float32
    parameter accepts Python scalars and tensors of several dtypes, while a
    float16 parameter rejects a float32 tensor with ``TypeError``.
    """
    # float
    fp = Parameter(0.5, 'fp')
    fp.default_input = 0.8
    assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
    fp.default_input = 1
    assert np.array_equal(fp.default_input.asnumpy(), np.array(1.0, np.float32))

    # int
    # NOTE(review): reuses the name 'fp'; probably meant 'int_' - harmless here.
    int_ = Parameter(1, 'fp')
    int_.default_input = 2
    assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
    with pytest.raises(TypeError):
        int_.default_input = 1.2

    # Tensor
    fp32 = Tensor(0.5, mstype.float32)
    int32 = Tensor(2, mstype.int32)
    fp16 = Tensor(0.6, mstype.float16)
    int16 = Tensor(3, mstype.int16)
    bool_ = Tensor(np.array(True, dtype=np.bool_))

    # update a float32 parameter by Python scalars and by tensors of various dtypes
    fp32_p = Parameter(fp32, 'fp32')
    for new_value in (0.8, 1, int32, fp32, int16, fp16, bool_):
        fp32_p.default_input = new_value

    # a float16 parameter rejects a float32 tensor
    fp16_p = Parameter(fp16, 'fp16')
    with pytest.raises(TypeError):
        fp16_p.default_input = fp32


def test_parameter_lazy_init():
    """Check lazy initialization of Parameters in SEMI_AUTO_PARALLEL mode.

    Covers the following:

    * ``init_data()`` turning an initializer-backed parameter into a Tensor.
    * dtype/shape validation when assigning ``default_input``.
    * Assigning an initializer after initialization.
    * ``set_parameter_data(..., slice_shape=True)``.

    The global auto-parallel context is reset in ``finally`` so a failing
    assertion cannot leak semi_auto_parallel mode into later tests.
    """
    # support lazy init in SEMI_AUTO_PARALLEL mode
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
    try:
        # Call init_data() without setting default_input first.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
        assert not isinstance(para.default_input, Tensor)
        para = para.init_data()
        assert isinstance(para.default_input, Tensor)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

        # Assign default_input before and after init_data().
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
        assert not isinstance(para.default_input, Tensor)
        # expect a type error while the parameter is not yet initialized
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # init then assign
        para = para.init_data()
        # check the type: np.zeros defaults to float64, not the parameter's float32
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # check the shape
        with pytest.raises(ValueError):
            para.default_input = Tensor(np.zeros((1, 2)))
        # matching dtype and shape: expect the change to succeed
        para.default_input = Tensor(np.zeros((1, 2, 3)).astype(np.float32))
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))

        # para is already initialized, so an assigned initializer is expected
        # to be readable as a Tensor right away
        para.default_input = initializer('ones', [1, 2, 3], mstype.float32)
        assert isinstance(para.default_input, Tensor)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
        # calling init_data() again is expected to have no effect
        para.init_data()
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

        # with slice_shape=True, data of shape (1, 2) is accepted for a [1, 2, 3] parameter
        para.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2)))
        para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
    finally:
        context.reset_auto_parallel_context()