未验证 提交 f6f18835 编写于 作者: L limingshu 提交者: GitHub

Support Linear operation in cuBlaslt and plug into attn_gemm and fusedLinear backward op (#52028)

* first commit

* restruct c++ interface to divide linear from matmulwithcublaslt

* finish building in cublaslt impl

* fix code bugs

* fix host cost

* add some changes
上级 2944d3c0
......@@ -68,25 +68,20 @@ class AttnMatMul {
"The output (= input * weight) is expected to be nullptr or the "
"same as bias_out when fused is true."));
auto fused_impl =
phi::funcs::MatmulPlanner(vectorize(input->dims()),
vectorize(weight->dims()),
transA_,
transB_,
phi::CppTypeToDataType<T>::Type(),
phi::funcs::MatmulFusedType::kMatmulBias,
static_cast<const void*>(bias->data<T>()),
nullptr);
phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx_,
input->data<T>(),
weight->data<T>(),
bias_out->data<T>(),
phi::funcs::LinearWithCublasLt<T>::Run(
dev_ctx_,
input, // x
weight, // y
bias_out, // out
static_cast<const void*>(bias->data<T>()), // bias
nullptr,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
&fused_impl);
phi::funcs::MatmulFusedType::kMatmulBias);
return;
}
#endif
......
......@@ -36,7 +36,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto bias_dims = ctx->GetInputDim("Bias");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
......@@ -88,27 +87,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
K_from_x,
K_from_y));
auto activation = ctx->Attrs().Get<std::string>("activation");
if (activation == "none" && ctx->HasOutput("ReserveSpace")) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The ReserveSpace would not be used when activation = \"none\""));
}
// cublasLt's restriction for auxiliary.
if (ctx->HasOutput("ReserveSpace") && activation != "none") {
int min_size_of_n = activation == "relu" ? 128 : 8;
int N_size = trans_y ? y_dims[0] : y_dims[1];
PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
0,
platform::errors::InvalidArgument(
"The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
"should be multiple of %d when auxiliary_key given "
"and activation=%s, but got N = %d.",
min_size_of_n,
activation,
N_size));
}
std::vector<int64_t> out_dims;
out_dims.reserve(static_cast<size_t>(x_dims.size()));
if (trans_x) {
......@@ -122,11 +100,29 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
} else {
out_dims.push_back(y_dims[1]);
}
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
auto activation = ctx->Attrs().Get<std::string>("activation");
if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
if (activation == "none") {
PADDLE_THROW(platform::errors::InvalidArgument(
"The ReserveSpace would not be used when activation = \"none\""));
} else {
int min_size_of_n = activation == "relu" ? 128 : 8;
int N_size = trans_y ? y_dims[0] : y_dims[1];
PADDLE_ENFORCE_EQ(
N_size % min_size_of_n,
0,
platform::errors::InvalidArgument(
"The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
"should be multiple of %d when auxiliary_key given "
"and activation=%s, but got N = %d.",
min_size_of_n,
activation,
N_size));
}
}
}
......@@ -202,7 +198,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto dout_dims = ctx->GetInputDim("DOut");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
......@@ -241,7 +236,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
x_dims.size()));
auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ(
......@@ -268,25 +262,17 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
false,
platform::errors::InvalidArgument(
"The ReserveSpace should not be empty. "
"when activation_grad == {relu_grad, gelu_grad}."));
"when activation == {relu_grad, gelu_grad}."));
}
if (ctx->HasOutput("DX")) {
std::vector<int64_t> dx_dims;
dx_dims.reserve(static_cast<size_t>(x_dims.size()));
for (int i = 0; i < x_dims.size(); ++i) {
dx_dims.push_back(x_dims[i]);
ctx->SetOutputDim("DX", x_dims);
}
ctx->SetOutputDim("DX", phi::make_ddim(dx_dims));
}
std::vector<int64_t> dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size());
ctx->SetOutputDim("DY", phi::make_ddim(dy_dims));
ctx->SetOutputDim("DY", y_dims);
if (ctx->HasOutput("DBias")) {
std::vector<int64_t> dbias_dims;
dbias_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);
ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims));
int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
ctx->SetOutputDim("DBias", phi::make_ddim({dbias_dim}));
}
}
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
namespace paddle {
......@@ -101,26 +100,19 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
<< ", activation=" << activation << ", fused_type=" << fused_type
<< ", reserve_space=" << reserve_space;
auto fused_impl =
phi::funcs::MatmulPlanner(vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
phi::CppTypeToDataType<T>::Type(),
fused_type,
phi::funcs::LinearWithCublasLt<T>::Run(
dev_ctx,
x,
y,
out,
static_cast<const void*>(bias->data<T>()),
reserve_data);
phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx,
x->data<T>(),
y->data<T>(),
out->data<T>(),
reserve_data,
M,
N,
K,
trans_x,
trans_y,
&fused_impl);
fused_type);
}
};
......
......@@ -25,7 +25,7 @@ size_t TransposeKey(const std::vector<int64_t>& x_dims,
const std::vector<int32_t>& perm,
phi::DataType dtype) {
const auto rank = perm.size();
return GenKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
return GenKey(x_dims, perm, rank, static_cast<int>(dtype));
}
std::string AlgorithmTypeString(int64_t algo_type) {
......
......@@ -52,6 +52,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
......@@ -68,7 +69,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
}
for (int i = 0; i < max_dim; i++) {
for (int i = 0; i < max_dim; ++i) {
PADDLE_ENFORCE_EQ(
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
y_dims_array[i] <= 1,
......
......@@ -350,8 +350,10 @@ void DropoutFwGPUKernelDriver(
auto dst_functor =
DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
std::vector<int64_t> out_dims =
std::move(phi::vectorize<int64_t>(x.dims()));
std::vector<int64_t> in_dims =
std::move(phi::vectorize<int64_t>(mask->dims()));
std::reverse(out_dims.begin(), out_dims.end());
std::reverse(in_dims.begin(), in_dims.end());
kps::details::BroadcastConfig broadcast_config(
......
......@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/scope_guard.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/utils/optional.h"
DECLARE_int64(cublaslt_exhaustive_search_times);
......@@ -488,62 +489,103 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
phi::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;
template <>
struct FusedGEMMGradTrait<false, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
struct BwdFusedEpilogueSetter {
public:
static phi::funcs::MatmulFusedType SetForDx(
const std::string& activation_grad) {
if (activation_grad == "none") {
return kMatmulGrad;
} else if (activation_grad == "relu_grad") {
return kMatmulReluGrad;
} else if (activation_grad == "gelu_grad") {
return kMatmulGeluGrad;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Fued linear epilogue type should be one of {none, relu, gelu}."
"But received activation is %s, please check",
activation_grad));
}
}
template <>
struct FusedGEMMGradTrait<true, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = false;
static constexpr auto kYGradBTrans = false;
template <typename DYT, bool TransY>
static phi::funcs::MatmulFusedType SetForDy(const phi::GPUContext& dev_ctx,
phi::DenseTensor* dbias) {
if (dbias != nullptr) {
dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
return TransY ? kMatmulBiasGradToB : kMatmulBiasGradToA;
} else {
return kMatmulGradWithoutBias;
}
}
};
template <>
struct FusedGEMMGradTrait<false, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = false;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
const phi::DenseTensor* reserve_space,
int64_t M,
int64_t N,
int64_t K,
const std::string activation_grad,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto_dx,
bool use_addto_dy) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
static_assert(std::is_same<DXT, T>::value || std::is_same<DXT, MT>::value);
static_assert(std::is_same<DYT, T>::value || std::is_same<DYT, MT>::value);
using Trait = FusedGEMMGradTrait<TransX, TransY>;
template <>
struct FusedGEMMGradTrait<true, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = true;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = true;
};
if (dx) {
constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
auto fused_type = BwdFusedEpilogueSetter::SetForDx(activation_grad);
void* reserve_data = (fused_type == kMatmulGrad)
? nullptr
: const_cast<void*>(reserve_space->data());
dev_ctx.Alloc<DXT>(dx, dx->numel() * sizeof(DXT));
phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
dev_ctx,
dout,
y,
dx,
nullptr,
reserve_data,
M,
N,
K,
fused_type,
Trait::kXGradATrans,
Trait::kXGradBTrans,
use_addto_dx,
kXGradAIsDZ);
}
if (dy) {
auto fused_type =
BwdFusedEpilogueSetter::SetForDy<DYT, TransY>(dev_ctx, dbias);
constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
// Caution: DYT is in front of DXT in this template arguments.
dev_ctx.Alloc<DYT>(dy, dy->numel() * sizeof(DYT));
phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
dev_ctx,
dout,
x,
dy,
dbias ? static_cast<const void*>(dbias->data<DYT>()) : nullptr,
nullptr,
M,
N,
K,
fused_type,
Trait::kYGradATrans,
Trait::kYGradBTrans,
use_addto_dy,
kYGradAIsDZ,
/*is_dx=*/false);
}
}
static constexpr auto BoolToCuBlasEnum(bool transpose) {
return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
......@@ -567,7 +609,8 @@ static cublasLtEpilogue_t GetEpilogueGradType(
}
template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
void ComputeFusedGemmEpilogueBackwardImplDev(
const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
......
......@@ -559,7 +559,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
// max index to read
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
#pragma unroll
// read data
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
......@@ -659,7 +659,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
// loss
phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
if (i >= local_batches) break;
loss[first_batch + i] = sumloss[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册