From 4ae35f9f35f38fc12b65b077a7f25eab1137e391 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Wed, 13 Sep 2023 05:51:12 -0700
Subject: [PATCH] [NFC] [XLA:GPU] Split up Triton emitter into multiple
 functions.

PiperOrigin-RevId: 565024582
---
 .../xla/xla/service/gpu/ir_emitter_triton.cc | 422 ++++++++++--------
 1 file changed, 232 insertions(+), 190 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
index 31aa89978e0..7ed29e64274 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -946,6 +946,203 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
           2 + (config.split_k() > 1 ? 1 : 0) + num_batch_dims);
 }
 
+struct Side {
+  TritonFusionAnalysis::Scope scope;
+  std::vector<DimProperties> tiled_dims;
+  std::optional<int> batch_dim_idx;
+};
+
+class MatMulEmitterHelper {
+ public:
+  MatMulEmitterHelper(mlir::OpBuilder builder, absl::string_view libdevice_path,
+                      const HloDotInstruction* dot_instr,
+                      ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims,
+                      const MatMulLaunchConfig& launch_config,
+                      const TritonFusionAnalysis& analysis)
+      : b_(b),
+        libdevice_path_(libdevice_path),
+        dot_instr_(dot_instr),
+        index_ty_(index_ty),
+        analysis_(analysis),
+        dims_(dims),
+        launch_config_(launch_config) {}
+
+  // TODO(b/266862493): Accumulator can be integer too.
+  // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
+  mlir::FloatType GetDotAccumulatorType() {
+    Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
+    // Data type of dot() immediate inputs.
+    Type dot_input_ty = [&] {
+      const Type lhs_ty =
+          TritonType(b_, dot_instr_->operand(0)->shape().element_type());
+      const Type rhs_ty =
+          TritonType(b_, dot_instr_->operand(1)->shape().element_type());
+      CHECK(lhs_ty == rhs_ty);
+      return lhs_ty;
+    }();
+    // TODO(b/266862493): Accumulator can be integer too.
+    // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
+    return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
+                                                           : b_.getF32Type();
+  }
+
+  std::vector<const HloInstruction*> EpiloguePostOrderTransitiveOperands(
+      const HloInstruction* root) {
+    // Collect all instructions of the dot's output scope.
+    absl::flat_hash_set<const HloInstruction*> to_order;
+    {
+      std::queue<const HloInstruction*> to_add;
+      if (root != dot_instr_) {
+        to_add.push(root);
+      }
+      while (!to_add.empty()) {
+        const HloInstruction* current = to_add.front();
+        for (const HloInstruction* operand : current->operands()) {
+          if (!to_order.contains(operand)) {
+            if (operand != dot_instr_) {
+              to_add.push(operand);
+            }
+          }
+        }
+        CHECK(to_order.insert(current).second);
+        to_add.pop();
+      }
+    }
+    // Order them producers before consumers.
+    std::vector<const HloInstruction*> to_emit;
+    for (const HloInstruction* hlo :
+         dot_instr_->parent()->MakeInstructionPostOrder()) {
+      if (to_order.contains(hlo)) {
+        to_emit.push_back(hlo);
+      }
+    }
+    return to_emit;
+  }
+
+  Value MakeInput(Side& side, int64_t operand_index,
+                  absl::flat_hash_map<const HloInstruction*, Value>& values) {
+    return *EmitScope(
+        b_, libdevice_path_, &analysis_, side.scope, side.tiled_dims,
+        dot_instr_->parent()->MakeInstructionPostOrderFrom(
+            const_cast<HloInstruction&>(*dot_instr_->operand(operand_index))),
+        values);
+  }
+
+  Value EmitTensorPointer(const HloInstruction* hlo, const Side& side,
+                          Value base, Value pid_k,
+                          std::vector<int32_t>& boundary_checks) {
+    auto pid_batch =
+        b_.create<mt::GetProgramIdOp>(launch_config_.batch_program_id_dim);
+
+    std::vector<Value> bounds;
+    std::vector<Value> strides;
+    std::vector<Value> offsets;
+    std::vector<int32_t> block_dims;
+    std::vector<int32_t> dim_order;
+
+    auto add_dim = [&](const DimProperties& properties) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, properties.index);
+      if (spec == nullptr) {
+        return;
+      }
+      const int64_t stride = spec->at(0).stride;
+      int64_t count = spec->at(0).count;
+      if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
+          properties.index == dims_.out_lhs_noncontracting_dim_idx &&
+          spec->size() == 1 && dims_.lhs_noncontracting_split.has_value()) {
+        // Dimension of the output produced by the non-contracting LHS one
+        // is logically split, major part is addressed using pid_batch.
+        count /= *dims_.lhs_noncontracting_split;
+      }
+      if (count % properties.block_size != 0) {
+        boundary_checks.push_back(bounds.size());
+      }
+      bounds.push_back(Cst64(count));
+      strides.push_back(Cst64(stride));
+      offsets.push_back(properties.offset);
+      block_dims.push_back(properties.block_size);
+      dim_order.emplace(dim_order.begin(), dim_order.size());
+    };
+    for (const DimProperties& dim : side.tiled_dims) {
+      add_dim(dim);
+    }
+
+    int64_t stride_batch = 0;
+    if (side.scope != TritonFusionAnalysis::Scope::RHS &&
+        dims_.lhs_noncontracting_split) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
+      if (spec != nullptr) {
+        if (spec->size() > 1) {
+          // Support one specific kind of output transpose that splits the
+          // dimension originating from the split LHS non-contracting one.
+          stride_batch = spec->at(1).stride;
+        } else {
+          // Because the major part of the split is implemented using the
+          // batch logic stride_batch is populated here as the stride of
+          // the minor part times its size.
+          stride_batch = spec->at(0).stride *
+                         (spec->at(0).count / *dims_.lhs_noncontracting_split);
+        }
+        CHECK_NE(stride_batch, 0);
+      }
+    } else if (side.batch_dim_idx.has_value()) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, *side.batch_dim_idx);
+      if (spec != nullptr) {
+        stride_batch = spec->at(0).stride;
+        CHECK_NE(stride_batch, 0);
+      }
+    }
+    if (stride_batch != 0) {
+      Value offset_batch =
+          b_.create<ma::MulIOp>(ConvertScalar(pid_batch), Cst(stride_batch));
+      base = AddPtr(b_, base, offset_batch);
+    }
+
+    if (dims_.out_split_k_dim_idx.has_value()) {
+      const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec(
+          TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx);
+      if (spec != nullptr) {
+        int64_t stride_split_k = spec->at(0).stride;
+        Value offset_split_k =
+            b_.create<ma::MulIOp>(ConvertScalar(pid_k), Cst(stride_split_k));
+        base = AddPtr(b_, base, offset_split_k);
+      }
+    }
+
+    if (block_dims.empty()) {
+      return base;
+    }
+    return b_.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
+                                          block_dims, dim_order);
+  }
+
+ private:
+  // Extend int32 indexes to int64, if necessary.
+  Value ConvertScalar(Value value) {
+    if (index_ty_.getIntOrFloatBitWidth() == 64) {
+      return b_.create<ma::ExtSIOp>(index_ty_, value);
+    }
+    return value;
+  }
+
+  Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); }
+
+  Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); }
+
+  ImplicitLocOpBuilder& b_;
+  absl::string_view libdevice_path_;
+  const HloDotInstruction* dot_instr_;
+  Type index_ty_;
+  TritonFusionAnalysis analysis_;
+  MatMulDims dims_;
+  MatMulLaunchConfig launch_config_;
+  Type i32_ty_ = b_.getI32Type();
+  Type i64_ty_ = b_.getI64Type();
+};
+
 }  // namespace
 
 LaunchDimensions GetMatMulLaunchDimensions(
@@ -964,7 +1161,6 @@ LaunchDimensions GetMatMulLaunchDimensions(
 }
 
 // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
-// TODO(b/270937368): Split this up into smaller functions.
 Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
                   const TritonFusionAnalysis& analysis,
                   const HloComputation* computation, mlir::triton::FuncOp fn,
@@ -979,7 +1175,7 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
       ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
       ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
       ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
-  mlir::Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
+  Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
 
   const HloInstruction* root = dot_instr->parent()->root_instruction();
   CHECK(!root->shape().IsTuple());
 
@@ -990,7 +1186,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name()));
   ImplicitLocOpBuilder b(loc, builder);
   Type i32_ty = b.getI32Type();
-  Type i64_ty = b.getI64Type();
 
   ValidateMatMulConfig(config, *dot_instr);
   const int split_k = config.split_k();
@@ -1002,62 +1197,38 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
   VLOG(6) << analysis.ToString();
 
-  constexpr int group_m = 8;
+  MatMulEmitterHelper emitter(builder, libdevice_path, dot_instr, b, index_ty,
+                              dims, launch_config, analysis);
 
+  constexpr int group_m = 8;
   const int64_t width = group_m * launch_config.grid_n;
 
-  auto pid_batch =
-      b.create<mt::GetProgramIdOp>(launch_config.batch_program_id_dim);
+  auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); };
+
   auto pid_nc = b.create<mt::GetProgramIdOp>(
       launch_config.noncontracting_program_id_dim);
   auto pid_k = b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::Z);
 
-  auto group_id = b.create<ma::DivSIOp>(pid_nc, CreateConst(b, i32_ty, width));
-  ma::ConstantOp group_m_op = CreateConst(b, i32_ty, group_m);
+  auto group_id = b.create<ma::DivSIOp>(pid_nc, c32(width));
+  ma::ConstantOp group_m_op = c32(group_m);
   auto first_pid_m = b.create<ma::MulIOp>(group_id, group_m_op);
-  auto sub0 = b.create<ma::SubIOp>(CreateConst(b, i32_ty, launch_config.grid_m),
-                                   first_pid_m);
+  auto sub0 = b.create<ma::SubIOp>(c32(launch_config.grid_m), first_pid_m);
   auto group_size = b.create<ma::SelectOp>(
       b.create<ma::CmpIOp>(ma::CmpIPredicate::slt, sub0, group_m_op), sub0,
       group_m_op);
 
