opr_impl.h 8.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
/**
 * \file dnn/src/naive/warp_perspective/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
namespace naive {

/*!
 * \brief Naive implementation of the forward warp-perspective operator.
 *
 * The class declares the kernel entry points and a small helper struct,
 * KernParam, that flattens the tensor arguments of exec() into plain
 * sizes and typed pointers.
 */
class WarpPerspectiveForwardImpl: public WarpPerspectiveForward {
    protected:
        using Format = Param::Format;

        /*!
         * \brief Flattened kernel arguments.
         *
         * \tparam ctype C type of the input/output data.
         * \tparam mtype C type of the transformation matrix data.
         */
        template <typename ctype, typename mtype>
        struct KernParam {
            Format format;
            BorderMode bmode;
            float border_val;
            size_t n_src, n_mat, c, ih, iw, oh, ow;
            ctype *sptr, *dptr;
            mtype *mptr;
            int *midx_ptr;  //!< can be null
            Workspace workspace;

            /*!
             * \brief Build a KernParam from the operator's tensors.
             *
             * Sizes are read from the layouts of \p src and \p dst according
             * to \p format. When \p mat_idx carries no data pointer, n_mat
             * equals n_src and midx_ptr is null; otherwise n_mat is the
             * length of the 1-D \p mat_idx tensor.
             *
             * The data pointers are filled in only when the dtype of \p src
             * is one of Float32, Float16/BFloat16 (subject to
             * MEGDNN_FLOAT16_SELECT), Int8, Uint8, QuantizedS8 or
             * Quantized8Asymm; for any other dtype sptr, mptr and dptr are
             * set to nullptr.
             */
            static KernParam from_tensors(
                    Format format, BorderMode bmode, float border_val,
                    _megdnn_tensor_in src, _megdnn_tensor_in mat,
                    _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
                    _megdnn_workspace workspace) {
                KernParam ret;
                ret.format = format;
                ret.bmode = bmode;
                ret.border_val = border_val;
                ret.n_src = src.layout.shape[0];
                if (mat_idx.raw_ptr) {
                    megdnn_assert(mat_idx.layout.ndim == 1);
                    ret.n_mat = mat_idx.layout.shape[0];
                    ret.midx_ptr = mat_idx.ptr<int>();
                } else {
                    megdnn_assert(mat_idx.layout.ndim == 0);
                    ret.n_mat = ret.n_src;
                    ret.midx_ptr = nullptr;
                }
                if (format == Format::NCHW) {
                    ret.c = src.layout.shape[1];
                    ret.ih = src.layout.shape[2];
                    ret.iw = src.layout.shape[3];
                    ret.oh = dst.layout.shape[2];
                    ret.ow = dst.layout.shape[3];
                } else if (format == Format::NHWC) {
                    ret.c = src.layout.shape[3];
                    ret.ih = src.layout.shape[1];
                    ret.iw = src.layout.shape[2];
                    ret.oh = dst.layout.shape[1];
                    ret.ow = dst.layout.shape[2];
                } else if (format == Format::NCHW4) {
                    // Channel axis is stored in groups of 4.
                    ret.c = src.layout.shape[1] * 4;
                    ret.ih = src.layout.shape[2];
                    ret.iw = src.layout.shape[3];
                    ret.oh = dst.layout.shape[2];
                    ret.ow = dst.layout.shape[3];
                } else {
                    megdnn_assert(format == Format::NHWCD4);
                    // Height is axis 1, channel groups of 4 are axis 2,
                    // width is axis 3.
                    ret.c = src.layout.shape[2] * 4;
                    ret.ih = src.layout.shape[1];
                    ret.iw = src.layout.shape[3];
                    ret.oh = dst.layout.shape[1];
                    ret.ow = dst.layout.shape[3];
                }

                const auto src_enumv = src.layout.dtype.enumv();
                const bool has_typed_data =
                        src_enumv == DTypeEnum::Float32 ||
                        MEGDNN_FLOAT16_SELECT(
                                (src_enumv == DTypeEnum::Float16 ||
                                 src_enumv == DTypeEnum::BFloat16),
                                false) ||
                        src_enumv == DTypeEnum::Int8 ||
                        src_enumv == DTypeEnum::Uint8 ||
                        src_enumv == DTypeEnum::QuantizedS8 ||
                        src_enumv == DTypeEnum::Quantized8Asymm;
                if (has_typed_data) {
                    ret.sptr = src.compatible_ptr<ctype>();
                    ret.mptr = mat.ptr<mtype>();
                    ret.dptr = dst.compatible_ptr<ctype>();
                } else {
                    // Unlisted dtypes get no typed pointers.
                    ret.sptr = nullptr;
                    ret.mptr = nullptr;
                    ret.dptr = nullptr;
                }
                ret.workspace = workspace;
                return ret;
            }
        };

