// proxy_fw.h
#pragma once
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/multi_head_attn/proxy_forward_base.h"
#include "src/common/reduce_helper.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/opr_impl.h"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

// Short aliases for types nested in the MultiHeadAttn operator. They are
// declared at namespace scope, so every file that includes this header sees
// them inside megdnn::cuda.
using Param = megdnn::MultiHeadAttn::Param;
// Nested type Param::AttnMaskType.
using MaskType = Param::AttnMaskType;
// Nested type Param::TensorCombinationType.
using InputType = Param::TensorCombinationType;
// Pull the matmul helpers from the multi_head_attn namespace into
// megdnn::cuda so they can be called unqualified.
using multi_head_attn::matmul_deduce_layout;
using multi_head_attn::matmul_exec;

/*!
 * \brief Concrete, non-derivable subclass of
 *      multi_head_attn::MHAForwardProxyBase declared in the cuda namespace.
 *
 * The class adds no data members of its own. It declares a default
 * constructor and one move_scaler_to_device() override for each dtype
 * enumerated by MEGDNN_FOREACH_COMPUTING_DTYPE. Only declarations appear
 * here; the bodies are not part of this header section.
 */
class MHAForwardProxyOpr final : public multi_head_attn::MHAForwardProxyBase {
public:
    //! Default constructor; explicitly invokes the base default constructor.
    MHAForwardProxyOpr() : MHAForwardProxyBase() {}

    // Declare one override per computing dtype. The helper macro exists only
    // for the duration of the expansion below and is undefined right after.
    // Each overload takes a Handle* followed by two pointers to the dtype's
    // ctype; the parameters are intentionally left unnamed, as in the
    // original declaration.
#define MHA_DECLARE_MOVE_SCALER_TO_DEVICE(DType) \
    void move_scaler_to_device(                  \
            Handle*, DTypeTrait<DType>::ctype*, DTypeTrait<DType>::ctype*) override;
    MEGDNN_FOREACH_COMPUTING_DTYPE(MHA_DECLARE_MOVE_SCALER_TO_DEVICE)
#undef MHA_DECLARE_MOVE_SCALER_TO_DEVICE
};
}  // namespace cuda
}  // namespace megdnn
   // vim: syntax=cpp.doxygen