-  // Extend int32 indexes to int64, if necessary.
-  auto convert_scalar = [&](Value value) -> Value {
-    if (index_ty.getIntOrFloatBitWidth() == 64) {
-      return b.create<ma::ExtSIOp>(index_ty, value);
-    }
-    return value;
-  };
-
   auto pid_m = b.create<ma::AddIOp>(first_pid_m,
                                     b.create<ma::RemSIOp>(pid_nc, group_size));
-  auto pid_m_offset =
-      b.create<ma::MulIOp>(pid_m, CreateConst(b, i32_ty, block_m));
+  auto pid_m_offset = b.create<ma::MulIOp>(pid_m, c32(block_m));
 
-  auto pid_n = b.create<ma::DivSIOp>(
-      b.create<ma::RemSIOp>(pid_nc, CreateConst(b, i32_ty, width)), group_size);
-  auto pid_n_offset =
-      b.create<ma::MulIOp>(pid_n, CreateConst(b, i32_ty, block_n));
+  auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
+                                     group_size);
+  auto pid_n_offset = b.create<ma::MulIOp>(pid_n, c32(block_n));
 
-  auto pid_k_offset =
-      b.create<ma::MulIOp>(pid_k, CreateConst(b, i32_ty, block_k));
+  auto pid_k_offset = b.create<ma::MulIOp>(pid_k, c32(block_k));
+
+  mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
 
-  Type dot_output_ty = TritonType(b, dot_instr->shape().element_type());
-  // Data type of dot() immediate inputs.
-  Type dot_input_ty = b.getF32Type();
-  {
-    const Type lhs_ty =
-        TritonType(b, dot_instr->operand(0)->shape().element_type());
-    const Type rhs_ty =
-        TritonType(b, dot_instr->operand(1)->shape().element_type());
-    CHECK(lhs_ty == rhs_ty);
-    dot_input_ty = lhs_ty;
-  }
-  // TODO(b/266862493): Accumulator can be integer too.
-  // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
-  mlir::FloatType acc_ty = (dot_output_ty.isF64() && dot_input_ty.isF64())
-                               ? b.getF64Type()
-                               : b.getF32Type();
   ma::ConstantOp accumulator_init =
       CreateConst(b, acc_ty, 0, {block_m, block_n});
 
@@ -1066,11 +1237,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   absl::flat_hash_map<int, const HloInstruction*> iter_args_to_parameters;
   absl::flat_hash_map<int, std::vector<int32_t>> iter_args_to_boundary_checks;
 
