Commit 1803461c authored by Johannes Reifferscheid, committed by TensorFlower Gardener

NFC: Extract launch dimensions computation from TritonWrapper.

This is another piece of preparatory refactoring to make Triton fusions
compatible with FusionInterface and partially fused HLO. For the former, we
need the LaunchDimensions computation to be a separate function. For the
latter, we change the launch-dimension function signatures to no longer
take an HloComputation, because we don't yet have that during fusion (at
least not a complete one). For now, this change is a no-op, since we do
not yet have any boundary functions for non-fusion ops.

PiperOrigin-RevId: 565015567
Parent 25ac48b7
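For orientation, here is a sketch of the calling pattern after this change, reconstructed from the hunks below. It is illustrative rather than a verbatim excerpt; `fusion_analysis`, `config`, and the compiler/device arguments (`cc`, `device_info`, `llvm_module`, `mlir_context`) are assumed to be in scope:

// TritonWrapper no longer computes launch dimensions; callers run the
// analysis up front, compile, and then derive launch dimensions from the
// fusion roots and boundary.
TF_ASSIGN_OR_RETURN(auto analysis,
                    TritonFusionAnalysis::Execute(*hlo_computation));
TF_ASSIGN_OR_RETURN(
    TritonWrapperResult result,
    TritonWrapper(analysis, "impl_fn", hlo_computation,
                  kTritonSoftmaxFusionKind, cc, device_info, config,
                  llvm_module, &EmitSoftMax, mlir_context));
LaunchDimensions launch_dimensions = GetSoftMaxLaunchDimensions(
    fusion_analysis.fusion_roots(), fusion_analysis.fusion_boundary(), config);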
@@ -473,6 +473,7 @@ cc_library(
     deps = [
         ":gemm_rewriter_triton",
         ":gpu_device_info",
+        ":hlo_traversal",
         ":ir_emission_utils",
         ":launch_dimensions",
         ":matmul_utils",
......
@@ -58,6 +58,9 @@ class HloFusionAnalysis {
   const std::vector<HloInstruction*>& fusion_roots() const {
     return fusion_roots_;
   }
+  const FusionBoundaryFn& fusion_boundary() const {
+    return fusion_boundary_fn_;
+  }
 
   // Determines the fusion type for the emitter.
   EmitterFusionKind GetEmitterFusionKind() const;
......
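The accessor added above exposes the `FusionBoundaryFn` that the traversal hunks below consume: a predicate over producer/consumer edges, where traversal follows an edge only when the function returns true. A custom boundary might look like the following sketch (hypothetical, for illustration only; `DefaultFusionBoundaryFn` is the stock implementation used in the tests further down):

// Hypothetical boundary: only follow edges out of elementwise producers.
// The signature matches the FusionBoundaryFn alias that replaces the raw
// std::function<bool(const HloInstruction&, const HloInstruction&)> below.
FusionBoundaryFn elementwise_producers_only =
    [](const HloInstruction& producer, const HloInstruction& consumer) {
      return producer.IsElementwise();
    };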
@@ -103,16 +103,21 @@ void FindFusionParameters(
       [&](const HloInstruction&) { return TraversalResult::kVisitOperands; });
 }
 
-bool HloAnyOf(
-    absl::Span<const HloInstruction* const> roots,
-    const std::function<bool(const HloInstruction& producer,
-                             const HloInstruction& consumer)>& boundary,
-    const std::function<bool(const HloInstruction& node)>& visit) {
-  bool result = false;
+bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
+              const FusionBoundaryFn& boundary,
+              const std::function<bool(const HloInstruction& node)>& visit) {
+  return HloFindIf(roots, boundary, visit) != nullptr;
+}
+
+const HloInstruction* HloFindIf(
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& boundary,
+    const std::function<bool(const HloInstruction& node)>& visit) {
+  const HloInstruction* result = nullptr;
   HloBfsConsumersFirstTraversal(roots, boundary,
                                 [&](const HloInstruction& node) {
                                   if (visit(node)) {
-                                    result = true;
+                                    result = &node;
                                     return TraversalResult::kAbortTraversal;
                                   }
                                   return TraversalResult::kVisitOperands;
......
@@ -46,16 +46,22 @@ bool DefaultFusionBoundaryFn(const HloInstruction& producer,
 // traversed along edges for which `boundary` returns true.
 void HloBfsConsumersFirstTraversal(
     absl::Span<const HloInstruction* const> roots,
-    const std::function<bool(const HloInstruction& producer,
-                             const HloInstruction& consumer)>& boundary,
+    const FusionBoundaryFn& boundary,
     const std::function<TraversalResult(const HloInstruction& node)>& visit);
 
 // Visit the HLO nodes starting from `roots`, returning true if the return value
-// of `visit` for any of the ones is true.
-bool HloAnyOf(
-    absl::Span<const HloInstruction* const> roots,
-    const std::function<bool(const HloInstruction& producer,
-                             const HloInstruction& consumer)>& boundary,
-    const std::function<bool(const HloInstruction& node)>& visit);
+// of `visit` for any of the nodes is true. Uses the same order as
+// `HloBfsConsumersFirstTraversal`.
+bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
+              const FusionBoundaryFn& boundary,
+              const std::function<bool(const HloInstruction& node)>& visit);
+
+// Visit the HLO nodes starting from `roots`, returning the first
+// node for which `visit` returns true, or `nullptr` if no node matches. Uses
+// the same order as `HloBfsConsumersFirstTraversal`.
+const HloInstruction* HloFindIf(
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& boundary,
+    const std::function<bool(const HloInstruction& node)>& visit);
 
 // Visit the producers of all parameters that are needed by the fusion.
......
@@ -183,6 +183,28 @@ TEST_F(HloTraversalTest, FuseConsumer) {
                                            "mul", "p0.1", "p1.1"));
 }
 
+TEST_F(HloTraversalTest, FindIf) {
+  auto module = ParseAndReturnVerifiedModule(kTestModule).value();
+  std::vector<std::string> visited_nodes;
+  auto* result = HloFindIf(
+      {module->GetComputationWithName("fused_computation")->root_instruction()},
+      DefaultFusionBoundaryFn, [&](const HloInstruction& node) {
+        return node.opcode() == HloOpcode::kMultiply;
+      });
+  ASSERT_NE(result, nullptr);
+  ASSERT_EQ(result->name(), "mul");
+}
+
+TEST_F(HloTraversalTest, NotFound) {
+  auto module = ParseAndReturnVerifiedModule(kTestModule).value();
+  std::vector<std::string> visited_nodes;
+  auto* result = HloFindIf(
+      {module->GetComputationWithName("fused_computation")->root_instruction()},
+      DefaultFusionBoundaryFn,
+      [&](const HloInstruction& node) { return false; });
+  ASSERT_EQ(result, nullptr);
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
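Since `HloAnyOf` is now a thin wrapper over `HloFindIf` (see the hlo_traversal.cc hunk above), the boolean form of the FindIf test would look like this sketch, under the same `kTestModule` fixture assumptions:

// Sketch only: mirrors the FindIf test above using the new HloAnyOf overload.
bool has_multiply = HloAnyOf(
    {module->GetComputationWithName("fused_computation")->root_instruction()},
    DefaultFusionBoundaryFn, [](const HloInstruction& node) {
      return node.opcode() == HloOpcode::kMultiply;
    });
// has_multiply is true for kTestModule, which contains the "mul" instruction.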
@@ -91,6 +91,7 @@ limitations under the License.
 #include "xla/service/dump.h"
 #include "xla/service/gpu/gemm_rewriter_triton.h"
 #include "xla/service/gpu/gpu_device_info.h"
+#include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
@@ -948,12 +949,17 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
 }  // namespace
 
 LaunchDimensions GetMatMulLaunchDimensions(
-    const TritonFusionAnalysis& analysis, const HloComputation* computation,
+    const TritonFusionAnalysis& analysis,
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& fusion_boundary,
     const AutotuneResult::TritonGemmKey& config) {
-  const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
-      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
-  const MatMulDims dims(config, *dot_instr, analysis);
-  const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
+  const auto* dot = static_cast<const HloDotInstruction*>(
+      HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
+        return node.opcode() == HloOpcode::kDot;
+      }));
+  CHECK_NE(dot, nullptr);
+  const MatMulDims dims(config, *dot, analysis);
+  const MatMulLaunchConfig launch_config(config, *dot, dims);
   return launch_config.launch_dims;
 }
@@ -1354,10 +1360,13 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
 }
 
 LaunchDimensions GetSoftMaxLaunchDimensions(
-    const TritonFusionAnalysis&, const HloComputation* computation,
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& fusion_boundary,
     const AutotuneResult::TritonGemmKey& config) {
-  const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode(
-      *computation, HloOpcode::kReduce);
+  const HloInstruction* reduce =
+      HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
+        return node.opcode() == HloOpcode::kReduce;
+      });
+  CHECK_NE(reduce, nullptr);
   const Shape& reduce_input_shape = reduce->operand(0)->shape();
   int num_rows = 1;
@@ -1482,12 +1491,11 @@ StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
 }
 
 StatusOr<TritonWrapperResult> TritonWrapper(
-    absl::string_view fn_name, const HloComputation* hlo_computation,
-    absl::string_view fusion_kind, const se::CudaComputeCapability& cc,
-    const GpuDeviceInfo& device_info,
+    const TritonFusionAnalysis& analysis, absl::string_view fn_name,
+    const HloComputation* hlo_computation, absl::string_view fusion_kind,
+    const se::CudaComputeCapability& cc, const GpuDeviceInfo& device_info,
     const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
-    LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter,
-    mlir::MLIRContext& mlir_context) {
+    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) {
   if (fusion_kind == kTritonGemmFusionKind) {
     // This is a heuristic that serves as a proxy for register usage and code
     // size.
@@ -1561,11 +1569,6 @@ StatusOr<TritonWrapperResult> TritonWrapper(
                                 .debug_options()
                                 .xla_gpu_cuda_data_dir());
 
-  TF_ASSIGN_OR_RETURN(
-      auto analysis,
-      fusion_kind == kTritonGemmFusionKind
-          ? TritonFusionAnalysis::Execute(*hlo_computation, config.split_k())
-          : TritonFusionAnalysis::Execute(*hlo_computation));
   TF_RETURN_IF_ERROR(ir_emitter(b, libdevice_path, analysis, hlo_computation,
                                 fn, config,
                                 device_info.shared_memory_per_block_optin));
@@ -1658,9 +1661,7 @@ StatusOr<TritonWrapperResult> TritonWrapper(
                                     llvm::Linker::Flags::OverrideFromSrc));
   LogAndVerify(llvm_module);
 
-  LaunchDimensions launch_dimensions =
-      launch_dims_generator(analysis, hlo_computation, config);
-  return {{launch_dimensions, shared_mem_bytes}};
+  return {{shared_mem_bytes}};
 }
} // namespace gpu
......
@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/service/gpu/gemm_rewriter_triton.h"
 #include "xla/service/gpu/gpu_device_info.h"
+#include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/statusor.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -32,13 +33,14 @@ namespace xla {
 namespace gpu {
 
 struct TritonWrapperResult {
-  LaunchDimensions launch_dimensions;
   int64_t shmem_bytes;
 };
 
 // Compute the launch dimensions for the given Triton MatMul.
 LaunchDimensions GetMatMulLaunchDimensions(
-    const TritonFusionAnalysis& analysis, const HloComputation* computation,
+    const TritonFusionAnalysis& analysis,
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& fusion_boundary,
     const AutotuneResult::TritonGemmKey& config);
 
 // Use tiling and execution parameters from 'config'.
 Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
@@ -49,7 +51,8 @@ Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
 
 // Compute the launch dimensions for the given Triton SoftMax.
 LaunchDimensions GetSoftMaxLaunchDimensions(
-    const TritonFusionAnalysis& analysis, const HloComputation* computation,
+    absl::Span<const HloInstruction* const> roots,
+    const FusionBoundaryFn& fusion_boundary,
     const AutotuneResult::TritonGemmKey& config);
 
 // Generate Softmax in Triton IR inside 'fn'.
 // Use execution parameters from 'config'.
@@ -59,24 +62,20 @@ Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
                    const AutotuneResult::TritonGemmKey& config,
                    int shmem_budget);
 
-using LaunchDimensionsGenerator = std::function<LaunchDimensions(
-    const TritonFusionAnalysis&, const HloComputation*,
-    const AutotuneResult::TritonGemmKey&)>;
 using TritonIrEmitter = std::function<Status(
     mlir::OpBuilder, absl::string_view, const TritonFusionAnalysis& analysis,
     const HloComputation*, mlir::triton::FuncOp,
     const AutotuneResult::TritonGemmKey&, int)>;
 
-// Generate Triton IR by running the provided generator, compile it into LLVM IR
-// and return launch dimensions.
+// Generate Triton IR by running the provided generator and compile it into LLVM
+// IR.
 // MatMul and SoftMax above are some such IR generators.
 StatusOr<TritonWrapperResult> TritonWrapper(
-    absl::string_view fn_name, const HloComputation* hlo_computation,
-    absl::string_view fusion_kind, const se::CudaComputeCapability& cc,
-    const GpuDeviceInfo& device_info,
+    const TritonFusionAnalysis& analysis, absl::string_view fn_name,
+    const HloComputation* hlo_computation, absl::string_view fusion_kind,
+    const se::CudaComputeCapability& cc, const GpuDeviceInfo& device_info,
    const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
-    LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter,
-    mlir::MLIRContext& mlir_context);
+    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context);
 
 }  // namespace gpu
 }  // namespace xla
......
@@ -211,11 +211,11 @@ ENTRY entry {
   config.set_num_stages(4);
   config.set_num_warps(8);
   EXPECT_THAT(
-      TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
+      TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
+                    "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
                                               /*minor=*/0},
-                    dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
-                    &EmitMatMul, mlir_context),
+                    dev_info, config, &llvm_module, &EmitMatMul, mlir_context),
       tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED,
                              "Shared memory size limit exceeded."));
@@ -225,11 +225,11 @@ ENTRY entry {
   config.set_num_stages(1);
   TF_ASSERT_OK_AND_ASSIGN(
       const auto result,
-      TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
+      TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
+                    "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
                                               /*minor=*/0},
-                    dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
-                    &EmitMatMul, mlir_context));
+                    dev_info, config, &llvm_module, &EmitMatMul, mlir_context));
   // Use optin shared memory which is > shared_memory_per_block.
   EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block);
 }
@@ -639,11 +639,11 @@ ENTRY entry {
   config.set_num_stages(1);
   config.set_num_warps(2);
   EXPECT_THAT(
-      TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
+      TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
+                    "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
                                               /*minor=*/0},
-                    dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
-                    &EmitMatMul, mlir_context),
+                    dev_info, config, &llvm_module, &EmitMatMul, mlir_context),
       tsl::testing::StatusIs(
           tsl::error::RESOURCE_EXHAUSTED,
           "Tiling complexity heuristic exceeded: 147456 > 9000"));
@@ -653,11 +653,11 @@ ENTRY entry {
   config.set_block_n(32);
   config.set_block_k(32);
   TF_CHECK_OK(
-      TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
+      TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
+                    "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
                                               /*minor=*/0},
-                    dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
-                    &EmitMatMul, mlir_context)
+                    dev_info, config, &llvm_module, &EmitMatMul, mlir_context)
           .status());
 }
@@ -1438,10 +1438,11 @@ ENTRY e {
                           ->backend_config<FusionBackendConfig>());
   TF_ASSERT_OK_AND_ASSIGN(
       const auto result,
-      TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
+      TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
+                    "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     GetCudaComputeCapability(), dev_info,
-                    config.triton_gemm_config(), &llvm_module,
-                    &GetMatMulLaunchDimensions, &EmitMatMul, mlir_context));
+                    config.triton_gemm_config(), &llvm_module, &EmitMatMul,
+                    mlir_context));
   // The config is chosen so that the used memory size is slightly above the
   // 48 kB boundary of standard / optin shared memory so that any GPU that
   // has the optin one should be able to execute the test.
......
@@ -1701,7 +1701,8 @@ static Status ProcessFusionForConversion(mlir::Region* region,
 
 #if GOOGLE_CUDA
 Status IrEmitterUnnested::EmitTritonFusion(
-    mlir::Operation* op, const AutotuneResult::TritonGemmKey& config,
+    const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
+    const AutotuneResult::TritonGemmKey& config,
     const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
         hlo_for_lmhlo) {
   // Note: In this method we can't use `BuildKernelThunk` as usual,
@@ -1740,23 +1741,35 @@ Status IrEmitterUnnested::EmitTritonFusion(
   absl::string_view fusion_kind = backend_config.kind();
 
   TritonWrapperResult triton_wrapper_result;
+  LaunchDimensions launch_dimensions;
   if (fusion_kind == kTritonSoftmaxFusionKind) {
+    TF_ASSIGN_OR_RETURN(auto analysis,
+                        TritonFusionAnalysis::Execute(*hlo_computation));
     TF_ASSIGN_OR_RETURN(
         triton_wrapper_result,
-        TritonWrapper(impl_fn_name, hlo_computation, kTritonSoftmaxFusionKind,
+        TritonWrapper(analysis, impl_fn_name, hlo_computation,
+                      kTritonSoftmaxFusionKind,
                       ir_emitter_context_->cuda_compute_capability(),
                       ir_emitter_context_->gpu_device_info(), config, module_,
-                      &GetSoftMaxLaunchDimensions, &EmitSoftMax,
-                      *ir_emitter_context_->mlir_context()));
+                      &EmitSoftMax, *ir_emitter_context_->mlir_context()));
+    launch_dimensions = GetSoftMaxLaunchDimensions(
+        hlo_fusion_analysis.fusion_roots(),
+        hlo_fusion_analysis.fusion_boundary(), config);
   } else {  // Must be a MatMul
     CHECK_EQ(fusion_kind, kTritonGemmFusionKind);
+    TF_ASSIGN_OR_RETURN(
+        auto analysis,
+        TritonFusionAnalysis::Execute(*hlo_computation, config.split_k()));
     TF_ASSIGN_OR_RETURN(
        triton_wrapper_result,
-        TritonWrapper(impl_fn_name, hlo_computation, kTritonGemmFusionKind,
+        TritonWrapper(analysis, impl_fn_name, hlo_computation,
+                      kTritonGemmFusionKind,
                       ir_emitter_context_->cuda_compute_capability(),
                       ir_emitter_context_->gpu_device_info(), config, module_,
-                      &GetMatMulLaunchDimensions, &EmitMatMul,
-                      *ir_emitter_context_->mlir_context()));
+                      &EmitMatMul, *ir_emitter_context_->mlir_context()));
+    launch_dimensions = GetMatMulLaunchDimensions(
+        analysis, hlo_fusion_analysis.fusion_roots(),
+        hlo_fusion_analysis.fusion_boundary(), config);
   }
 
   llvm::Function* impl_fn = module_->getFunction(impl_fn_name);
@@ -1764,7 +1777,7 @@ Status IrEmitterUnnested::EmitTritonFusion(
 
   auto [kernel, inputs, outputs] = BuildKernelPrototype(
       *ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(),
-      impl_fn->arg_size(), triton_wrapper_result.launch_dimensions, &b_);
+      impl_fn->arg_size(), launch_dimensions, &b_);
 
   // Move function body into kernel prototype.
   llvm::Function* prototype_func = b_.GetInsertBlock()->getParent();
@@ -1775,7 +1788,7 @@ Status IrEmitterUnnested::EmitTritonFusion(
   impl_fn->eraseFromParent();
 
   LogAndVerify(module_);
-  return {{kernel->getName().str(), triton_wrapper_result.launch_dimensions,
+  return {{kernel->getName().str(), launch_dimensions,
            triton_wrapper_result.shmem_bytes}};
 };
@@ -1850,7 +1863,8 @@ Status IrEmitterUnnested::EmitFusion(
       triton_config.set_num_stages(1);
       triton_config.set_num_warps(2);
     }
-    return EmitTritonFusion(fusion_op, backend_config.triton_gemm_config(),
+    return EmitTritonFusion(fusion_analysis, fusion_op,
+                            backend_config.triton_gemm_config(),
                             hlo_for_lmhlo);
   }
 
   if (backend_config.kind() == kTritonSoftmaxFusionKind) {
@@ -1858,7 +1872,8 @@ Status IrEmitterUnnested::EmitFusion(
     triton_config.set_num_stages(1);
     triton_config.set_num_warps(
         DeriveNumWarpsFromTritonSoftmaxComputation(fused_computation));
-    return EmitTritonFusion(fusion_op, backend_config.triton_gemm_config(),
+    return EmitTritonFusion(fusion_analysis, fusion_op,
+                            backend_config.triton_gemm_config(),
                             hlo_for_lmhlo);
   }
 #endif
......
@@ -136,7 +136,8 @@ class IrEmitterUnnested : public IrEmitter {
   Status EmitCublasLtMatmulThunkF8(mlir::Operation* op);
   Status EmitConvolutionReorderThunk(mlir::Operation* op);
   Status EmitTritonFusion(
-      mlir::Operation* op, const AutotuneResult::TritonGemmKey& config,
+      const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
+      const AutotuneResult::TritonGemmKey& config,
       const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
           hlo_for_lmhlo);
   Status EmitFusedMHAThunk(mlir::Operation* op);
......