/**
 * \file dnn/test/naive/resize.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "test/common/resize.h"
#include "megdnn/oprs/cv.h"

#include "test/common/checker.h"
#include "test/naive/fixture.h"

using namespace megdnn;
using namespace test;

TEST_F(NAIVE, RESIZE_NCHW4) {
    Checker<Resize> checker(handle());

    auto args = resize::get_nchw4_args();
    // Reinterpret an expanded NCHW layout as the equivalent NCHW4 layout:
    // (N, C, H, W) -> reshape (N, C/4, H, W, 4) -> dimshuffle (N, C/4, 4, H, W).
    auto convert_true_format = [](const TensorLayout& layout) {
        return layout.reshape({layout[0], layout[1] / 4, layout[2], layout[3], 4})
                .dimshuffle({0, 1, 4, 2, 3});
    };

    for (auto&& arg : args) {
        // Reference implementation for the checker: relayout the NCHW4 input
        // into a plain-NCHW staging buffer, run the NCHW resize kernel there,
        // then relayout the result back into the caller's NCHW4 output.
        auto extra_impl = [this, param = arg.param,
                           convert_true_format](const TensorNDArray& tensors) {
            auto resize = handle()->create_operator<Resize>();
            resize->param().imode = param.imode;
            resize->param().format = Resize::Param::Format::NCHW;

            // Staging tensors in NCHW: fold the trailing 4-pack back into the
            // channel dimension (C' = C * 4) and allocate backing storage.
            TensorNDArray nchw_tensors;
            for (size_t i = 0; i < tensors.size(); ++i) {
                auto layout = tensors[i].layout;
                layout = layout.reshape(
                        {layout[0], layout[1] * 4, layout[2], layout[3]});
                layout.dtype = dtype::Int8();
                nchw_tensors.emplace_back(malloc(layout.span().dist_byte()), layout);
            }

            // Non-owning views over the checker-provided buffers carrying the
            // true NCHW4 strides, so relayout can convert between formats.
            TensorNDArray nchw4_tensors;
            for (size_t i = 0; i < tensors.size(); ++i) {
                auto layout = convert_true_format(nchw_tensors[i].layout);
                nchw4_tensors.emplace_back(tensors[i].raw_ptr, std::move(layout));
            }

            auto relayout = handle()->create_operator<RelayoutForward>();
            relayout->exec(nchw4_tensors[0], nchw_tensors[0]);

            auto workspace_size = resize->get_workspace_in_bytes(
                    nchw_tensors[0].layout, nchw_tensors[1].layout);
            dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
            Workspace workspace{workspace_ptr, workspace_size};

            resize->exec(nchw_tensors[0], nchw_tensors[1], workspace);

            // Copy the NCHW result back into the caller's NCHW4 output buffer.
            relayout->exec(nchw_tensors[1], nchw4_tensors[1]);

            free(workspace_ptr);
            for (auto&& tensor : nchw_tensors) {
                free(tensor.raw_ptr);
            }
        };
        checker.set_extra_opr_impl(extra_impl);
        checker.set_param(arg.param)
                .set_dtype(0, dtype::QuantizedS8(0.1f))
                .set_dtype(1, dtype::QuantizedS8(0.1f))
                .set_epsilon(1 + 1e-3)
                .execs({arg.src, arg.dst});
    }
}