test_framstruct.py 21.3 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_framstruct """
import numpy as np
J
jinyaohui 已提交
17 18
import pytest

Z
zhunaipan 已提交
19 20
import mindspore.nn as nn
from mindspore import context
J
jinyaohui 已提交
21 22 23
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
Z
zhunaipan 已提交
24 25
from mindspore.ops import composite as C
from mindspore.ops import operations as P
J
jinyaohui 已提交
26 27
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
Z
zhunaipan 已提交
28 29 30 31 32 33 34 35 36
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
    ms_function, check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient, ScalarGradChecker)


def setup_module(module):
    """Pytest module hook: run every test in this module in PyNative mode.

    Args:
        module: the test module object supplied by pytest (unused).
    """
    execution_mode = context.PYNATIVE_MODE
    context.set_context(mode=execution_mode)

J
jinyaohui 已提交
37

Z
zhunaipan 已提交
38 39 40 41 42 43 44
@ms_function
def while_upper_bound(upper):
    """Square a value, starting from 2, until it is no longer below `upper`.

    Exercises a `while` loop whose exit bound is a function argument.
    """
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval

J
jinyaohui 已提交
45

Z
zhunaipan 已提交
46 47 48 49
def test_while_upper_bound():
    """Squaring from 2 (2 -> 4 -> 16) stops at the first value >= 10."""
    assert while_upper_bound(10) == 16

J
jinyaohui 已提交
50

Z
zhunaipan 已提交
51 52 53 54 55 56 57 58
@ms_function
def while_lower_bound(lower):
    """Square `lower` repeatedly until the value is no longer below 100.

    Exercises a `while` loop whose starting value is a function argument.
    """
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval

J
jinyaohui 已提交
59

Z
zhunaipan 已提交
60 61 62 63
def test_while_lower_bound():
    """Squaring from 2 (2 -> 4 -> 16 -> 256) stops at the first value >= 100."""
    assert while_lower_bound(2) == 256

J
jinyaohui 已提交
64

Z
zhunaipan 已提交
65 66 67 68 69 70 71 72 73
@ms_function
def dynamic_make_tuple(x, lower, upper):
    """Append `x` to a tuple once per `while`-loop iteration from `lower` up to `upper`.

    The tuple's length depends on the loop trip count; `test_dynamic_make_tuple`
    expects calling this to raise RuntimeError.
    """
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out

J
jinyaohui 已提交
74

Z
zhunaipan 已提交
75 76 77 78 79
def test_dynamic_make_tuple():
    """Growing a tuple inside a while loop must be rejected.

    MindSpore is statically typed, so a tuple whose length depends on a
    runtime loop cannot be given a static type; the call raises RuntimeError.
    """
    expected_error = RuntimeError
    with pytest.raises(expected_error):
        dynamic_make_tuple(2, 1, 5)

J
jinyaohui 已提交
80

Z
zhunaipan 已提交
81 82 83 84 85 86 87 88
def test_make_tuple():
    """Building a tuple in a loop with a constant trip count is valid.

    Statically recursively creating a static type is fine in MindSpore because
    `range(3)` has a fixed length known at compile time.
    """
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)

J
jinyaohui 已提交
93

Z
zhunaipan 已提交
94 95 96 97 98
@ms_function
def add(x, y):
    """Return `x + y`; decorated with ms_function."""
    return x + y

J
jinyaohui 已提交
99

Z
zhunaipan 已提交
100 101 102 103
def mul(x, y):
    """Return the product `x * y`.

    Differentiated by `mainf` and called from the while-loop helpers below.
    """
    return x * y

J
jinyaohui 已提交
104

Z
zhunaipan 已提交
105 106 107 108
def add_mul(x, y):
    """Return `(x + y) * y`; differentiated by `grad_add_mul`."""
    return (x + y) * y

J
jinyaohui 已提交
109

Z
zhunaipan 已提交
110 111 112 113
def mainf(x, y):
    """Return the gradients of `mul` with respect to both `x` and `y`."""
    grad_mul = C.grad_all(mul)
    return grad_mul(x, y)

J
jinyaohui 已提交
114

Z
zhunaipan 已提交
115 116 117 118
def grad_add_mul(x, y):
    """Return the gradients of `add_mul` with respect to both `x` and `y`."""
    grad_fn = C.grad_all(add_mul)
    return grad_fn(x, y)

J
jinyaohui 已提交
119

Z
zhunaipan 已提交
120 121 122 123 124
@ms_function
def sub(x, y):
    """Return `x - y`; decorated with ms_function."""
    return x - y

J
jinyaohui 已提交
125

J
jinyaohui 已提交
126
# pylint: disable=using-constant-test
Z
zhunaipan 已提交
127 128 129 130 131 132 133 134
@ms_function
def if_always_true(x):
    """Return `x` through an `if True:` branch; the `else` branch is never taken.

    The constant condition is deliberate (hence the pylint disable above).
    """
    if True:
        return x
    else:
        return 0

J
jinyaohui 已提交
135

Z
zhunaipan 已提交
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
def test_add():
    """Mixed float/int addition inside ms_function: 2.5 + 3 == 5.5."""
    expected = 5.5
    assert add(2.5, 3) == expected


def test_sub():
    """Mixed float/int subtraction inside ms_function: 3.5 - 3 == 0.5."""
    expected = 0.5
    assert sub(3.5, 3) == expected


@non_graph_engine
def test_if_always_true():
    """The constant-true branch returns the input unchanged."""
    assert if_always_true(1) == 1


@non_graph_engine
def test_f():
    """Gradients of x * y at (3, 2) are (y, x) = (2, 3)."""
    grads = mainf(3, 2)
    assert grads == (2, 3)

J
jinyaohui 已提交
161

Z
zhunaipan 已提交
162 163 164 165 166 167
@non_graph_engine
def test_grad_add_mul():
    """For (x + y) * y at (3, 2): d/dx = y = 2 and d/dy = x + 2y = 7."""
    assert grad_add_mul(3, 2) == (2, 7)

J
jinyaohui 已提交
168

Z
zhunaipan 已提交
169 170
def f(x):
    """Undecorated recursive countdown: recurse on `x - 1` while `x > 0`, then return `x`.

    NOTE(review): nothing in the visible part of this file calls `f`; `rec`
    is the ms_function-decorated equivalent that the tests exercise.
    """
    if x > 0:
        return f(x - 1)
    return x

J
jinyaohui 已提交
174

Z
zhunaipan 已提交
175 176 177
@ms_function
def list_subscript():
    """Index a constant list inside ms_function and return `x[0] * x[1]`."""
    x = [1, 2, 3]
    return x[0] * x[1]

J
jinyaohui 已提交
181

Z
zhunaipan 已提交
182 183 184 185 186
def test_list_subscript():
    """x[0] * x[1] of [1, 2, 3] is 2."""
    assert list_subscript() == 2

J
jinyaohui 已提交
187

Z
zhunaipan 已提交
188 189 190 191 192 193 194 195
@ms_function
def ms_infer_for(xs, y):
    """Return `y` plus every element of `xs`, accumulated in a for-loop."""
    rval = y
    for x in xs:
        rval = rval + x
    return rval

J
jinyaohui 已提交
196

Z
zhunaipan 已提交
197 198 199 200 201 202 203
def test_infer_for():
    """Summing (1, 2, 3) onto 4 inside the compiled for-loop gives 10."""
    values, start = (1, 2, 3), 4
    assert ms_infer_for(values, start) == 10

J
jinyaohui 已提交
204

Z
zhunaipan 已提交
205 206 207 208
@ms_function
def if_construct(a, b):
    """Two sequential if/else statements on scalars.

    `z` becomes `a + b` when `a > b`, otherwise `a * b`. The result is then
    `z - a` when `z > b`, otherwise `a - b`.
    """
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b

Z
zhunaipan 已提交
217 218 219 220 221 222

def test_if_construct():
    """For (3, 6): a <= b so z = 18; z > b so the result is 18 - 3 = 15."""
    assert if_construct(3, 6) == 15

J
jinyaohui 已提交
223

Z
zhunaipan 已提交
224 225 226 227 228 229 230
@ms_function
def if_scalar(a, b):
    """Return `a` if it is truthy, otherwise `b` (branch on a scalar condition)."""
    if a:
        return a
    return b

J
jinyaohui 已提交
231

Z
zhunaipan 已提交
232 233 234 235 236
def test_if_scalar1():
    """A truthy first argument is returned as-is."""
    assert if_scalar(3, 6) == 3

J
jinyaohui 已提交
237

Z
zhunaipan 已提交
238 239 240 241 242
def test_if_scalar2():
    """A zero (falsy) first argument falls through to the second."""
    assert if_scalar(0, 6) == 6

J
jinyaohui 已提交
243

Z
zhunaipan 已提交
244 245 246 247
@ms_function
def if_tensor(a, b):
    """Nested if/else on tensor comparisons; returns `c + c`.

    When `a < b`: `c = a + a`, then `c = a + c` if `c < b`, else `c = a + b`.
    Otherwise `c = b + b`.
    """
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out

J
jinyaohui 已提交
258

Z
zhunaipan 已提交
259 260 261 262
def test_if_tensor():
    """Equal all-ones inputs take the `else` branch: c = b + b = 2, so the result is 4."""
    def ones_tensor(scale=1):
        # A fresh int32 array is built on every call, as in the original.
        return Tensor(np.ones([64, 10]).astype(np.int32) * scale)

    res = if_tensor(ones_tensor(), ones_tensor())
    assert res == ones_tensor(4)

J
jinyaohui 已提交
263

Z
zhunaipan 已提交
264 265 266 267
@ms_function
def rec(x):
    """Recurse on `x - 1` while `x > 0`; return `x` once it is no longer positive."""
    if x > 0:
        return rec(x - 1)
    return x

J
jinyaohui 已提交
271

Z
zhunaipan 已提交
272 273 274 275 276
def test_grad_rec():
    """The recursion ends by returning its (decremented) argument, so the gradient is 1."""
    grad_rec = C.grad(rec)
    assert grad_rec(10) == 1

J
jinyaohui 已提交
277

Z
zhunaipan 已提交
278 279 280 281 282
def test_me_rec():
    """Recursing from 10 stops at 0."""
    assert rec(10) == 0

J
jinyaohui 已提交
283

Z
zhunaipan 已提交
284 285 286 287 288 289 290 291 292
@ms_function
def t2_while(x, y):
    """Run a 10-iteration `while` loop that sets `out = mul(x, y)` on each pass.

    The initial `y - x` is always overwritten, so the result is `x * y`.
    """
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out

J
jinyaohui 已提交
293

Z
zhunaipan 已提交
294 295 296 297
def test_while2():
    """The loop leaves out = x * y = 6."""
    assert t2_while(2, 3) == 6

J
jinyaohui 已提交
298

Z
zhunaipan 已提交
299 300 301 302
def test_grad_while2():
    """d(x * y)/dx = y = 3."""
    grad_fn = C.grad(t2_while)
    assert grad_fn(2, 3) == 3

J
jinyaohui 已提交
303

Z
zhunaipan 已提交
304 305 306 307 308 309
def if_test(a, b):
    """Return `3 * a` when `a > b`, otherwise `2 * b`; differentiated by `grad_if`."""
    if a > b:
        return 3 * a
    return 2 * b

J
jinyaohui 已提交
310

Z
zhunaipan 已提交
311 312 313 314
def grad_if(x, y):
    """Return the gradients of `if_test` with respect to both inputs."""
    grad_fn = C.grad_all(if_test)
    return grad_fn(x, y)

J
jinyaohui 已提交
315

Z
zhunaipan 已提交
316 317 318 319
def test_grad_if():
    """For (5, 4) the `3 * a` branch is taken, so the gradients are (3, 0)."""
    expected = (3, 0)
    assert grad_if(5, 4) == expected

J
jinyaohui 已提交
320

Z
zhunaipan 已提交
321 322 323 324 325 326 327 328 329
# While loop is not unrolled in forward and backward graphs.
def test_dont_unroll_while():
    """Differentiate a while loop from inside an ms_function.

    The loop overwrites `out` with `mul(x, y)` on every iteration, so the
    gradient with respect to `x` is `y` = 3.
    """
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return C.grad(dont_unroll_while)(x, y)

    res = invoke_while(2, 3)
    assert res == 3

J
jinyaohui 已提交
338

Z
zhunaipan 已提交
339
class ConvNet(nn.Cell):
    """A single dilated Conv2D whose weight is a fixed all-ones [16, 16, 3, 3] parameter."""

    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)
Z
zhunaipan 已提交
356

J
jinyaohui 已提交
357

Z
zhunaipan 已提交
358 359 360 361 362
# Module-level fixtures shared by the conv/while tests below.
conv = ConvNet()
# One-element float32 tensors used as loop counters, bounds and steps.
c1, c2, c3 = (Tensor([value], mstype.float32) for value in (2, 10, 1))

J
jinyaohui 已提交
363

Z
zhunaipan 已提交
364 365 366 367 368 369 370 371 372 373
@ms_function
def t1_while(x, y, z):
    """Add `conv(z)` to `x` once per loop step, then double the result.

    The tensor counter starts at `c1`, steps by `c3` and stops at `c2`.
    `y` is accepted but not used.
    """
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out

J
jinyaohui 已提交
374

Z
zhunaipan 已提交
375
def test_while_net():
    """Eight conv accumulations followed by doubling.

    Assuming standard convolution semantics, `conv` on an all-ones 16-channel
    input with the all-ones 3x3 kernel gives 16 * 9 = 144 per element. A 16x16
    input with dilation 2 and no padding yields 12x12. The loop runs for
    i = 2..9 (8 times), so out = 2 * (1 + 8 * 144) = 2306.
    """
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)

Z
zhunaipan 已提交
382 383 384 385 386 387 388

@ms_function
def if_while(a, b, x, z):
    """if/else whose true branch also runs a conv-accumulating while loop.

    The loop's `out` is discarded: the function always returns `c + c`, where
    `c` is `a + a` when `a < b` and `b + b` otherwise.
    """
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out

J
jinyaohui 已提交
398

Z
zhunaipan 已提交
399
def test_if_while():
    """Equal all-ones `a` and `b` take the `else` branch: c = 2, so the result is 4."""
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)),
                   Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)

J
jinyaohui 已提交
405

Z
zhunaipan 已提交
406 407 408 409 410 411 412 413 414
def _while(x):
    """Return `x * x` multiplied by 2 and then by 3 in a `while` loop (i.e. `6 * x * x`).

    Differentiated by `grad_while`.
    """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret

J
jinyaohui 已提交
415

Z
zhunaipan 已提交
416 417 418 419
def grad_while(x):
    """Return the gradient tuple of `_while` with respect to `x`."""
    grad_fn = C.grad_all(_while)
    return grad_fn(x)

J
jinyaohui 已提交
420

Z
zhunaipan 已提交
421 422 423 424
def test_grad_while():
    """d/dx (6 * x * x) = 12 * x = 60 at x = 5."""
    expected = (60,)
    assert grad_while(5) == expected

J
jinyaohui 已提交
425

Z
zhunaipan 已提交
426
@ms_function
def factorial(n):
    """Return `n!` by direct recursion, with base case `n == 0`."""
    if n == 0:
        return 1
    return n * factorial(n - 1)

433 434 435 436

def test_factorial():
    """3! computed through recursive ms_function calls is 6."""
    assert factorial(3) == 6
Z
zhunaipan 已提交
437

J
jinyaohui 已提交
438

439 440 441
def test_grad_factorial():
    """For f(n) = n * f(n - 1): f'(3) = f(2) + 3 * f'(2) = 2 + 3 * 3 = 11."""
    grad_factorial = C.grad(factorial)
    assert grad_factorial(3) == 11
Z
zhunaipan 已提交
442

J
jinyaohui 已提交
443

444 445 446 447
@ms_function
def factorial2(n):
    """Return `n!` through an if/elif/else chain with recursion.

    NOTE(review): the `elif n == 1` branch is unreachable because `n != 0`
    already covers `n == 1`. It is presumably kept on purpose to exercise a
    three-way branch in the frontend; confirm before removing it.
    """
    if n != 0:
        return n * factorial2(n - 1)
    elif n == 1:
        return 1 * factorial2(n - 1)
    else:
        return 1
J
jinyaohui 已提交
453 454


455 456 457 458
def test_factorial2():
    """The if/elif/else factorial also yields 3! = 6."""
    assert factorial2(3) == 6

J
jinyaohui 已提交
459

460 461 462 463
@ms_function
def foo(n):
    """Recurse down through nested if/else until `n < 1`, then return 1.

    For `n > 1` and for `n == 1` it recurses on `n - 1`; otherwise it returns 1.
    """
    if n <= 1:
        if n == 1:
            return foo(n - 1)
        else:
            return 1
    else:
        return foo(n - 1)


471 472 473 474
def test_foo():
    """Recursing from 5 bottoms out at the `return 1` branch."""
    assert foo(5) == 1

J
jinyaohui 已提交
475

476 477 478 479
@ms_function
def double_nested_loop(x):
    """Add 1 + 2 + 3 to `s` once for each of `x` outer while-loop iterations."""
    i = 0
    s = 0
    while i < x:
        j = 0
        i = i + 1
        while j < 3:
            j = j + 1
            s = s + j
    return s
J
jinyaohui 已提交
487 488


489 490 491 492
def test_nested_loop():
    """Three outer iterations, each adding 1 + 2 + 3, give 18."""
    assert double_nested_loop(3) == 18

J
jinyaohui 已提交
493

494 495 496 497 498 499 500
@ms_function
def double_nested_loop2(x):
    """Add 0 + 1 + 2 to `s` once for every outer iteration in `range(x)`.

    The outer index `i` is unused; only the inner `j` values are summed.
    """
    s = 0
    for i in range(x):
        for j in range(3):
            s = s + j
    return s
J
jinyaohui 已提交
501 502


503 504 505 506
def test_nested_loop2():
    """Nested for-loops: one outer iteration adds 0 + 1 + 2, giving 3."""
    res = double_nested_loop2(1)
    assert res == 3

J
jinyaohui 已提交
507

Z
zhunaipan 已提交
508 509 510 511 512 513 514
def _for(x):
    """Return `x * x` multiplied by each of 2 and 3 in a for-loop (i.e. `6 * x * x`).

    Differentiated by `grad_for`.
    """
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret

J
jinyaohui 已提交
515

Z
zhunaipan 已提交
516 517 518 519
def grad_for(x):
    """Return the gradient tuple of `_for` with respect to `x`."""
    grad_fn = C.grad_all(_for)
    return grad_fn(x)

J
jinyaohui 已提交
520

Z
zhunaipan 已提交
521 522 523 524
def test_grad_for():
    """d/dx (6 * x * x) = 12 * x = 60 at x = 5."""
    expected = (60,)
    assert grad_for(5) == expected

J
jinyaohui 已提交
525

Z
zhunaipan 已提交
526 527 528 529 530
@ms_function
def try_tail(x):
    """Apply `C.tail` to `x` inside an ms_function."""
    return C.tail(x)

J
jinyaohui 已提交
531

Z
zhunaipan 已提交
532 533 534 535 536
@non_graph_engine
def test_tail():
    """Smoke test: `C.tail` on a 4-tuple runs inside ms_function without error."""
    sample = (0, 1, 2, 3)
    try_tail(sample)

J
jinyaohui 已提交
537

Z
zhunaipan 已提交
538 539 540 541 542
@ms_function
def zero_like_tensor(x):
    """Return `C.zeros_like(x)` computed inside an ms_function."""
    return C.zeros_like(x)

J
jinyaohui 已提交
543

Z
zhunaipan 已提交
544 545 546 547 548 549
def test_zeros():
    """zeros_like of an int32 ones tensor equals an int32 zeros tensor of the same shape."""
    shape = [2, 3]
    res = zero_like_tensor(Tensor(np.ones(shape).astype(np.int32)))
    assert res == Tensor(np.zeros(shape).astype(np.int32))

J
jinyaohui 已提交
550

Z
zhunaipan 已提交
551 552
def test_ScalarGradChecker():
    """Numerically check the gradient of a scalar product with ScalarGradChecker."""

    def scalar_f(x, y):
        return x * y

    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)

J
jinyaohui 已提交
559

Z
zhunaipan 已提交
560 561 562
def test_GradCheckerPrimitive():
    """Check the gradient of a bare MatMul primitive with OperationGradChecker."""
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)

J
jinyaohui 已提交
571

Z
zhunaipan 已提交
572 573
def test_NNGradChecker():
    """Check the gradient of a single Dense(10, 10) layer with NNGradChecker."""

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)

J
jinyaohui 已提交
591

Z
zhunaipan 已提交
592 593
def test_OperationGradChecker():
    """Check MatMul gradients through a Cell that first scales `x` by a parameter `z`."""

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)), grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """Check the Jacobian of a scalar product, selecting input 0 (`input_selector=[0]`)."""

    def scalar_f(x, y):
        return x * y

    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """Check the Jacobian of a two-output MatMul Cell, selecting input 0 and output 0."""

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """Check the Jacobian of a Dense(10, 10) Cell that returns `(dense(x), x)`.

    NOTE(review): `input_selector=[1]` is used although `construct` takes a
    single input; presumably NNGradChecker indexes something other than the
    call arguments here. Confirm against check_gradient.
    """

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])

J
jinyaohui 已提交
665

Z
zhunaipan 已提交
666 667 668 669
def multi_outputs(x, y):
    """Return the pair `(2 * (x + y), 2 * (x + y))` for multi-output gradient tests."""
    z = x + y
    return 2 * z, 2 * z

J
jinyaohui 已提交
670

Z
zhunaipan 已提交
671 672 673
def test_grad_multi_outputs():
    """With sens (1, 1), each output 2 * (x + y) contributes 2 per input: (4, 4)."""
    sens = (1, 1)
    grads = C.grad_all_with_sens(multi_outputs)(2, 3, sens)
    assert grads == (4, 4)

J
jinyaohui 已提交
674

Z
zhunaipan 已提交
675 676 677 678 679 680 681 682 683
@ms_function
def while_sp(x, y, z):
    """Multiply `out` (initially `x`) by `x` once per loop step from `c3` up to `c2`.

    With c3 = 1 and c2 = 10 the loop runs 9 times, so the result is `x ** 10`.
    `y` and `z` are accepted but not used.
    """
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out

J
jinyaohui 已提交
684

Z
zhunaipan 已提交
685 686 687 688 689 690 691
def test_while_sp():
    """Nine multiplications by x = 2, starting from x, give 2 ** 10 = 1024."""
    shape = [1, 3]
    x = Tensor(np.ones(shape).astype(np.float32) * 2.0)
    y = Tensor(np.ones(shape).astype(np.float32))
    z = Tensor(np.ones(shape).astype(np.float32))
    res = while_sp(x, y, z)
    assert res == Tensor(np.ones(shape).astype(np.float32) * 1024.0)

J
jinyaohui 已提交
692

Z
zhunaipan 已提交
693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712
def grad_refactor_simple_1(x, y):
    """Return `x * x + 2 * y`."""
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    """Gradients of x * x + 2 * y at (2, 1) are (2x, 2) = (4, 2)."""
    grad_fn = C.grad_all(grad_refactor_simple_1)
    assert grad_fn(2, 1) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """Return `x*y + z + x*y*z + x + x*y`."""
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    """At (2, 3, 0): d/dx = 2y + yz + 1 = 7, d/dy = 2x + xz = 4, d/dz = 1 + xy = 7."""
    grad_fn = C.grad_all(grad_refactor_simple_2)
    assert grad_fn(2, 3, 0) == (7, 4, 7)


def grad_refactor_1(a, b):
    """Return `a * b`, computed through a nested helper function."""

    def inner(x, y):
        return x * y

    return inner(a, b)


def test_grad_refactor_1():
    """Gradients of a * b at (2, 3) are (b, a) = (3, 2)."""
    grad_fn = C.grad_all(grad_refactor_1)
    assert grad_fn(2, 3) == (3, 2)


def grad_refactor_2(a, b):
    """Return `inner(b) * inner(a)` with the closure `inner(x) = x * b`, i.e. `a * b ** 3`."""

    def inner(x):
        return x * b

    return inner(b) * inner(a)


def test_grad_refactor_2():
    """For a * b**3 at (2, 3): d/da = b**3 = 27 and d/db = 3 * a * b**2 = 54."""
    grad_fn = C.grad_all(grad_refactor_2)
    assert grad_fn(2, 3) == (27, 54)


def grad_refactor_3(a):
    """Return 0 when `a > 3`, otherwise `3 * a`."""
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    """At a = 3 the `3 * a` branch is taken, so the gradient is 3."""
    grad_fn = C.grad_all(grad_refactor_3)
    assert grad_fn(3) == (3,)


def grad_refactor_4(a):
    """Return `3 * a` when `a > 3`, otherwise 0."""
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    """At a = 4 the `3 * a` branch is taken, so the gradient is 3."""
    grad_fn = C.grad_all(grad_refactor_4)
    assert grad_fn(4) == (3,)


def grad_refactor_5(a):
    """Return the constant 1 when `a > 3`, otherwise `a` itself."""
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    """At a = 1 the identity branch is taken, so the gradient is 1."""
    grad_fn = C.grad_all(grad_refactor_5)
    assert grad_fn(1) == (1,)


def grad_refactor_6(a, b):
    """Return `3 * a + b` when `a > b`, otherwise `2 * b * a`."""
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    """At (3, 2) the `3 * a + b` branch is taken, so the gradients are (3, 1)."""
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
Z
zhunaipan 已提交
779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810


def grad_refactor_while(x):
    """Square `x` repeatedly until the value is no longer below 4."""
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def test_grad_refactor_9():
    """x = 3 is squared once (to 9), so the gradient of x * x is 6."""
    grad_fn = C.grad_all(grad_refactor_while)
    assert grad_fn(3) == (6,)


def grad_refactor__while_1(x):
    """Return `x * x` multiplied by 2 and then by 3 in a `while` loop (i.e. `6 * x * x`)."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """d/dx (6 * x * x) = 12 * x = 60 at x = 5."""
    grad_fn = C.grad_all(grad_refactor__while_1)
    assert grad_fn(5) == (60,)


def test_grad_refactor_11():
    """Smoke test: grad_all of a parameter-free Cell computing x * y * y runs."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))


def test_grad_refactor_12():
    """Smoke test: grad_all of a Cell computing x * z * y with a one-element parameter `z`."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def test_grad_refactor_13():
    """Smoke test: grad_by_list over the Cell's trainable parameters (`z`, shape [2])."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    C.grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def grad_refactor_14(a, b):
    """Sum three closures over `a` and `b`.

    The value is `b * b + a * b + (a if a > 2 else b)`.
    """

    def inner1(x):
        return x * b

    # Ignores its argument and reads `a` and `b` from the enclosing scope.
    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b

    return inner1(b) + inner2(a) + inner3(a)
J
jinyaohui 已提交
868 869


Z
zhunaipan 已提交
870 871
def test_grad_refactor_14():
    """With a = 2 (not > 2) the value is b*b + a*b + b: gradients (b, 2b + a + 1) = (3, 9)."""
    grad_fn = C.grad_all(grad_refactor_14)
    assert grad_fn(2, 3) == (3, 9)
872 873


J
jinyaohui 已提交
874
# pylint: disable=using-constant-test
875 876 877 878 879 880 881 882 883 884 885 886
class IfDeferInline(nn.Cell):
    """Multiply the input by a constant 0.6 tensor, then pass it through a trivial `if True:` branch.

    The self-assignment under the constant condition is deliberate: it adds a
    branch to the graph without changing the value (hence the pylint disable above).
    """
    def __init__(self, mul_size):
        super().__init__()
        # Plain (non-Parameter) weight tensor filled with 0.6, shaped `mul_size`.
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        if True:
            x = x
        return x

