Commit 4ae35f9f authored by George Karpenkov, committed by TensorFlower Gardener

[NFC] [XLA:GPU] Split up Triton emitter into multiple functions.

PiperOrigin-RevId: 565024582
Parent af5aa0e5
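The change is mechanical: state that EmitMatMul previously threaded through lambda captures (make_input, emit_tensor_pointer, convert_scalar, the epilogue-ordering block) moves into a MatMulEmitterHelper class constructed once with the shared context, so each former lambda becomes a method. Below is a minimal, self-contained sketch of the pattern; EmitBefore and EmitterHelper are illustrative names, not the XLA declarations.

#include <string>
#include <vector>

// Before: one large function owning all shared state through lambda captures.
void EmitBefore(std::vector<std::string>& out, int block) {
  auto emit_one = [&](const std::string& name) {
    // Uses the captured 'out' and 'block'.
    out.push_back(name + "/" + std::to_string(block));
  };
  emit_one("lhs");
  emit_one("rhs");
}

// After: the captures become members of a helper constructed once, and each
// former lambda becomes a method.
class EmitterHelper {
 public:
  EmitterHelper(std::vector<std::string>& out, int block)
      : out_(out), block_(block) {}
  void EmitOne(const std::string& name) {
    out_.push_back(name + "/" + std::to_string(block_));
  }

 private:
  std::vector<std::string>& out_;
  int block_;
};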
@@ -946,6 +946,203 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
2 + (config.split_k() > 1 ? 1 : 0) + num_batch_dims);
}
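// Scope and tiling information for one side (LHS, RHS or output) of the
// matmul.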
struct Side {
TritonFusionAnalysis::Scope scope;
std::vector<DimProperties> tiled_dims;
std::optional<int64_t> batch_dim_idx;
};
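// Owns the state shared by the steps of the matmul emitter: the builder,
// the fusion analysis, the problem dimensions and the launch configuration.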
class MatMulEmitterHelper {
public:
MatMulEmitterHelper(mlir::OpBuilder builder, absl::string_view libdevice_path,
const HloDotInstruction* dot_instr,
ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims,
const MatMulLaunchConfig& launch_config,
const TritonFusionAnalysis& analysis)
: b_(b),
libdevice_path_(libdevice_path),
dot_instr_(dot_instr),
index_ty_(index_ty),
analysis_(analysis),
dims_(dims),
launch_config_(launch_config) {}
// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
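// (For example, an f16 x f16 -> f16 dot still accumulates in f32; only an
// f64 x f64 -> f64 dot gets an f64 accumulator.)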
mlir::FloatType GetDotAccumulatorType() {
Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
// Data type of dot() immediate inputs.
Type dot_input_ty = [&] {
const Type lhs_ty =
TritonType(b_, dot_instr_->operand(0)->shape().element_type());
const Type rhs_ty =
TritonType(b_, dot_instr_->operand(1)->shape().element_type());
CHECK(lhs_ty == rhs_ty);
return lhs_ty;
}();
return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
: b_.getF32Type();
}
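// Returns the transitive operands of 'root' within the dot's output
// (epilogue) scope, excluding the dot itself, ordered so that producers
// come before consumers.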
std::vector<const HloInstruction*> EpiloguePostOrderTransitiveOperands(
const HloInstruction* root) {
// Collect all instructions of the dot's output scope.
absl::flat_hash_set<const HloInstruction*> to_order;
{
std::queue<const HloInstruction*> to_add;
if (root != dot_instr_) {
to_add.push(root);
}
while (!to_add.empty()) {
const HloInstruction* current = to_add.front();
for (const HloInstruction* operand : current->operands()) {
if (!to_order.contains(operand)) {
if (operand != dot_instr_) {
to_add.push(operand);
}
}
}
CHECK(to_order.insert(current).second);
to_add.pop();
}
}
// Order them producers before consumers.
std::vector<const HloInstruction*> to_emit;
for (const HloInstruction* hlo :
dot_instr_->parent()->MakeInstructionPostOrder()) {
if (to_order.contains(hlo)) {
to_emit.push_back(hlo);
}
}
return to_emit;
}
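// Emits the scope of instructions feeding dot operand 'operand_index' on
// the given side and returns the value to use as that dot input.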
Value MakeInput(Side& side, int64_t operand_index,
absl::flat_hash_map<const HloInstruction*, Value>& values) {
return *EmitScope(
b_, libdevice_path_, &analysis_, side.scope, side.tiled_dims,
dot_instr_->parent()->MakeInstructionPostOrderFrom(
const_cast<HloInstruction&>(*dot_instr_->operand(operand_index))),
values);
}
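// Builds a Triton tensor pointer for the tile of 'hlo' on the given side:
// bounds, strides, offsets and block sizes come from the fusion analysis,
// and every dimension whose size is not a multiple of its block size is
// recorded in 'boundary_checks'.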
Value EmitTensorPointer(const HloInstruction* hlo, const Side& side,
Value base, Value pid_k,
std::vector<int32_t>& boundary_checks) {
auto pid_batch =
b_.create<mt::GetProgramIdOp>(launch_config_.batch_program_id_dim);
std::vector<Value> bounds;
std::vector<Value> strides;
std::vector<Value> offsets;
std::vector<int32_t> block_dims;
std::vector<int32_t> dim_order;
auto add_dim = [&](const DimProperties& properties) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis_.IterSpec(side.scope, hlo, properties.index);
if (spec == nullptr) {
return;
}
const int64_t stride = spec->at(0).stride;
int64_t count = spec->at(0).count;
if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
properties.index == dims_.out_lhs_noncontracting_dim_idx &&
spec->size() == 1 && dims_.lhs_noncontracting_split.has_value()) {
// The output dimension produced by the non-contracting LHS dimension
// is logically split; the major part is addressed using pid_batch.
count /= *dims_.lhs_noncontracting_split;
}
if (count % properties.block_size != 0) {
boundary_checks.push_back(bounds.size());
}
bounds.push_back(Cst64(count));
strides.push_back(Cst64(stride));
offsets.push_back(properties.offset);
block_dims.push_back(properties.block_size);
dim_order.emplace(dim_order.begin(), dim_order.size());
};
for (const DimProperties& dim : side.tiled_dims) {
add_dim(dim);
}
int64_t stride_batch = 0;
if (side.scope != TritonFusionAnalysis::Scope::RHS &&
dims_.lhs_noncontracting_split) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis_.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
if (spec != nullptr) {
if (spec->size() > 1) {
// Support one specific kind of output transpose that splits the
// dimension originating from the split LHS non-contracting one.
stride_batch = spec->at(1).stride;
} else {
// Because the major part of the split is implemented using the
// batch logic, stride_batch is populated here as the stride of
// the minor part times its size.
stride_batch = spec->at(0).stride *
(spec->at(0).count / *dims_.lhs_noncontracting_split);
}
CHECK_NE(stride_batch, 0);
}
} else if (side.batch_dim_idx.has_value()) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis_.IterSpec(side.scope, hlo, *side.batch_dim_idx);
if (spec != nullptr) {
stride_batch = spec->at(0).stride;
CHECK_NE(stride_batch, 0);
}
}
if (stride_batch != 0) {
Value offset_batch =
b_.create<ma::MulIOp>(ConvertScalar(pid_batch), Cst(stride_batch));
base = AddPtr(b_, base, offset_batch);
}
if (dims_.out_split_k_dim_idx.has_value()) {
const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec(
TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx);
if (spec != nullptr) {
int64_t stride_split_k = spec->at(0).stride;
Value offset_split_k =
b_.create<ma::MulIOp>(ConvertScalar(pid_k), Cst(stride_split_k));
base = AddPtr(b_, base, offset_split_k);
}
}
if (block_dims.empty()) {
return base;
}
return b_.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
block_dims, dim_order);
}
private:
// Extend int32 indexes to int64, if necessary.
Value ConvertScalar(Value value) {
if (index_ty_.getIntOrFloatBitWidth() == 64) {
return b_.create<ma::ExtSIOp>(index_ty_, value);
}
return value;
}
Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); }
Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); }
ImplicitLocOpBuilder& b_;
absl::string_view libdevice_path_;
const HloDotInstruction* dot_instr_;
Type index_ty_;
TritonFusionAnalysis analysis_;
MatMulDims dims_;
MatMulLaunchConfig launch_config_;
Type i32_ty_ = b_.getI32Type();
Type i64_ty_ = b_.getI64Type();
};
} // namespace
LaunchDimensions GetMatMulLaunchDimensions(
@@ -964,7 +1161,6 @@ LaunchDimensions GetMatMulLaunchDimensions(
}
// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
// TODO(b/270937368): Split this up into smaller functions.
Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
const TritonFusionAnalysis& analysis,
const HloComputation* computation, mlir::triton::FuncOp fn,
@@ -979,7 +1175,7 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
mlir::Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
const HloInstruction* root = dot_instr->parent()->root_instruction();
CHECK(!root->shape().IsTuple());
@@ -990,7 +1186,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name()));
ImplicitLocOpBuilder b(loc, builder);
Type i32_ty = b.getI32Type();
Type i64_ty = b.getI64Type();
ValidateMatMulConfig(config, *dot_instr);
const int split_k = config.split_k();
@@ -1002,62 +1197,38 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
VLOG(6) << analysis.ToString();
constexpr int group_m = 8;
MatMulEmitterHelper emitter(builder, libdevice_path, dot_instr, b, index_ty,
dims, launch_config, analysis);
constexpr int group_m = 8;
const int64_t width = group_m * launch_config.grid_n;
auto pid_batch =
b.create<mt::GetProgramIdOp>(launch_config.batch_program_id_dim);
auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); };
auto pid_nc =
b.create<mt::GetProgramIdOp>(launch_config.noncontracting_program_id_dim);
auto pid_k = b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::Z);
auto group_id = b.create<ma::DivSIOp>(pid_nc, CreateConst(b, i32_ty, width));
ma::ConstantOp group_m_op = CreateConst(b, i32_ty, group_m);
auto group_id = b.create<ma::DivSIOp>(pid_nc, c32(width));
ma::ConstantOp group_m_op = c32(group_m);
auto first_pid_m = b.create<ma::MulIOp>(group_id, group_m_op);
auto sub0 = b.create<ma::SubIOp>(CreateConst(b, i32_ty, launch_config.grid_m),
first_pid_m);
auto sub0 = b.create<ma::SubIOp>(c32(launch_config.grid_m), first_pid_m);
auto group_size = b.create<ma::SelectOp>(
b.create<ma::CmpIOp>(ma::CmpIPredicate::slt, sub0, group_m_op), sub0,
group_m_op);
// Extend int32 indexes to int64, if necessary.
auto convert_scalar = [&](Value value) -> Value {
if (index_ty.getIntOrFloatBitWidth() == 64) {
return b.create<ma::ExtSIOp>(index_ty, value);
}
return value;
};
auto pid_m = b.create<ma::AddIOp>(first_pid_m,
b.create<ma::RemSIOp>(pid_nc, group_size));
auto pid_m_offset =
b.create<ma::MulIOp>(pid_m, CreateConst(b, i32_ty, block_m));
auto pid_m_offset = b.create<ma::MulIOp>(pid_m, c32(block_m));
auto pid_n = b.create<ma::DivSIOp>(
b.create<ma::RemSIOp>(pid_nc, CreateConst(b, i32_ty, width)), group_size);
auto pid_n_offset =
b.create<ma::MulIOp>(pid_n, CreateConst(b, i32_ty, block_n));
auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
group_size);
auto pid_n_offset = b.create<ma::MulIOp>(pid_n, c32(block_n));
auto pid_k_offset =
b.create<ma::MulIOp>(pid_k, CreateConst(b, i32_ty, block_k));
auto pid_k_offset = b.create<ma::MulIOp>(pid_k, c32(block_k));
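// The pid_nc -> (pid_m, pid_n) mapping above implements grouped ordering:
// with group_m = 8 and, say, grid_n = 4 (so width = 32), pid_nc 0..31
// walks an 8x4 block of output tiles column by column before moving on to
// the next 8 rows of tiles, which improves L2 reuse between programs.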
mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
Type dot_output_ty = TritonType(b, dot_instr->shape().element_type());
// Data type of dot() immediate inputs.
Type dot_input_ty = b.getF32Type();
{
const Type lhs_ty =
TritonType(b, dot_instr->operand(0)->shape().element_type());
const Type rhs_ty =
TritonType(b, dot_instr->operand(1)->shape().element_type());
CHECK(lhs_ty == rhs_ty);
dot_input_ty = lhs_ty;
}
// TODO(b/266862493): Accumulator can be integer too.
// Otherwise only f64 x f64 -> f64 uses f64 accumulator.
mlir::FloatType acc_ty = (dot_output_ty.isF64() && dot_input_ty.isF64())
? b.getF64Type()
: b.getF32Type();
ma::ConstantOp accumulator_init =
CreateConst(b, acc_ty, 0, {block_m, block_n});
@@ -1066,11 +1237,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
absl::flat_hash_map<int, const HloInstruction*> iter_args_to_parameters;
absl::flat_hash_map<int, std::vector<int32_t>> iter_args_to_boundary_checks;
struct Side {
TritonFusionAnalysis::Scope scope;
std::vector<DimProperties> tiled_dims;
std::optional<int64_t> batch_dim_idx;
};
Side lhs{TritonFusionAnalysis::Scope::LHS,
/*tiled_dims=*/
{{dims.lhs_noncontracting_dim_idx, pid_m_offset, block_m},
@@ -1114,10 +1280,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
// Only the contracting dimensions are advanced.
if (dim.index == (is_lhs ? dims.lhs_contracting_dim_idx
: dims.rhs_contracting_dim_idx)) {
increments.push_back(
CreateConst(b, i32_ty, dim.block_size * split_k));
increments.push_back(c32(dim.block_size * split_k));
} else {
increments.push_back(CreateConst(b, i32_ty, 0));
increments.push_back(c32(0));
}
}
if (increments.empty()) {
@@ -1129,15 +1294,8 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
}
// Emit all operations of LHS and RHS scopes.
auto make_input = [&](Side& side, int64_t operand_index, auto& values) {
return *EmitScope(
b, libdevice_path, &analysis, side.scope, side.tiled_dims,
dot_instr->parent()->MakeInstructionPostOrderFrom(
const_cast<HloInstruction&>(*dot_instr->operand(operand_index))),
values);
};
Value dot_input_lhs = make_input(lhs, 0, values_lhs);
Value dot_input_rhs = make_input(rhs, 1, values_rhs);
Value dot_input_lhs = emitter.MakeInput(lhs, 0, values_lhs);
Value dot_input_rhs = emitter.MakeInput(rhs, 1, values_rhs);
// Operation in the fusion before the dot can alter the elements of the
// tiles that were zero masked during loads. These have to be zeroed here
@@ -1186,151 +1344,35 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size() +
analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size() + 1);
auto emit_tensor_pointer =
[&](const HloInstruction* hlo, const Side& side, Value base,
std::vector<int32_t>& boundary_checks) -> Value {
std::vector<Value> bounds;
std::vector<Value> strides;
std::vector<Value> offsets;
std::vector<int32_t> block_dims;
std::vector<int32_t> dim_order;
auto add_dim = [&](const DimProperties& properties) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis.IterSpec(side.scope, hlo, properties.index);
if (spec == nullptr) {
return;
}
const int64_t stride = spec->at(0).stride;
int64_t count = spec->at(0).count;
if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
properties.index == dims.out_lhs_noncontracting_dim_idx &&
spec->size() == 1 && dims.lhs_noncontracting_split.has_value()) {
// The output dimension produced by the non-contracting LHS dimension
// is logically split; the major part is addressed using pid_batch.
count /= *dims.lhs_noncontracting_split;
}
if (count % properties.block_size != 0) {
boundary_checks.push_back(bounds.size());
}
bounds.push_back(CreateConst(b, i64_ty, count));
strides.push_back(CreateConst(b, i64_ty, stride));
offsets.push_back(properties.offset);
block_dims.push_back(properties.block_size);
dim_order.emplace(dim_order.begin(), dim_order.size());
};
for (const DimProperties& dim : side.tiled_dims) {
add_dim(dim);
}
int64_t stride_batch = 0;
if (side.scope != TritonFusionAnalysis::Scope::RHS &&
dims.lhs_noncontracting_split) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
if (spec != nullptr) {
if (spec->size() > 1) {
// Support one specific kind of output transpose that splits the
// dimension originating from the split LHS non-contracting one.
stride_batch = spec->at(1).stride;
} else {
// Because the major part of the split is implemented using the
// batch logic, stride_batch is populated here as the stride of
// the minor part times its size.
stride_batch = spec->at(0).stride *
(spec->at(0).count / *dims.lhs_noncontracting_split);
}
CHECK_NE(stride_batch, 0);
}
} else if (side.batch_dim_idx.has_value()) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis.IterSpec(side.scope, hlo, *side.batch_dim_idx);
if (spec != nullptr) {
stride_batch = spec->at(0).stride;
CHECK_NE(stride_batch, 0);
}
}
if (stride_batch != 0) {
Value offset_batch = b.create<ma::MulIOp>(
convert_scalar(pid_batch), CreateConst(b, index_ty, stride_batch));
base = AddPtr(b, base, offset_batch);
}
if (dims.out_split_k_dim_idx.has_value()) {
const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec(
TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims.out_split_k_dim_idx);
if (spec != nullptr) {
int64_t stride_split_k = spec->at(0).stride;
Value offset_split_k = b.create<ma::MulIOp>(
convert_scalar(pid_k), CreateConst(b, index_ty, stride_split_k));
base = AddPtr(b, base, offset_split_k);
}
}
if (block_dims.empty()) {
return base;
}
return b.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
block_dims, dim_order);
};
for (const auto& side : {lhs, rhs}) {
for (const HloInstruction* param : analysis.ScopeParameters(side.scope)) {
CHECK(iter_args_to_parameters.insert({iter_args.size(), param}).second);
iter_args.push_back(emit_tensor_pointer(
param, side, fn.getArgument(param->parameter_number()),
iter_args.push_back(emitter.EmitTensorPointer(
param, side, fn.getArgument(param->parameter_number()), pid_k,
iter_args_to_boundary_checks[iter_args.size()]));
}
}
iter_args.push_back(accumulator_init);
Value acc_final =
b.create<mlir::scf::ForOp>(
/*lowerBound=*/b.create<ma::ConstantIntOp>(0, /*width=*/32),
/*upperBound=*/b.create<ma::ConstantIntOp>(dims.k, /*width=*/32),
/*step=*/
b.create<ma::ConstantIntOp>(block_k * split_k,
/*width=*/32),
/*iterArgs=*/iter_args, body_builder)
.getResult(iter_args.size() - 1);
Value acc_final = b.create<mlir::scf::ForOp>(
/*lowerBound=*/c32(0),
/*upperBound=*/c32(dims.k),
/*step=*/c32(block_k * split_k),
/*iterArgs=*/iter_args, body_builder)
.getResult(iter_args.size() - 1);
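// The loop iterates over the contracting dimension in steps of
// block_k * split_k, carrying the tensor pointers and the accumulator as
// iteration arguments; its last result is the final accumulator tile.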
absl::flat_hash_map<const HloInstruction*, Value> values_out;
values_out[dot_instr] =
Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
// Collect all instructions of the dot's output scope.
absl::flat_hash_set<const HloInstruction*> to_order;
{
std::queue<const HloInstruction*> to_add;
if (root != dot_instr) {
to_add.push(root);
}
while (!to_add.empty()) {
const HloInstruction* current = to_add.front();
for (const HloInstruction* operand : current->operands()) {
if (!to_order.contains(operand)) {
if (operand != dot_instr) {
to_add.push(operand);
}
}
}
CHECK(to_order.insert(current).second);
to_add.pop();
}
}
// Order them producers before consumers.
std::vector<const HloInstruction*> to_emit;
for (const HloInstruction* hlo :
dot_instr->parent()->MakeInstructionPostOrder()) {
if (to_order.contains(hlo)) {
to_emit.push_back(hlo);
}
}
// Emit the output scope.
if (!to_emit.empty()) {
if (std::vector<const HloInstruction*> to_emit =
emitter.EpiloguePostOrderTransitiveOperands(root);
!to_emit.empty()) {
for (const HloInstruction* parameter :
analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT)) {
std::vector<int32_t> boundary_checks;
Value tensor_pointer = emit_tensor_pointer(
parameter, out, fn.getArgument(parameter->parameter_number()),
Value tensor_pointer = emitter.EmitTensorPointer(
parameter, out, fn.getArgument(parameter->parameter_number()), pid_k,
boundary_checks);
CHECK(values_out
.insert({parameter,
@@ -1349,9 +1391,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
const HloInstruction* producer =
root->shape().IsTuple() ? root->operand(i) : root;
std::vector<int32_t> boundary_checks;
Value tensor_pointer = emit_tensor_pointer(
Value tensor_pointer = emitter.EmitTensorPointer(
producer, out,
fn.getArgument(i + dot_instr->parent()->num_parameters()),
fn.getArgument(i + dot_instr->parent()->num_parameters()), pid_k,
boundary_checks);
b.create<mt::StoreOp>(tensor_pointer, values_out[producer], boundary_checks,
mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL);
......