// linalg.h
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

/*!
 * \brief batched matrix multiplication, C = op(A) * op(B)
 *
 * exec() and the layout-based get_workspace_in_bytes() are pure virtual, so
 * concrete subclasses have to provide them. The non-virtual members are only
 * declared in this header; their definitions live elsewhere.
 */
class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    // NOTE(review): these project macros are opaque from this header; they
    // presumably introduce the Param type (MatrixMul) and the input/output
    // arity (2 inputs, 1 output) -- confirm against their definitions.
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     * \param workspace scratch workspace handed to the implementation
     *
     * A, B, C must be 3-dimensional and C must be contiguous. A and B must
     * have stride[2] == 1, and stride[1] >= shape[2],
     * and stride[0] >= shape[1] * stride[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief dtype deduction helper; \p C is passed by non-const reference
     *
     * NOTE(review): presumably writes the output dtype derived from \p A and
     * \p B into \p C -- the definition is not visible here, confirm there.
     */
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);

    /*!
     * \brief layout deduction helper; \p C is passed by non-const reference
     *
     * NOTE(review): presumably fills \p C with the (B, m, n) output layout
     * implied by \p A and \p B -- confirm against the definition.
     */
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);

    /*!
     * \brief workspace size query for the given layouts
     *
     * Pure virtual; the name indicates the result is a byte count,
     * presumably the size exec() needs for these layouts.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    //! operator type tag of this operator: BATCHED_MATRIX_MUL_FORWARD
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
    }

protected:
    /*!
     * \brief argument checking helper intended for exec() implementations
     *
     * NOTE(review): presumably validates the layouts and the provided
     * workspace size; the definition is outside this header.
     */
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;

/*!
 * \brief matrix multiplication, C = op(A) * op(B)
 *
 * exec() and the layout-based get_workspace_in_bytes() are pure virtual, so
 * concrete subclasses have to provide them. The non-virtual members are only
 * declared in this header; their definitions live elsewhere.
 */
class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    // NOTE(review): these project macros are opaque from this header; they
    // presumably introduce the Param type (MatrixMul) and the input/output
    // arity (2 inputs, 1 output) -- confirm against their definitions.
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     * \param workspace scratch workspace handed to the implementation
     *
     * A, B, C must be 2-dimensional and C must be contiguous. A and B must
     * have stride[1] == 1, and stride[0] >= shape[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief dtype deduction helper; \p C is passed by non-const reference
     *
     * NOTE(review): presumably writes the output dtype derived from \p A and
     * \p B into \p C -- the definition is not visible here, confirm there.
     */
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);

    /*!
     * \brief layout deduction helper; \p C is passed by non-const reference
     *
     * NOTE(review): presumably fills \p C with the (m, n) output layout
     * implied by \p A and \p B -- confirm against the definition.
     */
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);

    /*!
     * \brief workspace size query for the given layouts
     *
     * Pure virtual; the name indicates the result is a byte count,
     * presumably the size exec() needs for these layouts.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    /*!
     * \brief static helper keyed on the matrix format of the operator param
     *
     * NOTE(review): the meaning of the returned size depends on the
     * definition, which is not part of this header -- confirm there.
     */
    static size_t pack_size(const Param::Format format);

    //! operator type tag of this operator: MATRIX_MUL_FORWARD
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::MATRIX_MUL_FORWARD;
    }

protected:
    /*!
     * \brief argument checking helper intended for exec() implementations
     *
     * NOTE(review): presumably validates the layouts and the provided
     * workspace size; the definition is outside this header.
     */
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;

/*!
 * \brief compute the inverse of a batch of matrices
 *
 * Input and output tensors have the same shape [..., n, n] where the last two
 * dimensions represent the matrices.
 *
 * Currently only float32 is supported.
 */
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
M
Megvii Engine Team 已提交
99 100 101
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
102
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
103
    size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);
104 105 106 107 108 109 110

protected:
    /*!
     * \brief get canonized params; throw exception on error.
     *
     * Note that \p batch and \p n can be null
     */
M
Megvii Engine Team 已提交
111
    static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);
112 113 114 115 116 117 118

    /*!
     * \brief canonize and validate input params for exec() impls
     *
     * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
     * be null
     */
M
Megvii Engine Team 已提交
119 120 121
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            _megdnn_workspace workspace, size_t* batch, size_t* n);
122

M
Megvii Engine Team 已提交
123 124
    virtual size_t get_workspace_in_bytes(
            size_t batch, size_t n, size_t dtype_size) = 0;
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
};

//! inter-product of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

public:
    /**
     * \param[in] A
     * \param[in] B
     * \param[out] C
     *
     * Calculating the dot product of A and B and store it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be scalar.
     */
M
Megvii Engine Team 已提交
142 143 144
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
145 146
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
M
Megvii Engine Team 已提交
147 148
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
149 150

protected:
M
Megvii Engine Team 已提交
151 152 153
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
};
using Dot = DotForward;

/*!
 * \brief Compute the singular value decomposition of a batch of matrices
 *
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensor u, s, vt,
 * the following equation holds: u * diag(s) * vt == src.
 *
 * Currently only float32 is supported.
 */
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
    DEF_OPR_PARAM(SVD);

public:
    /**
     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) The input tensor, let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
                (..., m, m) if full_matrices is true,
                empty tensor if compute_uv is false.
                The left singular vector.

     * \param s (..., p) The singular values.
     * \param vt (..., p, n) if full_matrices is false,
                 (..., n, n) if full_matrices is true,
                 empty tensor if compute_uv is false.
                 The right singular vector.
     *
     * src must be contiguous. The computation might be significantly faster
     * if compute_uv is false (default to true).
     *
     */
M
Megvii Engine Team 已提交
189 190 191 192 193 194 195 196 197
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
            _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& u, TensorLayout& s,
            TensorLayout& vt);
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt);
198 199

protected:
M
Megvii Engine Team 已提交
200 201 202 203 204 205 206
    static void canonize_params(
            const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
    virtual size_t get_workspace_in_bytes(
            size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
    void check_exec(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt, size_t workspace_in_bytes);
207 208 209 210 211 212 213 214 215
};

using SVD = SVDForward;

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen