From 907fb5bddf7eafea7166748109a537fdf7195d8c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 11 Sep 2023 10:47:35 -0700 Subject: [PATCH] NFC: Split launch grid and validation out of MatMul emitter. Also - Unify naming to ${scope}_${thing}_dim_idx. - Make logic for split_lhs_nc a bit easier to follow (at least in my opinion), by decoupling it from the batch_size computation. PiperOrigin-RevId: 564435916 --- .../xla/xla/service/gpu/ir_emitter_triton.cc | 419 +++++++++++------- 1 file changed, 255 insertions(+), 164 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 177aef14db0..3f503e565f0 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -729,6 +729,212 @@ struct GeneralizeKernelSignaturePass } }; +const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( + const TritonFusionAnalysis& analysis, int64_t lhs_noncontracting_dim_idx) { + const TensorIterationSpec::DimIterationSpec* result = nullptr; + for (const HloInstruction* lhs_param : + analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS)) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, lhs_param, + lhs_noncontracting_dim_idx); + if (spec != nullptr && spec->size() > 1) { + CHECK_EQ(spec->size(), 2); + if (result != nullptr) { + CHECK_EQ(result->at(0).count, spec->at(0).count); + CHECK_EQ(result->at(1).count, spec->at(1).count); + } + result = spec; + } + } + return result; +} + +// Structure for parameters relating to the MatMul shape and dimension indices. +// +// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. +// +// The logical output dimensions are always ordered as: +// split-K, batch, non-contracting LHS, non-contracting RHS, +// where split-K and batch are optional. +struct MatMulDims { + MatMulDims(const AutotuneResult::TritonGemmKey& config, + const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis); + + std::optional out_split_k_dim_idx = std::nullopt; + + // TODO(b/270937368): Create a struct for the batch dimensions. + std::optional lhs_batch_dim_idx = std::nullopt; + std::optional rhs_batch_dim_idx = std::nullopt; + std::optional out_batch_dim_idx = std::nullopt; + + // The LHS non-contracting can be split into two. + std::optional lhs_noncontracting_split = std::nullopt; + + int lhs_contracting_dim_idx; + int lhs_noncontracting_dim_idx; + int rhs_contracting_dim_idx; + int rhs_noncontracting_dim_idx; + // The index of the LHS noncontracting dim in the output. + int out_lhs_noncontracting_dim_idx; + // The index of the RHS noncontracting dim in the output. + int out_rhs_noncontracting_dim_idx; + + int64_t m; + int64_t n; + int64_t k; +}; + +// Structure for parameters relating to the MatMul launch grid. +struct MatMulLaunchConfig { + explicit MatMulLaunchConfig(const AutotuneResult::TritonGemmKey& config, + const HloDotInstruction& dot, + const MatMulDims& dims); + + int64_t grid_m; + int64_t grid_n; + LaunchDimensions launch_dims; + mt::ProgramIDDim batch_program_id_dim; + mt::ProgramIDDim noncontracting_program_id_dim; +}; + +MatMulDims::MatMulDims(const AutotuneResult::TritonGemmKey& config, + const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis) { + if (config.split_k() > 1) { + // split-k is always the first logical dimension. + out_split_k_dim_idx = 0; + } + + int64_t num_split_k_dims = config.split_k() > 1 ? 1 : 0; + const auto& dims = dot.dot_dimension_numbers(); + lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0); + lhs_noncontracting_dim_idx = + GetNonContractingDims(dot.operand(0)->shape(), + dims.lhs_batch_dimensions(), + dims.lhs_contracting_dimensions()) + .value()[0]; + rhs_contracting_dim_idx = dims.rhs_contracting_dimensions(0); + rhs_noncontracting_dim_idx = + GetNonContractingDims(dot.operand(1)->shape(), + dims.rhs_batch_dimensions(), + dims.rhs_contracting_dimensions()) + .value()[0]; + + if (dims.lhs_batch_dimensions_size() > num_split_k_dims) { + lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); + rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); + // The batch dimension (if present) comes after the split-k dimension (if + // present, otherwise it's the first dimension). + out_batch_dim_idx = num_split_k_dims; + } + + // Logical output dimensions are always ordered as: + // split-K, batch, non-contracting LHS, non-contracting RHS, + // where split-K and batch are optional. + out_rhs_noncontracting_dim_idx = dot.shape().rank() - 1; + out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; + + auto* root = dot.parent()->root_instruction(); + n = analysis + .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + out_rhs_noncontracting_dim_idx) + ->at(0) + .count; + // Contracting dimension length. + k = dot.operand(0)->shape().dimensions(dims.lhs_contracting_dimensions(0)) * + config.split_k(); + + auto* lhs_noncontracting_split_spec = + GetLhsNoncontractingSplitSpec(analysis, lhs_noncontracting_dim_idx); + if (lhs_noncontracting_split_spec != nullptr) { + // Just the fastest-varying part of it if the dimension is split. + m = lhs_noncontracting_split_spec->at(0).count; + lhs_noncontracting_split = lhs_noncontracting_split_spec->at(1).count; + } else { + m = analysis + .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + out_lhs_noncontracting_dim_idx) + ->at(0) + .count; + } + + // For now split non-contracting and batch are not supported + // simultaneously because they are implemented via same mechanism. + CHECK( + !(out_batch_dim_idx.has_value() && lhs_noncontracting_split.has_value())); + + CHECK_GE(m, 1); + CHECK_GE(n, 1); +} + +MatMulLaunchConfig::MatMulLaunchConfig( + const AutotuneResult::TritonGemmKey& config, const HloDotInstruction& dot, + const MatMulDims& dims) + : grid_m((dims.m + config.block_m() - 1) / config.block_m()), + grid_n((dims.n + config.block_n() - 1) / config.block_n()) { + int64_t batch_size = dims.lhs_noncontracting_split.value_or( + dims.out_batch_dim_idx.has_value() + ? dot.shape().dimensions(*dims.out_batch_dim_idx) + : 1); + // X block size is 32-bit, Y and Z are 16-bit. Use X for large dimensions. + constexpr int64_t kBlockCountYZLimit = 65536; + + // In the imaginary situation where both batch size and grid_m * grid_n + // are over 65535 we have to give up. Given the minimal m, n block sizes of 16 + // this requires at least 256 GB of output. + CHECK_LT(batch_size * grid_m * grid_n, + kBlockCountYZLimit * kBlockCountYZLimit); + + const bool large_batch = batch_size >= kBlockCountYZLimit; + if (large_batch) { + batch_program_id_dim = mt::ProgramIDDim::X; + noncontracting_program_id_dim = mt::ProgramIDDim::Y; + launch_dims = {{batch_size, grid_m * grid_n, config.split_k()}, + {config.num_warps() * WarpSize(), 1, 1}}; + } else { + batch_program_id_dim = mt::ProgramIDDim::Y; + noncontracting_program_id_dim = mt::ProgramIDDim::X; + launch_dims = + LaunchDimensions{{grid_m * grid_n, batch_size, config.split_k()}, + {config.num_warps() * WarpSize(), 1, 1}}; + } +} + +void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config, + const HloDotInstruction& dot) { + CHECK_GE(config.split_k(), 1); + CHECK_GE(config.block_m(), 16); + CHECK_GE(config.block_k(), 16); + CHECK_GE(config.block_n(), 16); + + const auto& dims = dot.dot_dimension_numbers(); + int num_batch_dims = + dims.lhs_batch_dimensions_size() - (config.split_k() > 1 ? 1 : 0); + CHECK_LE(num_batch_dims, 1); + if (config.split_k() > 1) { + // Split-K dimension has to be the first batch one and have an index + // just before the contracting one. + const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1; + const int rhs_split_k_dim_idx = dims.rhs_contracting_dimensions(0) - 1; + // Size of this dimension has to match the split_k value. + CHECK_EQ(dims.lhs_batch_dimensions(0), lhs_split_k_dim_idx); + CHECK_EQ(dims.rhs_batch_dimensions(0), rhs_split_k_dim_idx); + CHECK_EQ(config.split_k(), + dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx)); + CHECK_EQ(config.split_k(), + dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx)); + } + + // Rely on dot decomposer: there is just one contracting and one + // non-contracting dimension on each side + batch ones optionally. + CHECK_EQ(dims.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dims.rhs_contracting_dimensions_size(), 1); + + CHECK_EQ(dot.operand(0)->shape().rank(), + 2 + (config.split_k() > 1 ? 1 : 0) + num_batch_dims); +} + } // namespace // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. @@ -761,153 +967,34 @@ StatusOr MatMul(mlir::OpBuilder builder, Type i32_ty = b.getI32Type(); Type i64_ty = b.getI64Type(); + ValidateMatMulConfig(config, *dot_instr); const int split_k = config.split_k(); const int block_m = config.block_m(); const int block_k = config.block_k(); const int block_n = config.block_n(); - CHECK_GE(split_k, 1); - CHECK_GE(block_m, 16); - CHECK_GE(block_k, 16); - CHECK_GE(block_n, 16); - const DotDimensionNumbers& dims = dot_instr->dot_dimension_numbers(); TF_ASSIGN_OR_RETURN( const TritonFusionAnalysis analysis, TritonFusionAnalysis::Execute(*dot_instr->parent(), split_k)); + const MatMulDims dims(config, *dot_instr, analysis); + const MatMulLaunchConfig launch_config(config, *dot_instr, dims); VLOG(6) << analysis.ToString(); - // Rely on dot decomposer: there is just one contracting and one - // non-contracting dimension on each side + batch ones optionally. - CHECK_EQ(dims.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dims.rhs_contracting_dimensions_size(), 1); - - std::optional split_k_out_idx = std::nullopt; - if (split_k > 1) { - // split-k is always the first logical output dimension. - split_k_out_idx = 0; - // Split-K dimension has to be the first batch one and have an index - // just before the contracting one. - const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1; - const int rhs_split_k_dim_idx = dims.rhs_contracting_dimensions(0) - 1; - // Size of this dimension has to match the split_k value. - CHECK_EQ(dims.lhs_batch_dimensions(0), lhs_split_k_dim_idx); - CHECK_EQ(dims.rhs_batch_dimensions(0), rhs_split_k_dim_idx); - CHECK_EQ(split_k, - dot_instr->operand(0)->shape().dimensions(lhs_split_k_dim_idx)); - CHECK_EQ(split_k, - dot_instr->operand(1)->shape().dimensions(rhs_split_k_dim_idx)); - } - - int num_split_k_dims = split_k_out_idx.has_value() ? 1 : 0; - CHECK_LE(dims.lhs_batch_dimensions_size(), 1 + num_split_k_dims); - std::optional lhs_batch_dim_idx = std::nullopt; - std::optional rhs_batch_dim_idx = std::nullopt; - std::optional batch_out_idx = std::nullopt; - if (dims.lhs_batch_dimensions_size() > num_split_k_dims) { - lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); - rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); - // The batch dimension (if present) comes after the split-k dimension (if - // present, otherwise it's the first dimension). - batch_out_idx = num_split_k_dims; - } - CHECK_EQ(dot_instr->operand(0)->shape().rank(), - 2 + (split_k_out_idx.has_value() ? 1 : 0) + - (lhs_batch_dim_idx.has_value() ? 1 : 0)); - const int lhs_noncontracting_dim_idx = - GetNonContractingDims(dot_instr->operand(0)->shape(), - dims.lhs_batch_dimensions(), - dims.lhs_contracting_dimensions()) - .value()[0]; - const int rhs_noncontracting_dim_idx = - GetNonContractingDims(dot_instr->operand(1)->shape(), - dims.rhs_batch_dimensions(), - dims.rhs_contracting_dimensions()) - .value()[0]; - - // Logical output dimensions are always ordered as: - // split-K, batch, non-contracting LHS, non-contracting RHS, - // where split-K and batch are optional. - const int rhs_nc_out_idx = dot_instr->shape().rank() - 1; - const int lhs_nc_out_idx = dot_instr->shape().rank() - 2; - - // LHS non-contracting dimension length. - // LHS non-contracting can be split, this holds only its minor part. - int m = - analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, lhs_nc_out_idx) - ->at(0) - .count; - - // Contracting dimension length. - const int k = dot_instr->operand(0)->shape().dimensions( - dims.lhs_contracting_dimensions(0)) * - split_k; - - // LHS non-contracting can be split into two. - bool lhs_nc_split = false; - // Either batch GEMM size or major part of the split - // non-contracting LHS dimension. - int batch_size = 1; - for (const HloInstruction* lhs_param : - analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS)) { - const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = - analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, lhs_param, - lhs_noncontracting_dim_idx); - if (lhs_nc_iter_spec != nullptr && lhs_nc_iter_spec->size() > 1) { - // For now split non-contracting and batch are not supported - // simultaneously because they are implemented via same mechanism. - CHECK(!lhs_batch_dim_idx.has_value()); - CHECK_EQ(lhs_nc_iter_spec->size(), 2); - lhs_nc_split = true; - // If split dimension is used all parameters have to have either have - // the same split ratio or none. - if (batch_size == 1) { - batch_size = lhs_nc_iter_spec->at(1).count; - CHECK_GE(batch_size, 1); - } else { - CHECK_EQ(batch_size, lhs_nc_iter_spec->at(1).count); - } - // Just the fastest-varying part of it if the dimension is split. - m = lhs_nc_iter_spec->at(0).count; - } - } - if (batch_out_idx.has_value() && !lhs_nc_split) { - batch_size = dot_instr->shape().dimensions(*batch_out_idx); - } - CHECK_GE(m, 1); - constexpr int group_m = 8; - const int n = - analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, rhs_nc_out_idx) - ->at(0) - .count; - CHECK_GE(n, 1); - - const int grid_m = ceil(1.0 * m / block_m); - const int grid_n = ceil(1.0 * n / block_n); - const int width = group_m * grid_n; + const int64_t width = group_m * launch_config.grid_n; - // X block size is 32-bit, Y and Z are 16-bit. Use X for large dimensions. - constexpr int64_t kBlockCountYZLimit = 65536; - const bool large_batch = batch_size >= kBlockCountYZLimit; - auto pid_batch = b.create( - large_batch ? mt::ProgramIDDim::X : mt::ProgramIDDim::Y); - auto pid_nc = b.create(large_batch ? mt::ProgramIDDim::Y - : mt::ProgramIDDim::X); + auto pid_batch = + b.create(launch_config.batch_program_id_dim); + auto pid_nc = + b.create(launch_config.noncontracting_program_id_dim); auto pid_k = b.create(mt::ProgramIDDim::Z); - // In the imaginary situation where both batch size and grid_m * grid_n - // are over 65535 we have to give up. Given the minimal m, n block sizes of 16 - // this requires at least 256 GB of output. - CHECK_LT(batch_size * grid_m * grid_n, - kBlockCountYZLimit * kBlockCountYZLimit); - auto group_id = b.create(pid_nc, CreateConst(b, i32_ty, width)); ma::ConstantOp group_m_op = CreateConst(b, i32_ty, group_m); auto first_pid_m = b.create(group_id, group_m_op); - auto sub0 = b.create(CreateConst(b, i32_ty, grid_m), first_pid_m); + auto sub0 = b.create(CreateConst(b, i32_ty, launch_config.grid_m), + first_pid_m); auto group_size = b.create( b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, group_m_op); @@ -934,12 +1021,12 @@ StatusOr MatMul(mlir::OpBuilder builder, b.create(pid_k, CreateConst(b, i32_ty, block_k)); std::vector lhs_tiled_dims = { - {lhs_noncontracting_dim_idx, pid_m_offset, block_m}, - {dims.lhs_contracting_dimensions(0), pid_k_offset, block_k}}; + {dims.lhs_noncontracting_dim_idx, pid_m_offset, block_m}, + {dims.lhs_contracting_dim_idx, pid_k_offset, block_k}}; std::vector rhs_tiled_dims = { - {dims.rhs_contracting_dimensions(0), pid_k_offset, block_k}, - {rhs_noncontracting_dim_idx, pid_n_offset, block_n}}; + {dims.rhs_contracting_dim_idx, pid_k_offset, block_k}, + {dims.rhs_noncontracting_dim_idx, pid_n_offset, block_n}}; Type dot_output_ty = TritonType(b, dot_instr->shape().element_type()); // Data type of dot() immediate inputs. @@ -987,6 +1074,8 @@ StatusOr MatMul(mlir::OpBuilder builder, .second); std::vector& tiled_dims = is_lhs ? lhs_tiled_dims : rhs_tiled_dims; + int64_t contracting_dim_idx = + is_lhs ? dims.lhs_contracting_dim_idx : dims.rhs_contracting_dim_idx; SmallVector increments; for (const DimProperties& dim : tiled_dims) { const TensorIterationSpec::DimIterationSpec* spec = @@ -995,8 +1084,7 @@ StatusOr MatMul(mlir::OpBuilder builder, continue; } // Only the contracting dimensions are advanced. - if ((is_lhs && dim.index == dims.lhs_contracting_dimensions(0)) || - (!is_lhs && dim.index == dims.rhs_contracting_dimensions(0))) { + if (dim.index == contracting_dim_idx) { increments.push_back( CreateConst(b, i32_ty, dim.block_size * split_k)); } else { @@ -1032,10 +1120,10 @@ StatusOr MatMul(mlir::OpBuilder builder, // again just before the dot so that they do not affect the output. // Only the K dimension needs masking here because unnecessary elements in // the other two get discarded by the masked store at the end. - const bool need_masking = k % (block_k * split_k) > 0; + const bool need_masking = dims.k % (block_k * split_k) > 0; if (need_masking) { auto elements_in_tile = - b.create(CreateConst(b, i32_ty, k), ki); + b.create(CreateConst(b, i32_ty, dims.k), ki); auto range_k = b.create( Splat(b, b.create(pid_k, CreateConst(b, i32_ty, block_k)), block_k), @@ -1100,11 +1188,11 @@ StatusOr MatMul(mlir::OpBuilder builder, const int64_t stride = spec->at(0).stride; int64_t count = spec->at(0).count; if (scope == TritonFusionAnalysis::Scope::OUTPUT && - properties.index == lhs_nc_out_idx && spec->size() == 1 && - lhs_nc_split) { + properties.index == dims.out_lhs_noncontracting_dim_idx && + spec->size() == 1 && dims.lhs_noncontracting_split.has_value()) { // Dimension of the output produced by the non-contracting LHS one // is logically split, major part is addressed using pid_batch. - count /= batch_size; + count /= *dims.lhs_noncontracting_split; } if (count % properties.block_size != 0) { boundary_checks.push_back(bounds.size()); @@ -1120,7 +1208,8 @@ StatusOr MatMul(mlir::OpBuilder builder, } int64_t stride_batch = 0; - if (scope != TritonFusionAnalysis::Scope::RHS && lhs_nc_split) { + if (scope != TritonFusionAnalysis::Scope::RHS && + dims.lhs_noncontracting_split) { const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec(scope, hlo, tiled_dimensions[0].index); if (spec != nullptr) { @@ -1133,7 +1222,8 @@ StatusOr MatMul(mlir::OpBuilder builder, // batch logic stride_batch is populated here as the stride of // the minor part times its size. stride_batch = - spec->at(0).stride * (spec->at(0).count / batch_size); + spec->at(0).stride * + (spec->at(0).count / *dims.lhs_noncontracting_split); } CHECK_NE(stride_batch, 0); } @@ -1152,9 +1242,10 @@ StatusOr MatMul(mlir::OpBuilder builder, base = AddPtr(b, base, offset_batch); } - if (split_k_out_idx.has_value()) { - const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( - TritonFusionAnalysis::Scope::OUTPUT, hlo, *split_k_out_idx); + if (dims.out_split_k_dim_idx.has_value()) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, hlo, + *dims.out_split_k_dim_idx); if (spec != nullptr) { int64_t stride_split_k = spec->at(0).stride; Value offset_split_k = @@ -1176,25 +1267,27 @@ StatusOr MatMul(mlir::OpBuilder builder, for (const HloInstruction* parameter : analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS)) { CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second); - iter_args.push_back(emit_tensor_pointer( - parameter, TritonFusionAnalysis::Scope::LHS, - fn.getArgument(parameter->parameter_number()), lhs_tiled_dims, - lhs_batch_dim_idx, iter_args_to_boundary_checks[iter_args.size()])); + iter_args.push_back( + emit_tensor_pointer(parameter, TritonFusionAnalysis::Scope::LHS, + fn.getArgument(parameter->parameter_number()), + lhs_tiled_dims, dims.lhs_batch_dim_idx, + iter_args_to_boundary_checks[iter_args.size()])); } for (const HloInstruction* parameter : analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS)) { CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second); - iter_args.push_back(emit_tensor_pointer( - parameter, TritonFusionAnalysis::Scope::RHS, - fn.getArgument(parameter->parameter_number()), rhs_tiled_dims, - rhs_batch_dim_idx, iter_args_to_boundary_checks[iter_args.size()])); + iter_args.push_back( + emit_tensor_pointer(parameter, TritonFusionAnalysis::Scope::RHS, + fn.getArgument(parameter->parameter_number()), + rhs_tiled_dims, dims.rhs_batch_dim_idx, + iter_args_to_boundary_checks[iter_args.size()])); } iter_args.push_back(accumulator_init); Value acc_final = b.create( /*lowerBound=*/b.create(0, /*width=*/32), - /*upperBound=*/b.create(k, /*width=*/32), + /*upperBound=*/b.create(dims.k, /*width=*/32), /*step=*/ b.create(block_k * split_k, /*width=*/32), @@ -1233,17 +1326,17 @@ StatusOr MatMul(mlir::OpBuilder builder, } } std::vector out_tiled_dims = { - {lhs_nc_out_idx, pid_m_offset, block_m}, - {rhs_nc_out_idx, pid_n_offset, block_n}}; + {dims.out_lhs_noncontracting_dim_idx, pid_m_offset, block_m}, + {dims.out_rhs_noncontracting_dim_idx, pid_n_offset, block_n}}; // Emit the output scope. if (!to_emit.empty()) { for (const HloInstruction* parameter : analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT)) { std::vector boundary_checks; - Value tensor_pointer = - emit_tensor_pointer(parameter, TritonFusionAnalysis::Scope::OUTPUT, - fn.getArgument(parameter->parameter_number()), - out_tiled_dims, batch_out_idx, boundary_checks); + Value tensor_pointer = emit_tensor_pointer( + parameter, TritonFusionAnalysis::Scope::OUTPUT, + fn.getArgument(parameter->parameter_number()), out_tiled_dims, + dims.out_batch_dim_idx, boundary_checks); CHECK(values_out .insert({parameter, EmitParameterLoad(b, tensor_pointer, boundary_checks)}) @@ -1264,13 +1357,11 @@ StatusOr MatMul(mlir::OpBuilder builder, Value tensor_pointer = emit_tensor_pointer( producer, TritonFusionAnalysis::Scope::OUTPUT, fn.getArgument(i + dot_instr->parent()->num_parameters()), - out_tiled_dims, batch_out_idx, boundary_checks); + out_tiled_dims, dims.out_batch_dim_idx, boundary_checks); b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } - return LaunchDimensions{{large_batch ? batch_size : grid_m * grid_n, - large_batch ? grid_m * grid_n : batch_size, split_k}, - {config.num_warps() * WarpSize(), 1, 1}}; + return launch_config.launch_dims; } StatusOr SoftMax(mlir::OpBuilder builder, -- GitLab