From 81086cff59d1bebaf318dd04d785591c164666b9 Mon Sep 17 00:00:00 2001 From: Juncheng Date: Wed, 27 Oct 2021 20:56:57 +0800 Subject: [PATCH] Matmul kernels use primitive (#6589) * Matmul kernels use primitive * refine * fix Co-authored-by: guo ran <360112263@qq.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../core/kernel/util/cuda_blas_interface.cu | 136 +----- .../core/kernel/util/cuda_blas_interface.h | 13 - .../core/kernel/util/host_blas_interface.cpp | 35 +- .../core/kernel/util/host_blas_interface.h | 9 - .../core/primitive/common/batch_matmul.cpp | 8 +- oneflow/core/primitive/include/batch_matmul.h | 2 +- oneflow/user/kernels/matmul_kernels.cpp | 430 ++++++++---------- 7 files changed, 222 insertions(+), 411 deletions(-) diff --git a/oneflow/core/kernel/util/cuda_blas_interface.cu b/oneflow/core/kernel/util/cuda_blas_interface.cu index cc62f288f7..1702a5cfa7 100644 --- a/oneflow/core/kernel/util/cuda_blas_interface.cu +++ b/oneflow/core/kernel/util/cuda_blas_interface.cu @@ -21,7 +21,7 @@ namespace oneflow { namespace { cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) { - cublasOperation_t cublas_trans; + cublasOperation_t cublas_trans{}; if (trans == CBLAS_TRANSPOSE::CblasNoTrans) { cublas_trans = cublasOperation_t::CUBLAS_OP_N; } else if (trans == CBLAS_TRANSPOSE::CblasTrans) { @@ -29,6 +29,7 @@ cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) { } else if (trans == CBLAS_TRANSPOSE::CblasConjTrans) { cublas_trans = cublasOperation_t::CUBLAS_OP_C; } else { + UNIMPLEMENTED(); // do nothing } return cublas_trans; @@ -46,11 +47,14 @@ std::tuple PrepareToCallCub } template -void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a, +void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER /*order*/, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha, const T* a, const T* b, const double beta, T* c) { - int lda, ldb, ldc; - cublasOperation_t cublas_trans_a, cublas_trans_b; + int lda = 0; + int ldb = 0; + int ldc = 0; + cublasOperation_t cublas_trans_a{}; + cublasOperation_t cublas_trans_b{}; std::tie(lda, ldb, ldc, cublas_trans_a, cublas_trans_b) = PrepareToCallCublasGemm(trans_a, trans_b, m, n, k); @@ -61,13 +65,16 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE tra } template<> -void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a, +void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER /*order*/, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha, const half* a, const half* b, const double beta, half* c) { const float alpha_f = static_cast(alpha); const float beta_f = static_cast(beta); - int lda, ldb, ldc; - cublasOperation_t cublas_trans_a, cublas_trans_b; + int lda = 0; + int ldb = 0; + int ldc = 0; + cublasOperation_t cublas_trans_a{}; + cublasOperation_t cublas_trans_b; std::tie(lda, ldb, ldc, cublas_trans_a, cublas_trans_b) = PrepareToCallCublasGemm(trans_a, trans_b, m, n, k); #if CUDA_VERSION < 11000 @@ -86,21 +93,6 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE tra } } -std::tuple -PrepareToCallBatchedGemm(const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, - int batch_size, int m, int n, int k) { - const int a_stride = m * k; - const int b_stride = k * n; - const int c_stride = m * n; - const int lda = (trans_a == CblasNoTrans) ? k : m; - const int ldb = (trans_b == CblasNoTrans) ? n : k; - const int ldc = n; - cublasOperation_t cublas_trans_a = CblasTrans2CublasTrans(trans_a); - cublasOperation_t cublas_trans_b = CblasTrans2CublasTrans(trans_b); - return std::make_tuple(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, - cublas_trans_b); -} - #define CUDA_DATA_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(float, CUDA_R_32F) \ OF_PP_MAKE_TUPLE_SEQ(double, CUDA_R_64F) \ @@ -115,80 +107,6 @@ struct CudaDataType; OF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDA_DATA_TYPE, CUDA_DATA_TYPE_SEQ); #undef SPECIALIZE_CUDA_DATA_TYPE -template -cudaDataType_t GetCudaDataType4BatchedGemm() { - return CudaDataType::value; -} - -template -void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, - int batch_size, int m, int n, int k, const double alpha, const T* a, - const T* b, const double beta, T* c) { - int a_stride, b_stride, c_stride; - int lda, ldb, ldc; - const T alpha_val = static_cast(alpha); - const T beta_val = static_cast(beta); - cublasOperation_t cublas_trans_a, cublas_trans_b; - std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b) = - PrepareToCallBatchedGemm(trans_a, trans_b, batch_size, m, n, k); - - if (CUDA_VERSION >= 9010 && GetCudaSmVersion() >= 500) { -#if CUDA_VERSION >= 9010 - cudaDataType_t data_type = GetCudaDataType4BatchedGemm(); - OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx( - ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k, - reinterpret_cast(&alpha_val), reinterpret_cast(b), data_type, ldb, - b_stride, reinterpret_cast(a), data_type, lda, a_stride, - reinterpret_cast(&beta_val), reinterpret_cast(c), data_type, ldc, - c_stride, batch_size, data_type, CUBLAS_GEMM_DEFAULT)); -#endif - } else { - cublas_gemmStridedBatched(ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k, - &alpha_val, b, ldb, b_stride, a, lda, a_stride, &beta_val, c, ldc, - c_stride, batch_size); - } -} - -#if CUDA_VERSION >= 9010 -template<> -void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, - int batch_size, int m, int n, int k, const double alpha, const half* a, - const half* b, const double beta, half* c) { - int a_stride, b_stride, c_stride; - int lda, ldb, ldc; - cublasOperation_t cublas_trans_a, cublas_trans_b; - std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b) = - PrepareToCallBatchedGemm(trans_a, trans_b, batch_size, m, n, k); -#if CUDA_VERSION < 11000 - CublasMathModeGuard guard(ctx->cublas_handle(), CUBLAS_TENSOR_OP_MATH); -#else - CublasMathModeGuard guard(ctx->cublas_handle(), CUBLAS_DEFAULT_MATH); -#endif // CUDA_VERSION < 11000 - if (GetCudaSmVersion() >= 500) { - const float alpha_f = static_cast(alpha); - const float beta_f = static_cast(beta); -#if CUDA_VERSION >= 11000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; -#else - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; -#endif - OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx( - ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha_f, - reinterpret_cast(b), CUDA_R_16F, ldb, b_stride, - reinterpret_cast(a), CUDA_R_16F, lda, a_stride, &beta_f, - reinterpret_cast(c), CUDA_R_16F, ldc, c_stride, batch_size, CUDA_R_32F, algo)); - } else { - const half alpha_h = static_cast(alpha); - const half beta_h = static_cast(beta); - cublas_gemmStridedBatched(ctx->cublas_handle(), cublas_trans_b, cublas_trans_a, n, m, k, - &alpha_h, b, ldb, b_stride, a, lda, a_stride, &beta_h, c, ldc, - c_stride, batch_size); - } -} -#endif - } // namespace void BlasIf::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, @@ -211,30 +129,4 @@ void BlasIf::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans reinterpret_cast(b), beta, reinterpret_cast(c)); } -void BlasIf::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, - const int m, const int n, const int k, - const double alpha, const float* a, const float* b, - const double beta, float* c) { - BatchedGemmImpl(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b, - beta, c); -} -void BlasIf::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, - const int m, const int n, const int k, - const double alpha, const double* a, const double* b, - const double beta, double* c) { - BatchedGemmImpl(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b, - beta, c); -} -void BlasIf::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, - const int m, const int n, const int k, - const double alpha, const float16* a, const float16* b, - const double beta, float16* c) { - BatchedGemmImpl(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, - reinterpret_cast(a), reinterpret_cast(b), beta, - reinterpret_cast(c)); -} - } // namespace oneflow diff --git a/oneflow/core/kernel/util/cuda_blas_interface.h b/oneflow/core/kernel/util/cuda_blas_interface.h index a2a2b2a94f..ca846e322d 100644 --- a/oneflow/core/kernel/util/cuda_blas_interface.h +++ b/oneflow/core/kernel/util/cuda_blas_interface.h @@ -33,19 +33,6 @@ struct BlasIf { static void OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha, const float16* a, const float16* b, const double beta, float16* c); - - static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, - const int n, const int k, const double alpha, const float* a, - const float* b, const double beta, float* c); - static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, - const int n, const int k, const double alpha, const double* a, - const double* b, const double beta, double* c); - static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, - const int n, const int k, const double alpha, const float16* a, - const float16* b, const double beta, float16* c); }; } // namespace oneflow diff --git a/oneflow/core/kernel/util/host_blas_interface.cpp b/oneflow/core/kernel/util/host_blas_interface.cpp index 194d2042f0..184b9d067a 100644 --- a/oneflow/core/kernel/util/host_blas_interface.cpp +++ b/oneflow/core/kernel/util/host_blas_interface.cpp @@ -14,14 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/util/host_blas_interface.h" -#include "oneflow/core/register/blob.h" namespace oneflow { namespace { template -static void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a, +static void Gemm(DeviceCtx* /*ctx*/, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha, const T* a, const T* b, const double beta, T* c) { const int lda = (trans_a == CblasNoTrans) ? k : m; @@ -32,20 +31,6 @@ static void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSP static_cast(beta), c, ldc); } -template -void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, - int batch_size, int m, int n, int k, const double alpha, const T* a, - const T* b, const double beta, T* c) { - const int a_stride = m * k; - const int b_stride = k * n; - const int c_stride = m * n; - FOR_RANGE(int32_t, i, 0, batch_size) { - BlasIf::OFGemm(ctx, trans_a, trans_b, m, n, k, alpha, a + i * a_stride, - b + i * b_stride, beta, c + i * c_stride); - } -} - } // namespace void BlasIf::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, @@ -62,22 +47,4 @@ void BlasIf::OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans Gemm(ctx, CblasRowMajor, trans_a, trans_b, m, n, k, alpha, a, b, beta, c); } -void BlasIf::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, - const int m, const int n, const int k, - const double alpha, const float* a, const float* b, - const double beta, float* c) { - BatchedGemmImpl(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b, - beta, c); -} - -void BlasIf::OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, - const int m, const int n, const int k, - const double alpha, const double* a, const double* b, - const double beta, double* c) { - BatchedGemmImpl(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b, - beta, c); -} - } // namespace oneflow diff --git a/oneflow/core/kernel/util/host_blas_interface.h b/oneflow/core/kernel/util/host_blas_interface.h index 73b2f8628c..44eb46394c 100644 --- a/oneflow/core/kernel/util/host_blas_interface.h +++ b/oneflow/core/kernel/util/host_blas_interface.h @@ -32,15 +32,6 @@ struct BlasIf { static void OFGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const double alpha, const double* a, const double* b, const double beta, double* c); - - static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, - const int n, const int k, const double alpha, const float* a, - const float* b, const double beta, float* c); - static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, - enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, - const int n, const int k, const double alpha, const double* a, - const double* b, const double beta, double* c); }; } // namespace oneflow diff --git a/oneflow/core/primitive/common/batch_matmul.cpp b/oneflow/core/primitive/common/batch_matmul.cpp index d89751e1e2..1cbca6e002 100644 --- a/oneflow/core/primitive/common/batch_matmul.cpp +++ b/oneflow/core/primitive/common/batch_matmul.cpp @@ -32,14 +32,14 @@ class BatchMatmulImpl : public BatchMatmul { broadcast_matmul_(std::move(broadcast_matmul)) {} ~BatchMatmulImpl() override = default; - void Launch(StreamContext* stream_ctx, size_t num_batches, size_t m, size_t n, size_t k, + void Launch(StreamContext* stream_ctx, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) override { int64_t a_dims[3]; int64_t b_dims[3]; int64_t c_dims[3]; - a_dims[0] = num_batches; - b_dims[0] = num_batches; - c_dims[0] = num_batches; + a_dims[0] = batch_size; + b_dims[0] = batch_size; + c_dims[0] = batch_size; if (transpose_a_ == BlasTransposeType::N) { a_dims[1] = m; a_dims[2] = k; diff --git a/oneflow/core/primitive/include/batch_matmul.h b/oneflow/core/primitive/include/batch_matmul.h index 8ba9a7e9a6..a54616b162 100644 --- a/oneflow/core/primitive/include/batch_matmul.h +++ b/oneflow/core/primitive/include/batch_matmul.h @@ -30,7 +30,7 @@ class BatchMatmul : public Primitive { BatchMatmul() = default; ~BatchMatmul() override = default; - virtual void Launch(StreamContext* stream_ctx, size_t num_batches, size_t m, size_t n, size_t k, + virtual void Launch(StreamContext* stream_ctx, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha, const void* a, const void* b, Scalar beta, void* c) = 0; }; diff --git a/oneflow/user/kernels/matmul_kernels.cpp b/oneflow/user/kernels/matmul_kernels.cpp index 20798c0388..d9e1ebe902 100644 --- a/oneflow/user/kernels/matmul_kernels.cpp +++ b/oneflow/user/kernels/matmul_kernels.cpp @@ -17,244 +17,220 @@ limitations under the License. #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/framework/config_def.h" #include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/core/primitive/include/memcpy.h" +#include "oneflow/core/primitive/include/matmul.h" +#include "oneflow/core/primitive/include/batch_matmul.h" namespace oneflow { namespace { -std::tuple CalcMNK(const ShapeView& a_shape, const ShapeView& out_shape, - CBLAS_TRANSPOSE transpose_a) { - int32_t num_axes = a_shape.NumAxes(); - int m = out_shape.At(num_axes - 2); - int n = out_shape.At(num_axes - 1); - int k = transpose_a == CblasTrans ? a_shape.At(num_axes - 2) : a_shape.At(num_axes - 1); - return std::make_tuple(m, n, k); +primitive::BlasTransposeType GetBlasTransposeType(bool transpose) { + return transpose ? primitive::BlasTransposeType::T : primitive::BlasTransposeType::N; } -} // namespace +template +primitive::BlasTransposeType GetBlasTransposeType(Context* ctx, const std::string& attr) { + return GetBlasTransposeType(ctx->template Attr(attr)); +} -template -class MatmulFloatingKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { - public: - MatmulFloatingKernel() = default; - ~MatmulFloatingKernel() = default; +void InferMatmulMNK(const ShapeView& a_shape, const ShapeView& b_shape, const ShapeView& c_shape, + primitive::BlasTransposeType transpose_a, + primitive::BlasTransposeType transpose_b, size_t* m, size_t* n, size_t* k) { + const int64_t num_a_axes = a_shape.NumAxes(); + CHECK_GE(num_a_axes, 2); + const int64_t num_b_axes = b_shape.NumAxes(); + CHECK_GE(num_b_axes, 2); + const int64_t num_c_axes = c_shape.NumAxes(); + CHECK_GE(num_c_axes, 2); + if (transpose_a == primitive::BlasTransposeType::N) { + *m = a_shape.At(num_a_axes - 2); + *k = a_shape.At(num_a_axes - 1); + } else if (transpose_a == primitive::BlasTransposeType::T) { + *m = a_shape.At(num_a_axes - 1); + *k = a_shape.At(num_a_axes - 2); + } else { + UNIMPLEMENTED(); + } + if (transpose_b == primitive::BlasTransposeType::N) { + CHECK_EQ(b_shape.At(num_b_axes - 2), *k); + *n = b_shape.At(num_b_axes - 1); + } else if (transpose_b == primitive::BlasTransposeType::T) { + CHECK_EQ(b_shape.At(num_b_axes - 1), *k); + *n = b_shape.At(num_b_axes - 2); + } else { + UNIMPLEMENTED(); + } + CHECK_EQ(c_shape.At(num_c_axes - 2), *m); + CHECK_EQ(c_shape.At(num_c_axes - 1), *n); +} - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +template +std::unique_ptr NewMemcpyPrimitive(Context* ctx) { + return primitive::NewPrimitive(ctx->device_type(), + primitive::MemcpyKind::kDtoD); +} - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - CBLAS_TRANSPOSE trans_a = ctx->Attr("transpose_a") ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE trans_b = ctx->Attr("transpose_b") ? CblasTrans : CblasNoTrans; - const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); - const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); - user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - CHECK_EQ(2, a->shape().NumAxes()); +std::unique_ptr NewMatmulPrimitive(DeviceType device_type, DataType data_type, + bool transpose_a, bool transpose_b) { + const auto trans_a = GetBlasTransposeType(transpose_a); + const auto trans_b = GetBlasTransposeType(transpose_b); + return primitive::NewPrimitive(device_type, data_type, trans_a, + trans_b); +} - int32_t m = 0, n = 0, k = 0; - std::tie(m, n, k) = CalcMNK(a->shape(), out->shape(), trans_a); - const double alpha = ctx->Attr("alpha"); - double beta; - if (ctx->has_input("_add_to_output", 0)) { - const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); - CHECK_EQ(add_to_output->data_type(), out->data_type()); - CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), - add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); - beta = 1.0; - } else { - beta = 0.0; - } - NewKernelUtil::OFGemm(ctx->device_ctx(), trans_a, trans_b, m, n, k, alpha, - a->dptr(), b->dptr(), beta, out->mut_dptr()); - } -}; +template +std::unique_ptr NewMatmulPrimitive(Context* ctx) { + const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); + return NewMatmulPrimitive(ctx->device_type(), data_type, ctx->template Attr("transpose_a"), + ctx->template Attr("transpose_b")); +} -#define REGISTER_MATMUL_KERNEL(device, dtype) \ - REGISTER_USER_KERNEL("matmul") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("a", 0) == GetDataType::value)) \ - .SetInplaceProposalFn([](const user_op::InferContext& ctx, \ - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ - if (ctx.has_input("_add_to_output", 0)) { \ - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ - } \ - return Maybe::Ok(); \ - }); +template +std::unique_ptr NewBatchMatmulPrimitive(Context* ctx) { + const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); + const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); + const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); + return primitive::NewPrimitive(ctx->device_type(), data_type, + trans_a, trans_b); +} + +hob::HobContextGetter MemcpyPrimitiveExists() { + return user_op::HobCtxGetter("MemcpyPrimitiveExists", + [](const user_op::KernelRegContext& ctx) { + return NewMemcpyPrimitive(&ctx).operator bool(); + }); +} -REGISTER_MATMUL_KERNEL(DeviceType::kCPU, float); -REGISTER_MATMUL_KERNEL(DeviceType::kCPU, double); -#ifdef WITH_CUDA -REGISTER_MATMUL_KERNEL(DeviceType::kGPU, float); -REGISTER_MATMUL_KERNEL(DeviceType::kGPU, double); -#endif +hob::HobContextGetter MatmulPrimitiveExists() { + return user_op::HobCtxGetter("MatmulPrimitiveExists", + [](const user_op::KernelRegContext& ctx) { + return NewMatmulPrimitive(&ctx).operator bool(); + }); +} + +hob::HobContextGetter BatchMatmulPrimitiveExists() { + return user_op::HobCtxGetter("BatchMatmulPrimitiveExists", + [](const user_op::KernelRegContext& ctx) { + return NewBatchMatmulPrimitive(&ctx).operator bool(); + }); +} -#ifdef WITH_CUDA -class MatmulGpuHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { +class MatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: - MatmulGpuHalfKernel() = default; - ~MatmulGpuHalfKernel() = default; + MatmulKernel() = default; + ~MatmulKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { - CBLAS_TRANSPOSE trans_a = ctx->Attr("transpose_a") ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE trans_b = ctx->Attr("transpose_b") ? CblasTrans : CblasNoTrans; + const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); + const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); + CHECK_EQ(a->shape().NumAxes(), 2); + const DataType data_type = a->data_type(); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); + CHECK_EQ(b->shape().NumAxes(), 2); + CHECK_EQ(b->data_type(), data_type); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - CHECK_EQ(2, a->shape().NumAxes()); - - int32_t m = 0, n = 0, k = 0; - std::tie(m, n, k) = CalcMNK(a->shape(), out->shape(), trans_a); - bool has_add_to_output = ctx->has_input("_add_to_output", 0); - if (has_add_to_output) { + CHECK_EQ(out->shape().NumAxes(), 2); + CHECK_EQ(out->data_type(), data_type); + size_t m = 0, n = 0, k = 0; + InferMatmulMNK(a->shape(), b->shape(), out->shape(), trans_a, trans_b, &m, &n, &k); + const double alpha = ctx->Attr("alpha"); + double beta = 0.0; + if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); - CHECK_EQ(add_to_output->data_type(), out->data_type()); + CHECK_EQ(add_to_output->data_type(), data_type); CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), - add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); + auto memcpy = NewMemcpyPrimitive(ctx); + CHECK(memcpy); + memcpy->Launch(ctx->stream_ctx(), out->mut_dptr(), add_to_output->dptr(), + add_to_output->shape().elem_cnt() * GetSizeOfDataType(data_type)); + beta = 1.0; } - const double alpha = ctx->Attr("alpha"); - const double beta = has_add_to_output ? 1.0 : 0.0; - NewKernelUtil::OFGemm(ctx->device_ctx(), trans_a, trans_b, m, n, k, alpha, - a->dptr(), b->dptr(), beta, - out->mut_dptr()); + auto matmul = NewMatmulPrimitive(ctx); + CHECK(matmul); + matmul->Launch(ctx->stream_ctx(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; -#endif - -#ifdef WITH_CUDA REGISTER_USER_KERNEL("matmul") - .SetCreateFn() - .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") - & (user_op::HobDataType("a", 0) == DataType::kFloat16)) + .SetCreateFn() + .SetIsMatchedHob((MemcpyPrimitiveExists() == true) & (MatmulPrimitiveExists() == true)) .SetInplaceProposalFn([](const user_op::InferContext& ctx, - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { + const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); -#endif -template -class BatchMatmulFloatingKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { +class BatchMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: - BatchMatmulFloatingKernel() = default; - ~BatchMatmulFloatingKernel() = default; + BatchMatmulKernel() = default; + ~BatchMatmulKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: void Compute(user_op::KernelComputeContext* ctx) const override { - CBLAS_TRANSPOSE trans_a = ctx->Attr("transpose_a") ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE trans_b = ctx->Attr("transpose_b") ? CblasTrans : CblasNoTrans; + const auto trans_a = GetBlasTransposeType(ctx, "transpose_a"); + const auto trans_b = GetBlasTransposeType(ctx, "transpose_b"); const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); + const DataType data_type = a->data_type(); + const int64_t num_axes = a->shape().NumAxes(); + CHECK_GT(num_axes, 2); const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); + CHECK_EQ(b->data_type(), data_type); + CHECK_EQ(b->shape().NumAxes(), num_axes); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - int32_t num_axes = a->shape().NumAxes(); - CHECK_GT(num_axes, 2); - - int32_t m = 0, n = 0, k = 0; - std::tie(m, n, k) = CalcMNK(a->shape(), out->shape(), trans_a); + CHECK_EQ(out->data_type(), data_type); + CHECK_EQ(out->shape().NumAxes(), num_axes); + size_t m = 0; + size_t n = 0; + size_t k = 0; + InferMatmulMNK(a->shape(), b->shape(), out->shape(), trans_a, trans_b, &m, &n, &k); + size_t batch_size = 1; + for (size_t i = 0; i < num_axes - 2; ++i) { + const int64_t dim_size = a->shape().At(i); + CHECK_GT(dim_size, 0); + CHECK_EQ(b->shape().At(i), dim_size); + CHECK_EQ(out->shape().At(i), dim_size); + batch_size *= dim_size; + } const double alpha = ctx->Attr("alpha"); - double beta; + double beta = 0.0; if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); - CHECK_EQ(add_to_output->data_type(), out->data_type()); + CHECK_EQ(add_to_output->data_type(), data_type); CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), - add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); + auto memcpy = NewMemcpyPrimitive(ctx); + CHECK(memcpy); + memcpy->Launch(ctx->stream_ctx(), out->mut_dptr(), add_to_output->dptr(), + add_to_output->shape().elem_cnt() * GetSizeOfDataType(data_type)); beta = 1.0; - } else { - beta = 0.0; - } - size_t batch_size = a->shape().Count(0, num_axes - 2); - NewKernelUtil::OFBatchedGemm(ctx->device_ctx(), trans_a, trans_b, batch_size, m, n, - k, alpha, a->dptr(), b->dptr(), beta, - out->mut_dptr()); - } -}; - -#define REGISTER_BATCH_MATMUL_KERNEL(device, dtype) \ - REGISTER_USER_KERNEL("batch_matmul") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("a", 0) == GetDataType::value)) \ - .SetInplaceProposalFn([](const user_op::InferContext& ctx, \ - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ - if (ctx.has_input("_add_to_output", 0)) { \ - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ - } \ - return Maybe::Ok(); \ - }); - -REGISTER_BATCH_MATMUL_KERNEL(DeviceType::kCPU, float); -REGISTER_BATCH_MATMUL_KERNEL(DeviceType::kCPU, double); -#ifdef WITH_CUDA -REGISTER_BATCH_MATMUL_KERNEL(DeviceType::kGPU, float); -REGISTER_BATCH_MATMUL_KERNEL(DeviceType::kGPU, double); -#endif - -#ifdef WITH_CUDA -class BatchMatmulGpuHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { - public: - BatchMatmulGpuHalfKernel() = default; - ~BatchMatmulGpuHalfKernel() = default; - - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - CBLAS_TRANSPOSE trans_a = ctx->Attr("transpose_a") ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE trans_b = ctx->Attr("transpose_b") ? CblasTrans : CblasNoTrans; - const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("a", 0); - const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex("b", 0); - user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - int32_t num_axes = a->shape().NumAxes(); - CHECK_GT(num_axes, 2); - - int32_t m = 0, n = 0, k = 0; - std::tie(m, n, k) = CalcMNK(a->shape(), out->shape(), trans_a); - bool has_add_to_output = ctx->has_input("_add_to_output", 0); - if (has_add_to_output) { - const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); - CHECK_EQ(add_to_output->data_type(), out->data_type()); - CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), - add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); } - size_t batch_size = a->shape().Count(0, num_axes - 2); - const double alpha = ctx->Attr("alpha"); - const double beta = has_add_to_output ? 1.0 : 0.0; - NewKernelUtil::OFBatchedGemm( - ctx->device_ctx(), trans_a, trans_b, batch_size, m, n, k, alpha, a->dptr(), - b->dptr(), beta, out->mut_dptr()); + auto batch_matmul = NewBatchMatmulPrimitive(ctx); + CHECK(batch_matmul); + batch_matmul->Launch(ctx->stream_ctx(), batch_size, m, n, k, alpha, a->dptr(), b->dptr(), beta, + out->mut_dptr()); } }; REGISTER_USER_KERNEL("batch_matmul") - .SetCreateFn() - .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") - & (user_op::HobDataType("a", 0) == DataType::kFloat16)) + .SetCreateFn() + .SetIsMatchedHob((MemcpyPrimitiveExists() == true) & (BatchMatmulPrimitiveExists() == true)) .SetInplaceProposalFn([](const user_op::InferContext& ctx, - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { + const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { if (ctx.has_input("_add_to_output", 0)) { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); } return Maybe::Ok(); }); -#endif - -template +// TODO(liujuncheng): fully support class BroadcastMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: BroadcastMatmulKernel() = default; @@ -277,8 +253,10 @@ class BroadcastMatmulKernel final : public user_op::OpKernel, public user_op::Cu if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), + auto memcpy = NewMemcpyPrimitive(ctx); + CHECK(memcpy); + memcpy->Launch( + ctx->stream_ctx(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = 1.0; } @@ -295,15 +273,29 @@ class BroadcastMatmulKernel final : public user_op::OpKernel, public user_op::Cu n = b->shape().At(0); CHECK_EQ(k, b->shape().At(1)); } - - CBLAS_TRANSPOSE trans_a = transpose_a ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE trans_b = transpose_b ? CblasTrans : CblasNoTrans; - NewKernelUtil::OFGemm(ctx->device_ctx(), trans_a, trans_b, m, n, k, alpha, - a->dptr(), b->dptr(), beta, out->mut_dptr()); + auto matmul = NewMatmulPrimitive(ctx); + CHECK(matmul); + matmul->Launch(ctx->stream_ctx(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; -template +REGISTER_USER_KERNEL("broadcast_matmul") + .SetCreateFn() + .SetIsMatchedHob((MemcpyPrimitiveExists() == true) & (MatmulPrimitiveExists() == true)) + .SetInplaceProposalFn([](const user_op::InferContext& ctx, + const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { + if (ctx.has_input("_add_to_output", 0)) { + OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); + } + return Maybe::Ok(); + }); + +template +std::unique_ptr NewMatmulPrimitiveForBroadcastMatmulGradB(Context* ctx) { + const DataType data_type = ctx->TensorDesc4ArgNameAndIndex("out", 0)->data_type(); + return NewMatmulPrimitive(ctx->device_type(), data_type, true, false); +} + class BroadcastMatmulGradBKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: @@ -323,8 +315,10 @@ class BroadcastMatmulGradBKernel final : public user_op::OpKernel, if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); CHECK_EQ(add_to_output->shape(), out->shape()); - Memcpy( - ctx->device_ctx(), out->mut_dptr(), add_to_output->dptr(), + auto memcpy = NewMemcpyPrimitive(ctx); + CHECK(memcpy); + memcpy->Launch( + ctx->stream_ctx(), out->mut_dptr(), add_to_output->dptr(), add_to_output->shape().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); beta = 1.0; } @@ -335,51 +329,31 @@ class BroadcastMatmulGradBKernel final : public user_op::OpKernel, int64_t m = a->shape().At(a->shape().NumAxes() - 1); int64_t n = b->shape().At(b->shape().NumAxes() - 1); - NewKernelUtil::OFGemm(ctx->device_ctx(), CblasTrans, CblasNoTrans, m, n, k, alpha, - a->dptr(), b->dptr(), beta, out->mut_dptr()); + auto matmul = NewMatmulPrimitiveForBroadcastMatmulGradB(ctx); + CHECK(matmul); + matmul->Launch(ctx->stream_ctx(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr()); } }; -#define REGISTER_BROADCAST_MATMUL_KERNEL(device, dtype) \ - REGISTER_USER_KERNEL("broadcast_matmul") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("a", 0) == GetDataType::value)) \ - .SetInplaceProposalFn([](const user_op::InferContext& ctx, \ - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ - if (ctx.has_input("_add_to_output", 0)) { \ - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ - } \ - return Maybe::Ok(); \ - }) - -REGISTER_BROADCAST_MATMUL_KERNEL(DeviceType::kCPU, float); -REGISTER_BROADCAST_MATMUL_KERNEL(DeviceType::kCPU, double); -#ifdef WITH_CUDA -REGISTER_BROADCAST_MATMUL_KERNEL(DeviceType::kGPU, float); -REGISTER_BROADCAST_MATMUL_KERNEL(DeviceType::kGPU, double); -REGISTER_BROADCAST_MATMUL_KERNEL(DeviceType::kGPU, float16); -#endif - -#define REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(device, dtype) \ - REGISTER_USER_KERNEL("broadcast_matmul_grad_b") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("a", 0) == GetDataType::value)) \ - .SetInplaceProposalFn([](const user_op::InferContext& ctx, \ - user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ - if (ctx.has_input("_add_to_output", 0)) { \ - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); \ - } \ - return Maybe::Ok(); \ - }) - -REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(DeviceType::kCPU, float); -REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(DeviceType::kCPU, double); -#ifdef WITH_CUDA -REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(DeviceType::kGPU, float); -REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(DeviceType::kGPU, double); -REGISTER_BROADCAST_MATMUL_GRAD_B_KERNEL(DeviceType::kGPU, float16); -#endif +hob::HobContextGetter PrimitiveExistsForBroadcastMatmulGradB() { + return user_op::HobCtxGetter( + "MatmulPrimitiveExists", [](const user_op::KernelRegContext& ctx) { + return NewMatmulPrimitiveForBroadcastMatmulGradB(&ctx).operator bool(); + }); +} + +REGISTER_USER_KERNEL("broadcast_matmul_grad_b") + .SetCreateFn() + .SetIsMatchedHob((MemcpyPrimitiveExists() == true) + & (PrimitiveExistsForBroadcastMatmulGradB() == true)) + .SetInplaceProposalFn([](const user_op::InferContext& ctx, + const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe { + if (ctx.has_input("_add_to_output", 0)) { + OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "_add_to_output", 0, true)); + } + return Maybe::Ok(); + }); + +} // namespace } // namespace oneflow -- GitLab