// test/naive/region_restricted_convolution.cpp
#include "test/naive/fixture.h"

#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
// #include "test/common/region_restricted_convolution.h"
#include "test/common/extra_impl_helper.h"
#include "test/common/random_state.h"

using namespace megdnn;
using namespace test;

namespace {
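// Copy `in` to `out`, keeping only the elements whose per-pixel region id in
// `mask` equals `mask_val`; every other element is zeroed. `in`/`out` are
// float tensors of layout (N, C, H, W); `mask` has layout (N, H, W) and is
// shared across channels.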
template <typename rtype>
void mask_tensor_kernel(
        const TensorND& in, TensorND& out, const TensorND& mask,
        const int32_t mask_val) {
    megdnn_assert(
            in.layout.ndim == out.layout.ndim && in.layout.ndim == 4 &&
            mask.layout.ndim == 3);
    megdnn_assert_eq_layout(in.layout, out.layout);
    megdnn_assert(
            mask.layout[0] == in.layout[0] && mask.layout[1] == in.layout[2] &&
            mask.layout[2] == in.layout[3]);

    rtype* mask_ptr = mask.compatible_ptr<rtype>();
    float* src_ptr = in.compatible_ptr<float>();
    float* dst_ptr = out.compatible_ptr<float>();

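    // The mask carries no channel dimension, so `mask_off` is shared by all
    // channels at a given (n, h, w).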
    for (size_t n = 0; n < in.layout[0]; ++n) {
        for (size_t c = 0; c < in.layout[1]; ++c) {
            for (size_t h = 0; h < in.layout[2]; ++h) {
                for (size_t w = 0; w < in.layout[3]; ++w) {
                    size_t mask_off = n * mask.layout.stride[0] +
                                      h * mask.layout.stride[1] +
                                      w * mask.layout.stride[2];
                    size_t src_dst_off =
                            n * in.layout.stride[0] + c * in.layout.stride[1] +
                            h * in.layout.stride[2] + w * in.layout.stride[3];
                    if (mask_ptr[mask_off] == mask_val) {
                        dst_ptr[src_dst_off] = src_ptr[src_dst_off];
                    } else {
                        dst_ptr[src_dst_off] = 0.;
                    }
                }
            }
        }
    }
}

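// Dispatch on the mask dtype; the test below exercises both Int32 and Uint8
// region masks.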
void mask_tensor(
        const TensorND& in, TensorND& out, const TensorND& mask,
        const int32_t mask_val) {
    if (mask.layout.dtype == dtype::Int32()) {
        mask_tensor_kernel<dt_int32>(in, out, mask, mask_val);
    } else if (mask.layout.dtype == dtype::Uint8()) {
        mask_tensor_kernel<dt_uint8>(in, out, mask, mask_val);
    } else {
        megdnn_assert(0, "unsupported mask dtype");
    }
}
}  // namespace

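// Forward correctness check for RegionRestrictedConvolution on the naive
// handle, verified against a reference built from the plain Convolution
// operator (see `extra_impl` below).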
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
    Checker<RegionRestrictedConvolution> checker(handle());
    RegionRestrictedConvolution::Param param;
    constexpr int N = 3;

    UniformIntRNG rng{0, N - 1};

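    // Reference implementation: for each region id i in [0, N), zero out all
    // input pixels whose region mask differs from i, run a plain convolution,
    // mask the result with the output region mask, and accumulate the N
    // partial results into the final output.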
    auto extra_impl = [&, this](const TensorNDArray& tensors) {
        auto conv = handle()->create_operator<Convolution>();
        conv->param() = param;
        auto workspace_size = conv->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[4].layout, nullptr);
        dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
        Workspace workspace{workspace_ptr, workspace_size};

        TensorND masked_src(
                malloc(tensors[0].layout.span().dist_byte()), tensors[0].layout);
        TensorNDArray dst_tensors;
        for (int i = 0; i < N; ++i) {
            dst_tensors.emplace_back(
                    malloc(tensors[4].layout.span().dist_byte()), tensors[4].layout);
        }
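        // For each region id: mask the input, convolve, then mask the output.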
        for (int i = 0; i < N; ++i) {
            mask_tensor(tensors[0], masked_src, tensors[2], i);
            conv->exec(masked_src, tensors[1], dst_tensors[i], nullptr, workspace);
            mask_tensor(dst_tensors[i], dst_tensors[i], tensors[3], i);
        }
        free(workspace_ptr);

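        // Accumulate the per-region partial results into the final output.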
        using Mode = ElemwiseForward::Param::Mode;
        auto add = handle()->create_operator<ElemwiseForward>();
        add->param().mode = Mode::ADD;
        add->exec({dst_tensors[0], dst_tensors[1]}, tensors[4]);
        for (int i = 2; i < N; ++i) {
            add->exec({dst_tensors[i], tensors[4]}, tensors[4]);
        }
        // Release the scratch buffers allocated with malloc above (raw_ptr()
        // returns the pointer passed to the TensorND constructor).
        free(masked_src.raw_ptr());
        for (auto&& dst : dst_tensors) {
            free(dst.raw_ptr());
        }
    };

    checker.set_extra_opr_impl(extra_impl)
            .set_rng(2, &rng)
            .set_rng(3, &rng)
            .set_dtype(2, dtype::Int32())
            .set_dtype(3, dtype::Int32());

    checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
            .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
            .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});

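    // Repeat the dense cases with Uint8 region masks.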
    checker.set_dtype(2, dtype::Uint8()).set_dtype(3, dtype::Uint8());

    checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
            .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
            .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});

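    // Grouped (including channel-wise) convolution cases; the filter layout
    // becomes (G, OC/G, IC/G, FH, FW).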
    param.sparse = Convolution::Param::Sparse::GROUP;
    checker.set_param(param)
            .execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
            .execs({{20, 25, 30, 30},
                    {25, 1, 1, 3, 3},
                    {20, 30, 30},
                    {20, 28, 28},
                    {}});

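    // The same grouped cases, now with Int32 region masks.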
    checker.set_dtype(2, dtype::Int32()).set_dtype(3, dtype::Int32());
    checker.execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
            .execs({{20, 25, 30, 30},
                    {25, 1, 1, 3, 3},
                    {20, 30, 30},
                    {20, 28, 28},
                    {}});
}

// vim: syntax=cpp.doxygen