//! \file linalg.h
//! \brief Linear-algebra operator declarations: (batched) matrix multiplication,
//! matrix inverse, dot product and singular value decomposition.
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     *
     * A, B, C must be 3-dimensional and C must be contiguous. A and B must
     * have stride[2] == 1, and stride[1] >= shape[2],
     * and stride[0] >= shape[1] * stride[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;

    //! deduce the dtype of C from the dtypes of A and B
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);

    //! deduce the layout of C from the layouts of A and B
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);

    //! size in bytes of the workspace exec() needs for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    //! operator type tag used by the algorithm framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
    }

protected:
    //! check the layouts of A, B, C and the workspace size given to exec()
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;

class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     *
     * A, B, C must be 2-dimensional and C must be contiguous. A and B must
     * have stride[1] == 1, and stride[0] >= shape[1]
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;

    //! deduce the dtype of C from the dtypes of A and B
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);

    //! deduce the layout of C from the layouts of A and B
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);

    //! size in bytes of the workspace exec() needs for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    //! pack size associated with the given matrix-mul \p format
    //! NOTE(review): presumably the blocking factor of packed formats -- confirm
    static size_t pack_size(const Param::Format format);

    //! operator type tag used by the algorithm framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::MATRIX_MUL_FORWARD;
    }

protected:
    //! check the layouts of A, B, C and the workspace size given to exec()
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;

 * \brief compute the inverse of a batch of matrices
 * Input and output tensors have the same shape [..., n, n] where the last two
 * dimensions represent the matrices.
 * Currently only float32 is supported.
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);

Megvii Engine Team 已提交
97 98 99
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
Megvii Engine Team 已提交
    size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);
102 103 104 105 106 107 108

     * \brief get canonized params; throw exception on error.
     * Note that \p batch and \p n can be null
Megvii Engine Team 已提交
    static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);
110 111 112 113 114 115 116

     * \brief canonize and validate input params for exec() impls
     * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
     * be null
Megvii Engine Team 已提交
117 118 119
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            _megdnn_workspace workspace, size_t* batch, size_t* n);

Megvii Engine Team 已提交
121 122
    virtual size_t get_workspace_in_bytes(
            size_t batch, size_t n, size_t dtype_size) = 0;
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139

//! inter-product of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

     * \param[in] A
     * \param[in] B
     * \param[out] C
     * Calculating the dot product of A and B and store it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be scalar.
Megvii Engine Team 已提交
140 141 142
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
143 144
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
Megvii Engine Team 已提交
145 146
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
147 148

Megvii Engine Team 已提交
149 150 151
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
using Dot = DotForward;

 * \brief Compute the singular value decomposition of a batch of matrices
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensor u, s, vt,
 * the following equation holds: u * diag(s) * vt == src.
 * Currently only float32 is supported.
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);

     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) The input tensor, let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
                (..., m, m) if full_matrices is true,
                empty tensor if compute_uv is false.
                The left singular vector.

     * \param s (..., p) The singular values.
     * \param vt (..., p, n) if full_matrices is false,
                 (..., n, n) if full_matrices is true,
                 empty tensor if compute_uv is false.
                 The right singular vector.
     * src must be contiguous. The computation might be significantly faster
     * if compute_uv is false (default to true).
Megvii Engine Team 已提交
187 188 189 190 191 192 193 194 195
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
            _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& u, TensorLayout& s,
            TensorLayout& vt);
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt);
196 197

Megvii Engine Team 已提交
198 199 200 201 202 203 204
    static void canonize_params(
            const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
    virtual size_t get_workspace_in_bytes(
            size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
    void check_exec(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt, size_t workspace_in_bytes);
205 206 207 208 209 210 211 212 213

using SVD = SVDForward;

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen