imgproc.h 8.8 KB
Newer Older
1 2 3 4 5
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

M
Megvii Engine Team 已提交
6
class WarpPerspectiveBase : public OperatorBase {
7 8
    DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
    DEF_OPR_PARAM(WarpPerspective);
M
Megvii Engine Team 已提交
9 10 11 12 13 14 15 16 17 18

public:
    // Expose the nested types of Param under the operator's own scope so
    // callers can spell them without going through Param.
    using InterpolationMode = Param::InterpolationMode;
    using BorderMode = Param::BorderMode;

protected:
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
        check_layout_fwd(src, mat, {}, dst);
    }
19 20 21 22 23
    void check_layout_fwd(
            const TensorLayoutArray& srcs, const TensorLayout& mat,
            const TensorLayout& dst) {
        check_layout_fwd(srcs, mat, {}, dst);
    }
M
Megvii Engine Team 已提交
24 25 26 27

    // Layout check taking an explicit mat_idx layout. Declared only; the
    // definition is not part of this header. The three-layout overload
    // forwards here with a default-constructed mat_idx.
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst);
28 29 30
    // Same check for an array of source layouts with an explicit mat_idx
    // layout. Declared only; the definition is not part of this header.
    void check_layout_fwd(
            const TensorLayoutArray& srcs, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst);
M
Megvii Engine Team 已提交
31 32
    // Returns a string built without modifying the operator (const member).
    // NOTE(review): presumably a human-readable dump of the operator's
    // parameters for diagnostics -- confirm against the definition.
    std::string param_msg() const;
    // NOTE(review): presumably maps coordinate p onto an axis of length len
    // according to the border mode -- confirm against the definition.
    int get_real_coord(int p, int len);
33 34
};

M
Megvii Engine Team 已提交
35
class WarpPerspectiveForward : public WarpPerspectiveBase {
36
    DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
M
Megvii Engine Team 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

public:
    /**
     * \param[in] src (n, channel, in_height, in_width)
     * \param[in] mat (n, 3, 3)
     * \param[out] dst (n, channel, out_height, out_width)
     *
     * \see
     * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
     *
     * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
     * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
     *                 (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
     *
     * src and dst can have different shapes, as long as their n and c agree.
     * src, mat and dst should be contiguous.
     */
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) {
        // No mat_idx supplied: forward to the mat_idx overload with an
        // empty (default-constructed) argument in that position.
        exec(src, mat, {}, dst, workspace);
    }

60 61 62 63 64 65
    /**
     * Variant of exec taking an array of source tensors and no \p mat_idx.
     * Forwards to the array overload that accepts \p mat_idx.
     */
    void exec(
            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) {
        // Empty (default-constructed) argument stands in for mat_idx.
        exec(srcs, mat, {}, dst, workspace);
    }

M
Megvii Engine Team 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78
    /**
     * \p src should have batch size m, and \p mat and \p mat_idx should
     * both have batch size n. Each item in \p mat_idx must be in the range
     * of [0, m-1].
     *
     * \param mat_idx the indices of input image that each matrix in \p mat
     *      should act on. It can also be empty and in such case \p mat
     *      should have the same batch size as \p src.
     *
     * Pure virtual: concrete operators must implement it. The overload
     * without \p mat_idx forwards here with an empty \p mat_idx.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

79 80 81 82 83
    /**
     * Array-of-sources counterpart of the \p mat_idx exec overload.
     *
     * Pure virtual: concrete operators must implement it. The array overload
     * without \p mat_idx forwards here with an empty \p mat_idx.
     */
    virtual void exec(
            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
            _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

M
Megvii Engine Team 已提交
84 85 86 87 88
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
        return get_workspace_in_bytes(src, mat, {}, dst);
    }

89 90 91 92 93 94
    size_t get_workspace_in_bytes(
            const TensorLayoutArray& srcs, const TensorLayout& mat,
            const TensorLayout& dst) {
        return get_workspace_in_bytes(srcs, mat, {}, dst);
    }

M
Megvii Engine Team 已提交
95 96 97 98
    /**
     * Workspace size query with an explicit \p mat_idx layout.
     *
     * Pure virtual: concrete operators must implement it. The three-layout
     * overload forwards here with a default-constructed \p mat_idx.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst) = 0;

99 100 101 102
    /**
     * Array-of-sources workspace size query with an explicit \p mat_idx
     * layout.
     *
     * Pure virtual: concrete operators must implement it. The array overload
     * without \p mat_idx forwards here with a default-constructed \p mat_idx.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayoutArray& srcs, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst) = 0;

M
Megvii Engine Team 已提交
103 104 105 106 107 108 109 110 111 112
protected:
    // Takes the layouts of every operand plus the size of the workspace the
    // caller provides. Declared only; the definition is not in this header.
    // NOTE(review): presumably validates the arguments before exec runs --
    // confirm against the definition.
    void check_exec(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst,
            size_t workspace_in_bytes);

    // Same argument list as check_exec. Declared only; the definition is not
    // in this header.
    // NOTE(review): the name suggests a check that additionally accepts
    // mat_idx with NHWC-format input -- confirm against the definition.
    void check_exec_allow_nhwc_mat_idx(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst,
            size_t workspace_in_bytes);
113 114 115 116
    // Overload of check_exec_allow_nhwc_mat_idx for an array of source
    // layouts. Declared only; the definition is not in this header.
    void check_exec_allow_nhwc_mat_idx(
            const TensorLayoutArray& srcs, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& dst,
            size_t workspace_in_bytes);
117 118 119
};
// Short name for the forward operator.
using WarpPerspective = WarpPerspectiveForward;

M
Megvii Engine Team 已提交
120
class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
121
    DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
M
Megvii Engine Team 已提交
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154

public:
    /**
     * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace to perform backward
     */
    void exec(
            _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) {
        // No mat_idx supplied: forward to the mat_idx overload with an
        // empty (default-constructed) argument in that position.
        exec(mat, {}, diff, grad, workspace);
    }

    /**
     * Backward-data exec with an explicit \p mat_idx.
     *
     * Pure virtual: concrete operators must implement it. The overload
     * without \p mat_idx forwards here with an empty \p mat_idx.
     */
    virtual void exec(
            _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

    size_t get_workspace_in_bytes(
            const TensorLayout& mat, const TensorLayout& diff,
            const TensorLayout& grad) {
        return get_workspace_in_bytes(mat, {}, diff, grad);
    }

    /**
     * Workspace size query with an explicit \p mat_idx layout.
     *
     * Pure virtual: concrete operators must implement it. The three-layout
     * overload forwards here with a default-constructed \p mat_idx.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& mat, const TensorLayout& mat_idx,
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    // Takes the layouts of every operand plus the size of the workspace the
    // caller provides. Declared only; the definition is not in this header.
    // NOTE(review): presumably validates the arguments before exec runs --
    // confirm against the definition.
    void check_exec(
            const TensorLayout& mat, const TensorLayout& mat_idx,
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
155 156
};

M
Megvii Engine Team 已提交
157
class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
158
    DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
M
Megvii Engine Team 已提交
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194

public:
    /**
     * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
     * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. mat
     * \param[out] workspace temporary workspace to perform backward
     */
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) {
        // No mat_idx supplied: forward to the mat_idx overload with an
        // empty (default-constructed) argument in that position.
        exec(src, mat, {}, diff, grad, workspace);
    }

    /**
     * Backward-mat exec with an explicit \p mat_idx.
     *
     * Pure virtual: concrete operators must implement it. The overload
     * without \p mat_idx forwards here with an empty \p mat_idx.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
            const TensorLayout& grad) {
        return get_workspace_in_bytes(src, mat, {}, diff, grad);
    }

    /**
     * Workspace size query with an explicit \p mat_idx layout.
     *
     * Pure virtual: concrete operators must implement it. The four-layout
     * overload forwards here with a default-constructed \p mat_idx.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    // Takes the layouts of every operand plus the size of the workspace the
    // caller provides. Declared only; the definition is not in this header.
    // NOTE(review): presumably validates the arguments before exec runs --
    // confirm against the definition.
    void check_exec(
            const TensorLayout& src, const TensorLayout& mat,
            const TensorLayout& mat_idx, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
195 196
};

197 198 199 200 201 202 203 204 205 206 207 208
class DctChannelSelectForward : public OperatorBase {
    DEF_OPR_PARAM(DctChannelSelect);
    DEF_OPR_IMPL(DctChannelSelectForward, OperatorBase, 3, 1);

public:
    /**
     * \param[in] src input, must be uint8 nchw tensor
     * \param[in] mask_offset input, must be int32 nchw tensor
     * \param[in] mask_val input, must be int32 nchw tensor
     * \param[out] dst output, default fp32 nchw tensor
     * \param[out] workspace temporary workspace to perform forward
     */
M
Megvii Engine Team 已提交
209 210 211 212 213 214 215 216 217 218 219 220
    // Pure virtual: concrete operators must implement the forward pass.
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
            _megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    // Takes the three input layouts as const references and dst as a mutable
    // reference, so dst is the only argument it can write. Declared only;
    // the definition is not in this header.
    // NOTE(review): presumably fills dst with the output layout implied by
    // the inputs -- confirm against the definition.
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& mask_offset,
            const TensorLayout& mask_val, TensorLayout& dst);

    // Pure virtual workspace size query; concrete operators must implement
    // it. Takes the layouts of all four operands.
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& mask_offset,
            const TensorLayout& mask_val, const TensorLayout& dst) = 0;
221 222

protected:
M
Megvii Engine Team 已提交
223 224 225 226 227 228 229
    // Takes all four operand layouts read-only. Declared only; the
    // definition is not in this header.
    // NOTE(review): presumably validates the layouts for the forward pass --
    // confirm against the definition.
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& mask_offset,
            const TensorLayout& mask_val, const TensorLayout& dst);

    // Same signature shape as deduce_layout: inputs are const, dst is
    // writable. Declared only; the definition is not in this header.
    // NOTE(review): presumably the protected worker behind deduce_layout --
    // confirm against the definition.
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& mask_offset,
            const TensorLayout& mask_val, TensorLayout& dst);
230 231 232 233

    // Returns a string built without modifying the operator (const member).
    // NOTE(review): presumably a human-readable dump of the operator's
    // parameters for diagnostics -- confirm against the definition.
    std::string param_msg() const;
};

M
Megvii Engine Team 已提交
234
}  // namespace megdnn
235 236 237 238

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen