Unverified commit 2dfc3fa8 authored by limingshu, committed by GitHub

Support Linear operation in cuBlaslt and plug into attn_gemm and fusedLinear forward op (#51124)

* optimization for fused linear op

* fix code format

* optimization for linear fused forward

* merge with develop

* fix bugs for gemm_epilogue

* package cublaslt epilogue type with enum

* final fix before code reviewing

* fix missed fusedType typo

* fix code according to review suggestions

* fix windows ci error

* change location of MatmulPlanner

* add some changes for compiler error fix

---------
Parent 9983892e
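Editor's note (not part of the commit): the diff below replaces the string-based "activation" plumbing ("none"/"relu"/"gelu") with a phi::funcs::MatmulPlanner object that carries the matmul shapes, transpose flags, dtype, the cuBLASLt epilogue kind (MatmulFusedType), and the bias/auxiliary pointers. The following minimal sketch shows how a caller is expected to drive the new path; it mirrors the AttnMatMul change below, and the helper name, placeholder shapes, and the phi::vectorize call are illustrative assumptions rather than code from this commit.

// Sketch (editor's illustration, not part of this commit): minimal use of the
// planner-based API introduced below. Assumes the headers and symbols added in
// this diff (MatmulPlanner, MatmulFusedType, MatmulWithCublasLt) and a live
// phi::GPUContext; the helper name and tensor shapes are placeholders.
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"

template <typename T>
void LinearBiasForwardSketch(const phi::GPUContext& dev_ctx,
                             const phi::DenseTensor& x,     // [M, K]
                             const phi::DenseTensor& y,     // [K, N]
                             const phi::DenseTensor& bias,  // [N]
                             phi::DenseTensor* out,         // [M, N], pre-allocated
                             int64_t M, int64_t N, int64_t K) {
  // The planner bundles shapes, transposes, dtype, the fused epilogue kind,
  // and the bias/aux pointers; it also generates the autotune cache key.
  auto planner = phi::funcs::MatmulPlanner(
      phi::vectorize(x.dims()),
      phi::vectorize(y.dims()),
      /*trans_x=*/false,
      /*trans_y=*/false,
      paddle::experimental::CppTypeToDataType<T>::Type(),
      phi::funcs::MatmulFusedType::kMatmulBias,
      static_cast<const void*>(bias.data<T>()),
      /*reserve_data=*/nullptr);

  // out = x * y + bias, with the bias add fused into the cuBLASLt epilogue.
  phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx,
                                         x.data<T>(),
                                         y.data<T>(),
                                         out->data<T>(),
                                         M, N, K,
                                         /*trans_x=*/false,
                                         /*trans_y=*/false,
                                         &planner);
}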
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
......@@ -66,18 +67,26 @@ class AttnMatMul {
phi::errors::InvalidArgument(
"The output (= input * weight) is expected to be nullptr or the "
"same as bias_out when fused is true."));
ComputeFusedGemmEpilogueForward<T>(dev_ctx_,
input,
weight,
bias,
auto fused_impl = phi::funcs::MatmulPlanner(
vectorize(input->dims()),
vectorize(weight->dims()),
transA_,
transB_,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::funcs::MatmulFusedType::kMatmulBias,
static_cast<const void*>(bias->data<T>()),
nullptr);
phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx_,
input->data<T>(),
weight->data<T>(),
bias_out->data<T>(),
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
bias_out,
nullptr);
&fused_impl);
return;
}
#endif
......
......@@ -89,19 +89,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
K_from_y));
auto activation = ctx->Attrs().Get<std::string>("activation");
if ((activation != "relu") && (activation != "gelu") &&
(activation != "none")) {
PADDLE_ENFORCE_EQ(
true,
false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
if (activation == "none" && ctx->HasOutput("ReserveSpace")) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The ReserveSpace would not be used when activation = \"none\""));
......@@ -276,18 +263,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
x_mat_dims[0]));
auto activation_grad = ctx->Attrs().Get<std::string>("activation_grad");
if ((activation_grad != "relu_grad") && (activation_grad != "gelu_grad") &&
(activation_grad != "none")) {
PADDLE_ENFORCE_EQ(
true,
false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation_grad));
}
if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) {
PADDLE_ENFORCE_EQ(true,
false,
......
......@@ -18,12 +18,49 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
namespace paddle {
namespace operators {
#if CUDA_VERSION >= 11060
template <typename T>
phi::funcs::MatmulFusedType GetFwdFusedEpilogueType(
const phi::GPUContext& ctx,
const std::string& activation,
phi::DenseTensor* reserve_space) {
using FusedType = phi::funcs::MatmulFusedType;
FusedType fused_type = FusedType::kMatmulBias;
if (activation != "none") {
if (activation == "relu") {
if (reserve_space == nullptr) {
fused_type = FusedType::kMatmulBiasRelu;
} else {
fused_type = FusedType::kMatmulBiasReluWithReservedData;
int64_t reserve_size =
SizeOf(phi::DataType::BOOL) * phi::product(reserve_space->dims());
ctx.Alloc(reserve_space, phi::DataType::BOOL, reserve_size);
}
} else if (activation == "gelu") {
if (reserve_space == nullptr) {
fused_type = FusedType::kMatmulBiasGelu;
} else {
fused_type = FusedType::kMatmulBiasGeluWithReservedData;
int64_t reserve_size = sizeof(T) * phi::product(reserve_space->dims());
ctx.Alloc<T>(reserve_space, reserve_size);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Fued linear epilogue type should be one of {none, relu, gelu}."
"But received activation is %s, please check",
activation));
}
}
return fused_type;
}
template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
public:
......@@ -33,7 +70,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
const phi::DenseTensor* bias = ctx.Input<phi::DenseTensor>("Bias");
phi::DenseTensor* out = ctx.Output<phi::DenseTensor>("Out");
phi::DenseTensor* reserve_space =
ctx.Output<phi::DenseTensor>("ReserveSpace");
......@@ -43,7 +79,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
std::string activation = ctx.Attr<std::string>("activation");
dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
// (M * K) * (K * N)
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
......@@ -51,18 +86,36 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
ComputeFusedGemmEpilogueForward<T>(dev_ctx,
x,
y,
bias,
void* reserve_data = reserve_space ? reserve_space->data() : nullptr;
auto fused_type =
GetFwdFusedEpilogueType<T>(dev_ctx, activation, reserve_space);
VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims()
<< "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N
<< ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
<< ", activation=" << activation << ", fused_type=" << fused_type
<< ", reserve_space=" << reserve_space;
auto fused_impl = phi::funcs::MatmulPlanner(
vectorize(x->dims()),
vectorize(y->dims()),
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type(),
fused_type,
static_cast<const void*>(bias->data<T>()),
reserve_data);
phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx,
x->data<T>(),
y->data<T>(),
out->data<T>(),
M,
N,
K,
trans_x,
trans_y,
activation,
out,
reserve_space);
&fused_impl);
}
};
......@@ -77,13 +130,11 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
const phi::DenseTensor* reserve_space =
ctx.Input<phi::DenseTensor>("ReserveSpace");
phi::DenseTensor* dx = ctx.Output<phi::DenseTensor>("DX");
phi::DenseTensor* dy = ctx.Output<phi::DenseTensor>("DY");
phi::DenseTensor* dbias = ctx.Output<phi::DenseTensor>("DBias");
std::string activation_grad = ctx.Attr<std::string>("activation_grad");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
......@@ -94,6 +145,12 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims()
<< "}, dout.shape={" << dout->dims() << "}, M=" << M << ", N=" << N
<< ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
<< ", activation=" << activation_grad
<< ", reserve_space=" << reserve_space;
ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
dout,
x,
......
......@@ -330,186 +330,6 @@ class GemmEpilogueAlgoCache {
}
};
static cublasLtEpilogue_t GetEpilogueType(const std::string& activation,
bool enable_auxiliary) {
if (activation == "relu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
: CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == "gelu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
: CUBLASLT_EPILOGUE_GELU_BIAS;
} else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
}
template <typename T>
void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
const phi::DenseTensor* bias,
int64_t M,
int64_t N,
int64_t K,
bool trans_x,
bool trans_y,
const std::string& activation,
phi::DenseTensor* out,
phi::DenseTensor* reserve_space) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims()
<< "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N
<< ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
<< ", activation=" << activation
<< ", reserve_space=" << reserve_space;
bool enable_auxiliary = reserve_space == nullptr ? false : true;
auto* out_data = out->data<T>();
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, double>::value) {
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy)));
cublasLtEpilogue_t epiloque_func =
GetEpilogueType(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func,
sizeof(epiloque_func)));
const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
if (enable_auxiliary && activation != "none") {
// Note (Ming Huang): The initialization of ReseveSpace is happened in the
// dev_ctx.Alloc. Therefore, we set real date type up here.
if (activation == "relu") {
paddle::experimental::DataType rs_type =
paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else {
size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
}
void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
}
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, M, K, M));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
}
if (trans_y) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, K, N, K));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
MT alpha = static_cast<MT>(1);
MT beta = static_cast<MT>(0);
const auto* y_data = y->data<T>();
const auto* x_data = x->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc,
y_desc,
x_desc,
out_desc,
&alpha,
&beta,
y_data,
x_data,
out_data,
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle,
operation_desc,
&alpha,
y_data,
y_desc,
x_data,
x_desc,
&beta,
out_data,
out_desc,
out_data,
out_desc,
algo,
workspace->ptr(),
workspace_size,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
template <bool TransX, bool TransY>
......@@ -573,12 +393,12 @@ static constexpr auto BoolToCuBlasEnum(bool transpose) {
static cublasLtEpilogue_t GetEpilogueGradType(
const std::string& activation_grad) {
if (activation_grad == "relu_grad") {
if (activation_grad == "none") {
return CUBLASLT_EPILOGUE_DEFAULT;
} else if (activation_grad == "relu_grad") {
return CUBLASLT_EPILOGUE_DRELU;
} else if (activation_grad == "gelu_grad") {
return CUBLASLT_EPILOGUE_DGELU;
} else if (activation_grad == "none") {
return CUBLASLT_EPILOGUE_DEFAULT;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The activation_grad attribute of fused_gemm_epilogue op should "
......
......@@ -60,27 +60,6 @@ size_t GenKey(Args&&... args) {
return seed;
}
struct MatmulCacheKey {
public:
MatmulCacheKey() {}
MatmulCacheKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
const bool trans_x,
const bool trans_y,
phi::DataType dtype) {
key = GenKey(x_dims,
y_dims,
static_cast<int64_t>(trans_x),
static_cast<int64_t>(trans_y),
static_cast<int64_t>(dtype));
}
size_t GetKey() const { return key; }
size_t GenSubKey(int64_t idx) const { return GenKey(key, idx); }
private:
size_t key;
};
struct ConvCacheKey {
ConvCacheKey() {}
ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
......
......@@ -15,18 +15,64 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#endif
namespace phi {
namespace funcs {
enum MatmulImplType { kCublas = 1, kCublasLt = 2 };
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
// Set this enum according to
// https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t
enum MatmulFusedType {
kMatmul = CUBLASLT_EPILOGUE_DEFAULT, // No special postprocessing.
kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
kMatmulBiasRelu =
CUBLASLT_EPILOGUE_RELU_BIAS, // Apply bias and then ReLU transform.
kMatmulBiasGelu =
CUBLASLT_EPILOGUE_GELU_BIAS, // Apply Bias and then GELU transform.
kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS
};
struct MatmulPlanner {
public:
const void* bias{nullptr};
void* aux_data{nullptr};
MatmulPlanner() {}
MatmulPlanner(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
const bool trans_x,
const bool trans_y,
phi::DataType dtype,
MatmulFusedType impl_type,
const void* bias_data = nullptr,
void* reserve_data = nullptr)
: bias(bias_data), aux_data(reserve_data) {
type = impl_type;
key = phi::autotune::GenKey(x_dims,
y_dims,
static_cast<int64_t>(trans_x),
static_cast<int64_t>(trans_y),
static_cast<int64_t>(dtype));
}
MatmulFusedType ImplType() const { return type; }
size_t GetKey() const { return key; }
size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); }
private:
MatmulFusedType type;
size_t key;
};
template <typename T>
cublasComputeType_t GetCudaComputeType() {
......@@ -44,6 +90,7 @@ struct MatmulDescriptor {
cublasLtMatrixLayout_t y_desc{nullptr};
cublasLtMatrixLayout_t out_desc{nullptr};
cublasLtMatmulAlgo_t* algo{nullptr};
bool is_cached{false};
MatmulDescriptor() {}
MatmulDescriptor(const MatmulDescriptor& obj) {
......@@ -52,6 +99,7 @@ struct MatmulDescriptor {
y_desc = obj.y_desc;
op_desc = obj.op_desc;
out_desc = obj.out_desc;
is_cached = obj.is_cached;
}
~MatmulDescriptor() {
......@@ -78,6 +126,7 @@ struct MatmulDescriptor {
const int K,
const bool trans_x,
const bool trans_y,
phi::funcs::MatmulPlanner* planner,
const int batch_size = 1,
int64_t stride_x = 0,
int64_t stride_y = 0,
......@@ -116,17 +165,61 @@ struct MatmulDescriptor {
SetBatchAndStride(y_desc, batch_size, stride_y);
SetBatchAndStride(out_desc, batch_size, stride_out);
}
SetFusedEpilogueOpDescriptor(planner, N);
}
cublasLtMatmulAlgo_t* SetAlgo() {
// Entering this function implies that the descriptor is about to be cached.
is_cached = true;
algo = new cublasLtMatmulAlgo_t;
return algo;
}
void ValidateCache() { is_cached = true; }
template <typename T>
void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) {
if (planner->bias != nullptr) {
const T* bias_data = static_cast<const T*>(planner->bias);
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
op_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
if (planner->aux_data != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
op_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&(planner->aux_data),
sizeof(planner->aux_data)));
}
}
}
std::string GetDescResultString(std::string prefix,
bool has_algo = true) const {
std::ostringstream out;
out << prefix << " \n";
#define GET_DESC_DATA_INFO(src) \
do { \
out << #src << "= ["; \
int num = sizeof((*src)) / sizeof(src->data[0]); \
for (int i = 0; i < num; ++i) { \
out << src->data[i] << ", "; \
} \
out << "]\n"; \
} while (0);
if (has_algo) {
GET_DESC_DATA_INFO(&algo);
}
GET_DESC_DATA_INFO(x_desc);
GET_DESC_DATA_INFO(y_desc);
GET_DESC_DATA_INFO(out_desc);
GET_DESC_DATA_INFO(op_desc);
return out.str();
}
private:
bool is_cached{false};
void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
cudaDataType type,
uint64_t rows,
......@@ -155,42 +248,33 @@ struct MatmulDescriptor {
&stride,
sizeof(stride)));
}
};
inline std::string GetDescResultString(std::string prefix,
const MatmulDescriptor* desc,
bool has_algo = true) {
std::ostringstream out;
out << prefix << " \n";
#define GET_DESC_DATA_INFO(src) \
do { \
out << "#data " \
<< "= ["; \
int num = sizeof((*src)) / sizeof(src->data[0]); \
for (int i = 0; i < num; ++i) { \
out << src->data[i] << ", "; \
} \
out << "]\n"; \
} while (0);
if (has_algo) {
GET_DESC_DATA_INFO(desc->algo);
void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
int64_t lead_dim) {
if (planner->bias) {
auto fuse_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&fuse_type,
sizeof(fuse_type)));
if (planner->aux_data) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
op_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&lead_dim,
sizeof(lead_dim)));
}
GET_DESC_DATA_INFO(desc->x_desc);
GET_DESC_DATA_INFO(desc->y_desc);
GET_DESC_DATA_INFO(desc->out_desc);
GET_DESC_DATA_INFO(desc->op_desc);
return out.str();
}
}
}
};
template <typename T>
struct DescriptorSetter {
MatmulDescriptor* desc{nullptr};
MatmulDescriptor desc;
size_t sub_key{std::numeric_limits<size_t>::min()};
DescriptorSetter(phi::autotune::MatmulCacheKey* matmul_key,
MatmulDescriptor* desc_ptr,
DescriptorSetter(phi::funcs::MatmulPlanner* planner,
const int M,
const int N,
const int K,
......@@ -200,27 +284,31 @@ struct DescriptorSetter {
int64_t stride_x = 0,
int64_t stride_y = 0,
int64_t stride_out = 0) {
if (matmul_key != nullptr) {
sub_key =
matmul_key->GenSubKey(static_cast<size_t>(MatmulImplType::kCublasLt));
if (planner != nullptr) {
sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
}
auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
if (mamtul_cache.FindSubKey(sub_key)) {
desc =
reinterpret_cast<MatmulDescriptor*>(mamtul_cache.GetSubKey(sub_key));
VLOG(4) << GetDescResultString("[Heap MatmulDescriptor] ", desc);
desc = *(
reinterpret_cast<MatmulDescriptor*>(mamtul_cache.GetSubKey(sub_key)));
desc.SetFusedEpiloguePtr<T>(planner);
VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] ");
} else {
desc_ptr->Create<T>(M,
desc.Create<T>(M,
N,
K,
trans_x,
trans_y,
planner,
batch_size,
stride_x,
stride_y,
stride_out);
desc = desc_ptr;
VLOG(4) << GetDescResultString("[Stack MatmulDescriptor] ", desc, false);
if (planner != nullptr) {
desc.SetFusedEpiloguePtr<T>(planner);
}
VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false);
}
}
};
......@@ -239,16 +327,13 @@ struct MatmulWithCublasLt {
const int K,
const bool trans_x,
const bool trans_y,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
MatmulDescriptor desc;
auto setter =
DescriptorSetter<T>(matmul_key, &desc, M, N, K, trans_x, trans_y);
phi::funcs::MatmulPlanner* planner = nullptr) {
auto setter = DescriptorSetter<T>(planner, M, N, K, trans_x, trans_y);
RunImpl(
ctx, setter.desc, x_data, y_data, out_data, setter.sub_key, matmul_key);
ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
}
static void RunWithBatch(
const phi::GPUContext& ctx,
static void RunWithBatch(const phi::GPUContext& ctx,
const T* x_data,
const T* y_data,
T* out_data,
......@@ -261,10 +346,8 @@ struct MatmulWithCublasLt {
int64_t stride_x,
int64_t stride_y,
int64_t stride_out,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
MatmulDescriptor desc;
auto setter = DescriptorSetter<T>(matmul_key,
&desc,
phi::funcs::MatmulPlanner* planner = nullptr) {
auto setter = DescriptorSetter<T>(planner,
M,
N,
K,
......@@ -275,11 +358,10 @@ struct MatmulWithCublasLt {
stride_y,
stride_out);
RunImpl(
ctx, setter.desc, x_data, y_data, out_data, setter.sub_key, matmul_key);
ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
}
static void RunWithBatch(
const phi::GPUContext& ctx,
static void RunWithBatch(const phi::GPUContext& ctx,
const T** x_data,
const T** y_data,
T** out_data,
......@@ -289,7 +371,7 @@ struct MatmulWithCublasLt {
bool trans_x,
bool trans_y,
int batch_size,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
phi::funcs::MatmulPlanner* planner = nullptr) {
for (int i = 0; i < batch_size; ++i) {
Run(ctx,
x_data[i],
......@@ -300,7 +382,7 @@ struct MatmulWithCublasLt {
K,
trans_x,
trans_y,
matmul_key);
planner);
}
}
......@@ -315,11 +397,11 @@ struct MatmulWithCublasLt {
static void RunImpl(const phi::GPUContext& ctx,
MatmulDescriptor* desc,
const size_t sub_key,
const T* x_ptr,
const T* y_ptr,
T* out_ptr,
const size_t sub_key,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
phi::funcs::MatmulPlanner* planner) {
MT alpha = static_cast<MT>(1);
MT beta = static_cast<MT>(0);
......@@ -327,11 +409,9 @@ struct MatmulWithCublasLt {
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
if (matmul_key != nullptr) {
auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
if (planner != nullptr) {
if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
(!cache.FindSubKey(sub_key))) {
desc->ValidateCache();
(!desc->is_cached)) {
SearchBestAlgo(ctx,
cublaslt_handle,
desc,
......@@ -343,13 +423,15 @@ struct MatmulWithCublasLt {
workspace->ptr(),
workspace_size);
MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
VLOG(4) << GetDescResultString("[Searched MatmulDescriptor] ",
best_desc);
VLOG(6) << best_desc->GetDescResultString(
"[Searched MatmulDescriptor] ");
auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
cache.SetSubKey(sub_key, reinterpret_cast<void*>(best_desc));
}
}
VLOG(4) << GetDescResultString("[Impl MatmulDescriptor] ", desc);
VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] ");
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmul(cublaslt_handle,
desc->op_desc,
......@@ -454,8 +536,10 @@ struct MatmulWithCublasLt {
dynload::cublasLtMatmulPreferenceDestroy(preference));
}
};
#else
// An empty placeholder struct so compilation succeeds when cuBLASLt is unavailable.
struct MatmulPlanner {};
#endif // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
} // namespace funcs
} // namespace phi
#endif
......@@ -197,7 +197,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
if (x < out_effective_thread_num) {
int x_i = x / TileX;
int x_j = x % TileX;
int x_j = x - x_i * TileX;
IndexType output_ind =
output_origin_block_flat_index + x_i * output_dims[2] + x_j;
IndexType output_inc = BlockWriteRows * output_dims[2];
......
......@@ -17,10 +17,10 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#endif
namespace phi {
......@@ -101,7 +101,7 @@ void MatMulFunctionImplWithBlas(
bool trans_x,
bool trans_y,
bool flag = false,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
phi::funcs::MatmulPlanner* matmul_planner = nullptr) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
......@@ -493,7 +493,7 @@ void MatMulFunctionImplWithCublasLt(
bool trans_x,
bool trans_y,
bool flag = false,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
phi::funcs::MatmulPlanner* matmul_planner = nullptr) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
const T* x_data = X.data<T>();
......@@ -526,7 +526,7 @@ void MatMulFunctionImplWithCublasLt(
M,
false,
true,
matmul_key);
matmul_planner);
return;
}
......@@ -576,7 +576,7 @@ void MatMulFunctionImplWithCublasLt(
N,
false,
false,
matmul_key);
matmul_planner);
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y.numel() / (M * N);
......@@ -591,7 +591,7 @@ void MatMulFunctionImplWithCublasLt(
N,
true,
false,
matmul_key);
matmul_planner);
} else {
VLOG(3) << "MatMul with blaslt 4";
blaslt::RunWithBatch(dev_ctx,
......@@ -607,7 +607,7 @@ void MatMulFunctionImplWithCublasLt(
M * N,
0,
M,
matmul_key);
matmul_planner);
}
}
return;
......@@ -662,7 +662,7 @@ void MatMulFunctionImplWithCublasLt(
N,
true,
false,
matmul_key);
matmul_planner);
} else {
VLOG(3) << "MatMul with blaslt 6";
blaslt::RunWithBatch(dev_ctx,
......@@ -678,7 +678,7 @@ void MatMulFunctionImplWithCublasLt(
M * N,
0,
M,
matmul_key);
matmul_planner);
}
} else {
const int M = X.numel() / N;
......@@ -692,7 +692,7 @@ void MatMulFunctionImplWithCublasLt(
N,
false,
false,
matmul_key);
matmul_planner);
}
return;
}
......@@ -775,7 +775,7 @@ void MatMulFunctionImplWithCublasLt(
K,
trans_x,
trans_y,
matmul_key);
matmul_planner);
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul with blaslt 9";
......@@ -788,7 +788,7 @@ void MatMulFunctionImplWithCublasLt(
K,
false,
false,
matmul_key);
matmul_planner);
} else {
VLOG(3) << "MatMul with blaslt 10";
blaslt::RunWithBatch(dev_ctx,
......@@ -804,7 +804,7 @@ void MatMulFunctionImplWithCublasLt(
0,
K * N,
M * N,
matmul_key);
matmul_planner);
}
} else if (y_batch_size == 1) {
if (!trans_x) {
......@@ -818,7 +818,7 @@ void MatMulFunctionImplWithCublasLt(
K,
false,
trans_y,
matmul_key);
matmul_planner);
} else {
VLOG(3) << "MatMul with blaslt 12";
blaslt::RunWithBatch(dev_ctx,
......@@ -834,7 +834,7 @@ void MatMulFunctionImplWithCublasLt(
M * K,
0,
M * N,
matmul_key);
matmul_planner);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul with blaslt 13";
......@@ -851,7 +851,7 @@ void MatMulFunctionImplWithCublasLt(
M * K,
K * N,
M * N,
matmul_key);
matmul_planner);
} else {
// in the case, can't use stridedgemm
std::vector<const T*> x_ptr(out_batch_size);
......@@ -881,7 +881,7 @@ void MatMulFunctionImplWithCublasLt(
trans_x,
trans_y,
out_batch_size,
matmul_key);
matmul_planner);
}
}
#endif
......@@ -918,14 +918,15 @@ struct MatMulDispatcher<phi::GPUContext, T> {
auto* tuner = phi::autotune::MakeMatmulTuner<T>(
MatMulFunctionImplWithBlas<phi::GPUContext, T>);
tuner->AddCallBack(MatMulFunctionImplWithCublasLt<phi::GPUContext, T>);
phi::autotune::MatmulCacheKey matmul_cache(
phi::funcs::MatmulPlanner matmul_planner(
x_dims,
y_dims,
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type());
paddle::experimental::CppTypeToDataType<T>::Type(),
funcs::MatmulFusedType::kMatmul);
tuner->Run(ctx,
matmul_cache.GetKey(),
matmul_planner.GetKey(),
ctx,
x,
y,
......@@ -935,7 +936,7 @@ struct MatMulDispatcher<phi::GPUContext, T> {
trans_x,
trans_y,
flag,
&matmul_cache);
&matmul_planner);
#else
MatMulFunctionImplWithBlas<phi::GPUContext, T>(
ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
......
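Editor's summary (not part of the diff): in the MatMulDispatcher change above, MatmulPlanner now doubles as the autotune cache key and is forwarded to whichever callback the tuner selects. A condensed sketch of that pattern follows; the wrapper name and the flag=false argument are illustrative assumptions.

// Sketch only: consolidates the MatMulDispatcher<phi::GPUContext, T> change
// above into one place. Assumes the context of matmul_kernel_impl.h, where
// MatMulFunctionImplWithBlas / MatMulFunctionImplWithCublasLt and the phi
// autotune helpers referenced in this diff are visible.
template <typename T>
void DispatchMatmulSketch(const phi::GPUContext& ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& y,
                          const std::vector<std::int64_t>& x_dims,
                          const std::vector<std::int64_t>& y_dims,
                          phi::DenseTensor* out,
                          bool trans_x,
                          bool trans_y) {
  // Plain matmul: kMatmul means no bias or activation fused into the epilogue.
  phi::funcs::MatmulPlanner planner(
      x_dims, y_dims, trans_x, trans_y,
      paddle::experimental::CppTypeToDataType<T>::Type(),
      phi::funcs::MatmulFusedType::kMatmul);

  // The tuner benchmarks the cuBLAS and cuBLASLt callbacks once per key and
  // replays the faster one afterwards; planner.GetKey() identifies the problem.
  auto* tuner = phi::autotune::MakeMatmulTuner<T>(
      MatMulFunctionImplWithBlas<phi::GPUContext, T>);
  tuner->AddCallBack(MatMulFunctionImplWithCublasLt<phi::GPUContext, T>);
  tuner->Run(ctx, planner.GetKey(), ctx, x, y, x_dims, y_dims, out,
             trans_x, trans_y, /*flag=*/false, &planner);
}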