test_ops_check.py 12.6 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test checking for some ops"""
import functools
import logging

import numpy as np
import pytest
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.api import _executor

from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config
logging.basicConfig(level=logging.WARNING)


class NetMissConstruct(nn.Cell):
    """LeNet-like cell whose forward method is deliberately mis-named.

    The forward pass is written as ``construtc`` instead of ``construct``;
    ``test_net_without_construct`` compiles this cell to exercise the case of
    a Cell without a user-defined ``construct``.
    """
    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (order of assignment is kept as-is).
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # Fully connected head: 16*5*5 -> 120 -> 84 -> 10.
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        # Shared activation / pooling / reshaping ops.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2)
        self.flatten = P.Flatten()

    # pylint: disable=abstract-method
    # TestCase: 'construct' is intentionally mis-spelled as 'construtc'.
    def construtc(self, x):
        features = self.max_pool2d(self.relu(self.conv1(x)))
        features = self.max_pool2d(self.relu(self.conv2(features)))
        hidden = self.relu(self.fc1(self.flatten(features)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)


def test_net_without_construct():
    """Compile a cell whose ``construct`` method is mis-spelled.

    ``NetMissConstruct`` only defines ``construtc``. A RuntimeError whose
    message contains "Unsupported syntax 'Raise' at " is accepted and printed;
    any other RuntimeError is re-raised unchanged.

    NOTE(review): if compilation succeeds without raising, this test also
    passes silently.
    """
    net = NetMissConstruct()
    inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    try:
        _executor.compile(net, inp)
    except RuntimeError as err:
        if "Unsupported syntax 'Raise' at " not in str(err):
            # Bare raise keeps the original traceback intact.
            raise
        print(str(err))


class NetWithRaise(nn.Cell):
    """Cell whose ``construct`` consists solely of a ``raise`` statement.

    ``test_net_with_raise`` compiles it and accepts a RuntimeError mentioning
    "Unsupported syntax 'Raise'", so only the presence of the raise statement
    matters to that test.
    """
    def __init__(self):
        super(NetWithRaise, self).__init__()
        # Not used by construct; NOTE(review): presumably kept so the cell owns
        # a parameterized layer -- confirm whether it is still needed.
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')

    # raise exception in method 'construct'
    def construct(self, x):
        # Python 3 can only raise BaseException instances; the original
        # ``raise 'exception in construct'`` would itself fail with TypeError
        # if it were ever executed.
        raise ValueError('exception in construct')


def test_net_with_raise():
    """Compile a cell whose ``construct`` contains a ``raise`` statement.

    A RuntimeError whose message contains "Unsupported syntax 'Raise' at " is
    accepted and printed; any other RuntimeError is re-raised unchanged.

    NOTE(review): if compilation succeeds without raising, this test also
    passes silently.
    """
    net = NetWithRaise()
    inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    try:
        _executor.compile(net, inp)
    except RuntimeError as err:
        if "Unsupported syntax 'Raise' at " not in str(err):
            # Bare raise keeps the original traceback intact.
            raise
        print(str(err))


class NetAddN(nn.Cell):
    """Cell that passes its single input straight to ``P.AddN``.

    Used by the 'NetAddN_Error' case in ``raise_set``, which feeds it a tuple
    of numpy arrays and expects a TypeError.
    """
    def __init__(self):
        super(NetAddN, self).__init__()
        self.net = P.AddN()

    def construct(self, x):
        # x is forwarded unchanged; AddN is expected to reject invalid inputs.
        return self.net(x)


class NetSplit(nn.Cell):
    """Cell that passes its single input straight to ``P.Split(1, 2)``.

    Used by the 'Splite_Error' case in ``raise_set``, which feeds it ``None``
    and expects a TypeError.
    """
    def __init__(self):
        super(NetSplit, self).__init__()
        # NOTE(review): positional args presumably mean axis=1, output_num=2 --
        # confirm against the P.Split signature.
        self.net = P.Split(1, 2)

    def construct(self, x):
        return self.net(x)


class NetBatchMatMul(nn.Cell):
    """Cell that applies ``P.BatchMatMul`` to its two inputs.

    Used by the 'BatchMatMul' case in ``test_case_check_ops`` to run the
    primitive wrapped inside a Cell rather than standalone.
    """
    def __init__(self):
        super(NetBatchMatMul, self).__init__()
        self.op = P.BatchMatMul()

    def construct(self, x, y):
        return self.op(x, y)


# Positive cases: each entry is (case_name, config), where 'block' is the nn
# layer, primitive or Cell under test and 'desc_inputs' lists its inputs.
# They are gathered into test_exec_case and returned by test_exec, whose
# pipeline (pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
# presumably compiles each block's forward graph -- none is expected to raise.
test_case_check_ops = [
    # nn.Conv2d with each pad_mode; only 'pad' uses a non-zero padding.
    ('Conv_Padding_1', {
        'block': nn.Conv2d(1, 6, 5, pad_mode='same', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Padding_2', {
        'block': nn.Conv2d(1, 6, 5, pad_mode='valid', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Padding_3', {
        'block': nn.Conv2d(1, 6, 5, pad_mode='pad', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Padding_4', {
        'block': nn.Conv2d(1, 6, 5, pad_mode='pad', padding=7),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    # nn.Conv2d bias handling: has_bias on/off with a Tensor or string bias_init.
    ('Conv_Bias_1', {
        'block': nn.Conv2d(1, 6, 5, has_bias=True, bias_init=Tensor(np.ones([6]).astype(np.float32))),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Bias_2', {
        'block': nn.Conv2d(1, 6, 5, has_bias=True, bias_init='zeros'),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Bias_3', {
        'block': nn.Conv2d(1, 6, 5, has_bias=False, bias_init='zeros'),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    ('Conv_Bias_4', {
        'block': nn.Conv2d(1, 6, 5, has_bias=False, bias_init=Tensor(np.ones([6]).astype(np.float32))),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    # nn.Dense bias handling, mirroring the Conv_Bias cases.
    ('Dense_Bias_1', {
        'block': nn.Dense(1, 6, has_bias=True, bias_init=Tensor(np.ones([6]).astype(np.float32))),
        'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}),
    ('Dense_Bias_2', {
        'block': nn.Dense(1, 6, has_bias=True, bias_init='zeros'),
        'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}),
    ('Dense_Bias_3', {
        'block': nn.Dense(1, 6, has_bias=False, bias_init='zeros'),
        'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}),
    ('Dense_Bias_4', {
        'block': nn.Dense(1, 6, has_bias=False, bias_init=Tensor(np.ones([6]).astype(np.float32))),
        'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}),
    # Pooling layers with 'same' and 'valid' pad_mode and zero padding.
    ('MaxPool2d_1', {
        'block': nn.MaxPool2d(5, pad_mode='same', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
    ('MaxPool2d_2', {
        'block': nn.MaxPool2d(5, pad_mode='valid', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
    ('AvgPool2d_1', {
        'block': nn.AvgPool2d(5, pad_mode='same', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
    ('AvgPool2d_2', {
        'block': nn.AvgPool2d(5, pad_mode='valid', padding=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
    # Raw P.Conv2D primitive with each pad_mode.
    # NOTE(review): desc_inputs appear to be (input, weight) -- confirm the
    # ordering against P.Conv2D.
    ('Conv2D_1', {
        'block': P.Conv2D(1, 6, pad_mode='same', pad=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    ('Conv2D_2', {
        'block': P.Conv2D(1, 6, pad_mode='valid', pad=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    ('Conv2D_3', {
        'block': P.Conv2D(1, 6, pad_mode='pad', pad=0),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    ('Conv2D_4', {
        'block': P.Conv2D(1, 6, pad_mode='pad', pad=7),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    # MatMul / BatchMatMul with shape-compatible operands, incl. transpose flags.
    ('MatMul_1', {
        'block': P.MatMul(),
        'desc_inputs': [Tensor(np.ones(shape=[1, 3])), Tensor(np.ones(shape=[3, 4]))]}),
    ('MatMul_2', {
        'block': P.BatchMatMul(),
        'desc_inputs': [Tensor(np.ones(shape=[5, 1, 5])), Tensor(np.ones(shape=[5, 5, 4]))]}),
    ('MatMul_Transpose_1', {
        'block': P.MatMul(transpose_a=True),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1])), Tensor(np.ones(shape=[3, 4]))]}),
    ('MatMul_Transpose_2', {
        'block': P.MatMul(transpose_b=True),
        'desc_inputs': [Tensor(np.ones(shape=[3, 2])), Tensor(np.ones(shape=[5, 2]))]}),
    ('MatMul_Transpose_3', {
        'block': P.MatMul(transpose_a=True, transpose_b=True),
        'desc_inputs': [Tensor(np.ones(shape=[3, 2])), Tensor(np.ones(shape=[5, 3]))]}),
    # BatchMatMul wrapped in a Cell rather than used as a bare primitive.
    ('BatchMatMul', {
        'block': NetBatchMatMul(),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[3, 5, 4]))]}),
]

# All positive case lists, flattened into one list for test_exec.
test_case_lists = [test_case_check_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# Use -k to select certain test cases, e.g.
# pytest test_ops_check.py::test_exec -k MatMul


@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Run the positive cases in ``test_exec_case`` through the compile pipeline.

    Switches MindSpore to graph mode, then returns the case list, which the
    ``mindspore_test`` decorator presumably iterates over.
    """
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case


# Negative cases: each entry is (case_name, config), where 'block' is a
# (callable, {'exception': ExcType}) pair and 'desc_inputs' are the inputs
# for the callable. test_check_exception hands this list to
# pipeline_for_verify_exception_for_case_by_case_config, which presumably
# checks that calling the block on the inputs raises ExcType.
# The lambdas defer construction of the layer/primitive, presumably so that
# the constructor's own argument validation raises inside the verified call.
raise_set = [
    # nn.Conv2d: non-zero padding combined with a pad_mode other than 'pad'.
    ('Conv_Padding_1_Error', {
        'block': (lambda x: nn.Conv2d(1, 6, 5, pad_mode='same', padding=7), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    # 'valid' variant (mirrors Conv2D_2_Error) so it is not a duplicate of
    # Conv_Padding_1_Error.
    ('Conv_Padding_2_Error', {
        'block': (lambda x: nn.Conv2d(1, 6, 5, pad_mode='valid', padding=7), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 1, 6, 5]).astype(np.float32))]}),
    # Raw P.Conv2D: non-zero pad combined with 'same' / 'valid' pad_mode.
    ('Conv2D_1_Error', {
        'block': (lambda x, y: P.Conv2D(1, 6, pad_mode='same', pad=7), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    ('Conv2D_2_Error', {
        'block': (lambda x, y: P.Conv2D(1, 6, pad_mode='valid', pad=7), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32)),
                        Tensor(np.ones(shape=[1, 5, 6, 6]).astype(np.float32))]}),
    # AddN fed raw numpy arrays instead of Tensors, inside and outside a Cell.
    ('NetAddN_Error', {
        'block': (NetAddN(), {'exception': TypeError}),
        'desc_inputs': [(np.random.randn(1, 2, 3, 4).astype(np.float32),
                         np.random.randn(1, 2, 3, 4).astype(np.float32))]}),
    ('AddN_Error', {
        'block': (P.AddN(), {'exception': TypeError}),
        'desc_inputs': [(np.random.randn(1, 2, 3, 4).astype(np.float32),
                         np.random.randn(1, 2, 3, 4).astype(np.float32))]}),
    # Split fed None.
    ('Splite_Error', {
        'block': (NetSplit(), {'exception': TypeError}),
        'desc_inputs': [None]}),
    # MatMul with wrong rank or mismatched inner dimensions.
    ('MatMul_1_Error', {
        'block': (P.MatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[5])), Tensor(np.ones(shape=[4]))]}),
    ('MatMul_2_Error', {
        'block': (P.MatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 5])), Tensor(np.ones(shape=[3, 4]))]}),
    ('MatMul_3_Error', {
        'block': (P.MatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 5])), Tensor(np.ones(shape=[5, 5, 4]))]}),
    # MatMul whose transpose flags make otherwise-valid shapes incompatible.
    ('MatMul_Transpose_1_Error', {
        'block': (P.MatMul(transpose_a=True), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 3])), Tensor(np.ones(shape=[3, 4]))]}),
    ('MatMul_Transpose_2_Error', {
        'block': (P.MatMul(transpose_b=True), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 2])), Tensor(np.ones(shape=[2, 5]))]}),
    ('MatMul_Transpose_3_Error', {
        'block': (P.MatMul(transpose_a=True, transpose_b=True), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 2])), Tensor(np.ones(shape=[3, 5]))]}),
    # BatchMatMul with too few dims, mismatched inner dims, rank or batch size.
    ('BatchMatMul_1_Error', {
        'block': (P.BatchMatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[5])), Tensor(np.ones(shape=[4]))]}),
    ('BatchMatMul_2_Error', {
        'block': (P.BatchMatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[1, 5])), Tensor(np.ones(shape=[3, 4]))]}),
    ('BatchMatMul_3_Error', {
        'block': (P.BatchMatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[3, 3, 4]))]}),
    ('BatchMatMul_4_Error', {
        'block': (P.BatchMatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[1, 3, 5, 4]))]}),
    ('BatchMatMul_5_Error', {
        'block': (P.BatchMatMul(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[2, 5, 4]))]}),
]


@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    """Hand the negative cases in ``raise_set`` to the exception-verification pipeline."""
    negative_cases = raise_set
    return negative_cases