未验证 提交 d4217fc6 编写于 作者: L limingshu 提交者: GitHub

Matmul performance optimization with cuBlasLt (#46431)


* implement of matmul using cublasLt instead of cublas

* Update matmul_kernel_impl_via_blasLt.h

---------
Co-authored-by: Nzhangbopd <1299246947@qq.com>
Co-authored-by: NBo Zhang <105368690+zhangbopd@users.noreply.github.com>
Co-authored-by: NLiu Yiqun <liuyiqun01@baidu.com>
上级 57f6a469
......@@ -141,7 +141,43 @@ class AutoTuneBase {
}
};
// To init the auto_tuner object.
template <typename T, typename ReturnType, typename... Args>
class MatmulAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
public:
static MatmulAutoTuner<T, ReturnType, Args...>* Instance(
ReturnType (*func)(Args...)) {
static std::once_flag matmul_init_flag;
static std::unique_ptr<MatmulAutoTuner<T, ReturnType, Args...>> instance;
std::call_once(matmul_init_flag, [&] {
auto obj = MakeCallback<T>(func);
instance.reset(new MatmulAutoTuner<T, ReturnType, Args...>);
instance->AddCallBack(func);
});
return instance.get();
}
template <typename Context>
void Run(const Context& ctx, const size_t key, Args... args) {
this->is_init_ = true;
this->CheckKernelSize();
auto& cache = AutoTuneCache::Instance().GetMatmul();
if (cache.Find(key)) {
auto best_idx = cache.Get(key);
this->kernels_[best_idx].Run(args...);
} else {
bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
if (use_autotune) {
auto best_idx = this->PickBestKernel(ctx, args...);
cache.Set(key, best_idx);
} else {
this->kernels_[0].Run(args...);
}
}
}
};
// Define the auto_tuner inital object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \
class name##AutoTuner \
......@@ -161,7 +197,7 @@ class AutoTuneBase {
} \
};
// To init auto_tuner inital function.
// Define the auto_tuner inital function.
#define DEFINE_AUTOTUNER_FN(name) \
template <typename T, typename ReturnType, typename... Args> \
static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
......@@ -170,9 +206,11 @@ class AutoTuneBase {
}
#define DEFINE_AUTOTUNER(name) \
DEFINE_AUTOTUNER_COMMON_OBJ(name) DEFINE_AUTOTUNER_FN(name)
DEFINE_AUTOTUNER_COMMON_OBJ(name) \
DEFINE_AUTOTUNER_FN(name)
DEFINE_AUTOTUNER(Transpose)
DEFINE_AUTOTUNER_FN(Matmul)
#undef DEFINE_AUTOTUNER_COMMON_OBJECT
#undef DEFINE_AUTOTUNER_FN
......
......@@ -44,28 +44,33 @@ enum class AlgorithmType {
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kTranspose = 4,
#ifdef PADDLE_WITH_CUDNN_FRONTEND
kConvForwardV8 = 5,
kConvBackwardDataV8 = 6,
kConvBackwardFilterV8 = 7,
kAlgorithmCount = 8
kMatmul = 5,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 6
#else
kAlgorithmCount = 5
kConvForwardV8 = 6,
kConvBackwardDataV8 = 7,
kConvBackwardFilterV8 = 8,
kAlgorithmCount = 9
#endif
};
// AlgorithmsConfigKey -> AlgorithmsID
// (todo. hong) use cudnnConvolutionFwdAlgo_t
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
// AlgorithmType -> AlgorithmsCache
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
// (todo. hong) use cudnnConvolutionFwdAlgo_t
using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
using ConvAlgorithmsTypeMap =
std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;
using MatmulAlgorithmsCacheMap = MatmulAlgorithmsCache<size_t, int64_t>;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
using CudnnV8AlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif
class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
......@@ -77,6 +82,8 @@ class AutoTuneCache {
return auto_tune_map_[static_cast<int64_t>(algo_type)];
}
MatmulAlgorithmsCacheMap& GetMatmul() { return matmul_auto_tune_map_; }
ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
......@@ -87,8 +94,6 @@ class AutoTuneCache {
}
#endif
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
void Clean() {
for (auto& v : auto_tune_map_) {
v.second.Clean();
......@@ -162,6 +167,7 @@ class AutoTuneCache {
AlgorithmsTypeMap auto_tune_map_;
ConvAlgorithmsTypeMap conv_auto_tune_map_;
MatmulAlgorithmsCacheMap matmul_auto_tune_map_;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
#endif
......
......@@ -60,6 +60,31 @@ size_t GenKey(Args&&... args) {
return seed;
}
struct MatmulHashValueType {
uint64_t data[8];
};
struct MatmulCacheKey {
public:
MatmulCacheKey() {}
MatmulCacheKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
const bool trans_x,
const bool trans_y,
phi::DataType dtype) {
key = GenKey(x_dims,
y_dims,
static_cast<int64_t>(trans_x),
static_cast<int64_t>(trans_y),
static_cast<int64_t>(dtype));
}
size_t GetKey() const { return key; }
size_t GetSubKey(int64_t idx) const { return GenKey(key, idx); }
private:
size_t key;
};
struct ConvCacheKey {
ConvCacheKey() {}
ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
......@@ -213,5 +238,34 @@ class ConvAlgorithmsCache : public AlgorithmsCache<ConvCacheKey,
}
};
template <typename KeyT, typename AlgorithmT>
class MatmulAlgorithmsCache : public AlgorithmsCache<KeyT, AlgorithmT> {
public:
MatmulAlgorithmsCache() : AlgorithmsCache<KeyT, AlgorithmT>() {}
bool FindSubKey(const KeyT& sub_key) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
bool ret = (sub_hash_.find(sub_key) != sub_hash_.end()) ? true : false;
return ret;
}
void SetSubKey(const KeyT& sub_key, const MatmulHashValueType* algo) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
sub_hash_[sub_key] = *algo;
}
MatmulHashValueType* GetSubKey(const KeyT& sub_key) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
PADDLE_ENFORCE_NE(
sub_hash_.find(sub_key),
sub_hash_.end(),
phi::errors::PreconditionNotMet("The key does not exist."));
return &(sub_hash_[sub_key]);
}
private:
std::unordered_map<KeyT, MatmulHashValueType> sub_hash_;
};
} // namespace autotune
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include <cuda_runtime_api.h>
#include "cuda.h" // NOLINT
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
namespace phi {
namespace funcs {
enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };
template <typename T>
cudaDataType_t ConvertToCudaDataType() {
if (std::is_same<T, float>::value) {
return CUDA_R_32F;
} else if (std::is_same<T, double>::value) {
return CUDA_R_64F;
} else if (std::is_same<T, phi::dtype::float16>::value) {
return CUDA_R_16F;
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
}
}
template <typename T>
cublasComputeType_t GetCudaComputeType() {
if (std::is_same<T, double>::value) {
return CUBLAS_COMPUTE_64F;
} else {
return CUBLAS_COMPUTE_32F;
}
}
struct MatmulDescriptor {
public:
cublasLtMatmulDesc_t op_desc{nullptr};
cublasLtMatrixLayout_t x_desc{nullptr};
cublasLtMatrixLayout_t y_desc{nullptr};
cublasLtMatrixLayout_t out_desc{nullptr};
template <typename T>
void Create(const int M,
const int N,
const int K,
const bool trans_x,
const bool trans_y,
const int batch_size = 1,
int64_t stride_x = 0,
int64_t stride_y = 0,
int64_t stride_out = 0) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
cudaDataType_t mat_type = ConvertToCudaDataType<T>();
cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
cublasComputeType_t compute_type = GetCudaComputeType<T>();
// Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
// details about defaults; just need to set the transforms for A and B
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&cublas_trans_x,
sizeof(cublas_trans_x)));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&cublas_trans_y,
sizeof(cublas_trans_y)));
// Create matrix descriptors
CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);
CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y);
CreateMatrixLayout(&out_desc, mat_type, M, N, false);
// Config batch size and stride.
if (batch_size > 1) {
SetBatchAndStride(x_desc, batch_size, stride_x);
SetBatchAndStride(y_desc, batch_size, stride_y);
SetBatchAndStride(out_desc, batch_size, stride_out);
}
}
void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
cudaDataType type,
uint64_t rows,
uint64_t cols,
bool trans) {
if (trans) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols));
}
}
void SetBatchAndStride(cublasLtMatrixLayout_t desc,
int batch_size,
int64_t stride) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_size,
sizeof(batch_size)));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride,
sizeof(stride)));
}
void Release() {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc));
op_desc = nullptr;
x_desc = nullptr;
y_desc = nullptr;
out_desc = nullptr;
}
};
template <typename T>
struct MatmulWithCublasLt {
public:
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
static void Run(const phi::GPUContext& ctx,
const T* x_data,
const T* y_data,
T* out_data,
const int M,
const int N,
const int K,
const bool trans_x,
const bool trans_y,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
MatmulDescriptor desc;
desc.Create<T>(M, N, K, trans_x, trans_y);
RunImpl(ctx, desc, x_data, y_data, out_data, matmul_key);
desc.Release();
}
static void RunWithBatch(
const phi::GPUContext& ctx,
const T* x_data,
const T* y_data,
T* out_data,
const int M,
const int N,
const int K,
bool trans_x,
bool trans_y,
int batch_size,
int64_t stride_x,
int64_t stride_y,
int64_t stride_out,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
MatmulDescriptor desc;
desc.Create<T>(
M, N, K, trans_x, trans_y, batch_size, stride_x, stride_y, stride_out);
RunImpl(ctx, desc, x_data, y_data, out_data, matmul_key);
desc.Release();
}
static void RunWithBatch(
const phi::GPUContext& ctx,
const T** x_data,
const T** y_data,
T** out_data,
const int M,
const int N,
const int K,
bool trans_x,
bool trans_y,
int batch_size,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
for (int i = 0; i < batch_size; ++i) {
Run(ctx,
x_data[i],
y_data[i],
out_data[i],
M,
N,
K,
trans_x,
trans_y,
matmul_key);
}
}
private:
static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
size_t workspace_size) {
return paddle::memory::Alloc(
ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
}
static void RunImpl(const phi::GPUContext& ctx,
const MatmulDescriptor& desc,
const T* x_ptr,
const T* y_ptr,
T* out_ptr,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
MT alpha = static_cast<MT>(1);
MT beta = static_cast<MT>(0);
cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
cublasLtMatmulAlgo_t* best_algo = nullptr;
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
if (matmul_key != nullptr) {
auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
size_t sub_key = matmul_key->GetSubKey(
static_cast<int64_t>(MatmulImplType::kImplWithCublasLt));
if (cache.FindSubKey(sub_key)) {
best_algo =
reinterpret_cast<cublasLtMatmulAlgo_t*>(cache.GetSubKey(sub_key));
} else if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune()) {
cublasLtMatmulAlgo_t test_algo;
SearchBestAlgo(ctx,
cublaslt_handle,
desc.op_desc,
desc.y_desc,
desc.x_desc,
desc.out_desc,
static_cast<void*>(&alpha),
static_cast<void*>(&beta),
y_ptr,
x_ptr,
out_ptr,
workspace->ptr(),
workspace_size,
&(test_algo));
cache.SetSubKey(
sub_key,
reinterpret_cast<phi::autotune::MatmulHashValueType*>(&test_algo));
best_algo = &test_algo;
}
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(
cublaslt_handle,
desc.op_desc,
static_cast<void*>(&alpha),
y_ptr,
desc.y_desc,
x_ptr,
desc.x_desc,
static_cast<void*>(&beta),
out_ptr,
desc.out_desc,
out_ptr,
desc.out_desc,
reinterpret_cast<cublasLtMatmulAlgo_t*>(best_algo),
workspace->ptr(),
workspace_size,
ctx.stream()));
}
static void SearchBestAlgo(const phi::GPUContext& ctx,
const cublasLtHandle_t& lt_handle,
const cublasLtMatmulDesc_t& op_desc,
const cublasLtMatrixLayout_t& y_desc,
const cublasLtMatrixLayout_t& x_desc,
const cublasLtMatrixLayout_t& out_desc,
const void* alpha,
const void* beta,
const void* y_data,
const void* x_data,
void* out_data,
void* workspace_ptr,
size_t workspace_size,
cublasLtMatmulAlgo_t* best_algo) {
const auto& stream = ctx.stream();
int returned_results = 0;
constexpr int requested_algo_count = 10;
cublasLtMatmulPreference_t preference;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulPreferenceCreate(&preference));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(workspace_size)));
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
requested_algo_count);
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
op_desc,
y_desc,
x_desc,
out_desc,
out_desc,
preference,
requested_algo_count,
heuristic_results.data(),
&returned_results));
PADDLE_ENFORCE_GT(returned_results,
0,
phi::errors::Unavailable("No GEMM algorithm avaliable."));
phi::GpuTimer timer;
int best_algo_idx = -1;
constexpr int repeats = 6;
float min_time_cost = std::numeric_limits<float>::max();
for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
ctx.Wait();
float cur_time = 0.f;
for (int i = 0; i < repeats; ++i) {
timer.Start(stream);
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmul(lt_handle,
op_desc,
alpha,
y_data,
y_desc,
x_data,
x_desc,
beta,
out_data,
out_desc,
out_data,
out_desc,
&(heuristic_results[algo_idx].algo),
workspace_ptr,
workspace_size,
stream));
timer.Stop(stream);
auto time = timer.ElapsedTime();
if (i > 0) {
cur_time += time;
}
}
float time_cnt = (cur_time / (repeats - 1));
VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
<< "is : " << time_cnt << " s";
if (cur_time < min_time_cost) {
best_algo_idx = algo_idx;
min_time_cost = cur_time;
}
}
VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
*best_algo = heuristic_results[best_algo_idx].algo;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cublasLtMatmulPreferenceDestroy(preference));
}
};
} // namespace funcs
} // namespace phi
#endif
......@@ -15,8 +15,13 @@ limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#endif
namespace phi {
......@@ -84,8 +89,10 @@ static void IndexIncreaseFromDims(const int ndim,
}
}
// The general implementation with blas.
template <typename Context, typename T>
void MatMulFunction(const Context& dev_ctx,
void MatMulFunctionImplWithBlas(
const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
const std::vector<std::int64_t>& x_dims,
......@@ -93,7 +100,8 @@ void MatMulFunction(const Context& dev_ctx,
DenseTensor* Out,
bool trans_x,
bool trans_y,
bool flag = false) {
bool flag = false,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
......@@ -471,22 +479,487 @@ void MatMulFunction(const Context& dev_ctx,
}
}
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
// This is almost a copy from MatMulFunctionImplWithBlas,
// compare cublas with cublasLt kernels when Matmul autotune is on
template <typename Context, typename T>
void MatMulFunction(const Context& dev_ctx,
void MatMulFunctionImplWithCublasLt(
const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* Out,
bool trans_x,
bool trans_y,
bool flag = false,
phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
const T* x_data = X.data<T>();
const T* y_data = Y.data<T>();
using blaslt = phi::funcs::MatmulWithCublasLt<T>;
if (x_ndim == 1 && y_ndim == 1) {
const int M = X.numel();
const int N = Y.numel();
PADDLE_ENFORCE_EQ(
M,
N,
phi::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements",
M,
N));
// MatMul's case 0 => vector * vector
Out->Resize({1});
dev_ctx.template Alloc<T>(Out);
VLOG(3) << "MatMul with blaslt case 1";
blaslt::Run(dev_ctx,
y_data,
x_data,
dev_ctx.template Alloc<T>(Out),
1,
1,
M,
false,
true,
matmul_key);
return;
}
if (x_ndim == 1) {
const int N = X.numel();
if (trans_y) {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 1],
N,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1,
N,
y_ndim - 1,
y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 2],
N,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2,
N,
y_ndim - 2,
y_dims[y_ndim - 2]));
}
std::vector<std::int64_t> out_dims(y_ndim - 1);
if (trans_y) {
std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
} else {
std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
out_dims.back() = y_dims.back();
}
Out->ResizeAndAllocate(phi::make_ddim(out_dims));
dev_ctx.template Alloc<T>(Out);
if (trans_y) {
const int M = Y.numel() / N;
VLOG(3) << "MatMul with blaslt 2";
blaslt::Run(dev_ctx,
y_data,
x_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
false,
false,
matmul_key);
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y.numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul with blaslt 3";
blaslt::Run(dev_ctx,
y_data,
x_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
true,
false,
matmul_key);
} else {
VLOG(3) << "MatMul with blaslt 4";
blaslt::RunWithBatch(dev_ctx,
y_data,
x_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
true,
false,
batch_size,
M * N,
0,
M,
matmul_key);
}
}
return;
}
if (y_ndim == 1) {
const int N = Y.numel();
if (trans_x) {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 2],
N,
phi::errors::InvalidArgument("Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 2,
N,
x_ndim - 2,
x_dims[x_ndim - 2]));
} else {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 1],
N,
phi::errors::InvalidArgument("Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 1,
N,
x_ndim - 1,
x_dims[x_ndim - 1]));
}
std::vector<std::int64_t> out_dims(x_ndim - 1);
if (trans_x) {
std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
out_dims.back() = x_dims.back();
} else {
std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
}
Out->ResizeAndAllocate(phi::make_ddim(out_dims));
dev_ctx.template Alloc<T>(Out);
if (trans_x) {
const int M = x_dims[x_ndim - 1];
const int batch_size = X.numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul with blaslt 5";
blaslt::Run(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
true,
false,
matmul_key);
} else {
VLOG(3) << "MatMul with blaslt 6";
blaslt::RunWithBatch(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
true,
false,
batch_size,
M * N,
0,
M,
matmul_key);
}
} else {
const int M = X.numel() / N;
VLOG(3) << "MatMul with blaslt 7";
blaslt::Run(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
1,
N,
false,
false,
matmul_key);
}
return;
}
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 1],
K,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1,
K,
y_ndim - 1,
y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 2],
K,
phi::errors::InvalidArgument("Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2,
K,
y_ndim - 2,
y_dims[y_ndim - 2]));
}
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
const int ndim = (std::max)(x_ndim, y_ndim);
std::vector<std::int64_t> x_broadcast_dims(ndim);
std::vector<std::int64_t> y_broadcast_dims(ndim);
std::vector<std::int64_t> out_broadcast_dims(ndim);
GetBroadcastFromDims(x_ndim - 2,
x_dims.data(),
y_ndim - 2,
y_dims.data(),
x_broadcast_dims.data(),
y_broadcast_dims.data(),
out_broadcast_dims.data());
out_broadcast_dims[ndim - 2] = M;
out_broadcast_dims[ndim - 1] = N;
Out->ResizeAndAllocate(phi::make_ddim(out_broadcast_dims));
dev_ctx.template Alloc<T>(Out);
const int batch_dim = ndim - 2;
// broadcast message
const bool is_broadcast_dims =
!std::equal(x_broadcast_dims.cbegin(),
x_broadcast_dims.cbegin() + batch_dim,
y_broadcast_dims.cbegin());
const std::int64_t x_batch_size =
std::accumulate(x_broadcast_dims.cbegin(),
x_broadcast_dims.cbegin() + batch_dim,
1LL,
std::multiplies<std::int64_t>());
const std::int64_t y_batch_size =
std::accumulate(y_broadcast_dims.cbegin(),
y_broadcast_dims.cbegin() + batch_dim,
1LL,
std::multiplies<std::int64_t>());
const std::int64_t out_batch_size =
std::accumulate(out_broadcast_dims.cbegin(),
out_broadcast_dims.cbegin() + batch_dim,
1LL,
std::multiplies<std::int64_t>());
if (out_batch_size == 0) return;
if (x_batch_size == 1 && y_batch_size == 1) {
VLOG(3) << "MatMul with blaslt 8";
blaslt::Run(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
N,
K,
trans_x,
trans_y,
matmul_key);
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul with blaslt 9";
blaslt::Run(dev_ctx,
y_data,
x_data,
dev_ctx.template Alloc<T>(Out),
y_batch_size * N,
1,
K,
false,
false,
matmul_key);
} else {
VLOG(3) << "MatMul with blaslt 10";
blaslt::RunWithBatch(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
N,
K,
trans_x,
trans_y,
out_batch_size,
0,
K * N,
M * N,
matmul_key);
}
} else if (y_batch_size == 1) {
if (!trans_x) {
VLOG(3) << "MatMul with blaslt 11";
blaslt::Run(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
x_batch_size * M,
N,
K,
false,
trans_y,
matmul_key);
} else {
VLOG(3) << "MatMul with blaslt 12";
blaslt::RunWithBatch(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
N,
K,
true,
trans_y,
out_batch_size,
M * K,
0,
M * N,
matmul_key);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul with blaslt 13";
blaslt::RunWithBatch(dev_ctx,
x_data,
y_data,
dev_ctx.template Alloc<T>(Out),
M,
N,
K,
trans_x,
trans_y,
out_batch_size,
M * K,
K * N,
M * N,
matmul_key);
} else {
// in the case, can't use stridedgemm
std::vector<const T*> x_ptr(out_batch_size);
std::vector<const T*> y_ptr(out_batch_size);
std::vector<T*> out_ptr(out_batch_size);
std::vector<std::int64_t> index(batch_dim, 0);
for (std::int64_t i = 0; i < out_batch_size; ++i) {
// using the index to get offset
const std::int64_t x_index =
GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
const std::int64_t y_index =
GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());
x_ptr[i] = x_data + x_index * M * K;
y_ptr[i] = y_data + y_index * K * N;
out_ptr[i] = dev_ctx.template Alloc<T>(Out) + i * M * N;
IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
}
VLOG(3) << "MatMul with blaslt 14";
blaslt::RunWithBatch(dev_ctx,
x_ptr.data(),
y_ptr.data(),
out_ptr.data(),
M,
N,
K,
trans_x,
trans_y,
out_batch_size,
matmul_key);
}
}
#endif
template <typename Context, typename T>
struct MatMulDispatcher {
void operator()(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool trans_x,
bool trans_y,
bool flag = false) {
const std::vector<std::int64_t> x_dims = vectorize(X.dims());
const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
MatMulFunction<Context, T>(
dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
MatMulFunctionImplWithBlas<Context, T>(
ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
struct MatMulDispatcher<phi::GPUContext, T> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool trans_x,
bool trans_y,
bool flag = false) {
#if CUDA_VERSION >= 11060
auto* tuner = phi::autotune::MakeMatmulTuner<T>(
MatMulFunctionImplWithBlas<phi::GPUContext, T>);
tuner->AddCallBack(MatMulFunctionImplWithCublasLt<phi::GPUContext, T>);
phi::autotune::MatmulCacheKey matmul_cache(
x_dims,
y_dims,
trans_x,
trans_y,
paddle::experimental::CppTypeToDataType<T>::Type());
tuner->Run(ctx,
matmul_cache.GetKey(),
ctx,
x,
y,
x_dims,
y_dims,
out,
trans_x,
trans_y,
flag,
&matmul_cache);
#else
MatMulFunctionImplWithBlas<phi::GPUContext, T>(
ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
#endif
}
};
#endif // PADDLE_WITH_CUDA
template <typename Context, typename T>
void MatMulFunction(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool trans_x,
bool trans_y,
bool flag = false) {
MatMulDispatcher<Context, T>()(
ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
}
template <typename T, typename Context>
void MatmulKernel(const Context& dev_ctx,
void MatmulKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
......@@ -502,7 +975,10 @@ void MatmulKernel(const Context& dev_ctx,
0,
phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
const std::vector<std::int64_t> x_dims = vectorize(x.dims());
const std::vector<std::int64_t> y_dims = vectorize(y.dims());
MatMulFunction<Context, T>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
template <typename T, typename Context>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册