opr_impl.h 9.7 KB
Newer Older
1 2 3 4 5 6 7 8
/**
 * \file dnn/src/naive/warp_perspective/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
9 10
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15 16 17 18
 */
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
namespace naive {

19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
class WarpPerspectiveForwardImpl : public WarpPerspectiveForward {
protected:
    using Format = Param::Format;
    template <typename ctype, typename mtype>
    struct KernParam {
        Format format;
        BorderMode bmode;
        float border_val;
        size_t n_src, n_mat, c, ih, iw, oh, ow;
        ctype *sptr, *dptr;
        mtype* mptr;
        int* midx_ptr;  //!< can be null
        Workspace workspace;

        static KernParam from_tensors(Format format, BorderMode bmode,
                                      float border_val, _megdnn_tensor_in src,
                                      _megdnn_tensor_in mat,
                                      _megdnn_tensor_in mat_idx,
                                      _megdnn_tensor_out dst,
                                      _megdnn_workspace workspace) {
            KernParam ret;
            ret.format = format;
            ret.bmode = bmode;
            ret.border_val = border_val;
            ret.n_src = src.layout.shape[0];
            if (mat_idx.raw_ptr) {
                megdnn_assert(mat_idx.layout.ndim == 1);
                ret.n_mat = mat_idx.layout.shape[0];
                ret.midx_ptr = mat_idx.ptr<int>();
            } else {
                megdnn_assert(mat_idx.layout.ndim == 0);
                ret.n_mat = ret.n_src;
                ret.midx_ptr = nullptr;
            }
            if (format == Format::NCHW) {
                ret.c = src.layout.shape[1];
                ret.ih = src.layout.shape[2];
                ret.iw = src.layout.shape[3];
                ret.oh = dst.layout.shape[2];
                ret.ow = dst.layout.shape[3];
            } else if (format == Format::NHWC) {
                ret.c = src.layout.shape[3];
                ret.ih = src.layout.shape[1];
                ret.iw = src.layout.shape[2];
                ret.oh = dst.layout.shape[1];
                ret.ow = dst.layout.shape[2];
            } else if (format == Format::NCHW4) {
                ret.c = src.layout.shape[1] * 4;
                ret.ih = src.layout.shape[2];
                ret.iw = src.layout.shape[3];
                ret.oh = dst.layout.shape[2];
                ret.ow = dst.layout.shape[3];
            } else {
                megdnn_assert(format == Format::NHWCD4);
                ret.c = src.layout.shape[2] * 4;
                ret.ih = src.layout.shape[1];
                ret.iw = src.layout.shape[3];
                ret.oh = dst.layout.shape[1];
                ret.ow = dst.layout.shape[3];
            }
            if (src.layout.dtype.enumv() == DTypeEnum::Float32 ||
                MEGDNN_FLOAT16_SELECT(
                        (src.layout.dtype.enumv() == DTypeEnum::Float16 ||
                         src.layout.dtype.enumv() == DTypeEnum::BFloat16),
                        false) ||
                src.layout.dtype.enumv() == DTypeEnum::Int8 ||
                src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
                src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
                src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
                ret.sptr = src.compatible_ptr<ctype>();
                ret.mptr = mat.ptr<mtype>();
                ret.dptr = dst.compatible_ptr<ctype>();
            } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
                ret.sptr = src.compatible_ptr<ctype>();
                ret.mptr = mat.ptr<mtype>();
                ret.dptr = dst.compatible_ptr<ctype>();
            } else {
                ret.sptr = nullptr;
                ret.mptr = nullptr;
                ret.dptr = nullptr;
99
            }
100 101
            ret.workspace = workspace;
            return ret;
102
        }
103
    };
104

105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
    // ctype: C type of input data type.
    // mtype: C type of transformation matrix data type.
    template <typename ctype, typename mtype>
    void kern_naive(const KernParam<ctype, mtype>& kern_param, size_t task_id);

public:
    using WarpPerspectiveForward::WarpPerspectiveForward;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
              _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
              _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }

private:
    template <typename ctype, typename mtype>
    void kern_naive_nhwcd4(const KernParam<ctype, mtype>& kern_param,
                           size_t task_id);
125 126 127
};

class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData {
128 129 130
protected:
    template <typename ctype, typename mtype>
    struct KernParam {
131
        size_t n_src, n_mat, c, ih, iw, oh, ow;
132 133
        ctype *sptr, *hptr;
        mtype* mptr;
134
        int* midx_ptr;  //!< can be null
135 136

        static KernParam from_tensors(_megdnn_tensor_in mat,
137
                                      _megdnn_tensor_in mat_idx,
138 139 140
                                      _megdnn_tensor_in diff,
                                      _megdnn_tensor_out grad) {
            KernParam ret;
141
            ret.n_src = grad.layout.shape[0], ret.c = grad.layout.shape[1];
142 143 144 145 146
            ret.ih = grad.layout.shape[2], ret.iw = grad.layout.shape[3];
            ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3];
            ret.hptr = diff.ptr<ctype>();
            ret.mptr = mat.ptr<mtype>();
            ret.sptr = grad.ptr<ctype>();
147 148 149 150 151 152 153 154 155
            if (mat_idx.raw_ptr) {
                megdnn_assert(mat_idx.layout.ndim == 1);
                ret.n_mat = mat_idx.layout.shape[0];
                ret.midx_ptr = mat_idx.ptr<int>();
            } else {
                megdnn_assert(mat_idx.layout.ndim == 0);
                ret.n_mat = ret.n_src;
                ret.midx_ptr = nullptr;
            }
156 157 158 159
            return ret;
        }
    };

160 161
public:
    using WarpPerspectiveBackwardData::WarpPerspectiveBackwardData;
162 163 164
    void exec(_megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
              _megdnn_tensor_in diff, _megdnn_tensor_out grad,
              _megdnn_workspace workspace) override;
165
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
166
                                  const TensorLayout&,
167 168 169
                                  const TensorLayout&) override {
        return 0;
    }
170

171 172 173
private:
    template <typename ctype, typename mtype>
    void kern_naive(const KernParam<ctype, mtype>& kern_param);
174 175 176
};

class WarpPerspectiveBackwardMatImpl : public WarpPerspectiveBackwardMat {
177 178 179
protected:
    template <typename ctype, typename mtype>
    struct KernParam {
180
        size_t n_src, n_mat, c, ih, iw, oh, ow;
181
        ctype *sptr, *hptr;
182 183
        mtype *mptr, *res;
        int* midx_ptr;  //!< can be null
184
        float border_val;
185

186 187
        static KernParam from_tensors(float border_val_, _megdnn_tensor_in src,
                                      _megdnn_tensor_in mat,
188
                                      _megdnn_tensor_in mat_idx,
189 190 191 192
                                      _megdnn_tensor_in diff,
                                      _megdnn_tensor_out grad) {
            KernParam ret;
            ret.border_val = border_val_;
193
            ret.n_src = src.layout.shape[0], ret.c = src.layout.shape[1];
194 195 196 197 198 199
            ret.ih = src.layout.shape[2], ret.iw = src.layout.shape[3];
            ret.oh = diff.layout.shape[2], ret.ow = diff.layout.shape[3];
            ret.hptr = diff.ptr<ctype>();
            ret.mptr = mat.ptr<mtype>();
            ret.sptr = src.ptr<ctype>();
            ret.res = grad.ptr<mtype>();
200 201 202 203 204 205 206 207 208
            if (mat_idx.raw_ptr) {
                megdnn_assert(mat_idx.layout.ndim == 1);
                ret.n_mat = mat_idx.layout.shape[0];
                ret.midx_ptr = mat_idx.ptr<int>();
            } else {
                megdnn_assert(mat_idx.layout.ndim == 0);
                ret.n_mat = ret.n_src;
                ret.midx_ptr = nullptr;
            }
209 210 211 212
            return ret;
        }
    };

213 214 215
public:
    using WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
216 217
              _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
218
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
219
                                  const TensorLayout&, const TensorLayout&,
220 221 222
                                  const TensorLayout&) override {
        return 0;
    }
223 224 225 226

private:
    template <typename ctype, typename mtype>
    void kern_naive(const KernParam<ctype, mtype>& kern_param);
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
};

//! Unpack the members of a forward KernParam expression \p p into locals:
//! N_SRC, N_MAT, C, IH, IW, OH, OW, sptr, mptr, dptr, midx_ptr, bmode and
//! border_val. The type names `ctype` and `mtype` must be visible where the
//! macro is expanded. The last statement has no trailing semicolon, so the
//! caller writes `UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(x);`.
//! \p p is parenthesised at every use so that arguments such as `*ptr`
//! expand correctly.
#define UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(p)                       \
    auto N_SRC = (p).n_src, N_MAT = (p).n_mat, C = (p).c, IH = (p).ih,  \
         IW = (p).iw, OH = (p).oh, OW = (p).ow;                         \
    ctype* __restrict sptr = (p).sptr;                                  \
    mtype* __restrict mptr = (p).mptr;                                  \
    ctype* __restrict dptr = (p).dptr;                                  \
    int* __restrict midx_ptr = (p).midx_ptr;                            \
    auto bmode = (p).bmode;                                             \
    float border_val = (p).border_val

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen