dct.cpp 43.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679
/**
 * \file dnn/test/naive/dct.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/dct_ref.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

// Golden-value test for DctChannelSelectForward on the NAIVE fixture's handle,
// using a default-constructed Param.
//
// The first Testcase carries the operator input (a 1x1x16x16 Uint8 image) and
// leaves the remaining three slots empty; the second Testcase leaves the first
// three slots empty and carries the expected 1x64x2x2 Float32 result.
// NOTE(review): the 16x16 -> 64x2x2 shapes are consistent with 8x8 DCT blocks,
// but the block size comes from Param's default, which is not visible in this
// file -- confirm against the operator's Param definition.
TEST_F(NAIVE, DCT) {
    // The second constructor argument disables the checker's dispatch check
    // (see the inline /* check_dispatch */ tag).
    Checker<DctChannelSelectForward> checker(handle(),
                                             /* check_dispatch */ false);
    DctChannelSelectForward::Param param;

    checker.set_param(param).exect(
            // Input side: only the first slot (source image) is populated.
            Testcase{TensorValue(
                             {1, 1, 16, 16}, dtype::Uint8(),
                             {87,  155, 59,  161, 24,  200, 58,  3,   40,  43,
                              156, 7,   176, 232, 226, 78,  73,  236, 185, 109,
                              196, 169, 62,  32,  167, 180, 96,  157, 101, 53,
                              150, 47,  26,  238, 218, 210, 204, 236, 249, 111,
                              16,  35,  169, 204, 117, 16,  3,   147, 12,  233,
                              135, 162, 58,  118, 184, 237, 90,  105, 156, 195,
                              196, 104, 138, 19,  82,  62,  126, 140, 220, 171,
                              206, 232, 105, 123, 2,   135, 137, 41,  26,  219,
                              167, 245, 104, 103, 24,  144, 141, 210, 208, 114,
                              169, 170, 22,  11,  69,  106, 236, 150, 57,  184,
                              75,  241, 28,  175, 178, 186, 190, 124, 187, 116,
                              112, 162, 214, 154, 207, 31,  43,  40,  15,  188,
                              81,  197, 20,  199, 246, 132, 159, 111, 79,  95,
                              148, 184, 171, 173, 203, 146, 150, 33,  178, 9,
                              141, 49,  237, 222, 72,  5,   23,  38,  248, 82,
                              93,  229, 70,  180, 149, 232, 245, 72,  196, 138,
                              4,   31,  160, 30,  8,   109, 153, 252, 204, 126,
                              15,  182, 145, 130, 179, 234, 21,  240, 144, 105,
                              77,  116, 155, 232, 168, 99,  159, 92,  251, 223,
                              119, 173, 166, 39,  228, 91,  34,  5,   62,  172,
                              131, 164, 143, 10,  161, 165, 221, 214, 178, 110,
                              185, 254, 152, 149, 46,  144, 173, 237, 76,  210,
                              221, 45,  200, 113, 58,  20,  47,  135, 228, 80,
                              91,  51,  238, 194, 222, 231, 174, 244, 139, 96,
                              71,  25,  25,  62,  172, 181, 71,  27,  86,  0,
                              121, 38,  199, 236, 93,  158}),
                     {},
                     {},
                     {}},
            // Expected side: only the last slot (result tensor) is populated.
            Testcase{{},
                     {},
                     {},
                     TensorValue(
                             {1, 64, 2, 2}, dtype::Float32(),
                             {1.10687500e+03,  9.59500000e+02,  8.98125000e+02,
                              1.21912500e+03,  1.38846378e+01,  3.91629181e+01,
                              -1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
                              -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
                              -4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
                              6.31493912e+01,  -9.96249924e+01, 7.72499924e+01,
                              7.46250153e+01,  5.81250114e+01,  -9.07061768e+01,
                              -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
                              8.55864143e+00,  -7.36760712e+01, 6.20557327e+01,
                              -2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
                              5.21866226e+01,  1.07624054e+02,  -6.16851950e+00,
                              -8.56008530e+01, 7.35654449e+01,  -2.56767311e+01,
                              -2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
                              -6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
                              1.04179153e+02,  7.45253448e+01,  3.19730816e+01,
                              1.24306192e+01,  -9.93905945e+01, -8.95680237e+01,
                              -1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
                              4.50356903e+01,  -3.65339231e+00, 5.79474449e+01,
                              -2.46253452e+01, 3.29394951e+01,  -1.09065903e+02,
                              5.23808861e+01,  -1.00386992e+01, -7.92311325e+01,
                              -1.44292374e+01, 5.74285736e+01,  2.28798485e+01,
                              6.84826508e+01,  -1.49241837e+02, 9.35751495e+01,
                              -4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
                              -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
                              1.58941059e+01,  -5.66967869e+00, -1.53566467e+02,
                              -4.33494377e+01, 8.12108078e+01,  1.21169144e+02,
                              2.14673615e+02,  -3.72018318e+01, 2.45811577e+01,
                              -1.27189613e+02, 4.98553581e+01,  -5.83694696e+00,
                              -4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
                              5.16259460e+01,  1.07266571e+02,  -3.41748886e+01,
                              -5.44621315e+01, 6.25573196e+01,  -4.24649086e+01,
                              4.42625465e+01,  2.71147366e+01,  4.83264275e+01,
                              -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
                              2.48003254e+01,  -1.74687519e+01, 9.44530487e-01,
                              1.35930038e+02,  6.72219162e+01,  4.53297043e+01,
                              1.37072708e+02,  -7.73253784e+01, 6.12967606e+01,
                              9.78184891e+01,  3.63894577e+01,  -1.64039135e+01,
                              -6.67858887e+01, 5.27859840e+01,  -4.99117432e+01,
                              8.77927475e+01,  -5.86666260e+01, 3.86430244e+01,
                              2.17759323e+01,  8.34562683e+01,  3.06256886e+01,
                              1.61030369e+01,  8.11268158e+01,  1.36932516e+01,
                              -1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
                              -4.90609503e+00, 7.96453857e+01,  -1.02625000e+02,
                              1.40000076e+01,  3.18749981e+01,  -1.08375000e+02,
                              -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
                              -1.44041641e+02, 4.86086197e+01,  -7.13610382e+01,
                              3.06417294e+01,  7.20477829e+01,  -6.95384140e+01,
                              1.25441925e+02,  -1.54897385e+01, 3.78566666e+01,
                              4.23749886e+01,  -3.37500000e+01, -9.96250000e+01,
                              -6.73750076e+01, 3.34241295e+01,  -6.24825974e+01,
                              1.76387348e+01,  -6.45708389e+01, 1.70728874e+01,
                              -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
                              4.17566071e+01,  7.08248520e+00,  -2.59306641e+01,
                              1.37766739e+02,  -2.16669798e+00, 6.03565750e+01,
                              6.84421844e+01,  6.19825096e+01,  -1.44220114e+01,
                              -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
                              2.52050266e+01,  -8.35850677e+01, -4.70746574e+01,
                              1.73889160e+01,  1.18955564e+01,  6.16792488e+00,
                              -3.29667168e+01, 4.55779572e+01,  -4.17868996e+00,
                              -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
                              5.25992851e+01,  1.23662634e+01,  5.26129305e-01,
                              4.69518929e+01,  -1.52657738e+01, 9.96897888e+01,
                              -9.51726151e+01, 9.99432602e+01,  -1.75949844e+02,
                              1.00472336e+02,  -5.89417953e+01, -1.72231483e+01,
                              1.89282093e+01,  -8.17851868e+01, 7.22908936e+01,
                              -9.06294174e+01, 2.46093607e+00,  -4.03946457e+01,
                              2.17710762e+01,  -5.62999649e+01, 4.77665749e+01,
                              -4.04248848e+01, 4.78787374e+00,  1.05557320e+02,
                              -4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
                              1.71907139e+01,  -8.01314316e+01, 1.69647141e+01,
                              -8.24824219e+01, 8.29206543e+01,  3.72900200e+01,
                              3.77470016e+01,  6.70151443e+01,  1.79784470e+01,
                              -4.01441078e+01, 6.29196739e+01,  7.60664597e+01,
                              -5.59005699e+01, 8.81600475e+00,  -6.89491081e+00,
                              -8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
                              -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
                              -8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
                              -3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
                              3.94072838e+01,  1.24754738e+02,  9.51344986e+01,
                              -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
                              4.94889374e+01,  7.64868622e+01,  -1.49546280e+01,
                              -3.70648766e+01, 5.55572205e+01,  -1.17196434e+02,
                              9.20216217e+01,  3.29843826e+01,  3.25113411e+01,
                              5.62059135e+01,  6.30202141e+01,  4.99030991e+01,
                              2.85804024e+01,  -1.44606361e+01, 7.64952774e+01,
                              -2.95697536e+01})});
}