J
jinyaohui 已提交
887

888 889 890 891 892 893 894
def test_grad_if_defer_inline():
    """With defer_inline disabled, the gradient of inputs * 0.6 is 0.6 everywhere."""
    shape = [128, 96]
    network = IfDeferInline(shape)
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones(shape).astype(np.float32))
    grads = C.grad_all(network)(inp)
    assert grads == (Tensor(np.full(shape, 0.6, dtype=np.float32)),)
P
panyifeng 已提交
895

J
jinyaohui 已提交
896

P
panyifeng 已提交
897
def test_bprop_with_wrong_output_num():
    """With check_bprop enabled, a bprop that returns one gradient for a two-input primitive raises TypeError."""
    # NOTE(review): check_bprop is switched on globally and not restored in this
    # test, so it stays enabled for whatever runs afterwards in the same process.
    context.set_context(check_bprop=True)

    class BpropWithWrongOutputNum(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')

        def __call__(self, x, y):
            return x

        def infer_shape(self, x_shape, yshape):
            return x_shape

        def infer_dtype(self, x_type, y_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputNum)
    def get_bprop_with_wrong_output_num(self):
        """Generate bprop for BpropWithWrongOutputNum"""

        # Deliberately returns a single gradient although the primitive has two inputs.
        def bprop(x, y, out, dout):
            return (dout,)

        return bprop

    class BpropWithWrongOutputNumCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputNumCell, self).__init__()

        def construct(self, x, y):
            return BpropWithWrongOutputNum()(x, y)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)

def test_bprop_with_wrong_output_type():
    """With check_bprop enabled, a bprop returning a Python int for a Tensor input raises TypeError."""
    # NOTE(review): check_bprop is switched on globally and not restored in this test.
    context.set_context(check_bprop=True)

    class BpropWithWrongOutputType(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""

        # Deliberately returns the scalar 1 instead of a gradient matching the Tensor input.
        def bprop(x, out, dout):
            return (1,)

        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))

J
jinyaohui 已提交
967

P
panyifeng 已提交
968
def test_bprop_with_wrong_output_shape():
    """With check_bprop enabled, a bprop gradient of shape [2] for a [64, 10] input raises TypeError."""
    # NOTE(review): check_bprop is switched on globally and not restored in this test.
    context.set_context(check_bprop=True)

    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        # Gradient deliberately shaped [2], not matching the [64, 10] input below.
        ones = Tensor(np.ones([2,]).astype(np.int32))

        def bprop(x, out, dout):
            return (ones,)

        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))