-  struct Side {
-    TritonFusionAnalysis::Scope scope;
-    std::vector<DimProperties> tiled_dims;
-    std::optional<int> batch_dim_idx;
-  };
   Side lhs{TritonFusionAnalysis::Scope::LHS,
            /*tiled_dims=*/
            {{dims.lhs_noncontracting_dim_idx, pid_m_offset, block_m},
@@ -1114,10 +1280,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
         // Only the contracting dimensions are advanced.
         if (dim.index == (is_lhs ? dims.lhs_contracting_dim_idx
                                  : dims.rhs_contracting_dim_idx)) {
-          increments.push_back(
-              CreateConst(b, i32_ty, dim.block_size * split_k));
+          increments.push_back(c32(dim.block_size * split_k));
         } else {
-          increments.push_back(CreateConst(b, i32_ty, 0));
+          increments.push_back(c32(0));
         }
       }
       if (increments.empty()) {
@@ -1129,15 +1294,8 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
     }
 
     // Emit all operations of LHS and RHS scopes.
-    auto make_input = [&](Side& side, int64_t operand_index, auto& values) {
-      return *EmitScope(
-          b, libdevice_path, &analysis, side.scope, side.tiled_dims,
-          dot_instr->parent()->MakeInstructionPostOrderFrom(
-              const_cast<HloInstruction&>(*dot_instr->operand(operand_index))),
-          values);
-    };
-    Value dot_input_lhs = make_input(lhs, 0, values_lhs);
-    Value dot_input_rhs = make_input(rhs, 1, values_rhs);
+    Value dot_input_lhs = emitter.MakeInput(lhs, 0, values_lhs);
+    Value dot_input_rhs = emitter.MakeInput(rhs, 1, values_rhs);
 
     // Operation in the fusion before the dot can alter the elements of the
     // tiles that were zero masked during loads. These have to be zeroed here
@@ -1186,151 +1344,35 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
       analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size() +
       analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size() + 1);
 
