提交 2c8798d2 编写于 作者: J Johannes Reifferscheid 提交者: TensorFlower Gardener

NFC: Extract LaunchDimension computation from Triton codegen.

PiperOrigin-RevId: 564686943
上级 4aa2ee61
......@@ -934,14 +934,23 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
} // namespace
// Computes the kernel launch dimensions for a Triton GEMM fusion.
// The fusion is identified by the first dot instruction in `computation`;
// tiling comes from `config`, operand/dimension info from `analysis`.
LaunchDimensions GetMatMulLaunchDimensions(
    const TritonFusionAnalysis& analysis, const HloComputation* computation,
    const AutotuneResult::TritonGemmKey& config) {
  // Locate the dot that roots this GEMM fusion.
  const auto* dot =
      DynCast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
          *computation, HloOpcode::kDot));
  const MatMulDims matmul_dims(config, *dot, analysis);
  return MatMulLaunchConfig(config, *dot, matmul_dims).launch_dims;
}
// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
// TODO(b/270937368): Split this up into smaller functions.
StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget) {
Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
const TritonFusionAnalysis& analysis,
const HloComputation* computation, mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget) {
const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
// Use 32-bit indexing if addressing any of the inputs or the output (which
......@@ -970,9 +979,6 @@ StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
const int block_k = config.block_k();
const int block_n = config.block_n();
TF_ASSIGN_OR_RETURN(
const TritonFusionAnalysis analysis,
TritonFusionAnalysis::Execute(*dot_instr->parent(), split_k));
const MatMulDims dims(config, *dot_instr, analysis);
const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
VLOG(6) << analysis.ToString();
......@@ -1331,15 +1337,29 @@ StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
b.create<mt::StoreOp>(tensor_pointer, values_out[producer], boundary_checks,
mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL);
}
return launch_config.launch_dims;
return OkStatus();
}
StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int) {
// Computes the kernel launch dimensions for a Triton SoftMax fusion:
// one program instance (block) per reduced row, `config.num_warps()` warps
// per block. The `TritonFusionAnalysis` parameter is unused here; it is only
// part of the signature so this matches the LaunchDimensionsGenerator type.
LaunchDimensions GetSoftMaxLaunchDimensions(
    const TritonFusionAnalysis&, const HloComputation* computation,
    const AutotuneResult::TritonGemmKey& config) {
  const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode(
      *computation, HloOpcode::kReduce);
  CHECK_NE(reduce, nullptr);
  const Shape& reduce_input_shape = reduce->operand(0)->shape();
  // Dimension 0 (most minor) is the reduced one by construction; every
  // more-major dimension contributes rows, each of which gets its own block.
  // Accumulate in 64 bits: shape dimensions are 64-bit, and their product can
  // overflow a 32-bit `int` for large inputs.
  int64_t num_rows = 1;
  for (int minor_axis = 1; minor_axis < reduce_input_shape.rank();
       ++minor_axis) {
    num_rows *= reduce_input_shape.dimensions_minor(minor_axis);
  }
  return {{num_rows, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}};
}
Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path,
const TritonFusionAnalysis& analysis,
const HloComputation* computation, mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config, int) {
const HloInstruction* root = computation->root_instruction();
auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name()));
ImplicitLocOpBuilder b(loc, builder);
......@@ -1377,10 +1397,6 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
block_row *= 2;
}
int num_rows = 1;
for (int minor_axis = 1; minor_axis < reduce_input_shape.rank(); ++minor_axis)
num_rows *= reduce_input_shape.dimensions_minor(minor_axis);
Value row_index = b.create<ma::ExtSIOp>(
b.getI64Type(), b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::X));
Value row_stride = CreateConst(b, b.getI32Type(), row_len);
......@@ -1404,8 +1420,6 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
}
values_out[computation->parameter_instruction(0)] = EmitParameterLoad(
b, make_tensor_pointer(fn.getArgument(0)), boundary_checks);
TF_ASSIGN_OR_RETURN(const auto analysis,
TritonFusionAnalysis::Execute(*computation));
// Dimension 0 is the reduced one by construction and it's the only one
// present in the tile shapes.
std::vector<DimProperties> tiled_dims = {{0, row_index, block_row}};
......@@ -1418,11 +1432,7 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
b.create<mt::StoreOp>(make_tensor_pointer(fn.getArgument(1)), result,
std::vector<int32_t>{0}, mt::CacheModifier::NONE,
mt::EvictionPolicy::NORMAL);
const LaunchDimensions launch_dimensions{
{num_rows, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}};
return launch_dimensions;
return OkStatus();
}
// Simplified copy of translateLLVMToLLVMIR which in addition takes
......@@ -1463,7 +1473,8 @@ StatusOr<LaunchDimensions> TritonWrapper(
absl::string_view fusion_kind, const se::CudaComputeCapability& cc,
const GpuDeviceInfo& device_info,
const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
LaunchDimensionsGenerator generator, mlir::MLIRContext& mlir_context) {
LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context) {
if (fusion_kind == kTritonGemmFusionKind) {
// This is a heuristic that serves as a proxy for register usage and code
// size.
......@@ -1537,8 +1548,13 @@ StatusOr<LaunchDimensions> TritonWrapper(
.debug_options()
.xla_gpu_cuda_data_dir());
TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
generator(b, libdevice_path, hlo_computation, fn, config,
TF_ASSIGN_OR_RETURN(
auto analysis,
fusion_kind == kTritonGemmFusionKind
? TritonFusionAnalysis::Execute(*hlo_computation, config.split_k())
: TritonFusionAnalysis::Execute(*hlo_computation));
TF_RETURN_IF_ERROR(ir_emitter(b, libdevice_path, analysis, hlo_computation,
fn, config,
device_info.shared_memory_per_block_optin));
b.create<mt::ReturnOp>(loc);
......@@ -1613,7 +1629,6 @@ StatusOr<LaunchDimensions> TritonWrapper(
if (shared_mem_bytes > device_info.shared_memory_per_block_optin) {
return ResourceExhausted("Shared memory size limit exceeded.");
}
launch_dimensions.SetSharedMemBytes(shared_mem_bytes);
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::Module> ll_triton_module,
TranslateLLVMToLLVMIR(&llvm_module->getContext(),
......@@ -1630,6 +1645,9 @@ StatusOr<LaunchDimensions> TritonWrapper(
llvm::Linker::Flags::OverrideFromSrc));
LogAndVerify(llvm_module);
LaunchDimensions launch_dimensions =
launch_dims_generator(analysis, hlo_computation, config);
launch_dimensions.SetSharedMemBytes(shared_mem_bytes);
return launch_dimensions;
}
......
......@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/gemm_rewriter_triton.h"
#include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/statusor.h"
......@@ -30,26 +31,36 @@ limitations under the License.
namespace xla {
namespace gpu {
// Compute the launch dimensions for the given Triton MatMul.
LaunchDimensions GetMatMulLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloComputation* computation,
const AutotuneResult::TritonGemmKey& config);
// Use tiling and execution parameters from 'config'.
StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder b,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget);
Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
const TritonFusionAnalysis& analysis,
const HloComputation* computation, mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget);
// Compute the launch dimensions for the given Triton SoftMax.
LaunchDimensions GetSoftMaxLaunchDimensions(
const TritonFusionAnalysis& analysis, const HloComputation* computation,
const AutotuneResult::TritonGemmKey& config);
// Generate Softmax in Triton IR inside 'fn'.
// Use execution parameters from 'config'.
StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder b,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget);
Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
const TritonFusionAnalysis& analysis,
const HloComputation* computation, mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget);
using LaunchDimensionsGenerator = std::function<StatusOr<LaunchDimensions>(
mlir::OpBuilder, absl::string_view, const HloComputation*,
mlir::triton::FuncOp, const AutotuneResult::TritonGemmKey&, int)>;
using LaunchDimensionsGenerator = std::function<LaunchDimensions(
const TritonFusionAnalysis&, const HloComputation*,
const AutotuneResult::TritonGemmKey&)>;
using TritonIrEmitter = std::function<Status(
mlir::OpBuilder, absl::string_view, const TritonFusionAnalysis& analysis,
const HloComputation*, mlir::triton::FuncOp,
const AutotuneResult::TritonGemmKey&, int)>;
// Generate Triton IR inside 'fn' using the given IR emitter, compile it into
// LLVM IR, and return the launch dimensions computed by the launch-dimensions
// generator.
......@@ -59,7 +70,9 @@ StatusOr<LaunchDimensions> TritonWrapper(
absl::string_view fusion_kind, const se::CudaComputeCapability& cc,
const GpuDeviceInfo& device_info,
const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
LaunchDimensionsGenerator generator, mlir::MLIRContext& mlir_context);
LaunchDimensionsGenerator launch_dims_generator, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context);
} // namespace gpu
} // namespace xla
......
......@@ -215,7 +215,8 @@ ENTRY entry {
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0},
dev_info, config, &llvm_module, &MatMul, mlir_context),
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
&EmitMatMul, mlir_context),
tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED,
"Shared memory size limit exceeded."));
......@@ -228,7 +229,8 @@ ENTRY entry {
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0},
dev_info, config, &llvm_module, &MatMul, mlir_context));
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
&EmitMatMul, mlir_context));
// Use optin shared memory which is > shared_memory_per_block.
EXPECT_GT(launch_dimensions.SharedMemBytes(),
dev_info.shared_memory_per_block);
......@@ -642,7 +644,8 @@ ENTRY entry {
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0},
dev_info, config, &llvm_module, &MatMul, mlir_context),
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
&EmitMatMul, mlir_context),
tsl::testing::StatusIs(
tsl::error::RESOURCE_EXHAUSTED,
"Tiling complexity heuristic exceeded: 147456 > 9000"));
......@@ -655,7 +658,8 @@ ENTRY entry {
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE,
/*minor=*/0},
dev_info, config, &llvm_module, &MatMul, mlir_context)
dev_info, config, &llvm_module, &GetMatMulLaunchDimensions,
&EmitMatMul, mlir_context)
.status());
}
......@@ -1438,8 +1442,8 @@ ENTRY e {
const LaunchDimensions launch_dimensions,
TritonWrapper("test_fn", triton_dot_computation, kTritonGemmFusionKind,
GetCudaComputeCapability(), dev_info,
config.triton_gemm_config(), &llvm_module, &MatMul,
mlir_context));
config.triton_gemm_config(), &llvm_module,
&GetMatMulLaunchDimensions, &EmitMatMul, mlir_context));
// The config is chosen so that the used memory size is slightly above the
// 48 kB boundary of standard / optin shared memory so that any GPU that
// has the optin one should be able to execute the test.
......
......@@ -1745,7 +1745,8 @@ Status IrEmitterUnnested::EmitTritonFusion(
TritonWrapper(impl_fn_name, hlo_computation, kTritonSoftmaxFusionKind,
ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_device_info(), config, module_,
&SoftMax, *ir_emitter_context_->mlir_context()));
&GetSoftMaxLaunchDimensions, &EmitSoftMax,
*ir_emitter_context_->mlir_context()));
} else { // Must be a MatMul
CHECK_EQ(fusion_kind, kTritonGemmFusionKind);
TF_ASSIGN_OR_RETURN(
......@@ -1753,7 +1754,8 @@ Status IrEmitterUnnested::EmitTritonFusion(
TritonWrapper(impl_fn_name, hlo_computation, kTritonGemmFusionKind,
ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_device_info(), config, module_,
&MatMul, *ir_emitter_context_->mlir_context()));
&GetMatMulLaunchDimensions, &EmitMatMul,
*ir_emitter_context_->mlir_context()));
}
llvm::Function* impl_fn = module_->getFunction(impl_fn_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册