Commit d58a8dfc authored by Victor Stone, committed by TensorFlower Gardener


[XLA:GPU] Check that the types used in the matmul are supported by cublasLt. If they are not supported, we fall back to legacy cublas.

This fixes an issue found by some JAX dot_general tests whose results depended on the device (P100/V100): certain combinations of types work on some devices but not on others. However, according to the official cublasLt documentation, these combinations are unsupported by cublasLt on all devices.

PiperOrigin-RevId: 481251576
Parent 0549fd10
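To make the rewrite decision concrete, here is a minimal, self-contained C++ sketch of the check this commit adds. It uses simplified enums and an abbreviated support table rather than the se::blas/PrimitiveType machinery in the diff below; as in the pass, the compute and scale types are derived from the GEMM's output element type, and a miss in the table keeps the legacy "__cublas$gemm" custom call instead of the cublasLt path.

// Minimal sketch of the added check; simplified types, not the XLA API.
#include <algorithm>
#include <array>
#include <cstdio>
#include <tuple>

enum class ComputeType { kF16, kF32, kF64 };
enum class DataType { kHalf, kBF16, kFloat, kDouble };

// (compute type, scale type, A/B type, output type)
using TypeCombo = std::tuple<ComputeType, DataType, DataType, DataType>;

// Abbreviated support table; the real pass below lists all 18 combinations
// from the cublasLt documentation.
const std::array<TypeCombo, 4> kCublasLtSupported = {{
    {ComputeType::kF16, DataType::kHalf, DataType::kHalf, DataType::kHalf},
    {ComputeType::kF32, DataType::kFloat, DataType::kBF16, DataType::kBF16},
    {ComputeType::kF32, DataType::kFloat, DataType::kFloat, DataType::kFloat},
    {ComputeType::kF64, DataType::kDouble, DataType::kDouble,
     DataType::kDouble},
}};

bool TypesAreSupportedByCublasLt(const TypeCombo& combo) {
  return std::find(kCublasLtSupported.begin(), kCublasLtSupported.end(),
                   combo) != kCublasLtSupported.end();
}

int main() {
  // f32 inputs upcast to an f64 output are not in the table, so the rewriter
  // falls back to the legacy "__cublas$gemm" custom call (this is what the
  // UpcastingF32ToF64 test added below asserts).
  const TypeCombo f32_to_f64{ComputeType::kF64, DataType::kDouble,
                             DataType::kFloat, DataType::kDouble};
  std::printf("cublasLt supported: %d\n",
              TypesAreSupportedByCublasLt(f32_to_f64));
  return 0;
}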
@@ -830,6 +830,9 @@ cc_library(
":cublas_cudnn",
":ir_emission_utils",
":matmul_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"//tensorflow/compiler/xla/stream_executor:blas",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -840,9 +843,9 @@ cc_library(
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/stream_executor/lib",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
]),
)
cc_library(
......
@@ -15,8 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include <array>
#include <memory>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
@@ -36,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -416,6 +419,88 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return absl::string_view(kGemmCallTarget);
}
StatusOr<bool> TypesAreSupportedByCublasLt(
const HloInstruction *instr) const {
// cublasLt has a defined set of combinations of types that it supports.
// Figure out the computeType and scaleType.
TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
AsBlasDataType(instr->shape().element_type()));
TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
GetBlasComputationType(instr->shape().element_type()));
const se::blas::DataType scale_type =
se::cuda::BlasLt::GetScaleType(output_dtype, compute_type);
// Figure out the Atype/Btype.
const PrimitiveType a_dtype = instr->operand(0)->shape().element_type();
const PrimitiveType b_dtype = instr->operand(1)->shape().element_type();
if (a_dtype != b_dtype) {
// AType must match BType.
return false;
}
using se::blas::ComputationType;
using se::blas::DataType;
// This matrix of supported types is taken directly from cublasLt
// documentation.
// https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
const std::array<
std::tuple<ComputationType, DataType /*scale_type*/,
PrimitiveType /*a_dtype*/, DataType /*output_dtype*/>,
18>
supported_type_combinations = {{
{ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
DataType::kHalf},
{ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
DataType::kInt32},
{ComputationType::kI32, DataType::kFloat, PrimitiveType::S8,
DataType::kInt8},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
DataType::kBF16},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
DataType::kFloat},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
DataType::kFloat},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
DataType::kFloat},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
DataType::kFloat},
// There would be an entry here for A/BType complex int8, but we do
// not support that type.
{ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
DataType::kComplexFloat},
{ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
DataType::kFloat},
{ComputationType::kF16AsF32, DataType::kComplexFloat,
PrimitiveType::C64, DataType::kComplexFloat},
{ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
DataType::kFloat},
{ComputationType::kBF16AsF32, DataType::kComplexFloat,
PrimitiveType::C64, DataType::kComplexFloat},
{ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
DataType::kFloat},
{ComputationType::kTF32AsF32, DataType::kComplexFloat,
PrimitiveType::C64, DataType::kComplexFloat},
{ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
DataType::kDouble},
{ComputationType::kF64, DataType::kComplexDouble,
PrimitiveType::C128, DataType::kComplexDouble},
}};
return absl::c_linear_search(
supported_type_combinations,
std::make_tuple(compute_type, scale_type, a_dtype, output_dtype));
}
StatusOr<bool> GemmIsSupportedByCublasLt(
const HloInstruction *instr,
const GemmBackendConfig &gemm_backend_config) const {
@@ -423,6 +508,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const HloInstruction *rhs = instr->operand(1);
const Shape &output_shape = instr->shape();
TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt,
TypesAreSupportedByCublasLt(instr));
if (!types_are_supported_by_cublas_lt) {
return false;
}
// The cublasLt API has two currently known limitations:
// 1. Batch count must be <2^16.
constexpr int64_t kMaxBatchCount = 65535;
......
@@ -366,22 +366,6 @@ StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
op.getBeta().convertToDouble(), algorithm, compute_precision);
}
namespace {
// BLAS GeMM's output is column-major. If we require row-major, use identity:
// C^T = (A @ B)^T = B^T @ A^T.
bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs,
MatrixLayout& output) {
bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor;
if (swap_operands) {
std::swap(lhs, rhs);
lhs.Transpose();
rhs.Transpose();
output.Transpose();
}
return swap_operands;
}
StatusOr<se::blas::ComputationType> GetBlasComputationType(
PrimitiveType dtype) {
switch (dtype) {
@@ -402,6 +386,22 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
}
}
namespace {
// BLAS GeMM's output is column-major. If we require row-major, use identity:
// C^T = (A @ B)^T = B^T @ A^T.
bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs,
MatrixLayout& output) {
bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor;
if (swap_operands) {
std::swap(lhs, rhs);
lhs.Transpose();
rhs.Transpose();
output.Transpose();
}
return swap_operands;
}
se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) {
// BLAS is column-major by default.
return (order == MatrixLayout::Order::kColumnMajor)
@@ -560,8 +560,6 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
#if GOOGLE_CUDA
namespace {
StatusOr<se::blas::DataType> AsBlasDataType(PrimitiveType dtype) {
switch (dtype) {
case F16:
@@ -581,6 +579,8 @@ StatusOr<se::blas::DataType> AsBlasDataType(PrimitiveType dtype) {
}
}
namespace {
StatusOr<se::cuda::BlasLt::MatrixLayout> AsBlasLtMatrixLayout(
const MatrixLayout& layout) {
TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(layout.dtype));
......
@@ -107,6 +107,8 @@ struct GemmConfig {
int64_t compute_precision;
};
StatusOr<se::blas::ComputationType> GetBlasComputationType(PrimitiveType dtype);
// Run the given GEMM instruction `gemm` subject to the configuration
// in `gemm_config` and the passed buffers.
//
@@ -119,6 +121,8 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
#if GOOGLE_CUDA
StatusOr<se::blas::DataType> AsBlasDataType(PrimitiveType dtype);
namespace cublas_lt {
StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
......
@@ -990,6 +990,121 @@ ENTRY int8gemm {
}
}
TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
Arg_0.1 = bf16[4,3]{1,0} parameter(0)
Arg_1.2 = bf16[3,6]{1,0} parameter(1)
ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);
// This is a type combination which is not supported by cublasLt, expect
// GemmRewriter to choose legacy cublas.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall("__cublas$gemm")));
}
TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
Arg_0.1 = c64[4,3]{1,0} parameter(0)
Arg_1.2 = c64[3,6]{1,0} parameter(1)
ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);
// This is a type combination which is not supported by cublasLt, expect
// GemmRewriter to choose legacy cublas.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall("__cublas$gemm")));
}
TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
Arg_0.1 = f16[4,3]{1,0} parameter(0)
Arg_1.2 = f16[3,6]{1,0} parameter(1)
ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);
// This is a type combination which is not supported by cublasLt, expect
// GemmRewriter to choose legacy cublas.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall("__cublas$gemm")));
}
TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
Arg_0.1 = f16[4,3]{1,0} parameter(0)
Arg_1.2 = f16[3,6]{1,0} parameter(1)
ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);
// This is a type combination which is not supported by cublasLt, expect
// GemmRewriter to choose legacy cublas.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall("__cublas$gemm")));
}
TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
Arg_0.1 = f32[4,3]{1,0} parameter(0)
Arg_1.2 = f32[3,6]{1,0} parameter(1)
ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(GetCudaComputeCapability());
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);
// This is a type combination which is not supported by cublasLt, expect
// GemmRewriter to choose legacy cublas.
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall("__cublas$gemm")));
}
INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
ParameterizedGemmRewriteTest, ::testing::Bool());
......