Unverified commit f6f18835 authored by limingshu, committed by GitHub

Support Linear operation in cuBlaslt and plug into attn_gemm and fusedLinear backward op (#52028)

* first commit

* restructure C++ interface to split Linear from MatmulWithCublasLt

* finish building in cublaslt impl

* fix code bugs

* fix host cost

* add some changes
Parent 2944d3c0
@@ -68,25 +68,20 @@ class AttnMatMul {
           "The output (= input * weight) is expected to be nullptr or the "
           "same as bias_out when fused is true."));
-      auto fused_impl =
-          phi::funcs::MatmulPlanner(vectorize(input->dims()),
-                                    vectorize(weight->dims()),
-                                    transA_,
-                                    transB_,
-                                    phi::CppTypeToDataType<T>::Type(),
-                                    phi::funcs::MatmulFusedType::kMatmulBias,
-                                    static_cast<const void*>(bias->data<T>()),
-                                    nullptr);
-      phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx_,
-                                             input->data<T>(),
-                                             weight->data<T>(),
-                                             bias_out->data<T>(),
-                                             bsz_seq_,      // M
-                                             output_size_,  // N
-                                             input_size_,   // K
-                                             transA_,
-                                             transB_,
-                                             &fused_impl);
+      phi::funcs::LinearWithCublasLt<T>::Run(
+          dev_ctx_,
+          input,                                      // x
+          weight,                                     // y
+          bias_out,                                   // out
+          static_cast<const void*>(bias->data<T>()),  // bias
+          nullptr,
+          bsz_seq_,      // M
+          output_size_,  // N
+          input_size_,   // K
+          transA_,
+          transB_,
+          phi::funcs::MatmulFusedType::kMatmulBias);
       return;
     }
 #endif
...
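For readers not familiar with cublasLt epilogues: the kMatmulBias call above folds the bias add into the GEMM epilogue, so the attention projection no longer needs a separate broadcast-add kernel. A minimal CPU reference of the value it computes (row-major, no transposes; the helper name and float type are illustrative, not part of the patch):

#include <vector>

// Hypothetical reference of out = x * w + bias for an M x K input and a K x N
// weight; the fused kMatmulBias epilogue produces the same values in a single
// cublasLt launch instead of a GEMM followed by a broadcast bias add.
std::vector<float> LinearReference(const std::vector<float>& x,
                                   const std::vector<float>& w,
                                   const std::vector<float>& bias,
                                   int M, int N, int K) {
  std::vector<float> out(static_cast<size_t>(M) * N, 0.f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias[n];  // bias is broadcast along the M dimension
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * w[k * N + n];
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}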
@@ -36,7 +36,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
     auto bias_dims = ctx->GetInputDim("Bias");
-
     auto trans_x = ctx->Attrs().Get<bool>("trans_x");
     auto trans_y = ctx->Attrs().Get<bool>("trans_y");
@@ -88,27 +87,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
                           K_from_x,
                           K_from_y));
-    auto activation = ctx->Attrs().Get<std::string>("activation");
-
-    if (activation == "none" && ctx->HasOutput("ReserveSpace")) {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The ReserveSpace would not be used when activation = \"none\""));
-    }
-
-    // cublasLt's restriction for auxiliary.
-    if (ctx->HasOutput("ReserveSpace") && activation != "none") {
-      int min_size_of_n = activation == "relu" ? 128 : 8;
-      int N_size = trans_y ? y_dims[0] : y_dims[1];
-      PADDLE_ENFORCE_EQ(N_size % min_size_of_n,
-                        0,
-                        platform::errors::InvalidArgument(
-                            "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
-                            "should be multiple of %d when auxiliary_key given "
-                            "and activation=%s, but got N = %d.",
-                            min_size_of_n,
-                            activation,
-                            N_size));
-    }
-
     std::vector<int64_t> out_dims;
     out_dims.reserve(static_cast<size_t>(x_dims.size()));
     if (trans_x) {
@@ -122,11 +100,29 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     } else {
       out_dims.push_back(y_dims[1]);
     }
     ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
+    auto activation = ctx->Attrs().Get<std::string>("activation");
     if (ctx->HasOutput("ReserveSpace")) {
       ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
+      if (activation == "none") {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The ReserveSpace would not be used when activation = \"none\""));
+      } else {
+        int min_size_of_n = activation == "relu" ? 128 : 8;
+        int N_size = trans_y ? y_dims[0] : y_dims[1];
+        PADDLE_ENFORCE_EQ(
+            N_size % min_size_of_n,
+            0,
+            platform::errors::InvalidArgument(
+                "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
+                "should be multiple of %d when auxiliary_key given "
+                "and activation=%s, but got N = %d.",
+                min_size_of_n,
+                activation,
+                N_size));
+      }
     }
   }
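The relocated check above encodes cublasLt's alignment requirement on the auxiliary (ReserveSpace) output: with a relu epilogue N must be a multiple of 128, with gelu a multiple of 8. A standalone sketch of the same rule (hypothetical helper, exceptions instead of PADDLE_ENFORCE, not part of the patch):

#include <cstdint>
#include <stdexcept>
#include <string>

// Mirrors the InferShape check above: when a ReserveSpace output is attached,
// cublasLt only supports the RELU_AUX/GELU_AUX epilogues if the output
// dimension N is suitably aligned (128 for relu, 8 for gelu).
void CheckAuxOutputDim(const std::string& activation, int64_t n_size) {
  if (activation == "none") {
    throw std::invalid_argument(
        "ReserveSpace is not used when activation = \"none\"");
  }
  const int64_t min_size_of_n = (activation == "relu") ? 128 : 8;
  if (n_size % min_size_of_n != 0) {
    throw std::invalid_argument("N must be a multiple of " +
                                std::to_string(min_size_of_n) +
                                " for activation " + activation);
  }
}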
@@ -202,7 +198,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
     auto dout_dims = ctx->GetInputDim("DOut");
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
-
     auto trans_x = ctx->Attrs().Get<bool>("trans_x");
     auto trans_y = ctx->Attrs().Get<bool>("trans_y");
@@ -241,7 +236,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
                           x_dims.size()));
     auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
     auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
-
     PADDLE_ENFORCE_EQ(
@@ -268,25 +262,17 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
                       false,
                       platform::errors::InvalidArgument(
                           "The ReserveSpace should not be empty. "
-                          "when activation_grad == {relu_grad, gelu_grad}."));
+                          "when activation == {relu_grad, gelu_grad}."));
     }

     if (ctx->HasOutput("DX")) {
-      std::vector<int64_t> dx_dims;
-      dx_dims.reserve(static_cast<size_t>(x_dims.size()));
-      for (int i = 0; i < x_dims.size(); ++i) {
-        dx_dims.push_back(x_dims[i]);
-      }
-      ctx->SetOutputDim("DX", phi::make_ddim(dx_dims));
+      ctx->SetOutputDim("DX", x_dims);
     }
-
-    std::vector<int64_t> dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size());
-    ctx->SetOutputDim("DY", phi::make_ddim(dy_dims));
+    ctx->SetOutputDim("DY", y_dims);

     if (ctx->HasOutput("DBias")) {
-      std::vector<int64_t> dbias_dims;
-      dbias_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);
-      ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims));
+      int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1];
+      ctx->SetOutputDim("DBias", phi::make_ddim({dbias_dim}));
     }
   }
...
@@ -17,7 +17,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
-#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"

 namespace paddle {
@@ -101,26 +100,19 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
             << ", activation=" << activation << ", fused_type=" << fused_type
             << ", reserve_space=" << reserve_space;

-    auto fused_impl =
-        phi::funcs::MatmulPlanner(vectorize(x->dims()),
-                                  vectorize(y->dims()),
-                                  trans_x,
-                                  trans_y,
-                                  phi::CppTypeToDataType<T>::Type(),
-                                  fused_type,
-                                  static_cast<const void*>(bias->data<T>()),
-                                  reserve_data);
-    phi::funcs::MatmulWithCublasLt<T>::Run(dev_ctx,
-                                           x->data<T>(),
-                                           y->data<T>(),
-                                           out->data<T>(),
-                                           M,
-                                           N,
-                                           K,
-                                           trans_x,
-                                           trans_y,
-                                           &fused_impl);
+    phi::funcs::LinearWithCublasLt<T>::Run(
+        dev_ctx,
+        x,
+        y,
+        out,
+        static_cast<const void*>(bias->data<T>()),
+        reserve_data,
+        M,
+        N,
+        K,
+        trans_x,
+        trans_y,
+        fused_type);
   }
 };
...
@@ -25,7 +25,7 @@ size_t TransposeKey(const std::vector<int64_t>& x_dims,
                     const std::vector<int32_t>& perm,
                     phi::DataType dtype) {
   const auto rank = perm.size();
-  return GenKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
+  return GenKey(x_dims, perm, rank, static_cast<int>(dtype));
 }

 std::string AlgorithmTypeString(int64_t algo_type) {
...
@@ -33,20 +33,87 @@ namespace phi {
 namespace funcs {

 #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
 // Set this enum according to
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t
+// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same
+// enum value, but if all elements for MatmulPlanner->GetKey() is same,
+// no matter forward or backward, they could share the same descriptor
+// cache, in that the descritpor is for decription of matmul operation.
 enum MatmulFusedType {
-  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,  // No special postprocessing.
+  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,
+  kMatmulGrad = CUBLASLT_EPILOGUE_DEFAULT,
+  kMatmulGradWithoutBias = CUBLASLT_EPILOGUE_DEFAULT,
   kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
   kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
-  kMatmulBiasRelu =
-      CUBLASLT_EPILOGUE_RELU_BIAS,  // Apply bias and then ReLU transform.
-  kMatmulBiasGelu =
-      CUBLASLT_EPILOGUE_GELU_BIAS,  // Apply Bias and then GELU transform.
+  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,
+  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,
   kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
-  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS
+  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
+  kMatmulReluGrad = CUBLASLT_EPILOGUE_DRELU,
+  kMatmulGeluGrad = CUBLASLT_EPILOGUE_DGELU,
+  kMatmulBiasGradToA = CUBLASLT_EPILOGUE_BGRADA,
+  kMatmulBiasGradToB = CUBLASLT_EPILOGUE_BGRADB
 };
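The first three enumerators deliberately map onto the same CUBLASLT_EPILOGUE_DEFAULT value: as the new comment above notes, a plain forward matmul and a bias-less backward matmul with identical planner keys can share one cached descriptor, because the epilogue itself carries no fusion. A compile-time sketch of that property (assumes this header is included):

// The "plain matmul" entries collapse onto one cublasLt epilogue, so cached
// descriptors are told apart by the MatmulPlanner key, not by the fused type.
static_assert(phi::funcs::kMatmul == phi::funcs::kMatmulGrad, "");
static_assert(phi::funcs::kMatmul == phi::funcs::kMatmulGradWithoutBias, "");
// Fused variants with auxiliary output remain distinct epilogues.
static_assert(phi::funcs::kMatmulBiasRelu !=
                  phi::funcs::kMatmulBiasReluWithReservedData,
              "");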
+enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
+
+template <bool TransX, bool TransY>
+struct FusedGEMMGradTrait;
+
+template <>
+struct FusedGEMMGradTrait<false, false> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = true;
+
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<true, false> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = true;
+
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradATrans = false;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<false, true> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradATrans = false;
+  static constexpr auto kXGradBTrans = false;
+
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = false;
+};
+
+template <>
+struct FusedGEMMGradTrait<true, true> {
+  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
+  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
+  static constexpr auto kXGradATrans = true;
+  static constexpr auto kXGradBTrans = true;
+
+  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
+  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
+  static constexpr auto kYGradATrans = true;
+  static constexpr auto kYGradBTrans = true;
+};
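The trait table encodes which operands feed each backward GEMM. For x(MxK) * y(KxN) = out(MxN) with no transposes, dx = dout * y^T and dy = x^T * dout, which is exactly what the <false, false> specialization spells out. A compile-time sketch (assumes the definitions above are in scope):

// dx uses dZ (dout) untransposed times dY (y) transposed;
// dy uses dX (x) transposed times dZ (dout) untransposed.
using Trait00 = phi::funcs::FusedGEMMGradTrait<false, false>;
static_assert(Trait00::kXGradA == phi::funcs::FusedGEMMGradInType::kDZ, "");
static_assert(Trait00::kXGradB == phi::funcs::FusedGEMMGradInType::kDY, "");
static_assert(!Trait00::kXGradATrans && Trait00::kXGradBTrans, "");
static_assert(Trait00::kYGradA == phi::funcs::FusedGEMMGradInType::kDX, "");
static_assert(Trait00::kYGradB == phi::funcs::FusedGEMMGradInType::kDZ, "");
static_assert(Trait00::kYGradATrans && !Trait00::kYGradBTrans, "");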
+// To tell any matmul or fused matmul operation from each other.
 struct MatmulPlanner {
  public:
   const void* bias{nullptr};
@@ -60,23 +127,31 @@ struct MatmulPlanner {
                 phi::DataType dtype,
                 MatmulFusedType impl_type,
                 const void* bias_data = nullptr,
-                void* reserve_data = nullptr)
-      : bias(bias_data), aux_data(reserve_data) {
-    type = impl_type;
-    key = phi::autotune::GenKey(x_dims,
-                                y_dims,
-                                static_cast<int64_t>(trans_x),
-                                static_cast<int64_t>(trans_y),
-                                static_cast<int64_t>(dtype));
+                void* reserve_data = nullptr,  // Commonly for ReLu bit-mask.
+                bool use_addto = false,
+                bool no_exchange = true)
+      : bias(bias_data), aux_data(reserve_data), impl_type_(impl_type) {
+    use_addto_ = use_addto;
+    key_ = phi::autotune::GenKey(x_dims,
+                                 y_dims,
+                                 static_cast<int>(trans_x),
+                                 static_cast<int>(trans_y),
+                                 static_cast<int>(dtype),
+                                 static_cast<int>(no_exchange));
   }

-  MatmulFusedType ImplType() const { return type; }
-  size_t GetKey() const { return key; }
-  size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); }
+  bool UseAddTo() const { return use_addto_; }
+  size_t GetKey() const { return key_; }
+  MatmulFusedType ImplType() const { return impl_type_; }
+  size_t GenSubKey(int idx) const {
+    return phi::autotune::GenKey(key_, static_cast<int>(use_addto_), idx);
+  }

  private:
-  MatmulFusedType type;
-  size_t key;
+  MatmulFusedType impl_type_;
+  bool use_addto_;
+  size_t key_;
 };
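A minimal usage sketch of the extended planner (shapes, dtype and flags below are made up for illustration; only the constructor and key helpers shown above are relied on):

#include <cstdint>
#include <vector>

// Assumes blaslt_impl.cu.h is included and bias_ptr points at device memory.
size_t ExamplePlannerSubKey(const void* bias_ptr) {
  std::vector<int64_t> x_dims = {64, 128};   // hypothetical M x K
  std::vector<int64_t> y_dims = {128, 256};  // hypothetical K x N
  phi::funcs::MatmulPlanner planner(x_dims,
                                    y_dims,
                                    /*trans_x=*/false,
                                    /*trans_y=*/false,
                                    phi::DataType::FLOAT16,
                                    phi::funcs::MatmulFusedType::kMatmulBias,
                                    bias_ptr,
                                    /*reserve_data=*/nullptr,
                                    /*use_addto=*/false,
                                    /*no_exchange=*/true);
  // The main key hashes shapes, transposes, dtype and no_exchange; the
  // sub-key additionally mixes in the fused type and the addto flag.
  return planner.GenSubKey(static_cast<int>(planner.ImplType()));
}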
 template <typename T>
@@ -124,19 +199,19 @@ struct MatmulDescriptor {
   }

   // x_desc, y_desc, op_desc are allocated in heap memory.
-  template <typename T>
-  void Create(const int M,
-              const int N,
-              const int K,
+  template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+  void Create(const int64_t M,
+              const int64_t N,
+              const int64_t K,
               const bool trans_x,
               const bool trans_y,
               phi::funcs::MatmulPlanner* planner,
               const int batch_size = 1,
-              int64_t stride_x = 0,
-              int64_t stride_y = 0,
-              int64_t stride_out = 0) {
+              const int64_t stride_x = 0,
+              const int64_t stride_y = 0,
+              const int64_t stride_out = 0,
+              bool grad_for_dx = true) {
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
     cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
     cublasComputeType_t compute_type = GetCudaComputeType<T>();
@@ -145,18 +220,7 @@ struct MatmulDescriptor {
     // details about defaults; just need to set the transforms for A and B
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
-    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
-    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                CUBLASLT_MATMUL_DESC_TRANSB,
-                                                &cublas_trans_x,
-                                                sizeof(cublas_trans_x)));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                CUBLASLT_MATMUL_DESC_TRANSA,
-                                                &cublas_trans_y,
-                                                sizeof(cublas_trans_y)));
+    SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N);

     // Create matrix descriptors
     CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);
@@ -169,7 +233,6 @@ struct MatmulDescriptor {
       SetBatchAndStride(y_desc, batch_size, stride_y);
       SetBatchAndStride(out_desc, batch_size, stride_out);
     }
-    SetFusedEpilogueOpDescriptor(planner, N);
   }

   cublasLtMatmulAlgo_t* SetAlgo() {
@@ -188,7 +251,7 @@ struct MatmulDescriptor {
                                               CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                               &bias_data,
                                               sizeof(bias_data)));
+    }
     if (planner->aux_data != nullptr) {
       PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
           op_desc,
@@ -197,7 +260,6 @@ struct MatmulDescriptor {
           sizeof(planner->aux_data)));
     }
   }
-  }
   std::string GetDescResultString(std::string prefix,
                                   bool has_algo = true) const {
@@ -223,7 +285,42 @@ struct MatmulDescriptor {
     return out.str();
   }

-  private:
+  void ExchangeXYDesc(bool no_exchange) {}
+
+ protected:
+  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
+                                    const bool trans_x,
+                                    const bool trans_y,
+                                    int64_t lead_dim) {
+    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSB,
+                                                &cublas_trans_x,
+                                                sizeof(cublas_trans_x)));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSA,
+                                                &cublas_trans_y,
+                                                sizeof(cublas_trans_y)));
+    if (planner->ImplType() != kMatmul) {
+      auto fused_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                  CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                  &fused_type,
+                                                  sizeof(fused_type)));
+    }
+    if (planner->aux_data) {
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
+          op_desc,
+          CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
+          &lead_dim,
+          sizeof(lead_dim)));
+    }
+  }
+
   void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
                           cudaDataType type,
                           uint64_t rows,
@@ -252,145 +349,62 @@ struct MatmulDescriptor {
                                                   &stride,
                                                   sizeof(stride)));
   }
-
-  void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner,
-                                    int64_t lead_dim) {
-    if (planner->bias) {
-      auto fuse_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          dynload::cublasLtMatmulDescSetAttribute(op_desc,
-                                                  CUBLASLT_MATMUL_DESC_EPILOGUE,
-                                                  &fuse_type,
-                                                  sizeof(fuse_type)));
-      if (planner->aux_data) {
-        PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute(
-            op_desc,
-            CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
-            &lead_dim,
-            sizeof(lead_dim)));
-      }
-    }
-  }
 };
-template <typename T>
-struct DescriptorSetter {
-  MatmulDescriptor desc;
-  size_t sub_key{std::numeric_limits<size_t>::min()};
-
-  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
-                   const int M,
-                   const int N,
-                   const int K,
-                   const bool trans_x,
-                   const bool trans_y,
-                   const int batch_size = 1,
-                   int64_t stride_x = 0,
-                   int64_t stride_y = 0,
-                   int64_t stride_out = 0) {
-    if (planner != nullptr) {
-      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
-    }
-
-    auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
-    if (mamtul_cache.FindSubKey(sub_key)) {
-      desc = *(
-          reinterpret_cast<MatmulDescriptor*>(mamtul_cache.GetSubKey(sub_key)));
-      desc.SetFusedEpiloguePtr<T>(planner);
-      VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] ");
-    } else {
-      desc.Create<T>(M,
-                     N,
-                     K,
-                     trans_x,
-                     trans_y,
-                     planner,
-                     batch_size,
-                     stride_x,
-                     stride_y,
-                     stride_out);
-      if (planner != nullptr) {
-        desc.SetFusedEpiloguePtr<T>(planner);
-      }
-      VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false);
-    }
-  }
-};
+struct MatmulGradDescriptor : MatmulDescriptor {
+ public:
+  MatmulGradDescriptor() {}
+
+  template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+  void Create(const int64_t M,
+              const int64_t N,
+              const int64_t K,
+              const bool trans_x,
+              const bool trans_y,
+              phi::funcs::MatmulPlanner* planner,
+              const int batch_size = 1,
+              int64_t stride_x = 0,
+              int64_t stride_y = 0,
+              int64_t stride_out = 0,
+              bool grad_for_dx = true) {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
+    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
+    cublasComputeType_t compute_type = GetCudaComputeType<T>();
+
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+    this->SetFusedEpilogueOpDescriptor(
+        planner, trans_x, trans_y, TransX ? M : K);
+
+    // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
+    // details about defaults; just need to set the transforms for A and B
+    this->CreateMatrixLayout(&x_desc, mat_type, N, M, true);
+    if (grad_for_dx) {
+      this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY);
+      this->CreateMatrixLayout(
+          &out_desc, phi::backends::gpu::ToCudaDataType<DXT>(), M, K, TransX);
+    } else {
+      this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX);
+      this->CreateMatrixLayout(
+          &out_desc, phi::backends::gpu::ToCudaDataType<DYT>(), K, N, TransY);
+    }
+  }
+
+  void ExchangeXYDesc(bool no_exchange) {
+    if (no_exchange) {
+      return;
+    }
+    auto* temp = y_desc;
+    y_desc = x_desc;
+    x_desc = temp;
+  }
+};
-template <typename T>
-struct MatmulWithCublasLt {
+template <typename T, typename OutT = T, class MatmulDescT = MatmulDescriptor>
+struct CublasLtBase {
  public:
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-  static void Run(const phi::GPUContext& ctx,
-                  const T* x_data,
-                  const T* y_data,
-                  T* out_data,
-                  const int M,
-                  const int N,
-                  const int K,
-                  const bool trans_x,
-                  const bool trans_y,
-                  phi::funcs::MatmulPlanner* planner = nullptr) {
-    auto setter = DescriptorSetter<T>(planner, M, N, K, trans_x, trans_y);
-    RunImpl(
-        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
-  }
-
-  static void RunWithBatch(const phi::GPUContext& ctx,
-                           const T* x_data,
-                           const T* y_data,
-                           T* out_data,
-                           const int M,
-                           const int N,
-                           const int K,
-                           bool trans_x,
-                           bool trans_y,
-                           int batch_size,
-                           int64_t stride_x,
-                           int64_t stride_y,
-                           int64_t stride_out,
-                           phi::funcs::MatmulPlanner* planner = nullptr) {
-    auto setter = DescriptorSetter<T>(planner,
-                                      M,
-                                      N,
-                                      K,
-                                      trans_x,
-                                      trans_y,
-                                      batch_size,
-                                      stride_x,
-                                      stride_y,
-                                      stride_out);
-    RunImpl(
-        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
-  }
-
-  static void RunWithBatch(const phi::GPUContext& ctx,
-                           const T** x_data,
-                           const T** y_data,
-                           T** out_data,
-                           const int M,
-                           const int N,
-                           const int K,
-                           bool trans_x,
-                           bool trans_y,
-                           int batch_size,
-                           phi::funcs::MatmulPlanner* planner = nullptr) {
-    for (int i = 0; i < batch_size; ++i) {
-      Run(ctx,
-          x_data[i],
-          y_data[i],
-          out_data[i],
-          M,
-          N,
-          K,
-          trans_x,
-          trans_y,
-          planner);
-    }
-  }
-
- private:
   static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
                                                     size_t workspace_size) {
     return phi::memory_utils::Alloc(
@@ -400,16 +414,19 @@ struct MatmulWithCublasLt {
   }

   static void RunImpl(const phi::GPUContext& ctx,
-                      MatmulDescriptor* desc,
+                      MatmulDescT* desc,
                       const size_t sub_key,
                       const T* x_ptr,
                       const T* y_ptr,
-                      T* out_ptr,
+                      OutT* out_ptr,
                       phi::funcs::MatmulPlanner* planner) {
     MT alpha = static_cast<MT>(1);
-    MT beta = static_cast<MT>(0);
+    MT beta = planner->UseAddTo() ? static_cast<MT>(1) : static_cast<MT>(0);

     cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
+    // NOTE(limingshu): As workspace_size varies from different DL framework,
+    // I wonder is there any smarter idea for workspace setting, currently I
+    // just followed the settings from the NVIDIA colleague`s setting.
     size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
@@ -426,16 +443,16 @@ struct MatmulWithCublasLt {
                       out_ptr,
                       workspace->ptr(),
                       workspace_size);
-      MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
+      MatmulDescT* best_desc = new MatmulDescT(*desc);
       VLOG(6) << best_desc->GetDescResultString(
-          "[Searched MatmulDescriptor] ");
+          "[Searched CublasltDescriptor] ");

       auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
       cache.SetSubKey(sub_key, reinterpret_cast<void*>(best_desc));
      }
    }
-    VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] ");
+    VLOG(6) << desc->GetDescResultString("[Impl CublasltDescriptor] ");
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmul(cublaslt_handle,
                                 desc->op_desc,
@@ -457,7 +474,7 @@ struct MatmulWithCublasLt {
   static void SearchBestAlgo(const phi::GPUContext& ctx,
                              const cublasLtHandle_t& lt_handle,
-                             MatmulDescriptor* desc,
+                             MatmulDescT* desc,
                              const void* alpha,
                              const void* beta,
                              const void* y_data,
@@ -526,7 +543,7 @@ struct MatmulWithCublasLt {
       }
     }
     float time_cnt = (cur_time / (repeats - 1));
-    VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
+    VLOG(6) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
            << "is : " << time_cnt << " s";

    if (cur_time < min_time_cost) {
@@ -534,12 +551,241 @@ struct MatmulWithCublasLt {
      min_time_cost = cur_time;
    }
  }
-    VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
+    VLOG(6) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
    *best_algo = heuristic_results[best_algo_idx].algo;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::cublasLtMatmulPreferenceDestroy(preference));
  }
};
+// To judge if desc is cached or not.
+template <class DescT,
+          typename T,
+          typename DXT = T,
+          typename DYT = T,
+          bool TransX = false,
+          bool TransY = false>
+struct DescriptorSetter {
+ public:
+  DescT desc;
+  size_t sub_key{std::numeric_limits<size_t>::min()};
+
+  DescriptorSetter(phi::funcs::MatmulPlanner* planner,
+                   const int64_t M,
+                   const int64_t N,
+                   const int64_t K,
+                   const bool trans_x,
+                   const bool trans_y,
+                   const int batch_size = 1,
+                   int64_t stride_x = 0,
+                   int64_t stride_y = 0,
+                   int64_t stride_out = 0,
+                   const bool no_exchange = true,
+                   bool grad_for_dx = true) {
+    if (planner != nullptr) {
+      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
+    }
+
+    auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
+    if (mamtul_cache.FindSubKey(sub_key)) {
+      desc = *(reinterpret_cast<DescT*>(mamtul_cache.GetSubKey(sub_key)));
+      desc.template SetFusedEpiloguePtr<DYT>(planner);
+      VLOG(6) << desc.GetDescResultString("[Heap CublasltDescriptor] ");
+    } else {
+      desc.template Create<T, DXT, DYT, TransX, TransY>(M,
+                                                        N,
+                                                        K,
+                                                        trans_x,
+                                                        trans_y,
+                                                        planner,
+                                                        batch_size,
+                                                        stride_x,
+                                                        stride_y,
+                                                        stride_out,
+                                                        grad_for_dx);
+      desc.ExchangeXYDesc(no_exchange);
+      if (planner != nullptr) {
+        desc.template SetFusedEpiloguePtr<DYT>(planner);
+      }
+      VLOG(6) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false);
+    }
+  }
+};
+// For matmul with kernels autotune
+template <typename T>
+struct MatmulWithCublasLt : public CublasLtBase<T> {
+ public:
+  static void Run(const phi::GPUContext& ctx,
+                  const T* x_data,
+                  const T* y_data,
+                  T* out_data,
+                  const int64_t M,
+                  const int64_t N,
+                  const int64_t K,
+                  const bool trans_x,
+                  const bool trans_y,
+                  phi::funcs::MatmulPlanner* planner = nullptr) {
+    auto setter = DescriptorSetter<MatmulDescriptor, T>(
+        planner, M, N, K, trans_x, trans_y);
+    CublasLtBase<T>::RunImpl(
+        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
+  }
+
+  static void RunWithBatch(const phi::GPUContext& ctx,
+                           const T* x_data,
+                           const T* y_data,
+                           T* out_data,
+                           const int64_t M,
+                           const int64_t N,
+                           const int64_t K,
+                           bool trans_x,
+                           bool trans_y,
+                           int batch_size,
+                           int64_t stride_x,
+                           int64_t stride_y,
+                           int64_t stride_out,
+                           phi::funcs::MatmulPlanner* planner = nullptr) {
+    auto setter = DescriptorSetter<MatmulDescriptor, T>(planner,
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        trans_x,
+                                                        trans_y,
+                                                        batch_size,
+                                                        stride_x,
+                                                        stride_y,
+                                                        stride_out);
+    CublasLtBase<T>::RunImpl(
+        ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner);
+  }
+
+  static void RunWithBatch(const phi::GPUContext& ctx,
+                           const T** x_data,
+                           const T** y_data,
+                           T** out_data,
+                           const int64_t M,
+                           const int64_t N,
+                           const int64_t K,
+                           bool trans_x,
+                           bool trans_y,
+                           int batch_size,
+                           phi::funcs::MatmulPlanner* planner = nullptr) {
+    for (int i = 0; i < batch_size; ++i) {
+      Run(ctx,
+          x_data[i],
+          y_data[i],
+          out_data[i],
+          M,
+          N,
+          K,
+          trans_x,
+          trans_y,
+          planner);
+    }
+  }
+};
+// As for just Linear fused ephilogue below: out = matmul(x, y) + bias.
+template <typename T>
+struct LinearWithCublasLt : public CublasLtBase<T> {
+  static void Run(const phi::GPUContext& ctx,
+                  const phi::DenseTensor* x,
+                  const phi::DenseTensor* y,
+                  phi::DenseTensor* out,
+                  const void* bias_data,
+                  void* reserve_data,
+                  const int64_t M,
+                  const int64_t N,
+                  const int64_t K,
+                  const bool trans_x,
+                  const bool trans_y,
+                  const MatmulFusedType fused_type) {
+    auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
+                                             vectorize(y->dims()),
+                                             trans_x,
+                                             trans_y,
+                                             phi::CppTypeToDataType<T>::Type(),
+                                             fused_type,
+                                             bias_data,
+                                             reserve_data);
+    auto setter = DescriptorSetter<MatmulDescriptor, T>(
+        &planner, M, N, K, trans_x, trans_y);
+    CublasLtBase<T>::RunImpl(ctx,
+                             &setter.desc,
+                             setter.sub_key,
+                             x->data<T>(),
+                             y->data<T>(),
+                             out->data<T>(),
+                             &planner);
+  }
+};
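A call sketch of the forward interface above (tensor allocation and shape bookkeeping omitted; the wrapper name and shapes are illustrative): one Run call covers out = x * y plus whatever the chosen fused_type adds.

// Assumes blaslt_impl.cu.h is included and out is already allocated on device.
template <typename T>
void RunLinearExample(const phi::GPUContext& ctx,
                      const phi::DenseTensor& x,     // M x K
                      const phi::DenseTensor& y,     // K x N
                      const phi::DenseTensor& bias,  // N
                      phi::DenseTensor* out,         // M x N
                      int64_t M, int64_t N, int64_t K) {
  phi::funcs::LinearWithCublasLt<T>::Run(
      ctx,
      &x,
      &y,
      out,
      static_cast<const void*>(bias.data<T>()),
      /*reserve_data=*/nullptr,
      M,
      N,
      K,
      /*trans_x=*/false,
      /*trans_y=*/false,
      phi::funcs::MatmulFusedType::kMatmulBias);
}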
+template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+struct LinearGradWithCublasLt : public CublasLtBase<T> {
+  static void Run(
+      const phi::GPUContext& ctx,
+      const phi::DenseTensor* x,
+      const phi::DenseTensor* y,
+      phi::DenseTensor* out,
+      const void* bias_data,
+      void* reserve_data,
+      const int64_t M,
+      const int64_t N,
+      const int64_t K,
+      const MatmulFusedType fused_type,
+      const bool trans_x,
+      const bool trans_y,
+      const bool use_addto,
+      const bool no_exchange,  // exchange x_desc and y_desc for grad.
+      bool grad_for_dx = true) {
+    auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()),
+                                             vectorize(y->dims()),
+                                             trans_x,
+                                             trans_y,
+                                             phi::CppTypeToDataType<T>::Type(),
+                                             fused_type,
+                                             bias_data,
+                                             reserve_data,
+                                             use_addto,
+                                             no_exchange);
+    auto setter =
+        DescriptorSetter<MatmulGradDescriptor, T, DXT, DYT, TransX, TransY>(
+            &planner,
+            M,
+            N,
+            K,
+            trans_x,
+            trans_y,
+            /*batch_size=*/1,
+            /*stride_x=*/0,
+            /*stride_y=*/0,
+            /*stride_out=*/0,
+            /*exchange_x_y_desc=*/no_exchange,
+            /*grad_for_dx=*/grad_for_dx);
+
+    // To setting data type for different kinda out_data.
+    if (grad_for_dx) {
+      CublasLtBase<T, DXT, MatmulGradDescriptor>::RunImpl(
+          ctx,
+          &setter.desc,
+          setter.sub_key,
+          no_exchange ? x->data<T>() : y->data<T>(),
+          no_exchange ? y->data<T>() : x->data<T>(),
+          out->data<DXT>(),
+          &planner);
+    } else {
+      CublasLtBase<T, DYT, MatmulGradDescriptor>::RunImpl(
+          ctx,
+          &setter.desc,
+          setter.sub_key,
+          no_exchange ? x->data<T>() : y->data<T>(),
+          no_exchange ? y->data<T>() : x->data<T>(),
+          out->data<DYT>(),
+          &planner);
+    }
+  }
+};
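A dX-side call sketch for out = x * y with a relu epilogue and no transposes (so dx = dout * y^T); the argument pattern mirrors how the fused gemm epilogue backward helper later in this commit drives it, but the wrapper itself is illustrative:

// Assumes dx is allocated and relu_bitmask is the reserve space saved by the
// forward pass (required by the kMatmulReluGrad epilogue).
template <typename T>
void RunLinearDxExample(const phi::GPUContext& ctx,
                        const phi::DenseTensor& dout,  // M x N
                        const phi::DenseTensor& y,     // K x N weights
                        phi::DenseTensor* dx,          // M x K
                        void* relu_bitmask,
                        int64_t M, int64_t N, int64_t K) {
  phi::funcs::LinearGradWithCublasLt<T, T, T, false, false>::Run(
      ctx,
      &dout,
      &y,
      dx,
      /*bias_data=*/nullptr,
      relu_bitmask,
      M,
      N,
      K,
      phi::funcs::MatmulFusedType::kMatmulReluGrad,
      /*trans_x=*/false,
      /*trans_y=*/true,
      /*use_addto=*/false,
      /*no_exchange=*/true,
      /*grad_for_dx=*/true);
}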
 #else
 // A void structure just for successfully complile.
 struct MatmulPlanner {};
...
@@ -52,6 +52,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
           "Axis should be less than or equal to %d, but received axis is %d.",
           max_dim,
           axis));
+
   if (x_dims.size() > y_dims.size()) {
     std::fill(y_dims_array, y_dims_array + axis, 1);
     if (axis + y_dims.size() < max_dim) {
@@ -68,7 +69,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
     std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
   }

-  for (int i = 0; i < max_dim; i++) {
+  for (int i = 0; i < max_dim; ++i) {
     PADDLE_ENFORCE_EQ(
         x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
             y_dims_array[i] <= 1,
...
@@ -350,8 +350,10 @@ void DropoutFwGPUKernelDriver(
       auto dst_functor =
           DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

-      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
-      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
+      std::vector<int64_t> out_dims =
+          std::move(phi::vectorize<int64_t>(x.dims()));
+      std::vector<int64_t> in_dims =
+          std::move(phi::vectorize<int64_t>(mask->dims()));
       std::reverse(out_dims.begin(), out_dims.end());
       std::reverse(in_dims.begin(), in_dims.end());
       kps::details::BroadcastConfig broadcast_config(
...
@@ -37,6 +37,7 @@ limitations under the License. */
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/scope_guard.h"
+#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/utils/optional.h"

 DECLARE_int64(cublaslt_exhaustive_search_times);
@@ -488,62 +489,103 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
       phi::dynload::cublasLtMatrixLayoutDestroy(out_desc));
 }
-enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
-
-template <bool TransX, bool TransY>
-struct FusedGEMMGradTrait;
-
-template <>
-struct FusedGEMMGradTrait<false, false> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = true;
-
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<true, false> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = true;
-
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradATrans = false;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<false, true> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradATrans = false;
-  static constexpr auto kXGradBTrans = false;
-
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = false;
-};
-
-template <>
-struct FusedGEMMGradTrait<true, true> {
-  static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
-  static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
-  static constexpr auto kXGradATrans = true;
-  static constexpr auto kXGradBTrans = true;
-
-  static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
-  static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
-  static constexpr auto kYGradATrans = true;
-  static constexpr auto kYGradBTrans = true;
-};
+struct BwdFusedEpilogueSetter {
+ public:
+  static phi::funcs::MatmulFusedType SetForDx(
+      const std::string& activation_grad) {
+    if (activation_grad == "none") {
+      return kMatmulGrad;
+    } else if (activation_grad == "relu_grad") {
+      return kMatmulReluGrad;
+    } else if (activation_grad == "gelu_grad") {
+      return kMatmulGeluGrad;
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Fued linear epilogue type should be one of {none, relu, gelu}."
+          "But received activation is %s, please check",
+          activation_grad));
+    }
+  }
+
+  template <typename DYT, bool TransY>
+  static phi::funcs::MatmulFusedType SetForDy(const phi::GPUContext& dev_ctx,
+                                              phi::DenseTensor* dbias) {
+    if (dbias != nullptr) {
+      dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
+      return TransY ? kMatmulBiasGradToB : kMatmulBiasGradToA;
+    } else {
+      return kMatmulGradWithoutBias;
+    }
+  }
+};
+
+template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
+void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
+                                          const phi::DenseTensor* dout,
+                                          const phi::DenseTensor* x,
+                                          const phi::DenseTensor* y,
+                                          const phi::DenseTensor* reserve_space,
+                                          int64_t M,
+                                          int64_t N,
+                                          int64_t K,
+                                          const std::string activation_grad,
+                                          phi::DenseTensor* dx,
+                                          phi::DenseTensor* dy,
+                                          phi::DenseTensor* dbias,
+                                          bool use_addto_dx,
+                                          bool use_addto_dy) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  static_assert(std::is_same<DXT, T>::value || std::is_same<DXT, MT>::value);
+  static_assert(std::is_same<DYT, T>::value || std::is_same<DYT, MT>::value);
+  using Trait = FusedGEMMGradTrait<TransX, TransY>;
+
+  if (dx) {
+    constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
+    auto fused_type = BwdFusedEpilogueSetter::SetForDx(activation_grad);
+    void* reserve_data = (fused_type == kMatmulGrad)
+                             ? nullptr
+                             : const_cast<void*>(reserve_space->data());
+    dev_ctx.Alloc<DXT>(dx, dx->numel() * sizeof(DXT));
+    phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
+        dev_ctx,
+        dout,
+        y,
+        dx,
+        nullptr,
+        reserve_data,
+        M,
+        N,
+        K,
+        fused_type,
+        Trait::kXGradATrans,
+        Trait::kXGradBTrans,
+        use_addto_dx,
+        kXGradAIsDZ);
+  }
+
+  if (dy) {
+    auto fused_type =
+        BwdFusedEpilogueSetter::SetForDy<DYT, TransY>(dev_ctx, dbias);
+    constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
+    // Caution: DYT is in front of DXT in this template arguments.
+    dev_ctx.Alloc<DYT>(dy, dy->numel() * sizeof(DYT));
+    phi::funcs::LinearGradWithCublasLt<T, DXT, DYT, TransX, TransY>::Run(
+        dev_ctx,
+        dout,
+        x,
+        dy,
+        dbias ? static_cast<const void*>(dbias->data<DYT>()) : nullptr,
+        nullptr,
+        M,
+        N,
+        K,
+        fused_type,
+        Trait::kYGradATrans,
+        Trait::kYGradBTrans,
+        use_addto_dy,
+        kYGradAIsDZ,
+        /*is_dx=*/false);
+  }
+}
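Since the implementation above is templated on compile-time TransX/TransY flags, a caller holding runtime transpose attributes needs a small expansion before invoking it. The dispatch below is a hypothetical sketch in the same namespace (the real kernel's dispatch is not shown in this diff); the addto flags are fixed to false purely for illustration.

template <typename T>
void DispatchBackwardExample(const phi::GPUContext& ctx,
                             const phi::DenseTensor* dout,
                             const phi::DenseTensor* x,
                             const phi::DenseTensor* y,
                             const phi::DenseTensor* reserve_space,
                             int64_t M, int64_t N, int64_t K,
                             const std::string& activation_grad,
                             bool trans_x, bool trans_y,
                             phi::DenseTensor* dx,
                             phi::DenseTensor* dy,
                             phi::DenseTensor* dbias) {
  // Expand the two runtime booleans into the four template instantiations.
  if (!trans_x && !trans_y) {
    ComputeFusedGemmEpilogueBackwardImpl<T, T, T, false, false>(
        ctx, dout, x, y, reserve_space, M, N, K, activation_grad,
        dx, dy, dbias, /*use_addto_dx=*/false, /*use_addto_dy=*/false);
  } else if (trans_x && !trans_y) {
    ComputeFusedGemmEpilogueBackwardImpl<T, T, T, true, false>(
        ctx, dout, x, y, reserve_space, M, N, K, activation_grad,
        dx, dy, dbias, false, false);
  } else if (!trans_x && trans_y) {
    ComputeFusedGemmEpilogueBackwardImpl<T, T, T, false, true>(
        ctx, dout, x, y, reserve_space, M, N, K, activation_grad,
        dx, dy, dbias, false, false);
  } else {
    ComputeFusedGemmEpilogueBackwardImpl<T, T, T, true, true>(
        ctx, dout, x, y, reserve_space, M, N, K, activation_grad,
        dx, dy, dbias, false, false);
  }
}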
 static constexpr auto BoolToCuBlasEnum(bool transpose) {
   return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -567,7 +609,8 @@ static cublasLtEpilogue_t GetEpilogueGradType(
 }

 template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
-void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
+void ComputeFusedGemmEpilogueBackwardImplDev(
+    const phi::GPUContext& dev_ctx,
     const phi::DenseTensor* dout,
     const phi::DenseTensor* x,
     const phi::DenseTensor* y,
...
@@ -559,7 +559,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
     // max index to read
     int idx_max = (i < local_batches) ? element_count : 0;
     int idx_max_v = idx_max / kVSize;
 #pragma unroll
     // read data
     for (int it = 0; it < kIterationsV; ++it) {
       int src_idx = threadIdx.x + it * kWarpSize;
@@ -659,7 +659,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
   // loss
   phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);
 #pragma unroll
   for (int i = 0; i < kBatchSize; i++) {
     if (i >= local_batches) break;
     loss[first_batch + i] = sumloss[i];
...