提交 1803461c 编写于 作者: J Johannes Reifferscheid 提交者: TensorFlower Gardener

NFC: Extract launch dimensions computation from TritonWrapper.

This is another prefactoring to make Triton fusions compatible with
FusionInterface and partially fused HLO. For the former, we need
the LaunchDimension computation to be a separate function. For the
latter, we change the launch dimension function signatures to no longer
take an HloComputation, because we don't yet have that during fusion (at
least not a complete one). For now, this change is a no-op, since we do
not yet have any boundary functions for non-fusion ops.

PiperOrigin-RevId: 565015567
上级 25ac48b7
...@@ -473,6 +473,7 @@ cc_library( ...@@ -473,6 +473,7 @@ cc_library(
deps = [ deps = [
":gemm_rewriter_triton", ":gemm_rewriter_triton",
":gpu_device_info", ":gpu_device_info",
":hlo_traversal",
":ir_emission_utils", ":ir_emission_utils",
":launch_dimensions", ":launch_dimensions",
":matmul_utils", ":matmul_utils",
......
...@@ -58,6 +58,9 @@ class HloFusionAnalysis { ...@@ -58,6 +58,9 @@ class HloFusionAnalysis {
const std::vector<HloInstruction*>& fusion_roots() const { const std::vector<HloInstruction*>& fusion_roots() const {
return fusion_roots_; return fusion_roots_;
} }
const FusionBoundaryFn& fusion_boundary() const {
return fusion_boundary_fn_;
}
// Determines the fusion type for the emitter. // Determines the fusion type for the emitter.
EmitterFusionKind GetEmitterFusionKind() const; EmitterFusionKind GetEmitterFusionKind() const;
......
...@@ -103,16 +103,21 @@ void FindFusionParameters( ...@@ -103,16 +103,21 @@ void FindFusionParameters(
[&](const HloInstruction&) { return TraversalResult::kVisitOperands; }); [&](const HloInstruction&) { return TraversalResult::kVisitOperands; });
} }
bool HloAnyOf( bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& boundary,
const std::function<bool(const HloInstruction& node)>& visit) {
return HloFindIf(roots, boundary, visit) != nullptr;
}
const HloInstruction* HloFindIf(
absl::Span<const HloInstruction* const> roots, absl::Span<const HloInstruction* const> roots,
const std::function<bool(const HloInstruction& producer, const FusionBoundaryFn& boundary,
const HloInstruction& consumer)>& boundary,
const std::function<bool(const HloInstruction& node)>& visit) { const std::function<bool(const HloInstruction& node)>& visit) {
bool result = false; const HloInstruction* result = nullptr;
HloBfsConsumersFirstTraversal(roots, boundary, HloBfsConsumersFirstTraversal(roots, boundary,
[&](const HloInstruction& node) { [&](const HloInstruction& node) {
if (visit(node)) { if (visit(node)) {
result = true; result = &node;
return TraversalResult::kAbortTraversal; return TraversalResult::kAbortTraversal;
} }
return TraversalResult::kVisitOperands; return TraversalResult::kVisitOperands;
......
...@@ -46,16 +46,22 @@ bool DefaultFusionBoundaryFn(const HloInstruction& producer, ...@@ -46,16 +46,22 @@ bool DefaultFusionBoundaryFn(const HloInstruction& producer,
// traversed along edges for which `boundary` returns true. // traversed along edges for which `boundary` returns true.
void HloBfsConsumersFirstTraversal( void HloBfsConsumersFirstTraversal(
absl::Span<const HloInstruction* const> roots, absl::Span<const HloInstruction* const> roots,
const std::function<bool(const HloInstruction& producer, const FusionBoundaryFn& boundary,
const HloInstruction& consumer)>& boundary,
const std::function<TraversalResult(const HloInstruction& node)>& visit); const std::function<TraversalResult(const HloInstruction& node)>& visit);
// Visit the HLO nodes starting from `roots`, returning true if the return value // Visit the HLO nodes starting from `roots`, returning true if the return value
// of `visit` for any of the ones is true. // of `visit` for any of nodes is true. Uses the same order as
bool HloAnyOf( // `HloBfsConsumersFirstTraversal`.
bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& boundary,
const std::function<bool(const HloInstruction& node)>& visit);
// Visit the HLO nodes stating from `roots`, returning the first
// node for which `visit` returns true, or `nullptr` if no node matches. Uses
// the same order as `HloBfsConsumersFirstTraversal`.
const HloInstruction* HloFindIf(
absl::Span<const HloInstruction* const> roots, absl::Span<const HloInstruction* const> roots,
const std::function<bool(const HloInstruction& producer, const FusionBoundaryFn& boundary,
const HloInstruction& consumer)>& boundary,
const std::function<bool(const HloInstruction& node)>& visit); const std::function<bool(const HloInstruction& node)>& visit);
// Visit the producers of all parameters that are needed by the fusion. // Visit the producers of all parameters that are needed by the fusion.
......
...@@ -183,6 +183,28 @@ TEST_F(HloTraversalTest, FuseConsumer) { ...@@ -183,6 +183,28 @@ TEST_F(HloTraversalTest, FuseConsumer) {
"mul", "p0.1", "p1.1")); "mul", "p0.1", "p1.1"));
} }
TEST_F(HloTraversalTest, FindIf) {
auto module = ParseAndReturnVerifiedModule(kTestModule).value();
std::vector<std::string> visited_nodes;
auto* result = HloFindIf(
{module->GetComputationWithName("fused_computation")->root_instruction()},
DefaultFusionBoundaryFn, [&](const HloInstruction& node) {
return node.opcode() == HloOpcode::kMultiply;
});
ASSERT_NE(result, nullptr);
ASSERT_EQ(result->name(), "mul");
}
TEST_F(HloTraversalTest, NotFound) {
auto module = ParseAndReturnVerifiedModule(kTestModule).value();
std::vector<std::string> visited_nodes;
auto* result = HloFindIf(
{module->GetComputationWithName("fused_computation")->root_instruction()},
DefaultFusionBoundaryFn,
[&](const HloInstruction& node) { return false; });
ASSERT_EQ(result, nullptr);
}
} // namespace } // namespace
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla
...@@ -91,6 +91,7 @@ limitations under the License. ...@@ -91,6 +91,7 @@ limitations under the License.
#include "xla/service/dump.h" #include "xla/service/dump.h"
#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/gemm_rewriter_triton.h"
#include "xla/service/gpu/gpu_device_info.h" #include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
...@@ -948,12 +949,17 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config, ...@@ -948,12 +949,17 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
} // namespace } // namespace
LaunchDimensions GetMatMulLaunchDimensions( LaunchDimensions GetMatMulLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloComputation* computation, const TritonFusionAnalysis& analysis,
absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& fusion_boundary,
const AutotuneResult::TritonGemmKey& config) { const AutotuneResult::TritonGemmKey& config) {
const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>( const auto* dot = static_cast<const HloDotInstruction*>(
hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot)); HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
const MatMulDims dims(config, *dot_instr, analysis); return node.opcode() == HloOpcode::kDot;
const MatMulLaunchConfig launch_config(config, *dot_instr, dims); }));
CHECK_NE(dot, nullptr);
const MatMulDims dims(config, *dot, analysis);
const MatMulLaunchConfig launch_config(config, *dot, dims);
return launch_config.launch_dims; return launch_config.launch_dims;
} }
...@@ -1354,10 +1360,13 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, ...@@ -1354,10 +1360,13 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
} }
LaunchDimensions GetSoftMaxLaunchDimensions( LaunchDimensions GetSoftMaxLaunchDimensions(
const TritonFusionAnalysis&, const HloComputation* computation, absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& fusion_boundary,
const AutotuneResult::TritonGemmKey& config) { const AutotuneResult::TritonGemmKey& config) {
const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode( const HloInstruction* reduce =
*computation, HloOpcode::kReduce); HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
return node.opcode() == HloOpcode::kReduce;
});
CHECK_NE(reduce, nullptr); CHECK_NE(reduce, nullptr);
const Shape& reduce_input_shape = reduce->operand(0)->shape(); const Shape& reduce_input_shape = reduce->operand(0)->shape();
int num_rows = 1; int num_rows = 1;
...@@ -1482,12 +1491,11 @@ StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR( ...@@ -1482,12 +1491,11 @@ StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
} }
StatusOr<TritonWrapperResult> TritonWrapper( StatusOr<TritonWrapperResult> TritonWrapper(
absl::string_view fn_name, const HloComputation* hlo_computation, const TritonFusionAnalysis& analysis, absl::string_view fn_name,
absl::string_view fusion_kind, const se::CudaComputeCapability& cc, const HloComputation* hlo_computation, absl::string_view fusion_kind,
const GpuDeviceInfo& device_info, const se::CudaComputeCapability& cc, const GpuDeviceInfo& device_info,
const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module, const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) {
mlir::MLIRContext& mlir_context) {
if (fusion_kind == kTritonGemmFusionKind) { if (fusion_kind == kTritonGemmFusionKind) {
// This is a heuristic that serves as a proxy for register usage and code // This is a heuristic that serves as a proxy for register usage and code
// size. // size.
...@@ -1561,11 +1569,6 @@ StatusOr<TritonWrapperResult> TritonWrapper( ...@@ -1561,11 +1569,6 @@ StatusOr<TritonWrapperResult> TritonWrapper(
.debug_options() .debug_options()
.xla_gpu_cuda_data_dir()); .xla_gpu_cuda_data_dir());
TF_ASSIGN_OR_RETURN(
auto analysis,
fusion_kind == kTritonGemmFusionKind
? TritonFusionAnalysis::Execute(*hlo_computation, config.split_k())
: TritonFusionAnalysis::Execute(*hlo_computation));
TF_RETURN_IF_ERROR(ir_emitter(b, libdevice_path, analysis, hlo_computation, TF_RETURN_IF_ERROR(ir_emitter(b, libdevice_path, analysis, hlo_computation,
fn, config, fn, config,
device_info.shared_memory_per_block_optin)); device_info.shared_memory_per_block_optin));
...@@ -1658,9 +1661,7 @@ StatusOr<TritonWrapperResult> TritonWrapper( ...@@ -1658,9 +1661,7 @@ StatusOr<TritonWrapperResult> TritonWrapper(
llvm::Linker::Flags::OverrideFromSrc)); llvm::Linker::Flags::OverrideFromSrc));
LogAndVerify(llvm_module); LogAndVerify(llvm_module);
LaunchDimensions launch_dimensions = return {{shared_mem_bytes}};
launch_dims_generator(analysis, hlo_computation, config);
return {{launch_dimensions, shared_mem_bytes}};
} }
} // namespace gpu } // namespace gpu
......
...@@ -24,6 +24,7 @@ limitations under the License. ...@@ -24,6 +24,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/gemm_rewriter_triton.h"
#include "xla/service/gpu/gpu_device_info.h" #include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/launch_dimensions.h"
#include "xla/statusor.h" #include "xla/statusor.h"
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
...@@ -32,13 +33,14 @@ namespace xla { ...@@ -32,13 +33,14 @@ namespace xla {
namespace gpu { namespace gpu {
struct TritonWrapperResult { struct TritonWrapperResult {
LaunchDimensions launch_dimensions;
int64_t shmem_bytes; int64_t shmem_bytes;
}; };
// Compute the launch dimensions for the given Triton MatMul. // Compute the launch dimensions for the given Triton MatMul.
LaunchDimensions GetMatMulLaunchDimensions( LaunchDimensions GetMatMulLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloComputation* computation, const TritonFusionAnalysis& analysis,
absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& fusion_boundary,
const AutotuneResult::TritonGemmKey& config); const AutotuneResult::TritonGemmKey& config);
// Use tiling and execution parameters from 'config'. // Use tiling and execution parameters from 'config'.
Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
...@@ -49,7 +51,8 @@ Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, ...@@ -49,7 +51,8 @@ Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
// Compute the launch dimensions for the given Triton SoftMax. // Compute the launch dimensions for the given Triton SoftMax.
LaunchDimensions GetSoftMaxLaunchDimensions( LaunchDimensions GetSoftMaxLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloComputation* computation, absl::Span<const HloInstruction* const> roots,
const FusionBoundaryFn& fusion_boundary,
const AutotuneResult::TritonGemmKey& config); const AutotuneResult::TritonGemmKey& config);
// Generate Softmax in Triton IR inside 'fn'. // Generate Softmax in Triton IR inside 'fn'.
// Use execution parameters from 'config'. // Use execution parameters from 'config'.
...@@ -59,24 +62,20 @@ Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, ...@@ -59,24 +62,20 @@ Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
const AutotuneResult::TritonGemmKey& config, const AutotuneResult::TritonGemmKey& config,
int shmem_budget); int shmem_budget);
using LaunchDimensionsGenerator = std::function<LaunchDimensions(
const TritonFusionAnalysis&, const HloComputation*,
const AutotuneResult::TritonGemmKey&)>;
using TritonIrEmitter = std::function<Status( using TritonIrEmitter = std::function<Status(
mlir::OpBuilder, absl::string_view, const TritonFusionAnalysis& analysis, mlir::OpBuilder, absl::string_view, const TritonFusionAnalysis& analysis,
const HloComputation*, mlir::triton::FuncOp, const HloComputation*, mlir::triton::FuncOp,
const AutotuneResult::TritonGemmKey&, int)>; const AutotuneResult::TritonGemmKey&, int)>;
// Generate Triton IR by running the provided generator, compile it into LLVM IR // Generate Triton IR by running the provided generator and compile it into LLVM
// and return launch dimensions. // IR.
// MatMul and SoftMax above are some such IR generators. // MatMul and SoftMax above are some such IR generators.
StatusOr<TritonWrapperResult> TritonWrapper( StatusOr<TritonWrapperResult> TritonWrapper(
absl::string_view fn_name, const HloComputation* hlo_computation, const TritonFusionAnalysis& analysis, absl::string_view fn_name,
absl::string_view fusion_kind, const se::CudaComputeCapability& cc, const HloComputation* hlo_computation, absl::string_view fusion_kind,
const GpuDeviceInfo& device_info, const se::CudaComputeCapability& cc, const GpuDeviceInfo& device_info,
const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module, const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context);
mlir::MLIRContext& mlir_context);
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla
......
...@@ -211,11 +211,11 @@ ENTRY entry { ...@@ -211,11 +211,11 @@ ENTRY entry {
config.set_num_stages(4); config.set_num_stages(4);
config.set_num_warps(8); config.set_num_warps(8);
EXPECT_THAT( EXPECT_THAT(
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
"test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0}, /*minor=*/0},
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions, dev_info, config, &llvm_module, &EmitMatMul, mlir_context),
&EmitMatMul, mlir_context),
tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED, tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED,
"Shared memory size limit exceeded.")); "Shared memory size limit exceeded."));
...@@ -225,11 +225,11 @@ ENTRY entry { ...@@ -225,11 +225,11 @@ ENTRY entry {
config.set_num_stages(1); config.set_num_stages(1);
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
const auto result, const auto result,
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
"test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0}, /*minor=*/0},
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions, dev_info, config, &llvm_module, &EmitMatMul, mlir_context));
&EmitMatMul, mlir_context));
// Use optin shared memory which is > shared_memory_per_block. // Use optin shared memory which is > shared_memory_per_block.
EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block); EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block);
} }
...@@ -639,11 +639,11 @@ ENTRY entry { ...@@ -639,11 +639,11 @@ ENTRY entry {
config.set_num_stages(1); config.set_num_stages(1);
config.set_num_warps(2); config.set_num_warps(2);
EXPECT_THAT( EXPECT_THAT(
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
"test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0}, /*minor=*/0},
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions, dev_info, config, &llvm_module, &EmitMatMul, mlir_context),
&EmitMatMul, mlir_context),
tsl::testing::StatusIs( tsl::testing::StatusIs(
tsl::error::RESOURCE_EXHAUSTED, tsl::error::RESOURCE_EXHAUSTED,
"Tiling complexity heuristic exceeded: 147456 > 9000")); "Tiling complexity heuristic exceeded: 147456 > 9000"));
...@@ -653,11 +653,11 @@ ENTRY entry { ...@@ -653,11 +653,11 @@ ENTRY entry {
config.set_block_n(32); config.set_block_n(32);
config.set_block_k(32); config.set_block_k(32);
TF_CHECK_OK( TF_CHECK_OK(
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
"test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0}, /*minor=*/0},
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions, dev_info, config, &llvm_module, &EmitMatMul, mlir_context)
&EmitMatMul, mlir_context)
.status()); .status());
} }
...@@ -1438,10 +1438,11 @@ ENTRY e { ...@@ -1438,10 +1438,11 @@ ENTRY e {
->backend_config<FusionBackendConfig>()); ->backend_config<FusionBackendConfig>());
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
const auto result, const auto result,
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
"test_fn", triton_dot_computation, kTritonGemmFusionKind,
GetCudaComputeCapability(), dev_info, GetCudaComputeCapability(), dev_info,
config.triton_gemm_config(), &llvm_module, config.triton_gemm_config(), &llvm_module, &EmitMatMul,
&GetMatMulLaunchDimensions, &EmitMatMul, mlir_context)); mlir_context));
// The config is chosen so that the used memory size is slightly above the // The config is chosen so that the used memory size is slightly above the
// 48 kB boundary of standard / optin shared memory so that any GPU that // 48 kB boundary of standard / optin shared memory so that any GPU that
// has the optin one should be able to execute the test. // has the optin one should be able to execute the test.
......
...@@ -1701,7 +1701,8 @@ static Status ProcessFusionForConversion(mlir::Region* region, ...@@ -1701,7 +1701,8 @@ static Status ProcessFusionForConversion(mlir::Region* region,
#if GOOGLE_CUDA #if GOOGLE_CUDA
Status IrEmitterUnnested::EmitTritonFusion( Status IrEmitterUnnested::EmitTritonFusion(
mlir::Operation* op, const AutotuneResult::TritonGemmKey& config, const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
const AutotuneResult::TritonGemmKey& config,
const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>& const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
hlo_for_lmhlo) { hlo_for_lmhlo) {
// Note: In this method we can't use `BuildKernelThunk` as usual, // Note: In this method we can't use `BuildKernelThunk` as usual,
...@@ -1740,23 +1741,35 @@ Status IrEmitterUnnested::EmitTritonFusion( ...@@ -1740,23 +1741,35 @@ Status IrEmitterUnnested::EmitTritonFusion(
absl::string_view fusion_kind = backend_config.kind(); absl::string_view fusion_kind = backend_config.kind();
TritonWrapperResult triton_wrapper_result; TritonWrapperResult triton_wrapper_result;
LaunchDimensions launch_dimensions;
if (fusion_kind == kTritonSoftmaxFusionKind) { if (fusion_kind == kTritonSoftmaxFusionKind) {
TF_ASSIGN_OR_RETURN(auto analysis,
TritonFusionAnalysis::Execute(*hlo_computation));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
triton_wrapper_result, triton_wrapper_result,
TritonWrapper(impl_fn_name, hlo_computation, kTritonSoftmaxFusionKind, TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonSoftmaxFusionKind,
ir_emitter_context_->cuda_compute_capability(), ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_device_info(), config, module_, ir_emitter_context_->gpu_device_info(), config, module_,
&GetSoftMaxLaunchDimensions, &EmitSoftMax, &EmitSoftMax, *ir_emitter_context_->mlir_context()));
*ir_emitter_context_->mlir_context())); launch_dimensions = GetSoftMaxLaunchDimensions(
hlo_fusion_analysis.fusion_roots(),
hlo_fusion_analysis.fusion_boundary(), config);
} else { // Must be a MatMul } else { // Must be a MatMul
CHECK_EQ(fusion_kind, kTritonGemmFusionKind); CHECK_EQ(fusion_kind, kTritonGemmFusionKind);
TF_ASSIGN_OR_RETURN(
auto analysis,
TritonFusionAnalysis::Execute(*hlo_computation, config.split_k()));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
triton_wrapper_result, triton_wrapper_result,
TritonWrapper(impl_fn_name, hlo_computation, kTritonGemmFusionKind, TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonGemmFusionKind,
ir_emitter_context_->cuda_compute_capability(), ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_device_info(), config, module_, ir_emitter_context_->gpu_device_info(), config, module_,
&GetMatMulLaunchDimensions, &EmitMatMul, &EmitMatMul, *ir_emitter_context_->mlir_context()));
*ir_emitter_context_->mlir_context())); launch_dimensions = GetMatMulLaunchDimensions(
analysis, hlo_fusion_analysis.fusion_roots(),
hlo_fusion_analysis.fusion_boundary(), config);
} }
llvm::Function* impl_fn = module_->getFunction(impl_fn_name); llvm::Function* impl_fn = module_->getFunction(impl_fn_name);
...@@ -1764,7 +1777,7 @@ Status IrEmitterUnnested::EmitTritonFusion( ...@@ -1764,7 +1777,7 @@ Status IrEmitterUnnested::EmitTritonFusion(
auto [kernel, inputs, outputs] = BuildKernelPrototype( auto [kernel, inputs, outputs] = BuildKernelPrototype(
*ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), *ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(),
impl_fn->arg_size(), triton_wrapper_result.launch_dimensions, &b_); impl_fn->arg_size(), launch_dimensions, &b_);
// Move function body into kernel prototype. // Move function body into kernel prototype.
llvm::Function* prototype_func = b_.GetInsertBlock()->getParent(); llvm::Function* prototype_func = b_.GetInsertBlock()->getParent();
...@@ -1775,7 +1788,7 @@ Status IrEmitterUnnested::EmitTritonFusion( ...@@ -1775,7 +1788,7 @@ Status IrEmitterUnnested::EmitTritonFusion(
impl_fn->eraseFromParent(); impl_fn->eraseFromParent();
LogAndVerify(module_); LogAndVerify(module_);
return {{kernel->getName().str(), triton_wrapper_result.launch_dimensions, return {{kernel->getName().str(), launch_dimensions,
triton_wrapper_result.shmem_bytes}}; triton_wrapper_result.shmem_bytes}};
}; };
...@@ -1850,7 +1863,8 @@ Status IrEmitterUnnested::EmitFusion( ...@@ -1850,7 +1863,8 @@ Status IrEmitterUnnested::EmitFusion(
triton_config.set_num_stages(1); triton_config.set_num_stages(1);
triton_config.set_num_warps(2); triton_config.set_num_warps(2);
} }
return EmitTritonFusion(fusion_op, backend_config.triton_gemm_config(), return EmitTritonFusion(fusion_analysis, fusion_op,
backend_config.triton_gemm_config(),
hlo_for_lmhlo); hlo_for_lmhlo);
} }
if (backend_config.kind() == kTritonSoftmaxFusionKind) { if (backend_config.kind() == kTritonSoftmaxFusionKind) {
...@@ -1858,7 +1872,8 @@ Status IrEmitterUnnested::EmitFusion( ...@@ -1858,7 +1872,8 @@ Status IrEmitterUnnested::EmitFusion(
triton_config.set_num_stages(1); triton_config.set_num_stages(1);
triton_config.set_num_warps( triton_config.set_num_warps(
DeriveNumWarpsFromTritonSoftmaxComputation(fused_computation)); DeriveNumWarpsFromTritonSoftmaxComputation(fused_computation));
return EmitTritonFusion(fusion_op, backend_config.triton_gemm_config(), return EmitTritonFusion(fusion_analysis, fusion_op,
backend_config.triton_gemm_config(),
hlo_for_lmhlo); hlo_for_lmhlo);
} }
#endif #endif
......
...@@ -136,7 +136,8 @@ class IrEmitterUnnested : public IrEmitter { ...@@ -136,7 +136,8 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitCublasLtMatmulThunkF8(mlir::Operation* op); Status EmitCublasLtMatmulThunkF8(mlir::Operation* op);
Status EmitConvolutionReorderThunk(mlir::Operation* op); Status EmitConvolutionReorderThunk(mlir::Operation* op);
Status EmitTritonFusion( Status EmitTritonFusion(
mlir::Operation* op, const AutotuneResult::TritonGemmKey& config, const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
const AutotuneResult::TritonGemmKey& config,
const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>& const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
hlo_for_lmhlo); hlo_for_lmhlo);
Status EmitFusedMHAThunk(mlir::Operation* op); Status EmitFusedMHAThunk(mlir::Operation* op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册