// resize.cpp: naive-backend tests for the Resize operator.
#include "test/common/resize.h"
#include "megdnn/oprs/cv.h"

#include "test/common/checker.h"
#include "test/naive/fixture.h"

#include <vector>

using namespace megdnn;
using namespace test;

// Checks Resize in NCHW4 format on the naive backend. The reference
// implementation (extra_impl) relayouts the NCHW4 input into a contiguous NCHW
// scratch tensor, runs Resize with Format::NCHW, and relayouts the result back
// into the NCHW4 output tensor.
TEST_F(NAIVE, RESIZE_NCHW4) {
    Checker<Resize> checker(handle());

    auto args = resize::get_nchw4_args();

    // Given a contiguous NCHW layout, builds a layout that addresses NCHW4
    // storage in NCHW order: a contiguous (N, C/4, H, W, 4) block whose
    // trailing 4-channel axis is moved next to C/4, i.e. (N, C/4, 4, H, W).
    // Assumes C is divisible by 4.
    auto convert_true_format = [](const TensorLayout& layout) {
        return layout.reshape({layout[0], layout[1] / 4, layout[2], layout[3], 4})
                .dimshuffle({0, 1, 4, 2, 3});
    };

    for (auto&& arg : args) {
        // Reference implementation: tensors[0] is the source and tensors[1]
        // the destination handed over by the checker.
        auto extra_impl = [this, param = arg.param,
                           convert_true_format](const TensorNDArray& tensors) {
            auto resize = handle()->create_operator<Resize>();
            resize->param().imode = param.imode;
            resize->param().format = Resize::Param::Format::NCHW;

            // Scratch NCHW buffers are owned by std::vector so they are
            // released even if an operator below throws. Reserving up front
            // keeps every inner buffer in place while TensorNDs point into it.
            std::vector<std::vector<dt_byte>> nchw_storage;
            nchw_storage.reserve(tensors.size());
            TensorNDArray nchw_tensors;
            for (auto&& tensor : tensors) {
                // NOTE(review): assumes the checker passes NCHW4 tensors shaped
                // (N, C/4, H, W, 4); the channel groups are folded back into C.
                auto layout = tensor.layout.reshape(
                        {tensor.layout[0], tensor.layout[1] * 4, tensor.layout[2],
                         tensor.layout[3]});
                // The checker feeds QuantizedS8 data (see set_dtype below); the
                // NCHW reference path handles it as plain Int8.
                layout.dtype = dtype::Int8();
                nchw_storage.emplace_back(layout.span().dist_byte());
                nchw_tensors.emplace_back(nchw_storage.back().data(), layout);
            }

            // Views of the caller's NCHW4 buffers, addressed in NCHW order so
            // they can be relayouted to/from the contiguous scratch tensors.
            TensorNDArray nchw4_tensors;
            for (size_t i = 0; i < tensors.size(); ++i) {
                nchw4_tensors.emplace_back(
                        tensors[i].raw_ptr(),
                        convert_true_format(nchw_tensors[i].layout));
            }

            auto relayout = handle()->create_operator<RelayoutForward>();
            // NCHW4 source -> contiguous NCHW scratch.
            relayout->exec(nchw4_tensors[0], nchw_tensors[0]);

            auto workspace_size = resize->get_workspace_in_bytes(
                    nchw_tensors[0].layout, nchw_tensors[1].layout);
            std::vector<dt_byte> workspace_storage(workspace_size);
            Workspace workspace{workspace_storage.data(), workspace_size};

            resize->exec(nchw_tensors[0], nchw_tensors[1], workspace);

            // Contiguous NCHW result -> caller's NCHW4 destination.
            relayout->exec(nchw_tensors[1], nchw4_tensors[1]);
        };
        checker.set_extra_opr_impl(extra_impl);
        // NOTE(review): an epsilon slightly above 1 presumably tolerates a
        // one-step int8 rounding difference between the two paths — confirm.
        checker.set_param(arg.param)
                .set_dtype(0, dtype::QuantizedS8(0.1f))
                .set_dtype(1, dtype::QuantizedS8(0.1f))
                .set_epsilon(1 + 1e-3)
                .execs({arg.src, arg.dst});
    }
}