/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace paddle {
namespace operators {
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
 public:
  // (m, n, k) = bsz_seq, output_size, input_size
L
Leo Chen 已提交
34
  AttnMatMul(const phi::GPUContext& dev_ctx,
35 36 37 38 39
             bool transA,
             bool transB,
             int bsz_seq,
             int output_size,
             int input_size,
40 41 42 43 44 45 46 47 48
             bool compute_bias)
      : dev_ctx_(dev_ctx),
        transA_(transA),
        transB_(transB),
        bsz_seq_(bsz_seq),
        output_size_(output_size),
        input_size_(input_size),
        compute_bias_(compute_bias) {}

49 50 51 52
  void ComputeForward(const phi::DenseTensor* weight,
                      const phi::DenseTensor* input,
                      const phi::DenseTensor* bias,
                      phi::DenseTensor* output,
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
                      phi::DenseTensor* bias_out,
                      bool fused = false) {
    VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={"
            << weight->dims() << "}, output.shape={" << output->dims()
            << "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_
            << ", input_size=" << input_size_ << ", transA=" << transA_
            << ", transB=" << transB_ << ", compute_bias=" << compute_bias_
            << ", fused=" << fused;

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
    if (compute_bias_ && fused) {
      PADDLE_ENFORCE_EQ(
          !output || output == bias_out,
          true,
          phi::errors::InvalidArgument(
              "The output (= input * weight) is expected to be nullptr or the "
              "same as bias_out when fused is true."));
70

71 72 73 74 75 76 77 78 79
      auto fused_impl =
          phi::funcs::MatmulPlanner(vectorize(input->dims()),
                                    vectorize(weight->dims()),
                                    transA_,
                                    transB_,
                                    phi::CppTypeToDataType<T>::Type(),
                                    phi::funcs::MatmulFusedType::kMatmulBias,
                                    static_cast<const void*>(bias->data<T>()),
                                    nullptr);
80 81 82 83 84 85 86 87 88 89
      phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx_,
                                             input->data<T>(),
                                             weight->data<T>(),
                                             bias_out->data<T>(),
                                             bsz_seq_,      // M
                                             output_size_,  // N
                                             input_size_,   // K
                                             transA_,
                                             transB_,
                                             &fused_impl);
90 91 92 93
      return;
    }
#endif

94 95
    // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
    // here: (transa, transb): nt, input * weight.
96 97
    CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
98 99 100
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);

101
    // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
L
Leo Chen 已提交
102
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
103 104 105 106 107 108 109 110 111 112
    blas.GEMM(transA,
              transB,
              bsz_seq_,
              output_size_,
              input_size_,
              alpha,
              input->data<T>(),
              weight->data<T>(),
              beta,
              output->data<T>());
113
    if (compute_bias_) {
114
      // bias_out = output + bias
115 116
      std::vector<const phi::DenseTensor*> ins = {output, bias};
      std::vector<phi::DenseTensor*> outs = {bias_out};
117
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
118
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
119 120 121
    }
  }

122 123 124 125 126 127
  void ComputeBackward(const phi::DenseTensor* input,
                       const phi::DenseTensor* weight,
                       const phi::DenseTensor* d_output,
                       phi::DenseTensor* d_input,
                       phi::DenseTensor* d_weight,
                       phi::DenseTensor* d_bias,
128 129 130 131
                       bool use_addto = false,
                       bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
    if (compute_bias_ && fused) {
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
      phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
                                                      d_output,
                                                      input,
                                                      weight,
                                                      nullptr,
                                                      bsz_seq_,      // M
                                                      output_size_,  // N
                                                      input_size_,   // K
                                                      transA_,
                                                      transB_,
                                                      "none",
                                                      d_input,
                                                      d_weight,
                                                      d_bias,
                                                      use_addto);
147 148 149 150
      return;
    }
#endif

151
    T alpha = static_cast<T>(1.0);
152 153
    T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
    T beta_dB = static_cast<T>(0.0);
154

L
Leo Chen 已提交
155
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
156
    if (!transA_) {
157
      // forward: gemm-nt
158
      if (transB_) {
159 160 161 162 163 164 165
        // backward: gemm-tn, dB = (dC)^T * A
        if (d_weight) {
          int dB_m = output_size_;
          int dB_n = input_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
166 167 168 169 170 171 172 173 174
          blas.GEMM(CblasTrans,
                    CblasNoTrans,
                    dB_m,
                    dB_n,
                    dB_k,
                    alpha,
                    d_output->data<T>(),
                    input->data<T>(),
                    beta_dB,
175 176 177 178 179 180 181 182 183 184
                    dB_output_ptr);
        }

        // backward: gemm-nn, dA = dC * B
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
185 186 187 188 189 190 191 192 193
          blas.GEMM(CblasNoTrans,
                    CblasNoTrans,
                    dA_m,
                    dA_n,
                    dA_k,
                    alpha,
                    d_output->data<T>(),
                    weight->data<T>(),
                    beta_dA,
194 195
                    dA_output_ptr);
        }
196
      } else {  // fw: gemm-nn
197 198 199 200 201 202 203
        // backward: gemm-tn, dB = A^T * dC
        if (d_weight) {
          int dB_m = input_size_;
          int dB_n = output_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
204 205 206 207 208 209 210 211 212
          blas.GEMM(CblasTrans,
                    CblasNoTrans,
                    dB_m,
                    dB_n,
                    dB_k,
                    alpha,
                    input->data<T>(),
                    d_output->data<T>(),
                    beta_dB,
213 214 215 216 217 218 219 220 221 222
                    dB_output_ptr);
        }

        // backward: gemm-nt, dA = dC * B^T
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
223 224 225 226 227 228 229 230 231
          blas.GEMM(CblasNoTrans,
                    CblasTrans,
                    dA_m,
                    dA_n,
                    dA_k,
                    alpha,
                    d_output->data<T>(),
                    weight->data<T>(),
                    beta_dA,
232 233
                    dA_output_ptr);
        }
234 235 236
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
237
          "AttnMatMul wrapper do not support (transA=T, transB=T/N)"
238 239
          "parameters."));
    }
