/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace paddle {
namespace operators {
28 29

using Tensor = framework::Tensor;
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
 public:
  // (m, n, k) = bsz_seq, output_size, input_size
  AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA,
             bool transB, int bsz_seq, int output_size, int input_size,
             bool compute_bias)
      : dev_ctx_(dev_ctx),
        transA_(transA),
        transB_(transB),
        bsz_seq_(bsz_seq),
        output_size_(output_size),
        input_size_(input_size),
        compute_bias_(compute_bias) {}

  ~AttnMatMul() {}

L
Li Min 已提交
48 49 50 51
  void ComputeForward(const framework::Tensor* weight,
                      const framework::Tensor* input,
                      const framework::Tensor* bias, framework::Tensor* output,
                      framework::Tensor* bias_out) {
52 53
    // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
    // here: (transa, transb): nt, input * weight.
54 55
    CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
56 57 58
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);

59
    // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
60
    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
61
    blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
L
Li Min 已提交
62
              input->data<T>(), weight->data<T>(), beta, output->data<T>());
63
    if (compute_bias_) {
64 65 66
      // bias_out = output + bias
      std::vector<const Tensor*> ins = {output, bias};
      std::vector<Tensor*> outs = {bias_out};
67
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
68
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
69 70 71
    }
  }

L
Li Min 已提交
72 73 74 75
  void ComputeBackward(const framework::Tensor* input,
                       const framework::Tensor* weight,
                       const framework::Tensor* d_output,
                       framework::Tensor* d_input, framework::Tensor* d_weight,
76
                       framework::Tensor* d_bias, bool use_addto = false) {
77
    T alpha = static_cast<T>(1.0);
78 79
    T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
    T beta_dB = static_cast<T>(0.0);
80

81
    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
82
    if (!transA_) {
83
      // forward: gemm-nt
84
      if (transB_) {
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
        // backward: gemm-tn, dB = (dC)^T * A
        if (d_weight) {
          int dB_m = output_size_;
          int dB_n = input_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
          blas.GEMM(CblasTrans, CblasNoTrans, dB_m, dB_n, dB_k, alpha,
                    d_output->data<T>(), input->data<T>(), beta_dB,
                    dB_output_ptr);
        }

        // backward: gemm-nn, dA = dC * B
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
          blas.GEMM(CblasNoTrans, CblasNoTrans, dA_m, dA_n, dA_k, alpha,
                    d_output->data<T>(), weight->data<T>(), beta_dA,
                    dA_output_ptr);
        }
108
      } else {  // fw: gemm-nn
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
        // backward: gemm-tn, dB = A^T * dC
        if (d_weight) {
          int dB_m = input_size_;
          int dB_n = output_size_;
          int dB_k = bsz_seq_;

          T* dB_output_ptr = d_weight->data<T>();
          blas.GEMM(CblasTrans, CblasNoTrans, dB_m, dB_n, dB_k, alpha,
                    input->data<T>(), d_output->data<T>(), beta_dB,
                    dB_output_ptr);
        }

        // backward: gemm-nt, dA = dC * B^T
        if (d_input) {
          int dA_m = bsz_seq_;
          int dA_n = input_size_;
          int dA_k = output_size_;

          T* dA_output_ptr = d_input->data<T>();
          blas.GEMM(CblasNoTrans, CblasTrans, dA_m, dA_n, dA_k, alpha,
                    d_output->data<T>(), weight->data<T>(), beta_dA,
                    dA_output_ptr);
        }
132 133 134
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
135
          "AttnMatMul wrapper do not support (transA=T, transB=T/N)"
136 137
          "parameters."));
    }
138 139 140
    if (compute_bias_ && d_bias) {
      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
      // -> {3} or {0,1,2,3,4} -> {3,4}
L
Li Min 已提交
141 142 143 144 145 146 147 148 149 150
      const auto input_dims = d_output->dims();
      const auto output_dims = d_bias->dims();
      bool support_case_1 =
          (input_dims.size() == 5 && output_dims.size() == 3 &&
           (input_dims[2] == output_dims[0]) &&
           (input_dims[3] == output_dims[1]) &&
           (input_dims[4] == output_dims[2]));
      bool support_case_2 =
          (input_dims.size() == 3 && output_dims.size() == 1 &&
           (input_dims[2] == output_dims[0]));
151 152 153 154 155 156 157 158
      bool support_case_3 =
          (input_dims.size() == 4 && output_dims.size() == 1 &&
           input_dims[3] == output_dims[0]);
      bool support_case_4 =
          (input_dims.size() == 5 && output_dims.size() == 2 &&
           input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);

      gpuStream_t stream = dev_ctx_.stream();
L
Li Min 已提交
159
      if (support_case_1 || support_case_2) {
160
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
W
Wilber 已提交
161 162
            dev_ctx_, *d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1},
            stream);
163 164 165 166
      } else if (support_case_3 || support_case_4) {
        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
            dev_ctx_, *d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1, 2},
            stream);
L
Li Min 已提交
167 168 169 170 171 172
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only support reduce when the input dims are [0,1,2,3,4] and "
            "output is [2,3,4]"
            "or input is [0,1,2] and output is [2]."));
      }
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
    }
  }

 private:
  const platform::CUDADeviceContext& dev_ctx_;

  bool transA_;
  bool transB_;

  int bsz_seq_;
  int output_size_;
  int input_size_;

  int compute_bias_;
};

}  // namespace operators
}  // namespace paddle