        // ctype: C type of input data type.
        // mtype: C type of transformation matrix data type.
        template <typename ctype, typename mtype>
        void kern_naive(const KernParam<ctype, mtype>& kern_param,
                        size_t task_id);

    public:
        using WarpPerspectiveForward::WarpPerspectiveForward;
        void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
                  _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
                  _megdnn_workspace workspace) override;

        //! The naive implementation reports a workspace requirement of zero.
        size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                      const TensorLayout&,
                                      const TensorLayout&) override {
            return 0;
        }

    private:
        // Kernel variant named for the NHWCD4 format; same template
        // parameters as kern_naive.
        template <typename ctype, typename mtype>
        void kern_naive_nhwcd4(const KernParam<ctype, mtype>& kern_param,
                               size_t task_id);
};

class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData {
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
protected:
    template <typename ctype, typename mtype>
    struct KernParam {
        size_t n, c, ih, iw, oh, ow;
        ctype *sptr, *hptr;
        mtype* mptr;

        static KernParam from_tensors(_megdnn_tensor_in mat,
                                      _megdnn_tensor_in diff,
                                      _megdnn_tensor_out grad) {
            KernParam ret;
            ret.n = grad.layout.shape[0], ret.c = grad.layout.shape[1],
            ret.ih = grad.layout.shape[2], ret.iw = grad.layout.shape[3];
            ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3];
            ret.hptr = diff.ptr<ctype>();
            ret.mptr = mat.ptr<mtype>();
            ret.sptr = grad.ptr<ctype>();
            return ret;
        }
    };

148 149 150 151 152 153 154 155
public:
    using WarpPerspectiveBackwardData::WarpPerspectiveBackwardData;
    void exec(_megdnn_tensor_in mat, _megdnn_tensor_in diff,
              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
156 157 158
private:
    template <typename ctype, typename mtype>
    void kern_naive(const KernParam<ctype, mtype>& kern_param);
159 160 161
};

class WarpPerspectiveBackwardMatImpl : public WarpPerspectiveBackwardMat {
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
protected:
    template <typename ctype, typename mtype>
    struct KernParam {
        size_t n, c, ih, iw, oh, ow;
        ctype *sptr, *hptr;
        mtype* mptr, *res;
        float border_val;
        static KernParam from_tensors(float border_val_, _megdnn_tensor_in src,
                                      _megdnn_tensor_in mat,
                                      _megdnn_tensor_in diff,
                                      _megdnn_tensor_out grad) {
            KernParam ret;
            ret.border_val = border_val_;
            ret.n = src.layout.shape[0], ret.c = src.layout.shape[1],
            ret.ih = src.layout.shape[2], ret.iw = src.layout.shape[3];
            ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3];
            ret.hptr = diff.ptr<ctype>();
            ret.mptr = mat.ptr<mtype>();
            ret.sptr = src.ptr<ctype>();
            ret.res = grad.ptr<mtype>();
            return ret;
        }
    };

186 187 188 189 190 191 192 193 194 195
public:
    using WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
              _megdnn_tensor_in diff, _megdnn_tensor_out grad,
              _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
196 197 198 199

private:
    template <typename ctype, typename mtype>
    void kern_naive(const KernParam<ctype, mtype>& kern_param);
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
};

/*
 * Unpack the fields of a forward KernParam-like object `p` into locals:
 * N_SRC, N_MAT, C, IH, IW, OH, OW, sptr, mptr, dptr, midx_ptr, bmode and
 * border_val. The types `ctype` and `mtype` must be visible at the
 * expansion site. The expansion has no trailing semicolon, so the caller
 * supplies it. The argument is parenthesized so that expressions such as
 * `*ptr` expand correctly.
 */
#define UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(p)                       \
    auto N_SRC = (p).n_src, N_MAT = (p).n_mat, C = (p).c, IH = (p).ih,  \
         IW = (p).iw, OH = (p).oh, OW = (p).ow;                         \
    ctype* __restrict sptr = (p).sptr;                                  \
    mtype* __restrict mptr = (p).mptr;                                  \
    ctype* __restrict dptr = (p).dptr;                                  \
    int* __restrict midx_ptr = (p).midx_ptr;                            \
    auto bmode = (p).bmode;                                             \
    float border_val = (p).border_val

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen