提交 af5aa0e5 编写于 作者: T Tamás Danyluk 提交者: TensorFlower Gardener

[XLA:GPU][NFC] Cosmetic changes to gemm_rewriter_triton.cc

Unify iterator names to *_it;
Use auto* instead of auto if the type is a pointer type.

PiperOrigin-RevId: 565016627
上级 1803461c
...@@ -530,25 +530,26 @@ FusionDecision FusionContext::RequireSupportedDimOrder( ...@@ -530,25 +530,26 @@ FusionDecision FusionContext::RequireSupportedDimOrder(
const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder(); const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder();
for (const auto& [dim_index, dim_fragments] : order.DimFragmentsOrders()) { for (const auto& [dim_index, dim_fragments] : order.DimFragmentsOrders()) {
int split_counter = -1; int split_counter = -1;
auto fragment = dim_fragments.cbegin(); auto fragment_it = dim_fragments.cbegin();
while (true) { while (true) {
if (fragment == dim_fragments.cend()) { if (fragment_it == dim_fragments.cend()) {
break; break;
} }
int64_t grouped_size = tensor_dim_fragments[*fragment].size(); int64_t grouped_size = tensor_dim_fragments[*fragment_it].size();
// Gather contiguous fragments. // Gather contiguous fragments.
while ((fragment + 1) != dim_fragments.cend() && while ((fragment_it + 1) != dim_fragments.cend() &&
*(fragment + 1) == *fragment + 1) { *(fragment_it + 1) == *fragment_it + 1) {
++fragment; ++fragment_it;
grouped_size *= tensor_dim_fragments[*fragment].size(); grouped_size *= tensor_dim_fragments[*fragment_it].size();
} }
if (grouped_size == 1) { if (grouped_size == 1) {
++fragment; ++fragment_it;
continue; continue;
} }
if (fragment != dim_fragments.cbegin() && *fragment < *(fragment - 1)) { if (fragment_it != dim_fragments.cbegin() &&
*fragment_it < *(fragment_it - 1)) {
return "Transpose within a dimension."; return "Transpose within a dimension.";
} }
...@@ -570,7 +571,7 @@ FusionDecision FusionContext::RequireSupportedDimOrder( ...@@ -570,7 +571,7 @@ FusionDecision FusionContext::RequireSupportedDimOrder(
} }
} }
++fragment; ++fragment_it;
} }
} }
return FusionDecision{}; return FusionDecision{};
...@@ -811,7 +812,7 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( ...@@ -811,7 +812,7 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
// Temporary storage for fragments of new dimensions created by reductions. // Temporary storage for fragments of new dimensions created by reductions.
std::list<Fragment> new_fragments; std::list<Fragment> new_fragments;
if (hlo->opcode() == HloOpcode::kTranspose) { if (hlo->opcode() == HloOpcode::kTranspose) {
const auto transpose = Cast<HloTransposeInstruction>(hlo); const auto* transpose = Cast<HloTransposeInstruction>(hlo);
std::vector<int64_t> permutation(transpose->dimensions().cbegin(), std::vector<int64_t> permutation(transpose->dimensions().cbegin(),
transpose->dimensions().cend()); transpose->dimensions().cend());
if (direction == TransformDirection::kInputToOutput) { if (direction == TransformDirection::kInputToOutput) {
...@@ -822,13 +823,13 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( ...@@ -822,13 +823,13 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
dst_logical[permutation[i]] = src_logical[i]; dst_logical[permutation[i]] = src_logical[i];
} }
} else if (hlo->opcode() == HloOpcode::kBroadcast) { } else if (hlo->opcode() == HloOpcode::kBroadcast) {
const auto broadcast = Cast<HloBroadcastInstruction>(hlo); const auto* broadcast = Cast<HloBroadcastInstruction>(hlo);
dst_logical.resize(broadcast->dimensions().size()); dst_logical.resize(broadcast->dimensions().size());
for (int i = 0; i < broadcast->dimensions().size(); ++i) { for (int i = 0; i < broadcast->dimensions().size(); ++i) {
dst_logical[i] = src_logical[broadcast->dimensions()[i]]; dst_logical[i] = src_logical[broadcast->dimensions()[i]];
} }
} else if (hlo->opcode() == HloOpcode::kReduce) { } else if (hlo->opcode() == HloOpcode::kReduce) {
const auto reduce = Cast<HloReduceInstruction>(hlo); const auto* reduce = Cast<HloReduceInstruction>(hlo);
dst_logical.resize(src_logical.size() + reduce->dimensions().size()); dst_logical.resize(src_logical.size() + reduce->dimensions().size());
if (reduce->dimensions().size() != 1) { if (reduce->dimensions().size() != 1) {
return FusionDecision("Unsupported reduction."); return FusionDecision("Unsupported reduction.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册