240 241 242
    if (compute_bias_ && d_bias) {
      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
      // -> {3} or {0,1,2,3,4} -> {3,4}
L
Li Min 已提交
243 244 245 246 247 248 249 250 251 252
      const auto input_dims = d_output->dims();
      const auto output_dims = d_bias->dims();
      bool support_case_1 =
          (input_dims.size() == 5 && output_dims.size() == 3 &&
           (input_dims[2] == output_dims[0]) &&
           (input_dims[3] == output_dims[1]) &&
           (input_dims[4] == output_dims[2]));
      bool support_case_2 =
          (input_dims.size() == 3 && output_dims.size() == 1 &&
           (input_dims[2] == output_dims[0]));
253 254 255 256 257 258 259 260
      bool support_case_3 =
          (input_dims.size() == 4 && output_dims.size() == 1 &&
           input_dims[3] == output_dims[0]);
      bool support_case_4 =
          (input_dims.size() == 5 && output_dims.size() == 2 &&
           input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);

      gpuStream_t stream = dev_ctx_.stream();
L
Li Min 已提交
261
      if (support_case_1 || support_case_2) {
262
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
263 264 265 266 267
            dev_ctx_,
            *d_output,
            d_bias,
            kps::IdentityFunctor<T>(),
            {0, 1},
W
Wilber 已提交
268
            stream);
269 270
      } else if (support_case_3 || support_case_4) {
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
271 272 273 274 275
            dev_ctx_,
            *d_output,
            d_bias,
            kps::IdentityFunctor<T>(),
            {0, 1, 2},
276
            stream);
L
Li Min 已提交
277 278 279 280 281 282
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only support reduce when the input dims are [0,1,2,3,4] and "
            "output is [2,3,4]"
            "or input is [0,1,2] and output is [2]."));
      }
283 284 285 286
    }
  }

 private:
L
Leo Chen 已提交
287
  const phi::GPUContext& dev_ctx_;
288 289 290 291 292 293 294 295 296 297 298 299 300

  bool transA_;
  bool transB_;

  int bsz_seq_;
  int output_size_;
  int input_size_;

  int compute_bias_;
};

}  // namespace operators
}  // namespace paddle