Unverified Commit 81086cff authored by Juncheng, committed by GitHub

Matmul kernels use primitive (#6589)

* Matmul kernels use primitive

* refine

* fix
Co-authored-by: guo ran <360112263@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent a633be98
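In short: the matmul kernels stop going through the `BlasIf<DeviceType>::OFBatchedGemm` interface (removed below) and launch the `BatchMatmul` primitive instead. A minimal sketch of the call-site pattern, using only the `Launch` signature declared in this diff — how the primitive instance is created is not shown in this excerpt, so the wrapper and the `Scalar` literals below are illustrative:

```cpp
// Illustrative sketch only: Launch() matches the virtual method declared in
// this diff; obtaining `batch_matmul` from a primitive factory is assumed.
void ComputeBatchMatmul(StreamContext* stream_ctx, BatchMatmul* batch_matmul,
                        size_t batch_size, size_t m, size_t n, size_t k,
                        const void* a, const void* b, void* c) {
  // alpha = 1, beta = 0 gives a plain per-batch C = A * B.
  batch_matmul->Launch(stream_ctx, batch_size, m, n, k, Scalar(1.0), a, b,
                       Scalar(0.0), c);
}
```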
@@ -21,7 +21,7 @@ namespace oneflow {
namespace {
cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) {
- cublasOperation_t cublas_trans;
+ cublasOperation_t cublas_trans{};
if (trans == CBLAS_TRANSPOSE::CblasNoTrans) {
cublas_trans = cublasOperation_t::CUBLAS_OP_N;
} else if (trans == CBLAS_TRANSPOSE::CblasTrans) {
@@ -29,6 +29,7 @@ cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) {
} else if (trans == CBLAS_TRANSPOSE::CblasConjTrans) {
cublas_trans = cublasOperation_t::CUBLAS_OP_C;
} else {
+ UNIMPLEMENTED();
// do nothing
}
return cublas_trans;
@@ -46,11 +47,14 @@ std::tuple<int, int, int, cublasOperation_t, cublasOperation_t> PrepareToCallCublasGemm(
}
template<typename T>
- void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
+ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER /*order*/, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha,
const T* a, const T* b, const double beta, T* c) {
- int lda, ldb, ldc;
- cublasOperation_t cublas_trans_a, cublas_trans_b;
+ int lda = 0;
+ int ldb = 0;
+ int ldc = 0;
+ cublasOperation_t cublas_trans_a{};
+ cublasOperation_t cublas_trans_b{};
std::tie(lda, ldb, ldc, cublas_trans_a, cublas_trans_b) =
PrepareToCallCublasGemm(trans_a, trans_b, m, n, k);
@@ -61,13 +65,16 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
}
template<>
- void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
+ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER /*order*/, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha,
const half* a, const half* b, const double beta, half* c) {
const float alpha_f = static_cast<float>(alpha);
const float beta_f = static_cast<float>(beta);
- int lda, ldb, ldc;
- cublasOperation_t cublas_trans_a, cublas_trans_b;
+ int lda = 0;
+ int ldb = 0;
+ int ldc = 0;
+ cublasOperation_t cublas_trans_a{};
+ cublasOperation_t cublas_trans_b{};
std::tie(lda, ldb, ldc, cublas_trans_a, cublas_trans_b) =
PrepareToCallCublasGemm(trans_a, trans_b, m, n, k);
#if CUDA_VERSION < 11000
@@ -86,21 +93,6 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
}
}
- std::tuple<int, int, int, int, int, int, cublasOperation_t, cublasOperation_t>
- PrepareToCallBatchedGemm(const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
- int batch_size, int m, int n, int k) {
- const int a_stride = m * k;
- const int b_stride = k * n;
- const int c_stride = m * n;
- const int lda = (trans_a == CblasNoTrans) ? k : m;
- const int ldb = (trans_b == CblasNoTrans) ? n : k;
- const int ldc = n;
- cublasOperation_t cublas_trans_a = CblasTrans2CublasTrans(trans_a);
- cublasOperation_t cublas_trans_b = CblasTrans2CublasTrans(trans_b);
- return std::make_tuple(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a,
- cublas_trans_b);
- }
#define CUDA_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(float, CUDA_R_32F) \
OF_PP_MAKE_TUPLE_SEQ(double, CUDA_R_64F) \
@@ -115,80 +107,6 @@ struct CudaDataType;
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDA_DATA_TYPE, CUDA_DATA_TYPE_SEQ);
#undef SPECIALIZE_CUDA_DATA_TYPE
- template<typename T>
- cudaDataType_t GetCudaDataType4BatchedGemm() {
- return CudaDataType<T>::value;
- }
- template<typename T>
- void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
- int batch_size, int m, int n, int k, const double alpha, const T* a,
- const T* b, const double beta, T* c) {
- int a_stride, b_stride, c_stride;
- int lda, ldb, ldc;
- const T alpha_val = static_cast<T>(alpha);
- const T beta_val = static_cast<T>(beta);
- cublasOperation_t cublas_trans_a, cublas_trans_b;
- std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b) =
- PrepareToCallBatchedGemm(trans_a, trans_b, batch_size, m, n, k);
- if (CUDA_VERSION >= 9010 && GetCudaSmVersion() >= 500) {
- #if CUDA_VERSION >= 9010
- cudaDataType_t data_type = GetCudaDataType4BatchedGemm<T>();
- OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx(
- ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k,
- reinterpret_cast<const void*>(&alpha_val), reinterpret_cast<const void*>(b), data_type, ldb,
- b_stride, reinterpret_cast<const void*>(a), data_type, lda, a_stride,
- reinterpret_cast<const void*>(&beta_val), reinterpret_cast<void*>(c), data_type, ldc,
- c_stride, batch_size, data_type, CUBLAS_GEMM_DEFAULT));
- #endif
- } else {
- cublas_gemmStridedBatched<T>(ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k,
- &alpha_val, b, ldb, b_stride, a, lda, a_stride, &beta_val, c, ldc,
- c_stride, batch_size);
- }
- }
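A note on the operand order in the removed wrappers above and below: cuBLAS works in column-major layout while these interfaces take row-major matrices, so the calls rely on the transpose identity

$$C^{\top} = (AB)^{\top} = B^{\top}A^{\top}$$

Passing `b` before `a` and swapping `m` and `n` makes cuBLAS write the column-major n-by-m matrix `C^T`, whose memory layout is exactly the row-major m-by-n product `C`. The leading dimensions computed in `PrepareToCallBatchedGemm` (`lda = k` or `m`, `ldb = n` or `k`, `ldc = n`) follow from the same reinterpretation.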
- #if CUDA_VERSION >= 9010
- template<>
- void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
- int batch_size, int m, int n, int k, const double alpha, const half* a,
- const half* b, const double beta, half* c) {
- int a_stride, b_stride, c_stride;
- int lda, ldb, ldc;
- cublasOperation_t cublas_trans_a, cublas_trans_b;
- std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b) =
- PrepareToCallBatchedGemm(trans_a, trans_b, batch_size, m, n, k);
- #if CUDA_VERSION < 11000
- CublasMathModeGuard guard(ctx->cublas_handle(), CUBLAS_TENSOR_OP_MATH);
- #else
- CublasMathModeGuard guard(ctx->cublas_handle(), CUBLAS_DEFAULT_MATH);
- #endif // CUDA_VERSION < 11000
- if (GetCudaSmVersion() >= 500) {
- const float alpha_f = static_cast<float>(alpha);
- const float beta_f = static_cast<float>(beta);
- #if CUDA_VERSION >= 11000
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
- #else
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
- #endif
- OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx(
- ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha_f,
- reinterpret_cast<const void*>(b), CUDA_R_16F, ldb, b_stride,
- reinterpret_cast<const void*>(a), CUDA_R_16F, lda, a_stride, &beta_f,
- reinterpret_cast<void*>(c), CUDA_R_16F, ldc, c_stride, batch_size, CUDA_R_32F, algo));
- } else {
- const half alpha_h = static_cast<half>(alpha);
- const half beta_h = static_cast<half>(beta);
- cublas_gemmStridedBatched<half>(ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k,
- &alpha_h, b, ldb, b_stride, a, lda, a_stride, &beta_h, c, ldc,
- c_stride, batch_size);
- }
- }
- #endif
} // namespace
void BlasIf<DeviceType::kGPU>::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
@@ -211,30 +129,4 @@ void BlasIf<DeviceType::kGPU>::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
reinterpret_cast<const half*>(b), beta, reinterpret_cast<half*>(c));
}
- void BlasIf<DeviceType::kGPU>::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size,
- const int m, const int n, const int k,
- const double alpha, const float* a, const float* b,
- const double beta, float* c) {
- BatchedGemmImpl<float>(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b,
- beta, c);
- }
- void BlasIf<DeviceType::kGPU>::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size,
- const int m, const int n, const int k,
- const double alpha, const double* a, const double* b,
- const double beta, double* c) {
- BatchedGemmImpl<double>(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b,
- beta, c);
- }
- void BlasIf<DeviceType::kGPU>::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size,
- const int m, const int n, const int k,
- const double alpha, const float16* a, const float16* b,
- const double beta, float16* c) {
- BatchedGemmImpl<half>(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha,
- reinterpret_cast<const half*>(a), reinterpret_cast<const half*>(b), beta,
- reinterpret_cast<half*>(c));
- }
} // namespace oneflow
@@ -33,19 +33,6 @@ struct BlasIf<DeviceType::kGPU> {
static void OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b,
const int m, const int n, const int k, const double alpha, const float16* a,
const float16* b, const double beta, float16* c);
- static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
- const int n, const int k, const double alpha, const float* a,
- const float* b, const double beta, float* c);
- static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
- const int n, const int k, const double alpha, const double* a,
- const double* b, const double beta, double* c);
- static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
- const int n, const int k, const double alpha, const float16* a,
- const float16* b, const double beta, float16* c);
};
} // namespace oneflow
@@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/util/host_blas_interface.h"
#include "oneflow/core/register/blob.h"
namespace oneflow {
namespace {
template<typename T>
- static void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
+ static void Gemm(DeviceCtx* /*ctx*/, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const double alpha, const T* a, const T* b, const double beta, T* c) {
const int lda = (trans_a == CblasNoTrans) ? k : m;
@@ -32,20 +31,6 @@ static void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
static_cast<T>(beta), c, ldc);
}
- template<typename T>
- void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
- int batch_size, int m, int n, int k, const double alpha, const T* a,
- const T* b, const double beta, T* c) {
- const int a_stride = m * k;
- const int b_stride = k * n;
- const int c_stride = m * n;
- FOR_RANGE(int32_t, i, 0, batch_size) {
- BlasIf<DeviceType::kCPU>::OFGemm(ctx, trans_a, trans_b, m, n, k, alpha, a + i * a_stride,
- b + i * b_stride, beta, c + i * c_stride);
- }
- }
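The removed CPU `BatchedGemmImpl` above realizes batched GEMM as a loop of single GEMMs at fixed pointer strides. A self-contained sketch of that stride arithmetic, with a generic callback standing in for OneFlow's `OFGemm`:

```cpp
#include <cstddef>

// Batch i of a row-major batched GEMM starts at fixed offsets m*k, k*n and
// m*n into a, b and c respectively; `gemm` stands in for a single-matrix
// GEMM call such as the OFGemm used above.
template<typename T, typename GemmFn>
void BatchedGemmByLoop(GemmFn gemm, int batch_size, int m, int n, int k,
                       const T* a, const T* b, T* c) {
  const std::ptrdiff_t a_stride = static_cast<std::ptrdiff_t>(m) * k;
  const std::ptrdiff_t b_stride = static_cast<std::ptrdiff_t>(k) * n;
  const std::ptrdiff_t c_stride = static_cast<std::ptrdiff_t>(m) * n;
  for (int i = 0; i < batch_size; ++i) {
    gemm(m, n, k, a + i * a_stride, b + i * b_stride, c + i * c_stride);
  }
}
```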
} // namespace
void BlasIf<DeviceType::kCPU>::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
@@ -62,22 +47,4 @@ void BlasIf<DeviceType::kCPU>::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
Gemm<double>(ctx, CblasRowMajor, trans_a, trans_b, m, n, k, alpha, a, b, beta, c);
}
- void BlasIf<DeviceType::kCPU>::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size,
- const int m, const int n, const int k,
- const double alpha, const float* a, const float* b,
- const double beta, float* c) {
- BatchedGemmImpl<float>(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b,
- beta, c);
- }
- void BlasIf<DeviceType::kCPU>::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size,
- const int m, const int n, const int k,
- const double alpha, const double* a, const double* b,
- const double beta, double* c) {
- BatchedGemmImpl<double>(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b,
- beta, c);
- }
} // namespace oneflow
@@ -32,15 +32,6 @@ struct BlasIf<DeviceType::kCPU> {
static void OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b,
const int m, const int n, const int k, const double alpha, const double* a,
const double* b, const double beta, double* c);
- static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
- const int n, const int k, const double alpha, const float* a,
- const float* b, const double beta, float* c);
- static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
- enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
- const int n, const int k, const double alpha, const double* a,
- const double* b, const double beta, double* c);
};
} // namespace oneflow
@@ -32,14 +32,14 @@ class BatchMatmulImpl : public BatchMatmul {
broadcast_matmul_(std::move(broadcast_matmul)) {}
~BatchMatmulImpl() override = default;
- void Launch(StreamContext* stream_ctx, size_t num_batches, size_t m, size_t n, size_t k,
+ void Launch(StreamContext* stream_ctx, size_t batch_size, size_t m, size_t n, size_t k,
Scalar alpha, const void* a, const void* b, Scalar beta, void* c) override {
int64_t a_dims[3];
int64_t b_dims[3];
int64_t c_dims[3];
- a_dims[0] = num_batches;
- b_dims[0] = num_batches;
- c_dims[0] = num_batches;
+ a_dims[0] = batch_size;
+ b_dims[0] = batch_size;
+ c_dims[0] = batch_size;
if (transpose_a_ == BlasTransposeType::N) {
a_dims[1] = m;
a_dims[2] = k;
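The hunk above is cut off inside the `transpose_a_` branch. For orientation, a hedged reconstruction of the full dims setup it implies, mirroring the `N` branch shown; the actual elided code may differ:

```cpp
#include <cstdint>

// Hypothetical sketch, not the actual elided code: a transposed operand is
// stored with its two inner dims swapped, while c is always batch x m x n.
void FillBatchMatmulDims(bool transpose_a, bool transpose_b, int64_t batch_size,
                         int64_t m, int64_t n, int64_t k,
                         int64_t a_dims[3], int64_t b_dims[3], int64_t c_dims[3]) {
  a_dims[0] = batch_size;
  b_dims[0] = batch_size;
  c_dims[0] = batch_size;
  a_dims[1] = transpose_a ? k : m;
  a_dims[2] = transpose_a ? m : k;
  b_dims[1] = transpose_b ? n : k;
  b_dims[2] = transpose_b ? k : n;
  c_dims[1] = m;
  c_dims[2] = n;
}
```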
@@ -30,7 +30,7 @@ class BatchMatmul : public Primitive {
BatchMatmul() = default;
~BatchMatmul() override = default;
- virtual void Launch(StreamContext* stream_ctx, size_t num_batches, size_t m, size_t n, size_t k,
+ virtual void Launch(StreamContext* stream_ctx, size_t batch_size, size_t m, size_t n, size_t k,
Scalar alpha, const void* a, const void* b, Scalar beta, void* c) = 0;
};