-  auto emit_tensor_pointer =
-      [&](const HloInstruction* hlo, const Side& side, Value base,
-          std::vector<int32_t>& boundary_checks) -> Value {
-    std::vector<Value> bounds;
-    std::vector<Value> strides;
-    std::vector<Value> offsets;
-    std::vector<int32_t> block_dims;
-    std::vector<int32_t> dim_order;
-
-    auto add_dim = [&](const DimProperties& properties) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, properties.index);
-      if (spec == nullptr) {
-        return;
-      }
-      const int64_t stride = spec->at(0).stride;
-      int64_t count = spec->at(0).count;
-      if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
-          properties.index == dims.out_lhs_noncontracting_dim_idx &&
-          spec->size() == 1 && dims.lhs_noncontracting_split.has_value()) {
-        // Dimension of the output produced by the non-contracting LHS one
-        // is logically split, major part is addressed using pid_batch.
-        count /= *dims.lhs_noncontracting_split;
-      }
-      if (count % properties.block_size != 0) {
-        boundary_checks.push_back(bounds.size());
-      }
-      bounds.push_back(CreateConst(b, i64_ty, count));
-      strides.push_back(CreateConst(b, i64_ty, stride));
-      offsets.push_back(properties.offset);
-      block_dims.push_back(properties.block_size);
-      dim_order.emplace(dim_order.begin(), dim_order.size());
-    };
-    for (const DimProperties& dim : side.tiled_dims) {
-      add_dim(dim);
-    }
-
-    int64_t stride_batch = 0;
-    if (side.scope != TritonFusionAnalysis::Scope::RHS &&
-        dims.lhs_noncontracting_split) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
-      if (spec != nullptr) {
-        if (spec->size() > 1) {
-          // Support one specific kind of output transpose that splits the
-          // dimension originating from the split LHS non-contracting one.
-          stride_batch = spec->at(1).stride;
-        } else {
-          // Because the major part of the split is implemented using the
-          // batch logic stride_batch is populated here as the stride of
-          // the minor part times its size.
-          stride_batch = spec->at(0).stride *
-                         (spec->at(0).count / *dims.lhs_noncontracting_split);
-        }
-        CHECK_NE(stride_batch, 0);
-      }
-    } else if (side.batch_dim_idx.has_value()) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, *side.batch_dim_idx);
-      if (spec != nullptr) {
-        stride_batch = spec->at(0).stride;
-        CHECK_NE(stride_batch, 0);
-      }
-    }
-    if (stride_batch != 0) {
-      Value offset_batch = b.create<ma::MulIOp>(
-          convert_scalar(pid_batch), CreateConst(b, index_ty, stride_batch));
-      base = AddPtr(b, base, offset_batch);
-    }
-
-    if (dims.out_split_k_dim_idx.has_value()) {
-      const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec(
-          TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims.out_split_k_dim_idx);
-      if (spec != nullptr) {
-        int64_t stride_split_k = spec->at(0).stride;
-        Value offset_split_k = b.create<ma::MulIOp>(
-            convert_scalar(pid_k), CreateConst(b, index_ty, stride_split_k));
-        base = AddPtr(b, base, offset_split_k);
-      }
-    }
-
-    if (block_dims.empty()) {
-      return base;
-    }
-    return b.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
-                                         block_dims, dim_order);
-  };
 
   for (const auto& side : {lhs, rhs}) {
     for (const HloInstruction* param : analysis.ScopeParameters(side.scope)) {
       CHECK(iter_args_to_parameters.insert({iter_args.size(), param}).second);
-      iter_args.push_back(emit_tensor_pointer(
-          param, side, fn.getArgument(param->parameter_number()),
+      iter_args.push_back(emitter.EmitTensorPointer(
+          param, side, fn.getArgument(param->parameter_number()), pid_k,
           iter_args_to_boundary_checks[iter_args.size()]));
     }
   }
 
   iter_args.push_back(accumulator_init);
-  Value acc_final =
-      b.create<mlir::scf::ForOp>(
-           /*lowerBound=*/b.create<ma::ConstantIntOp>(0, /*width=*/32),
-           /*upperBound=*/b.create<ma::ConstantIntOp>(dims.k, /*width=*/32),
-           /*step=*/
-           b.create<ma::ConstantIntOp>(block_k * split_k,
-                                       /*width=*/32),
-           /*iterArgs=*/iter_args, body_builder)
-          .getResult(iter_args.size() - 1);
+  Value acc_final = b.create<mlir::scf::ForOp>(
+                         /*lowerBound=*/c32(0),
+                         /*upperBound=*/c32(dims.k),
+                         /*step=*/c32(block_k * split_k),
+                         /*iterArgs=*/iter_args, body_builder)
+                        .getResult(iter_args.size() - 1);
   absl::flat_hash_map<const HloInstruction*, Value> values_out;
   values_out[dot_instr] =
       Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
 
-  // Collect all instructions of the dot's output scope.
-  absl::flat_hash_set<const HloInstruction*> to_order;
-  {
-    std::queue<const HloInstruction*> to_add;
-    if (root != dot_instr) {
-      to_add.push(root);
-    }
-    while (!to_add.empty()) {
-      const HloInstruction* current = to_add.front();
-      for (const HloInstruction* operand : current->operands()) {
-        if (!to_order.contains(operand)) {
-          if (operand != dot_instr) {
-            to_add.push(operand);
-          }
-        }
-      }
-      CHECK(to_order.insert(current).second);
-      to_add.pop();
-    }
-  }
-  // Order them producers before consumers.
-  std::vector<const HloInstruction*> to_emit;
-  for (const HloInstruction* hlo :
-       dot_instr->parent()->MakeInstructionPostOrder()) {
-    if (to_order.contains(hlo)) {
-      to_emit.push_back(hlo);
-    }
-  }
 
   // Emit the output scope.
-  if (!to_emit.empty()) {
+  if (std::vector<const HloInstruction*> to_emit =
+          emitter.EpiloguePostOrderTransitiveOperands(root);
+      !to_emit.empty()) {
     for (const HloInstruction* parameter :
          analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT)) {
       std::vector<int32_t> boundary_checks;
-      Value tensor_pointer = emit_tensor_pointer(
-          parameter, out, fn.getArgument(parameter->parameter_number()),
+      Value tensor_pointer = emitter.EmitTensorPointer(
+          parameter, out, fn.getArgument(parameter->parameter_number()), pid_k,
           boundary_checks);
       CHECK(values_out
                 .insert({parameter,
@@ -1349,9 +1391,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
     const HloInstruction* producer =
         root->shape().IsTuple() ? root->operand(i) : root;
     std::vector<int32_t> boundary_checks;
-    Value tensor_pointer = emit_tensor_pointer(
+    Value tensor_pointer = emitter.EmitTensorPointer(
         producer, out,
-        fn.getArgument(i + dot_instr->parent()->num_parameters()),
+        fn.getArgument(i + dot_instr->parent()->num_parameters()), pid_k,
        boundary_checks);
     b.create<mt::StoreOp>(tensor_pointer, values_out[producer], boundary_checks,
                           mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL);
-- 
GitLab