/**
 * \file dnn/test/arm_common/resize.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/common/resize.h"
#include "test/arm_common/fixture.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

using namespace resize;

// Appends NCHW resize test cases that use interpolation mode `imode` to `args`.
// First comes an exhaustive sweep over small shapes, then a handful of larger
// hand-picked cases. Existing contents of `args` are kept; new cases are
// appended at the end.
static void set_nchw_args(IMode imode, std::vector<TestArg>& args) {
    param::Resize cfg;
    cfg.format = param::Resize::Format::NCHW;
    cfg.imode = imode;

    // Exhaustive sweep: batch, channel, input and output spatial extents each
    // take four small values. rep presumably counts from 0, hence the +1 so
    // that no dimension is ever zero.
    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) {
        const TensorShape src{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul};
        const TensorShape dst{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul};
        args.emplace_back(cfg, src, dst);
    }

    // A few larger hand-picked cases: 2x upscaling, 2x downscaling and a
    // non-integer downscale.
    struct ShapePair {
        TensorShape src;
        TensorShape dst;
    };
    const ShapePair extra[] = {
            {TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20}},
            {TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9}},
            {TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}},
            {TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}},
    };
    for (const ShapePair& sp : extra)
        args.emplace_back(cfg, sp.src, sp.dst);
}

// Checks the ARM resize kernels against the reference on the shared CV-style
// cases from get_cv_args(): first with Uint8 tensors, then with Float32 tensors.
TEST_F(ARM_COMMON, RESIZE_CV) {
    std::vector<TestArg> args = get_cv_args();
    Checker<Resize> checker(handle());

    // Integer outputs: the 1 + 1e-3 tolerance presumably absorbs one-unit
    // rounding differences between the optimized and reference kernels.
    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_epsilon(1 + 1e-3)
                .set_dtype(0, dtype::Uint8())
                .set_dtype(1, dtype::Uint8())
                .execs({arg.src, arg.dst});
    }

    // set_epsilon() changes the state of this shared checker. Reset it
    // explicitly so the Float32 pass is not silently run with the loose Uint8
    // tolerance from the loop above.
    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_epsilon(1e-3)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .execs({arg.src, arg.dst});
    }
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Runs the NCHW resize sweep on Float16 tensors for both bilinear and
// nearest-neighbour interpolation, with a 0.01 comparison tolerance.
// Only built when the compiler advertises fp16 vector arithmetic.
TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) {
    std::vector<TestArg> cases;
    for (IMode mode : {IMode::INTER_LINEAR, IMode::INTER_NEAREST})
        set_nchw_args(mode, cases);

    Checker<Resize> checker(handle());
    // Tolerance and dtypes are the same for every case, so configure them once.
    checker.set_epsilon(0.01)
            .set_dtype(0, dtype::Float16())
            .set_dtype(1, dtype::Float16());
    for (const auto& tc : cases) {
        checker.set_param(tc.param).execs({tc.src, tc.dst});
    }
}
#endif

// Runs the NCHW resize sweep on Float32 tensors for both bilinear and
// nearest-neighbour interpolation. A freshly constructed checker is used, so
// its default comparison tolerance applies.
TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) {
    std::vector<TestArg> cases;
    for (IMode mode : {IMode::INTER_LINEAR, IMode::INTER_NEAREST})
        set_nchw_args(mode, cases);

    Checker<Resize> checker(handle());
    // Dtypes are the same for every case, so configure them once.
    checker.set_dtype(0, dtype::Float32()).set_dtype(1, dtype::Float32());
    for (const auto& tc : cases) {
        checker.set_param(tc.param).execs({tc.src, tc.dst});
    }
}

// Checks the ARM resize kernels on Float32 tensors using the NCHW44 cases from
// get_nchw44_args(). A freshly constructed checker is used, so its default
// comparison tolerance applies.
TEST_F(ARM_COMMON, RESIZE_NCHW44_FP32) {
    std::vector<TestArg> args = get_nchw44_args();
    Checker<Resize> checker(handle());

    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .execs({arg.src, arg.dst});
    }
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Checks the ARM resize kernels on Float16 tensors using the NCHW88 cases from
// get_nchw88_args(), with a 0.01 comparison tolerance. Only built when the
// compiler advertises fp16 vector arithmetic.
TEST_F(ARM_COMMON, RESIZE_NCHW88_FP16) {
    std::vector<TestArg> args = get_nchw88_args();
    Checker<Resize> checker(handle());

    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_epsilon(0.01)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .execs({arg.src, arg.dst});
    }
}
#endif
}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen