// opr_impl.h
#pragma once
#include <memory>
#include "megdnn/oprs.h"
#include "megdnn/oprs/cv.h"
#include "megdnn/oprs/general.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"

namespace megdnn {
namespace naive {

class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
public:
    using MultiHeadAttnForward::MultiHeadAttnForward;
    void exec(
            _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
            _megdnn_tensor_in wqkv, _megdnn_tensor_out out,
            _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& queries, const TensorLayout& keys,
            const TensorLayout& values, const TensorLayout& wqkv,
            const TensorLayout& out, const TensorLayout& reserveSpace) override;
    size_t get_reservespace_in_bytes(
            const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
            const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/,
            const TensorLayout& /*out*/,
            const TensorLayout& /*reserveSpace*/) override {
        return 0;
    }
};

class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
public:
    using MultiHeadAttnBackward::MultiHeadAttnBackward;
    void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
            _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
            _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
            _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
            _megdnn_tensor_out dweights, _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& /*diff*/, const TensorLayout& /* queries*/,
            const TensorLayout& /*keyes*/, const TensorLayout& /* values*/,
            const TensorLayout& /*wqkv*/, const TensorLayout& /* reserveSpace*/,
            const TensorLayout& /*dqueries*/, const TensorLayout& /* dkeyes*/,
            const TensorLayout& /*dvalues*/,
            const TensorLayout& /* dweights*/) override {
        return 0;
    }
};

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen