// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/runtime/cpu/cblas.h"

#include <vector>

#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/common/cas.h"

namespace {

// Map a boolean transpose flag onto the CBLAS enum.
inline CBLAS_TRANSPOSE ToCblasTranspose(bool trans) {
  return trans ? CblasTrans : CblasNoTrans;
}

}  // namespace

void cinn_cpu_mkl_gemm_fp32(float alpha,
                            int M,
                            int N,
                            int K,
                            bool ta,
                            bool tb,
                            int lda,
                            int ldb,
                            int ldc,
                            float beta,
                            cinn_buffer_t* A,
                            cinn_buffer_t* B,
                            cinn_buffer_t* C) {
  cblas_sgemm(CblasRowMajor,
              ToCblasTranspose(ta),
              ToCblasTranspose(tb),
              M,
              N,
              K,
              alpha,
              reinterpret_cast<float*>(A->memory),
              lda,
              reinterpret_cast<float*>(B->memory),
              ldb,
              beta,
              reinterpret_cast<float*>(C->memory),
              ldc);
}

void cinn_cpu_mkl_gemm_batch_fp32(float alpha,
                                  int batch_size,
                                  int M,
                                  int N,
                                  int K,
                                  bool ta,
                                  bool tb,
                                  int lda,
                                  int ldb,
                                  int ldc,
                                  int a_stride,
                                  int b_stride,
                                  int c_stride,
                                  float beta,
                                  cinn_buffer_t* A,
                                  cinn_buffer_t* B,
                                  cinn_buffer_t* C) {
  // Build per-batch pointer arrays; matrix i starts i * stride floats into
  // its buffer.
  std::vector<const float*> A_array(batch_size);
  std::vector<const float*> B_array(batch_size);
  std::vector<float*> C_array(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    A_array[i] = reinterpret_cast<float*>(A->memory) + i * a_stride;
    B_array[i] = reinterpret_cast<float*>(B->memory) + i * b_stride;
    C_array[i] = reinterpret_cast<float*>(C->memory) + i * c_stride;
  }
  CBLAS_TRANSPOSE trans_a = ToCblasTranspose(ta);
  CBLAS_TRANSPOSE trans_b = ToCblasTranspose(tb);
  // A single group of batch_size GEMMs sharing the same shapes and scalars.
  cblas_sgemm_batch(CblasRowMajor,
                    &trans_a,
                    &trans_b,
                    &M,
                    &N,
                    &K,
                    &alpha,
                    A_array.data(),
                    &lda,
                    B_array.data(),
                    &ldb,
                    &beta,
                    C_array.data(),
                    &ldc,
                    1,
                    &batch_size);
}
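// A minimal usage sketch (kept out of the build) of how the plain GEMM
// wrapper above could be driven from host code. It touches only the
// cinn_buffer_t field already used in this file (`memory`); the helper name
// `example_gemm_usage`, the shapes, and the leading dimensions are
// illustrative assumptions, not part of the runtime API.
#if 0
void example_gemm_usage() {
  const int M = 2, N = 3, K = 4;
  std::vector<float> a(M * K, 1.0f);  // A: M x K, row major
  std::vector<float> b(K * N, 1.0f);  // B: K x N, row major
  std::vector<float> c(M * N, 0.0f);  // C: M x N, row major

  cinn_buffer_t A{}, B{}, C{};
  A.memory = reinterpret_cast<uint8_t*>(a.data());
  B.memory = reinterpret_cast<uint8_t*>(b.data());
  C.memory = reinterpret_cast<uint8_t*>(c.data());

  // C = alpha * A * B + beta * C. With no transposes, the leading dimensions
  // are the row-major row widths: lda = K, ldb = N, ldc = N.
  cinn_cpu_mkl_gemm_fp32(/*alpha=*/1.0f, M, N, K,
                         /*ta=*/false, /*tb=*/false,
                         /*lda=*/K, /*ldb=*/N, /*ldc=*/N,
                         /*beta=*/0.0f, &A, &B, &C);
  // Every entry of C is now 4.0f: the dot product of K ones with K ones.
}
#endif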
/**
 * This function is temporarily unavailable; see the error message in the
 * following PR for details. The likely cause is that the custom call does
 * not support host ops.
 * See: https://github.com/PaddlePaddle/CINN/pull/1133
 */
void cinn_call_cholesky_host(
    void* v_args, int num_args, int batch_size, int m, bool upper) {
#ifdef CINN_WITH_MKL_CBLAS
  cinn_pod_value_t* args = static_cast<cinn_pod_value_t*>(v_args);
  cinn_buffer_t* x = args[0].operator cinn_buffer_t*();
  cinn_buffer_t* out = args[1].operator cinn_buffer_t*();
  // potrf factorizes in place, so seed the output with the input matrices.
  memcpy(out->memory, x->memory, x->memory_size);

  uint8_t bits = x->type.bits;
  CHECK(bits == 32 || bits == 64)
      << "Unsupported bits = " << bits << " float data type for cholesky";
  char uplo = upper ? 'U' : 'L';
  for (int i = 0; i < batch_size; i++) {
    if (bits == 32) {
      float* matrix = reinterpret_cast<float*>(out->memory) + i * m * m;
      LAPACKE_spotrf(LAPACK_ROW_MAJOR, uplo, m, matrix, m);
    } else if (bits == 64) {
      double* matrix = reinterpret_cast<double*>(out->memory) + i * m * m;
      LAPACKE_dpotrf(LAPACK_ROW_MAJOR, uplo, m, matrix, m);
    }
  }
#else
  CINN_NOT_IMPLEMENTED
#endif
}

CINN_REGISTER_HELPER(cinn_cpu_mkl) {
  using namespace cinn;  // NOLINT
  using backends::FunctionProto;
  auto host_target = common::DefaultHostTarget();

  FunctionProto::shape_inference_t inference_shape_gemm =
      [](const std::vector<Expr>& args, int offset) {
        CHECK_EQ(offset, 0UL) << "Only one output";
        CHECK_EQ(args.size(), 12UL) << "Wrong number of arguments passed in";
        auto M = common::AutoSimplify(args[1]);
        auto N = common::AutoSimplify(args[2]);
        std::vector<Expr> shape;
        shape.push_back(M);
        shape.push_back(N);
        return shape;
      };

  FunctionProto::shape_inference_t inference_shape_gemm_batch =
      [](const std::vector<Expr>& args, int offset) {
        CHECK_EQ(offset, 0UL) << "Only one output";
        CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in";
        auto& A = args[14];
        auto A_tensor = A.as_tensor();
        CHECK(A_tensor);

        auto batch_size = common::AutoSimplify(args[1]);
        int32_t batch_size_val = batch_size.as_int32();

        auto M = common::AutoSimplify(args[2]);
        auto N = common::AutoSimplify(args[3]);

        // Prepend the leading (batch) dimensions of A until their product
        // covers batch_size, then append the output matrix dimensions.
        std::vector<Expr> shape;
        int total = 1;
        for (auto& v : A_tensor->shape) {
          auto val = common::AutoSimplify(v);
          CHECK(val.is_constant());
          shape.push_back(val);
          total *= val.as_int32();
          if (total >= batch_size_val) break;
        }
        shape.push_back(M);
        shape.push_back(N);
        return shape;
      };

  REGISTER_EXTERN_FUNC_HELPER(cinn_cpu_mkl_gemm_fp32, host_target)
      .SetRetType<void>()
      .AddInputType<float>()            // alpha
      .AddInputType<int>()              // M
      .AddInputType<int>()              // N
      .AddInputType<int>()              // K
      .AddInputType<bool>()             // ta
      .AddInputType<bool>()             // tb
      .AddInputType<int>()              // lda
      .AddInputType<int>()              // ldb
      .AddInputType<int>()              // ldc
      .AddInputType<float>()            // beta
      .AddInputType<cinn_buffer_t*>()   // A
      .AddInputType<cinn_buffer_t*>()   // B
      .AddOutputType<cinn_buffer_t*>()  // C
      .SetShapeInference(inference_shape_gemm)
      .End();

  REGISTER_EXTERN_FUNC_HELPER(cinn_cpu_mkl_gemm_batch_fp32, host_target)
      .SetRetType<void>()
      .AddInputType<float>()            // alpha
      .AddInputType<int>()              // batch
      .AddInputType<int>()              // M
      .AddInputType<int>()              // N
      .AddInputType<int>()              // K
      .AddInputType<bool>()             // ta
      .AddInputType<bool>()             // tb
      .AddInputType<int>()              // lda
      .AddInputType<int>()              // ldb
      .AddInputType<int>()              // ldc
      .AddInputType<int>()              // a_stride
      .AddInputType<int>()              // b_stride
      .AddInputType<int>()              // c_stride
      .AddInputType<float>()            // beta
      .AddInputType<cinn_buffer_t*>()   // A
      .AddInputType<cinn_buffer_t*>()   // B
      .AddOutputType<cinn_buffer_t*>()  // C
      .SetShapeInference(inference_shape_gemm_batch)
      .End();

  REGISTER_EXTERN_FUNC_HELPER(cinn_call_cholesky_host, host_target)
      .SetRetType<void>()
      .AddInputType<void*>()  // v_args
      .AddInputType<int>()    // num_args
      .AddInputType<int>()    // batch_size
      .AddInputType<int>()    // m
      .AddInputType<bool>()   // upper
      .End();

  return true;
}
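// A usage sketch (kept out of the build) for the packed-argument convention
// consumed by cinn_call_cholesky_host: a cinn_pod_value_t array with the
// input buffer at index 0 and the output buffer at index 1, mirroring the
// unpacking done in the function body. The cinn_pod_value_t constructor
// taking cinn_buffer_t* is assumed from the CINN runtime headers; the helper
// name `example_cholesky_usage` and the 3x3 SPD matrix are illustrative.
#if 0
void example_cholesky_usage() {
  // Row-major 3x3 symmetric positive definite matrix.
  std::vector<double> x_data = {4, 12, -16, 12, 37, -43, -16, -43, 98};
  std::vector<double> out_data(9, 0.0);

  cinn_buffer_t x{}, out{};
  x.memory = reinterpret_cast<uint8_t*>(x_data.data());
  x.memory_size = x_data.size() * sizeof(double);
  x.type.bits = 64;  // selects the fp64 path (LAPACKE_dpotrf)
  out.memory = reinterpret_cast<uint8_t*>(out_data.data());

  cinn_pod_value_t args[] = {cinn_pod_value_t(&x), cinn_pod_value_t(&out)};
  // One 3x3 matrix; request the lower-triangular factor L with x = L * L^T.
  cinn_call_cholesky_host(args, /*num_args=*/2, /*batch_size=*/1, /*m=*/3,
                          /*upper=*/false);
  // out_data now holds L in its lower triangle; the strict upper triangle
  // still holds the values copied over from x, since potrf only writes the
  // requested triangle.
}
#endif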