/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
// Supports gemm-nt and gemm-nn, as used in fused_attention_op.
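// Shape sketch (inferred from the GEMM calls below):
//   input:  [bsz_seq, input_size]
//   weight: [output_size, input_size] for gemm-nt,
//           [input_size, output_size] for gemm-nn
//   output: [bsz_seq, output_size]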
template <typename T>
class AttnMatMul {
 public:
  // (m, n, k) = bsz_seq, output_size, input_size
  AttnMatMul(const platform::CUDADeviceContext& dev_ctx,
             bool transA,
             bool transB,
             int bsz_seq,
             int output_size,
             int input_size,
             bool compute_bias)
      : dev_ctx_(dev_ctx),
        transA_(transA),
        transB_(transB),
        bsz_seq_(bsz_seq),
        output_size_(output_size),
        input_size_(input_size),
        compute_bias_(compute_bias) {}

  ~AttnMatMul() {}

  void ComputeForward(const framework::Tensor* weight,
                      const framework::Tensor* input,
                      const framework::Tensor* bias,
                      framework::Tensor* output,
                      framework::Tensor* bias_out) {
    // Note: Paddle's blas.GEMM API treats all inputs as row-major.
    // Here (transA, transB) is (N, T) or (N, N): output = input * weight(^T).
    CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);

    // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
    blas.GEMM(transA,
              transB,
              bsz_seq_,
              output_size_,
              input_size_,
              alpha,
              input->data<T>(),
              weight->data<T>(),
              beta,
              output->data<T>());
    if (compute_bias_) {
      // bias_out = output + bias
      std::vector<const Tensor*> ins = {output, bias};
      std::vector<Tensor*> outs = {bias_out};
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
    }
  }

  void ComputeBackward(const framework::Tensor* input,
                       const framework::Tensor* weight,
                       const framework::Tensor* d_output,
                       framework::Tensor* d_input,
                       framework::Tensor* d_weight,
                       framework::Tensor* d_bias,
                       bool use_addto = false) {
    T alpha = static_cast<T>(1.0);
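    // When use_addto is true, the input gradient is accumulated into the
    // existing contents of d_input (beta = 1) instead of overwriting them.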
    T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
    T beta_dB = static_cast<T>(0.0);

    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
    if (!transA_) {
      // forward: gemm-nt
      if (transB_) {
        // backward: gemm-tn, dB = (dC)^T * A
        if (d_weight) {
          int dB_m = output_size_;
          int dB_n = input_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
          blas.GEMM(CblasTrans,
                    CblasNoTrans,
                    dB_m,
                    dB_n,
                    dB_k,
                    alpha,
                    d_output->data<T>(),
                    input->data<T>(),
                    beta_dB,
                    dB_output_ptr);
        }

        // backward: gemm-nn, dA = dC * B
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
          blas.GEMM(CblasNoTrans,
                    CblasNoTrans,
                    dA_m,
                    dA_n,
                    dA_k,
                    alpha,
                    d_output->data<T>(),
                    weight->data<T>(),
                    beta_dA,
                    dA_output_ptr);
        }
      } else {  // fw: gemm-nn
        // backward: gemm-tn, dB = A^T * dC
        if (d_weight) {
          int dB_m = input_size_;
          int dB_n = output_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
          blas.GEMM(CblasTrans,
                    CblasNoTrans,
                    dB_m,
                    dB_n,
                    dB_k,
                    alpha,
                    input->data<T>(),
                    d_output->data<T>(),
                    beta_dB,
                    dB_output_ptr);
        }

        // backward: gemm-nt, dA = dC * B^T
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
          blas.GEMM(CblasNoTrans,
                    CblasTrans,
                    dA_m,
                    dA_n,
                    dA_k,
                    alpha,
                    d_output->data<T>(),
                    weight->data<T>(),
                    beta_dA,
                    dA_output_ptr);
        }
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "AttnMatMul wrapper do not support (transA=T, transB=T/N)"
          "parameters."));
    }
    if (compute_bias_ && d_bias) {
      // reduce: {0,1,2,3,4} -> {2,3,4}, {0,1,2} -> {2}, {0,1,2,3} -> {3},
      // or {0,1,2,3,4} -> {3,4}
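      // For example (a shape sketch, assuming the 5-D case): d_output of
      // shape [bsz, seq_len, 3, num_head, dim_head] reduced over dims {0, 1}
      // yields d_bias of shape [3, num_head, dim_head].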
      const auto input_dims = d_output->dims();
      const auto output_dims = d_bias->dims();
      bool support_case_1 =
          (input_dims.size() == 5 && output_dims.size() == 3 &&
           (input_dims[2] == output_dims[0]) &&
           (input_dims[3] == output_dims[1]) &&
           (input_dims[4] == output_dims[2]));
      bool support_case_2 =
          (input_dims.size() == 3 && output_dims.size() == 1 &&
           (input_dims[2] == output_dims[0]));
      bool support_case_3 =
          (input_dims.size() == 4 && output_dims.size() == 1 &&
           input_dims[3] == output_dims[0]);
      bool support_case_4 =
          (input_dims.size() == 5 && output_dims.size() == 2 &&
           input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);

      gpuStream_t stream = dev_ctx_.stream();
      if (support_case_1 || support_case_2) {
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
            dev_ctx_,
            *d_output,
            d_bias,
            kps::IdentityFunctor<T>(),
            {0, 1},
            stream);
      } else if (support_case_3 || support_case_4) {
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
            dev_ctx_,
            *d_output,
            d_bias,
            kps::IdentityFunctor<T>(),
            {0, 1, 2},
            stream);
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only support reduce when the input dims are [0,1,2,3,4] and "
            "output is [2,3,4]"
            "or input is [0,1,2] and output is [2]."));
      }
    }
  }

 private:
  const platform::CUDADeviceContext& dev_ctx_;

  bool transA_;
  bool transB_;

  int bsz_seq_;
  int output_size_;
  int input_size_;

  bool compute_bias_;
};
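
// A minimal usage sketch (hypothetical names, loosely following the QKV
// projection in fused_attention_op): x is [bsz * seq_len, dim_embed] and the
// fused QKV weight is viewed as [3 * num_head * dim_head, dim_embed], so the
// projection is a gemm-nt followed by a broadcast bias add:
//
//   AttnMatMul<T> qkv_compute(dev_ctx, /*transA=*/false, /*transB=*/true,
//                             bsz_seq, output_size, input_size,
//                             /*compute_bias=*/true);
//   qkv_compute.ComputeForward(&qkv_weight, &x, &qkv_bias, &qkv_out,
//                              &qkv_bias_out);
//   // ... and in the gradient kernel:
//   qkv_compute.ComputeBackward(&x, &qkv_weight, &d_qkv_bias_out, &d_x,
//                               &d_qkv_weight, &d_qkv_bias);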

}  // namespace operators
}  // namespace paddle