// Golden-value test for DctChannelSelectForward with Param::format set to
// NCHW4: a 1x1x16x16 Uint8 image is expected to produce a 1x16x2x2x4 tensor
// of dtype QuantizedS8(10.f).
//
// As in the float test, the input Testcase fills only its first slot and the
// expected Testcase fills only its last slot.
// NOTE(review): 16 * 4 = 64 packed channels matches the float test's 64
// output channels for the same input shape; the exact NCHW4 packing order is
// defined by the operator, not by this file.
TEST_F(NAIVE, DCT_INT8) {
    // The second constructor argument disables the checker's dispatch check
    // (see the inline /* check_dispatch */ tag).
    Checker<DctChannelSelectForward> checker(handle(),
                                             /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;
    checker.set_param(param).exect(
            Testcase{TensorValue(
                             {1, 1, 16, 16}, dtype::Uint8(),
                             {113, 223, 229, 159, 249, 252, 89,  84,  45,  16,
                              41,  72,  184, 236, 70,  184, 86,  172, 218, 211,
                              47,  177, 18,  85,  174, 226, 37,  109, 38,  135,
                              228, 195, 133, 238, 47,  246, 244, 118, 175, 143,
                              34,  10,  28,  4,   82,  103, 89,  55,  235, 78,
                              151, 178, 249, 62,  183, 84,  105, 0,   121, 98,
                              249, 90,  161, 114, 121, 241, 21,  199, 196, 119,
                              231, 209, 250, 180, 192, 213, 116, 105, 114, 169,
                              1,   142, 3,   30,  140, 245, 201, 109, 19,  26,
                              224, 68,  123, 228, 64,  150, 184, 212, 136, 172,
                              241, 152, 222, 233, 15,  72,  130, 144, 107, 130,
                              242, 79,  195, 46,  226, 57,  183, 36,  88,  161,
                              121, 170, 2,   215, 109, 212, 35,  18,  76,  197,
                              117, 81,  208, 8,   237, 75,  15,  20,  16,  192,
                              61,  113, 96,  126, 211, 57,  49,  62,  185, 211,
                              155, 87,  233, 163, 164, 84,  61,  28,  1,   11,
                              190, 253, 145, 30,  38,  98,  153, 56,  231, 152,
                              12,  204, 96,  8,   47,  87,  25,  237, 21,  150,
                              173, 19,  41,  175, 164, 231, 39,  145, 39,  187,
                              210, 123, 165, 98,  87,  242, 38,  136, 182, 145,
                              41,  47,  147, 171, 172, 35,  170, 148, 26,  89,
                              107, 151, 130, 232, 65,  217, 27,  206, 68,  219,
                              60,  106, 3,   209, 175, 189, 191, 32,  119, 141,
                              56,  48,  105, 58,  94,  163, 185, 60,  83,  249,
                              112, 245, 137, 60,  178, 51,  177, 106, 199, 209,
                              4,   247, 3,   127, 88,  46}),
                     {},
                     {},
                     {}},
            Testcase{{},
                     {},
                     {},
                     // Expected quantized values; the `-0` entries are plain
                     // integer zeros (unary minus applied to the literal 0).
                     TensorValue(
                             {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
                             {122, -1,  -8,  4,   92,  -13, -5,  7,   99,  4,
                              5,   3,   89,  7,   2,   -6,  3,   -8,  -10, 2,
                              -1,  0,   4,   -3,  -5,  -8,  -11, 1,   14,  4,
                              -10, -18, 3,   12,  -14, -2,  -4,  -9,  12,  4,
                              -2,  -2,  2,   6,   -9,  6,   1,   5,   -5,  -1,
                              2,   -12, 4,   -5,  -0,  4,   1,   5,   -8,  5,
                              -3,  4,   2,   6,   -0,  9,   -4,  -7,  -4,  -5,
                              -2,  8,   2,   4,   0,   7,   -8,  4,   -2,  3,
                              -6,  -5,  19,  5,   -4,  -4,  -5,  -16, -8,  -3,
                              -5,  19,  4,   3,   4,   -6,  1,   -12, -1,  7,
                              11,  -5,  -1,  -8,  2,   -12, -9,  -2,  -4,  -20,
                              -11, -15, -15, -9,  -2,  -9,  -2,  -3,  13,  2,
                              5,   6,   7,   -4,  1,   -7,  6,   4,   2,   6,
                              0,   -0,  8,   8,   -6,  5,   1,   -2,  -2,  -12,
                              2,   -12, -2,  6,   7,   3,   4,   14,  14,  -3,
                              1,   -3,  6,   0,   -20, 2,   -10, 10,  -5,  -5,
                              13,  0,   -3,  7,   -12, -17, -13, 1,   -6,  10,
                              -1,  -9,  4,   -16, 3,   2,   5,   1,   -4,  9,
                              -0,  1,   3,   15,  -4,  -13, -6,  4,   3,   -2,
                              -1,  -4,  -7,  -7,  -2,  8,   -16, -4,  -10, 5,
                              1,   -3,  2,   -9,  -4,  1,   -1,  -1,  -4,  -6,
                              -4,  1,   0,   -9,  15,  -1,  -7,  -3,  -5,  -0,
                              3,   -0,  -6,  -17, 16,  -3,  3,   -2,  -3,  5,
                              3,   -2,  3,   13,  8,   1,   -3,  -8,  -7,  -4,
                              6,   -6,  -15, -7,  0,   4,   -3,  -3,  -10, 14,
                              1,   3,   14,  4,   -1,  14})});
}

// Golden-value test for DctChannelSelectForward with Param::format set to
// NCHW4 when the two middle Testcase slots (both Int32 tensors) are supplied.
//
// Two cases are run on the same 1x3x8x16 Uint8 image:
//   1. side tensors {0, 16, 24, 32} and a 32-entry index list, expecting a
//      1x8x1x2x4 QuantizedS8(10.f) result (8 * 4 = 32 packed channels);
//   2. side tensors {0, 12, 20, 28} and a 28-entry index list, expecting a
//      1x7x1x2x4 QuantizedS8(10.f) result (7 * 4 = 28 packed channels).
// NOTE(review): the first Int32 tensor looks like per-input-channel offsets
// into the second (a list of selected DCT coefficient indices) -- confirm
// against the operator documentation.
TEST_F(NAIVE, DCT_INT8_MASK) {
    // The second constructor argument disables the checker's dispatch check
    // (see the inline /* check_dispatch */ tag).
    Checker<DctChannelSelectForward> checker(handle(),
                                             /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.format = DctChannelSelectForward::Param::Format::NCHW4;

    // Builds the shared source image. The literal used to be written out
    // twice (once per case); keeping a single copy avoids the two drifting
    // apart. A fresh tensor is created on every call, exactly as the original
    // code did for each case, rather than sharing one tensor object.
    // NOTE(review): presumably Testcase takes ownership of the tensor storage
    // it is given, which is why one tensor must not be placed into two
    // Testcases -- confirm in test/common/checker.h.
    auto make_src = [] {
        return TensorValue(
                {1, 3, 8, 16}, dtype::Uint8(),
                {195, 165, 82,  30,  154, 60,  175, 195, 179, 165, 132, 37,  250,
                 107, 36,  80,  5,   54,  247, 218, 191, 211, 239, 76,  140, 33,
                 253, 85,  132, 101, 105, 177, 46,  183, 102, 99,  19,  175, 108,
                 252, 42,  238, 48,  251, 108, 90,  176, 2,   35,  46,  161, 252,
                 38,  225, 195, 174, 58,  165, 198, 249, 162, 118, 198, 41,  154,
                 10,  87,  24,  201, 12,  188, 1,   93,  179, 246, 134, 18,  178,
                 173, 36,  122, 89,  115, 46,  43,  205, 232, 55,  149, 30,  206,
                 97,  186, 125, 35,  209, 51,  48,  222, 222, 130, 173, 63,  0,
                 223, 19,  5,   162, 154, 143, 134, 63,  123, 102, 102, 212, 145,
                 80,  87,  212, 42,  26,  219, 225, 120, 94,  213, 238,

                 25,  172, 141, 45,  182, 203, 50,  94,  44,  88,  74,  76,  151,
                 105, 138, 87,  125, 55,  60,  211, 15,  158, 198, 37,  54,  203,
                 239, 79,  56,  6,   53,  201, 97,  233, 178, 74,  193, 46,  249,
                 65,  5,   208, 130, 67,  191, 168, 152, 129, 253, 195, 231, 3,
                 109, 229, 254, 193, 229, 202, 108, 22,  89,  251, 13,  53,  47,
                 192, 12,  81,  19,  53,  93,  104, 41,  217, 215, 184, 136, 249,
                 14,  244, 4,   220, 33,  53,  142, 219, 43,  28,  68,  198, 202,
                 88,  235, 7,   233, 47,  84,  127, 28,  17,  189, 135, 183, 192,
                 239, 116, 31,  118, 186, 49,  251, 233, 220, 27,  97,  30,  43,
                 193, 217, 48,  24,  225, 15,  3,   26,  71,  82,  104,

                 175, 125, 79,  195, 50,  236, 114, 179, 180, 177, 230, 173, 43,
                 195, 123, 111, 106, 5,   91,  254, 34,  76,  52,  82,  193, 179,
                 185, 71,  57,  215, 18,  5,   151, 13,  59,  206, 154, 95,  149,
                 40,  229, 16,  116, 144, 249, 67,  97,  223, 208, 144, 92,  174,
                 246, 77,  196, 211, 20,  123, 239, 250, 235, 65,  184, 54,  239,
                 168, 135, 17,  79,  117, 171, 173, 109, 39,  57,  13,  129, 79,
                 236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226, 44,
                 113, 164, 217, 46,  249, 182, 22,  98,  228, 49,  78,  101, 236,
                 181, 5,   245, 72,  62,  182, 151, 210, 254, 190, 35,  73,  190,
                 247, 50,  81,  49,  217, 86,  229, 139, 203, 57,  194});
    };

    // Case 1: 16 + 8 + 8 = 32 selected entries across the three channels.
    checker.set_param(param).exect(
            Testcase{make_src(),
                     TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                     TensorValue({32}, dtype::Int32(),
                                 {0,  1,  8,  16, 9, 2,  3, 10, 17, 24, 32,
                                  25, 18, 11, 4,  5, 0,  1, 8,  16, 9,  2,
                                  3,  10, 0,  1,  8, 16, 9, 2,  3,  10}),
                     {}},
            Testcase{{},
                     {},
                     {},
                     TensorValue(
                             {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
                             {100, -12, 7,   7,  104, 2,  -2,  -2, -7,  -7, -3,
                              8,   12,  -12, -5, -1,  5,  -7,  -1, 7,   -7, -3,
                              6,   7,   -0,  -2, -7,  11, 6,   3,  -1,  7,  94,
                              -5,  6,   -5,  98, 0,   -3, -16, 5,  7,   13, -8,
                              1,   5,   -5,  -8, 108, -3, -8,  -7, 110, 1,  -2,
                              5,   -0,  7,   8,  -9,  14, -0,  1,  -4})});

    // Case 2: 12 + 8 + 8 = 28 selected entries across the three channels.
    checker.set_param(param).exect(
            Testcase{make_src(),
                     TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
                     TensorValue({28}, dtype::Int32(),
                                 {0,  1,  8, 16, 9, 2,  3, 10, 17, 24,
                                  32, 25, 0, 1,  8, 16, 9, 2,  3,  10,
                                  0,  1,  8, 16, 9, 2,  3, 10}),
                     {}},
            Testcase{{},
                     {},
                     {},
                     TensorValue(
                             {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
                             {100, -12, 7,   7,  104, 2,   -2, -2,  -7, -7,  -3,
                              8,   12,  -12, -5, -1,  5,   -7, -1,  7,  -7,  -3,
                              6,   7,

                              94,  -5,  6,   -5, 98,  0,   -3, -16, 5,  7,   13,
                              -8,  1,   5,   -5, -8,  108, -3, -8,  -7, 110, 1,
                              -2,  5,   -0,  7,  8,   -9,  14, -0,  1,  -4})});
}

// Golden-value test for DctChannelSelectForward with Param::dct_block_size
// set to 4, on a 1x1x8x8 Uint8 image.
//
// Two cases are run on the same image:
//   1. no Int32 side tensors, expecting the full 1x16x2x2 Float32 result;
//   2. side tensors {0, 6} and {0, 1, 8, 4, 2, 3}, expecting a 1x6x2x2
//      Float32 result. Its expected values are channels 0, 1, 8, 4, 2, 3 of
//      the case-1 expectation, in that order (printed with fewer digits).
TEST_F(NAIVE, DCT_4x4) {
    // The second constructor argument disables the checker's dispatch check
    // (see the inline /* check_dispatch */ tag).
    Checker<DctChannelSelectForward> checker(handle(),
                                             /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    param.dct_block_size = 4;

    // Builds the shared source image. The literal used to be written out once
    // per case; keeping a single copy avoids the two drifting apart. A fresh
    // tensor is created on every call, exactly as the original code did for
    // each case, rather than sharing one tensor object.
    // NOTE(review): presumably Testcase takes ownership of the tensor storage
    // it is given -- confirm in test/common/checker.h before sharing tensors.
    auto make_src = [] {
        return TensorValue(
                {1, 1, 8, 8}, dtype::Uint8(),
                {186, 120, 112, 220, 69,  80,  201, 127, 246, 254,
                 175, 50,  240, 251, 76,  37,  34,  166, 250, 195,
                 231, 139, 128, 233, 75,  80,  3,   2,   19,  140,
                 193, 203, 115, 107, 250, 209, 14,  243, 199, 60,
                 234, 107, 174, 156, 81,  87,  13,  116, 96,  140,
                 197, 253, 113, 223, 229, 159, 249, 252, 89,  84,
                 45,  16,  41,  72});
    };

    // Case 1: all 16 coefficients of every 4x4 block.
    checker.set_param(param).exect(
            Testcase{make_src(),
                     {},
                     {},
                     {}},
            Testcase{{},
                     {},
                     {},
                     TensorValue(
                             {1, 16, 2, 2}, dtype::Float32(),
                             {5.42000000e+02,  5.91750000e+02,  6.78000000e+02,
                              4.27750000e+02,  3.49953423e+01,  -1.17686939e+01,
                              -1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
                              -1.22500000e+01, 2.00000000e+01,  -9.77500000e+01,
                              -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
                              -4.92537880e+01, 1.66958221e+02,  -4.26609573e+01,
                              2.56999969e-01,  5.39384537e+01,  1.71819706e+01,
                              9.00009003e+01,  -1.23818558e+02, 1.18912420e+01,
                              6.61014938e+01,  -2.49261990e+01, 4.95798302e+00,
                              -1.02324417e+02, 7.85859919e+00,  3.73140755e+01,
                              1.03783745e+02,  -4.61430321e+01, -1.43000000e+02,
                              -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
                              1.34834738e+01,  -1.93409515e+02, 6.84791718e+01,
                              -4.01652241e+00, 1.22000000e+02,  -8.57500000e+01,
                              -4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
                              5.76532059e+01,  -2.67414131e+01, 1.70877876e+01,
                              3.85416756e+01,  3.09300461e+01,  5.84670639e+00,
                              1.85747864e+02,  -2.05141403e+02, -9.91859360e+01,
                              -1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
                              8.41980438e+01,  -3.50666313e+01, -1.48387482e+02,
                              1.08180256e+01,  5.49991112e+01,  -1.06814528e+01,
                              1.86087704e+01})});

    // Case 2: six selected coefficients, given by the Int32 side tensors.
    checker.set_param(param).exect(
            Testcase{make_src(),
                     TensorValue({2}, dtype::Int32(), {0, 6}),
                     TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}),
                     {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 6, 2, 2}, dtype::Float32(),
                            {5.4200000e+02,  5.9175000e+02,  6.7800000e+02,
                             4.2775000e+02,  3.4995342e+01,  -1.1768694e+01,
                             -1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
                             -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
                             1.6695822e+02,  -4.2660957e+01, 2.5699997e-01,
                             5.3938454e+01,  -3.8000000e+01, -1.2250000e+01,
                             2.0000000e+01,  -9.7750000e+01, -1.6119131e+01,
                             -9.4669533e+00, 3.2888241e+01,  -4.9253788e+01})});
}

TEST_F(NAIVE, DCT_WITH_MASK) {
    Checker<DctChannelSelectForward> checker(handle(),
                                             /* check_dispatch */ false);
    DctChannelSelectForward::Param param;
    checker.set_param(param).exect(
            Testcase{TensorValue(
                             {1, 3, 8, 16}, dtype::Uint8(),
                             {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204,

                              109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204,

                              109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204}),
                     TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
                     TensorValue({32}, dtype::Int32(),
                                 {0,  1,  8,  16, 9, 2,  3, 10, 17, 24, 32,
                                  25, 18, 11, 4,  5, 0,  1, 8,  16, 9,  2,
                                  3,  10, 0,  1,  8, 16, 9, 2,  3,  10}),
                     {}},
            Testcase{{},
                     {},
                     {},
                     TensorValue({1, 32, 1, 2}, dtype::Float32(),
                                 {890.12494,   941.25,     -7.0498576,
                                  99.47632,    -22.850792, -97.862236,
                                  -101.043236, -4.727012,  28.275675,
                                  -157.96654,  42.1377,    45.06531,
                                  -149.77373,  24.487143,  -8.054966,
                                  -13.990831,  -6.9395194, -3.9211385,
                                  64.79172,    -12.363858, -47.875,
                                  59.,         56.271786,  -62.725567,
                                  120.522675,  16.559765,  85.74334,
                                  112.904495,  99.375,     29.499973,
                                  2.0220923,   -19.681704, 890.12494,
                                  941.25,      -7.0498576, 99.47632,
                                  -22.850792,  -97.862236, -101.043236,
                                  -4.727012,   28.275675,  -157.96654,
                                  42.1377,     45.06531,   -149.77373,
                                  24.487143,   -8.054966,  -13.990831,
                                  890.12494,   941.25,     -7.0498576,
                                  99.47632,    -22.850792, -97.862236,
                                  -101.043236, -4.727012,  28.275675,
                                  -157.96654,  42.1377,    45.06531,
                                  -149.77373,  24.487143,  -8.054966,
                                  -13.990831})});
    checker.set_param(param).exect(
            Testcase{TensorValue(
                             {1, 3, 8, 16}, dtype::Uint8(),
                             {109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204,

                              109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204,

                              109, 39,  30,  115, 71,  15,  206, 139, 221, 5,
                              18,  16,  93,  185, 99,  102, 205, 172, 191, 29,
                              185, 6,   47,  84,  0,   47,  105, 203, 251, 73,
                              196, 83,  3,   211, 32,  181, 49,  111, 114, 83,
                              148, 232, 77,  17,  35,  2,   154, 100, 41,  135,
                              141, 206, 56,  91,  137, 199, 104, 192, 75,  122,
                              78,  65,  184, 69,  91,  82,  2,   172, 194, 240,
                              49,  145, 87,  210, 97,  190, 179, 93,  125, 105,
                              181, 207, 148, 178, 133, 53,  25,  198, 238, 151,
                              14,  120, 213, 195, 145, 20,  122, 107, 217, 185,
                              65,  5,   115, 110, 82,  206, 163, 86,  2,   2,
                              44,  125, 50,  38,  41,  106, 30,  5,   151, 243,
                              238, 181, 232, 191, 161, 57,  23,  204}),
                     TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
                     TensorValue({24}, dtype::Int32(),
                                 {17, 24, 32, 25, 18, 11, 4, 5,  0, 1, 8, 16,
                                  9,  2,  3,  10, 0,  1,  8, 16, 9, 2, 3, 10}),
                     {}},
            Testcase{{},
                     {},
                     {},
                     TensorValue({1, 24, 1, 2}, dtype::Float32(),
                                 {-6.9395194, -3.9211385,  64.79172,
                                  -12.363858, -47.875,     59.,
                                  56.271786,  -62.725567,  120.522675,
                                  16.559765,  85.74334,    112.904495,
                                  99.375,     29.499973,   2.0220923,
                                  -19.681704, 890.12494,   941.25,
                                  -7.0498576, 99.47632,    -22.850792,
                                  -97.862236, -101.043236, -4.727012,
                                  28.275675,  -157.96654,  42.1377,
                                  45.06531,   -149.77373,  24.487143,
                                  -8.054966,  -13.990831,  890.12494,
                                  941.25,     -7.0498576,  99.47632,
                                  -22.850792, -97.862236,  -101.043236,
                                  -4.727012,  28.275675,   -157.96654,
                                  42.1377,    45.06531,    -149.77373,
                                  24.487143,  -8.054966,   -13.990831})});
}

// Runs DctChannelSelectForward on the naive handle with
// param.fastImpl == FIX_32_MASK and compares the result against a fixed,
// precomputed expectation via Checker::exect.
TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
    using DctParam = DctChannelSelectForward::Param;

    Checker<DctChannelSelectForward> dct_checker(handle(),
                                                 /* check_dispatch */ false);
    DctParam fix32_param;
    fix32_param.fastImpl = DctParam::FastImpl::FIX_32_MASK;

    // Operator inputs followed by an empty slot for the output tensor.
    const Testcase dct_inputs{
            // Uint8 source of shape {1, 3, 8, 16}. The three channels carry
            // the same 128 values; they are written eight per line, so two
            // consecutive lines make up one 16-wide row.
            TensorValue({1, 3, 8, 16}, dtype::Uint8(),
                        {// channel 0
                         109, 39,  30,  115, 71,  15,  206, 139,
                         221, 5,   18,  16,  93,  185, 99,  102,
                         205, 172, 191, 29,  185, 6,   47,  84,
                         0,   47,  105, 203, 251, 73,  196, 83,
                         3,   211, 32,  181, 49,  111, 114, 83,
                         148, 232, 77,  17,  35,  2,   154, 100,
                         41,  135, 141, 206, 56,  91,  137, 199,
                         104, 192, 75,  122, 78,  65,  184, 69,
                         91,  82,  2,   172, 194, 240, 49,  145,
                         87,  210, 97,  190, 179, 93,  125, 105,
                         181, 207, 148, 178, 133, 53,  25,  198,
                         238, 151, 14,  120, 213, 195, 145, 20,
                         122, 107, 217, 185, 65,  5,   115, 110,
                         82,  206, 163, 86,  2,   2,   44,  125,
                         50,  38,  41,  106, 30,  5,   151, 243,
                         238, 181, 232, 191, 161, 57,  23,  204,

                         // channel 1
                         109, 39,  30,  115, 71,  15,  206, 139,
                         221, 5,   18,  16,  93,  185, 99,  102,
                         205, 172, 191, 29,  185, 6,   47,  84,
                         0,   47,  105, 203, 251, 73,  196, 83,
                         3,   211, 32,  181, 49,  111, 114, 83,
                         148, 232, 77,  17,  35,  2,   154, 100,
                         41,  135, 141, 206, 56,  91,  137, 199,
                         104, 192, 75,  122, 78,  65,  184, 69,
                         91,  82,  2,   172, 194, 240, 49,  145,
                         87,  210, 97,  190, 179, 93,  125, 105,
                         181, 207, 148, 178, 133, 53,  25,  198,
                         238, 151, 14,  120, 213, 195, 145, 20,
                         122, 107, 217, 185, 65,  5,   115, 110,
                         82,  206, 163, 86,  2,   2,   44,  125,
                         50,  38,  41,  106, 30,  5,   151, 243,
                         238, 181, 232, 191, 161, 57,  23,  204,

                         // channel 2
                         109, 39,  30,  115, 71,  15,  206, 139,
                         221, 5,   18,  16,  93,  185, 99,  102,
                         205, 172, 191, 29,  185, 6,   47,  84,
                         0,   47,  105, 203, 251, 73,  196, 83,
                         3,   211, 32,  181, 49,  111, 114, 83,
                         148, 232, 77,  17,  35,  2,   154, 100,
                         41,  135, 141, 206, 56,  91,  137, 199,
                         104, 192, 75,  122, 78,  65,  184, 69,
                         91,  82,  2,   172, 194, 240, 49,  145,
                         87,  210, 97,  190, 179, 93,  125, 105,
                         181, 207, 148, 178, 133, 53,  25,  198,
                         238, 151, 14,  120, 213, 195, 145, 20,
                         122, 107, 217, 185, 65,  5,   115, 110,
                         82,  206, 163, 86,  2,   2,   44,  125,
                         50,  38,  41,  106, 30,  5,   151, 243,
                         238, 181, 232, 191, 161, 57,  23,  204}),
            // NOTE(review): presumably the boundaries of each channel's slice
            // of the 32-entry list below (16 + 8 + 8) -- confirm against the
            // operator definition.
            TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
            // 32 Int32 mask entries, shown here grouped as 16 / 8 / 8.
            TensorValue({32}, dtype::Int32(),
                        {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
                         0, 1, 8, 16, 9, 2, 3, 10,
                         0, 1, 8, 16, 9, 2, 3, 10}),
            {}};

    // Only the last slot (the output) holds a reference value; the input
    // slots are left empty.
    const Testcase dct_expected{
            {},
            {},
            {},
            TensorValue({1, 32, 1, 2}, dtype::Float32(),
                        {// values 0..31
                         890.12494,  941.25,     -7.0498576,  99.47632,
                         -22.850792, -97.862236, -101.043236, -4.727012,
                         28.275675,  -157.96654, 42.1377,     45.06531,
                         -149.77373, 24.487143,  -8.054966,   -13.990831,
                         -6.9395194, -3.9211385, 64.79172,    -12.363858,
                         -47.875,    59.,        56.271786,   -62.725567,
                         120.522675, 16.559765,  85.74334,    112.904495,
                         99.375,     29.499973,  2.0220923,   -19.681704,

                         // values 32..47 (equal to values 0..15)
                         890.12494,  941.25,     -7.0498576,  99.47632,
                         -22.850792, -97.862236, -101.043236, -4.727012,
                         28.275675,  -157.96654, 42.1377,     45.06531,
                         -149.77373, 24.487143,  -8.054966,   -13.990831,

                         // values 48..63 (equal to values 0..15)
                         890.12494,  941.25,     -7.0498576,  99.47632,
                         -22.850792, -97.862236, -101.043236, -4.727012,
                         28.275675,  -157.96654, 42.1377,     45.06531,
                         -149.77373, 24.487143,  -8.054966,   -13.990831})};

    dct_checker.set_param(fix32_param).exect(dct_inputs, dct_expected);
}

// Sweeps DctChannelSelectForward over a grid of batch sizes, channel counts
// and spatial sizes. Each configuration gets a freshly drawn mask channel
// count and is checked against the case produced by gen_dct_case.
TEST_F(NAIVE, DCT_WITH_MASK2) {
    Checker<DctChannelSelectForward> dct_checker(handle(), false);
    DctChannelSelectForward::Param dct_param;
    // NOTE(review): presumably draws integers between 0 and 3 * 64 -- confirm
    // the bound semantics in UniformIntRNG.
    UniformIntRNG oc_picker(0, 3 * 64);
    for (size_t batch : {1, 3}) {
        for (size_t channels : {1, 3}) {
            // Divisor that folds the drawn value into this channel count's
            // range; it only depends on the channel count.
            const int oc_limit = static_cast<int>(channels * 64);
            for (size_t height : {8, 16, 32, 512, 1024}) {
                for (size_t width : {8, 16, 32, 64, 128, 256, 512, 1024}) {
                    // One draw per configuration, taken before the case is
                    // generated.
                    const int drawn =
                            static_cast<int>(oc_picker.gen_single_val());
                    // The remainder is shifted up by one before use.
                    int mask_oc = drawn % oc_limit + 1;
                    auto generated = gen_dct_case(batch, channels, height,
                                                  width, mask_oc, dct_param);
                    dct_checker.set_param(dct_param)
                            .exect(generated->testcase_in,
                                   generated->testcase_out);
                }
            }
        }
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen