/**
 * \file dnn/include/megdnn/oprs/imgproc.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

/**
 * Common base of the WarpPerspective forward/backward operators: carries
 * the WarpPerspective param and shared layout-checking helpers.
 */
class WarpPerspectiveBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
    DEF_OPR_PARAM(WarpPerspective);
    public:
        //! interpolation mode enum re-exported from the operator param
        using InterpolationMode = Param::InterpolationMode;
        //! border mode enum re-exported from the operator param
        using BorderMode = Param::BorderMode;

    protected:
        /**
         * Forward layout check without a \p mat_idx; delegates to the
         * four-layout overload, passing an empty `{}` layout as mat_idx.
         */
        void check_layout_fwd(const TensorLayout& src,
                              const TensorLayout& mat,
                              const TensorLayout& dst) {
            this->check_layout_fwd(src, mat, {}, dst);
        }

        //! forward layout check with an explicit \p mat_idx (defined out of line)
        void check_layout_fwd(const TensorLayout& src,
                              const TensorLayout& mat,
                              const TensorLayout& mat_idx,
                              const TensorLayout& dst);

        //! NOTE(review): defined out of line; presumably renders the param
        //! as text for diagnostics — confirm
        std::string param_msg() const;

        //! NOTE(review): defined out of line; presumably maps coordinate
        //! \p p against extent \p len according to the border mode — confirm
        int get_real_coord(int p, int len);
};

class WarpPerspectiveForward: public WarpPerspectiveBase {
    DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
    public:
        /**
         * \brief Warp every image of \p src with its 3x3 matrix from \p mat.
         *
         * \param[in] src (n, channel, in_height, in_width)
         * \param[in] mat (n, 3, 3)
         * \param[out] dst (n, channel, out_height, out_width)
         * \param workspace temporary workspace for the computation
         *
         * \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
         *
         * For each output position (h, w):
         *     denominator = mat[2][0]*w + mat[2][1]*h + mat[2][2]
         *     dst(h, w) = src((mat[1][0]*w + mat[1][1]*h + mat[1][2]) / denominator,
         *                     (mat[0][0]*w + mat[0][1]*h + mat[0][2]) / denominator)
         *
         * \p src and \p dst may have different shapes as long as their n and
         * channel agree; \p src, \p mat and \p dst should be contiguous.
         * Same as the \p mat_idx overload called with an empty mat_idx.
         */
        void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
                  _megdnn_tensor_out dst, _megdnn_workspace workspace) {
            this->exec(src, mat, {}, dst, workspace);
        }

        /**
         * Indexed variant: \p src has batch size m while \p mat and
         * \p mat_idx both have batch size n; every entry of \p mat_idx must
         * lie in [0, m-1].
         *
         * \param mat_idx for each matrix in \p mat, the index of the input
         *      image it acts on. May be empty, in which case \p mat must
         *      have the same batch size as \p src.
         */
        virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
                          _megdnn_tensor_in mat_idx,
                          _megdnn_tensor_out dst,
                          _megdnn_workspace workspace) = 0;

        //! workspace size query without \p mat_idx; delegates with an empty one
        size_t get_workspace_in_bytes(const TensorLayout& src,
                                      const TensorLayout& mat,
                                      const TensorLayout& dst) {
            return this->get_workspace_in_bytes(src, mat, {}, dst);
        }

        //! workspace size query (in bytes) for the indexed variant
        virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                              const TensorLayout& mat,
                                              const TensorLayout& mat_idx,
                                              const TensorLayout& dst) = 0;
    protected:
        //! NOTE(review): defined out of line; presumably validates layouts
        //! and \p workspace_in_bytes before exec — confirm
        void check_exec(const TensorLayout& src,
                        const TensorLayout& mat,
                        const TensorLayout& mat_idx,
                        const TensorLayout& dst,
                        size_t workspace_in_bytes);

        //! NOTE(review): defined out of line; by its name, a check_exec
        //! variant that additionally accepts NHWC with mat_idx — confirm
        void check_exec_allow_nhwc_mat_idx(const TensorLayout& src,
                                           const TensorLayout& mat,
                                           const TensorLayout& mat_idx,
                                           const TensorLayout& dst,
                                           size_t workspace_in_bytes);
};
//! shorthand name for the forward operator
using WarpPerspective = WarpPerspectiveForward;

class WarpPerspectiveBackwardData: public WarpPerspectiveBase {
    DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
    public:
        /**
         * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
         * \param[in] diff the backpropagated gradient wrt. dst
         * \param[out] grad the backpropagated gradient wrt. src
         * \param[out] workspace temporary workspace to perform backward
         */
108 109 110 111 112 113 114
        void exec(_megdnn_tensor_in mat,
                _megdnn_tensor_in diff,
                _megdnn_tensor_out grad,
                _megdnn_workspace workspace) {
            exec(mat, {}, diff, grad, workspace);
        }

115
        virtual void exec(_megdnn_tensor_in mat,
116
                _megdnn_tensor_in mat_idx,
117 118 119
                _megdnn_tensor_in diff,
                _megdnn_tensor_out grad,
                _megdnn_workspace workspace) = 0;
120 121 122 123 124 125 126

        size_t get_workspace_in_bytes(const TensorLayout &mat,
                const TensorLayout &diff,
                const TensorLayout &grad) {
            return get_workspace_in_bytes(mat, {}, diff, grad);
        }

127
        virtual size_t get_workspace_in_bytes(const TensorLayout &mat,
128
                const TensorLayout &mat_idx,
129 130 131 132
                const TensorLayout &diff,
                const TensorLayout &grad) = 0;
    protected:
        void check_exec(const TensorLayout &mat,
133
                const TensorLayout &mat_idx,
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
                const TensorLayout &diff,
                const TensorLayout &grad,
                size_t workspace_in_bytes);
};

class WarpPerspectiveBackwardMat: public WarpPerspectiveBase {
    DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
    public:
        /**
         * \param[in] src the `src' parameter in WarpPerspectiveForward::exec
         * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
         * \param[in] diff the backpropagated gradient wrt. dst
         * \param[out] grad the backpropagated gradient wrt. mat
         * \param[out] workspace temporary workspace to perform backward
         */
149 150 151 152 153 154 155 156
        void exec(_megdnn_tensor_in src,
                _megdnn_tensor_in mat,
                _megdnn_tensor_in diff,
                _megdnn_tensor_out grad,
                _megdnn_workspace workspace) {
            exec(src, mat, {}, diff, grad, workspace);
        }

157 158
        virtual void exec(_megdnn_tensor_in src,
                _megdnn_tensor_in mat,
159
                _megdnn_tensor_in mat_idx,
160 161 162
                _megdnn_tensor_in diff,
                _megdnn_tensor_out grad,
                _megdnn_workspace workspace) = 0;
163 164 165 166 167 168 169 170

        size_t get_workspace_in_bytes(const TensorLayout &src,
                const TensorLayout &mat,
                const TensorLayout &diff,
                const TensorLayout &grad) {
            return get_workspace_in_bytes(src, mat, {}, diff, grad);
        }

171 172
        virtual size_t get_workspace_in_bytes(const TensorLayout &src,
                const TensorLayout &mat,
173
                const TensorLayout &mat_idx,
174 175 176 177 178
                const TensorLayout &diff,
                const TensorLayout &grad) = 0;
    protected:
        void check_exec(const TensorLayout &src,
                const TensorLayout &mat,
179
                const TensorLayout &mat_idx,
180 181 182 183 184 185 186 187 188 189
                const TensorLayout &diff,
                const TensorLayout &grad,
                size_t workspace_in_bytes);
};

} // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen