attn_gemm.h 5.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
 public:
  // (m, n, k) = bsz_seq, output_size, input_size
  AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA,
             bool transB, int bsz_seq, int output_size, int input_size,
             bool compute_bias)
      : dev_ctx_(dev_ctx),
        transA_(transA),
        transB_(transB),
        bsz_seq_(bsz_seq),
        output_size_(output_size),
        input_size_(input_size),
        compute_bias_(compute_bias) {}

  ~AttnMatMul() {}

  void ComputeForward(const T* weight_data, const T* input_data,
                      const T* bias_data, T* output_data, T* bias_out_data) {
    // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
    // here: (transa, transb): nt, input * weight.
    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;
    if (transA_) {
      transA = CblasTrans;
    }
    if (transB_) {
      transB = CblasTrans;
    }
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);

    // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
    blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
              input_data, weight_data, beta, output_data);
    if (compute_bias_) {
      // compute output + bias
      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
                            bias_data, bias_out_data);
    }
  }

  void ComputeBackward(const T* input, const T* weight, const T* d_output,
                       T* d_input, T* d_weight, T* d_bias) {
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);

    CBLAS_TRANSPOSE dB_transA = CblasNoTrans;
    CBLAS_TRANSPOSE dB_transB = CblasNoTrans;
    CBLAS_TRANSPOSE dA_transA = CblasNoTrans;
    CBLAS_TRANSPOSE dA_transB = CblasNoTrans;
    int dB_m = 1;
    int dB_n = 1;
    int dB_k = 1;
    int dA_m = 1;
    int dA_n = 1;
    int dA_k = 1;

    T* dB_input_1_ptr = nullptr;
    T* dB_input_2_ptr = nullptr;
    T* dB_output_ptr = d_weight;

    T* dA_input_1_ptr = nullptr;
    T* dA_input_2_ptr = nullptr;
    T* dA_output_ptr = d_input;

    if (!transA_) {
      // fw: gemm-nt
      if (transB_) {
        // bw: gemm-tn, dB = (dC)^t * A
        dB_transA = CblasTrans;
        dB_transB = CblasNoTrans;
        dB_m = output_size_;
        dB_n = input_size_;
        dB_k = bsz_seq_;

        // bw: gemm-nn, dA = dC * B
        dA_transA = CblasNoTrans;
        dA_transB = CblasNoTrans;
        dA_m = bsz_seq_;
        dA_n = input_size_;
        dA_k = output_size_;

        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
                  input, beta, dB_output_ptr);
        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
                  weight, beta, dA_output_ptr);
      } else {  // fw: gemm-nn
        // bw: gemm-tn, dB = A^t * dC
        dB_transA = CblasTrans;
        dB_transB = CblasNoTrans;
        dB_m = input_size_;
        dB_n = output_size_;
        dB_k = bsz_seq_;

        // bw: gemm-nt, dA = dC * B^t
        dA_transA = CblasNoTrans;
        dA_transB = CblasTrans;
        dA_m = bsz_seq_;
        dA_n = input_size_;
        dA_k = output_size_;

        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
                  d_output, beta, dB_output_ptr);
        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
                  weight, beta, dA_output_ptr);
      }
    } else if (transB_) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "AttnMatMul wrapper do not support (transA=T, transB=T)"
          "parameters."));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "AttnMatMul wrapper do not support (transA=T, transB=N)"
          "parameters."));
    }
    if (compute_bias_) {
      LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
    }
  }

 private:
  const platform::CUDADeviceContext& dev_ctx_;

  bool transA_;
  bool transB_;

  int bsz_seq_;
  int output_size_;
  int input_size_;

  int compute_bias_;
};

}  // namespace operators
}  // namespace paddle