Commit 8cf6ece8 authored by Johannes Reifferscheid, committed by TensorFlower Gardener

Fusion analysis: Remove remaining dependencies on fusion_.

PiperOrigin-RevId: 563691979
Parent 453a8537
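
Note on the pattern: this change makes the analysis operate on the fusion roots instead of holding the `HloFusionInstruction` (`fusion_`) and its fused computation. Walks over `fused_computation_->instructions()` become boundary-aware BFS traversals from the roots, via the new `HloAnyOf` helper added below. As a minimal self-contained sketch of that traversal pattern (not XLA code; `Node` and `AnyOf` are hypothetical stand-ins for `HloInstruction` and `HloAnyOf`):

// Sketch of the "any-of" BFS pattern this commit introduces. `Node` is a
// hypothetical stand-in for HloInstruction; the real HloAnyOf (in the diff
// below) walks HLO operands and honors a fusion boundary function.
#include <functional>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<const Node*> operands;
};

// Returns true if `visit` returns true for any node reachable from `roots`
// (consumers before producers). Edges for which `boundary` returns true are
// not traversed, and each node is visited at most once.
bool AnyOf(const std::vector<const Node*>& roots,
           const std::function<bool(const Node&, const Node&)>& boundary,
           const std::function<bool(const Node&)>& visit) {
  std::queue<const Node*> queue;
  std::unordered_set<const Node*> enqueued;
  for (const Node* root : roots) {
    if (enqueued.insert(root).second) queue.push(root);
  }
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop();
    // Short-circuit, like TraversalResult::kAbortTraversal in the real code.
    if (visit(*node)) return true;
    for (const Node* operand : node->operands) {
      if (!boundary(*operand, *node) && enqueued.insert(operand).second) {
        queue.push(operand);
      }
    }
  }
  return false;
}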
@@ -323,7 +323,7 @@ Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder,
 StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
     IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion,
-    const HloFusionAnalysis& fusion_analysis,
+    const HloComputation* fused_computation,
     ElementalIrEmitter& elemental_emitter, KernelReuseCache& kernel_cache,
     int output_index, llvm::IRBuilder<>* builder) {
   auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
@@ -349,9 +349,6 @@ StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
   auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
                         std::vector<llvm_ir::IrArray> outputs) -> Status {
-    const HloComputation* fused_computation =
-        fusion_analysis.fused_computation();
-
     FusedIrEmitter fused_emitter(elemental_emitter);
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
       fused_emitter.BindGenerator(
@@ -375,11 +372,11 @@ StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
     return OkStatus();
   };
-  return BuildKernelThunkForFusion(
-      ir_emitter_context, kernel_cache, fusion,
-      fusion_analysis.fused_computation(), launch_dimensions,
-      /*discriminator=*/
-      absl::StrCat("init_", output_index), builder_fn, builder);
+  return BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, fusion,
+                                   fused_computation, launch_dimensions,
+                                   /*discriminator=*/
+                                   absl::StrCat("init_", output_index),
+                                   builder_fn, builder);
 }
 
 // Gets the output offset as calculated from thread_id.x (to be applied to the
@@ -972,14 +969,17 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   VLOG(3) << "Launch dimensions of "
           << mlir::mhlo::GetDebugNameFromLocation(fusion_op.getLoc()) << ": "
           << launch_dimensions.ToString();
+  const HloComputation* fused_computation =
+      fusion.fused_instructions_computation();
   if (!reduction_codegen_info->IsRaceFree()) {
     absl::Span<HloInstruction* const> fusion_roots = analysis_.fusion_roots();
     for (int i = 0; i < fusion_roots.size(); ++i) {
       if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) {
-        TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(),
-                            BuildFusedInitializerThunk(
-                                ir_emitter_context, fusion_op, analysis_,
-                                elemental_emitter, kernel_cache, i, builder));
+        TF_ASSIGN_OR_RETURN(
+            result.thunks.emplace_back(),
+            BuildFusedInitializerThunk(ir_emitter_context, fusion_op,
+                                       fused_computation, elemental_emitter,
+                                       kernel_cache, i, builder));
       }
     }
   }
@@ -987,7 +987,6 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   auto builder_fn = [&, this](std::vector<llvm_ir::IrArray> inputs,
                               std::vector<llvm_ir::IrArray> outputs) -> Status {
     FusedIrEmitter fused_emitter(elemental_emitter);
-    const HloComputation* fused_computation = analysis_.fused_computation();
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
       HloInstruction* fused_operand =
           fused_computation->parameter_instruction(i);
@@ -1042,8 +1041,8 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   TF_ASSIGN_OR_RETURN(
       result.thunks.emplace_back(),
       BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, fusion_op,
-                                analysis_.fused_computation(),
-                                launch_dimensions, "", builder_fn, builder));
+                                fused_computation, launch_dimensions, "",
+                                builder_fn, builder));
 
   return result;
 }
@@ -79,37 +79,30 @@ bool AllSliceInputsAreCompatible(
   });
 }
 
-bool MayPreventVectorization(const HloComputation* fusion) {
+bool MayPreventVectorization(const std::vector<HloInstruction*>& fusion_roots) {
   // An empirically chosen constant: unrolling concat with a large amount of
   // arguments causes excessive register spilling.
   static constexpr int kMaxConcatArgumentsForUnrolling = 10;
-  for (const HloInstruction* instr : fusion->instructions()) {
-    switch (instr->opcode()) {
-      case HloOpcode::kReduceWindow:
-      case HloOpcode::kSort:
-      case HloOpcode::kDot:
-      case HloOpcode::kSin:
-      case HloOpcode::kCos:
-      case HloOpcode::kTan:
-      case HloOpcode::kPower:
-      case HloOpcode::kAtan2:
-        return true;
-      case HloOpcode::kConcatenate:
-        if (instr->operand_count() > kMaxConcatArgumentsForUnrolling) {
-          return true;
-        }
-        break;
-      case HloOpcode::kReduce:
-        if (instr->shape().tuple_shapes_size() > 1) {
-          return true;
-        }
-        break;
-      default:
-        break;
-    }
-  }
-  return false;
+  return HloAnyOf(
+      fusion_roots, DefaultFusionBoundaryFn, [&](const HloInstruction& node) {
+        switch (node.opcode()) {
+          case HloOpcode::kReduceWindow:
+          case HloOpcode::kSort:
+          case HloOpcode::kDot:
+          case HloOpcode::kSin:
+          case HloOpcode::kCos:
+          case HloOpcode::kTan:
+          case HloOpcode::kPower:
+          case HloOpcode::kAtan2:
+            return true;
+          case HloOpcode::kConcatenate:
+            return node.operand_count() > kMaxConcatArgumentsForUnrolling;
+          case HloOpcode::kReduce:
+            return node.shape().tuple_shapes_size() > 1;
+          default:
+            return false;
+        }
+      });
 }
 
 // Determines if we enable the row optimized codegen. When we have a fusion with
@@ -303,10 +296,10 @@ StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
   std::optional<TransposeDescription> tiled_transpose_hero =
       FindConsistentTransposeHero(hlo_roots, heroes);
 
-  return HloFusionAnalysis(
-      fusion, std::move(backend_config), std::move(hlo_roots),
-      std::move(fusion_parameter_inputs), std::move(heroes), device_info,
-      tiled_transpose_hero);
+  return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots),
+                           std::move(fusion_parameter_inputs),
+                           std::move(heroes), device_info,
+                           tiled_transpose_hero);
 }
 
 // Returns true if the fusion has consistent transpose heros.
@@ -473,7 +466,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
   int64_t n_threads_max =
       device_info_->threads_per_core_limit * device_info_->core_count;
   if (num_elements >= n_threads_max &&
-      !MayPreventVectorization(fused_computation_)) {
+      !MayPreventVectorization(fusion_roots_)) {
     unroll_factor = ComputeMaxUnrollFactor(num_elements);
  }
  VLOG(2) << "Unroll factor: " << unroll_factor;
@@ -488,25 +481,23 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
   int num_big_inputs;
   std::tie(row_vectorized, num_big_inputs) =
       RowVectorizationEnabled(fusion_roots(), GetElementShape().rank());
-  bool few_waves = [this, row_vectorized, num_big_inputs]() {
-    for (const HloInstruction* instr : fused_computation_->instructions()) {
-      if (instr->opcode() == HloOpcode::kParameter ||
-          instr->opcode() == HloOpcode::kConstant ||
-          HloInstruction::IsOpElementwise(instr->opcode())) {
-        continue;
-      }
-      if (auto broadcast = DynCast<HloBroadcastInstruction>(instr)) {
-        if (broadcast->dimensions().empty() ||
-            // More than 3 big inputs cause a speed regression.
-            (row_vectorized && num_big_inputs <= 3)) {
-          continue;
-        }
-      }
-      VLOG(2) << "few_waves not enabled due to: " << instr->ToString();
-      return false;
-    }
-    return true;
-  }();
+  bool few_waves = !HloAnyOf(
+      fusion_roots_, DefaultFusionBoundaryFn, [&](const HloInstruction& instr) {
+        if (instr.opcode() == HloOpcode::kParameter ||
+            instr.opcode() == HloOpcode::kConstant ||
+            HloInstruction::IsOpElementwise(instr.opcode())) {
+          return false;
+        }
+        if (auto broadcast = DynCast<HloBroadcastInstruction>(&instr)) {
+          if (broadcast->dimensions().empty() ||
+              // More than 3 big inputs cause a speed regression.
+              (row_vectorized && num_big_inputs <= 3)) {
+            return false;
+          }
+        }
+        VLOG(2) << "few_waves not enabled due to: " << instr.ToString();
+        return true;
+      });
 
   LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
                                        row_vectorized};
@@ -523,7 +514,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
 }
 
 const Shape& HloFusionAnalysis::GetElementShape() const {
-  const Shape* shape = &fusion_->shape();
+  const Shape* shape = &fusion_roots_.front()->shape();
   while (shape->IsTuple()) {
     shape = &shape->tuple_shapes(0);
   }
@@ -581,16 +572,20 @@ HloFusionAnalysis::GroupDisjointReductions() const {
     return {{fusion_roots()[0]}};
   }
 
-  HloInstructionMap<tensorflow::UnionFind<HloInstruction*>> disjoint_sets;
+  ConstHloInstructionMap<tensorflow::UnionFind<const HloInstruction*>>
+      disjoint_sets;
 
   // TODO(b/249976438): we currently do not treat properly
   // aliasing between inputs and outputs of the fusion, so for now put all
   // non-reduction roots into one group to avoid read-after-write conflicts.
   HloInstruction* first_non_reduction_root = nullptr;
 
+  ConstHloInstructionMap<absl::flat_hash_set<const HloInstruction*>>
+      reachable_outputs;
   absl::flat_hash_set<HloInstruction*> roots_with_reduction;
   for (auto [root, hero] : llvm::zip(fusion_roots(), fusion_heroes_)) {
     disjoint_sets[root].Get() = root;
+    reachable_outputs[root].insert(root);
     if (IsRealReductionHero(*root, *hero)) {
       roots_with_reduction.insert(root);
     } else if (first_non_reduction_root) {
@@ -600,9 +595,23 @@ HloFusionAnalysis::GroupDisjointReductions() const {
     }
   }
 
-  std::unique_ptr<HloReachabilityMap> reachability_map =
-      HloReachabilityMap::Build(fused_computation_);
-  for (HloInstruction* instr : fused_computation_->instructions()) {
+  std::vector<const HloInstruction*> instructions;
+  HloBfsConsumersFirstTraversal(
+      fusion_roots_,
+      [&](const HloInstruction& producer, const HloInstruction& consumer) {
+        auto& producer_reachable = reachable_outputs[&producer];
+        for (auto* instruction : reachable_outputs[&consumer]) {
+          producer_reachable.insert(instruction);
+        }
+        return DefaultFusionBoundaryFn(producer, consumer);
+      },
+      [&](const HloInstruction& node) {
+        instructions.push_back(&node);
+        return TraversalResult::kVisitOperands;
+      });
+
+  for (const HloInstruction* instr : instructions) {
+    const auto& reachable = reachable_outputs[instr];
     std::vector<HloInstruction*> reached_output_ids;
     bool added_to_reduce = false;
     for (HloInstruction* output : fusion_roots()) {
@@ -618,7 +627,7 @@ HloFusionAnalysis::GroupDisjointReductions() const {
         }
       }
       // Now group output instructions if they have common predecessors.
-      if (reachability_map->IsReachable(instr, output)) {
+      if (reachable.contains(output)) {
        VLOG(3) << "Reaching " << output->ToString() << " from "
                << instr->ToString();
        reached_output_ids.push_back(output);
@@ -634,12 +643,13 @@ HloFusionAnalysis::GroupDisjointReductions() const {
   }
 
   // Place output instructions in the same set into the same group.
-  HloInstructionMap<std::vector<HloInstruction*>> groups;
+  ConstHloInstructionMap<std::vector<HloInstruction*>> groups;
   for (HloInstruction* root : fusion_roots()) {
     groups[disjoint_sets[root].Get()].push_back(root);
   }
   std::vector<std::vector<HloInstruction*>> ret;
   ret.reserve(groups.size());
   absl::c_for_each(
       groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
   return ret;
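
For readers following the GroupDisjointReductions change above: the HloReachabilityMap is replaced by computing, for every instruction, the set of fusion roots that transitively consume it. Each root starts with itself, and every producer inherits its consumers' sets during the traversal. A hedged standalone sketch of the same idea (hypothetical `Node` type; the diff does the propagation inline in the BFS boundary callback):

#include <algorithm>
#include <map>
#include <set>
#include <vector>

struct Node {
  std::vector<const Node*> operands;
};

// DFS post-order: every node is appended after all of its operands.
void PostOrder(const Node* node, std::set<const Node*>& seen,
               std::vector<const Node*>& out) {
  if (!seen.insert(node).second) return;
  for (const Node* operand : node->operands) PostOrder(operand, seen, out);
  out.push_back(node);
}

// For each node reachable from `roots`, the set of roots that transitively
// consume it (the analogue of reachable_outputs in the diff above).
std::map<const Node*, std::set<const Node*>> ReachableRoots(
    const std::vector<const Node*>& roots) {
  std::set<const Node*> seen;
  std::vector<const Node*> order;
  for (const Node* root : roots) PostOrder(root, seen, order);
  // Reverse the post-order so that consumers come before their producers.
  std::reverse(order.begin(), order.end());

  std::map<const Node*, std::set<const Node*>> reachable;
  for (const Node* root : roots) reachable[root].insert(root);
  for (const Node* node : order) {
    for (const Node* operand : node->operands) {
      // Every root that reaches a consumer also reaches its producers.
      reachable[operand].insert(reachable[node].begin(),
                                reachable[node].end());
    }
  }
  return reachable;
}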
@@ -725,7 +735,7 @@ bool HloFusionAnalysis::CanVectorizeReduction(
   }
 
   if (reduction_dimensions.dimensions[kDimX] % 2 != 0 ||
-      MayPreventVectorization(fusion_->fused_instructions_computation())) {
+      MayPreventVectorization(fusion_roots_)) {
     return false;
   }
@@ -50,7 +50,6 @@ class HloFusionAnalysis {
   static StatusOr<HloFusionAnalysis> Create(const HloFusionInstruction* fusion,
                                             const GpuDeviceInfo* device_info);
 
-  const HloComputation* fused_computation() const { return fused_computation_; }
   const std::vector<HloInstruction*>& fusion_roots() const {
     return fusion_roots_;
   }
@@ -78,16 +77,13 @@ class HloFusionAnalysis {
   const HloInstruction* FindHeroReduction() const;
 
  private:
-  HloFusionAnalysis(const HloFusionInstruction* fusion,
-                    FusionBackendConfig fusion_backend_config,
+  HloFusionAnalysis(FusionBackendConfig fusion_backend_config,
                     std::vector<HloInstruction*> fusion_roots,
                     std::vector<const HloInstruction*> fusion_parameters,
                     std::vector<const HloInstruction*> fusion_heroes,
                     const GpuDeviceInfo* device_info,
                     std::optional<TransposeDescription> tiled_transpose)
-      : fusion_(fusion),
-        fusion_backend_config_(std::move(fusion_backend_config)),
-        fused_computation_(fusion->fused_instructions_computation()),
+      : fusion_backend_config_(std::move(fusion_backend_config)),
         fusion_roots_(std::move(fusion_roots)),
         fusion_parameter_inputs_(std::move(fusion_parameters)),
         fusion_heroes_(std::move(fusion_heroes)),
@@ -111,9 +107,7 @@ class HloFusionAnalysis {
       const HloInstruction* hero_reduction) const;
   bool HasConsistentTransposeHeros() const;
 
-  const HloFusionInstruction* fusion_;
   FusionBackendConfig fusion_backend_config_;
-  const HloComputation* fused_computation_;
   std::vector<HloInstruction*> fusion_roots_;
   // The HLO instructions that are inputs into the fusion. These instructions
   // are /outside/ the fusion.
@@ -103,5 +103,22 @@ void FindFusionParameters(
       [&](const HloInstruction&) { return TraversalResult::kVisitOperands; });
 }
 
+bool HloAnyOf(
+    absl::Span<const HloInstruction* const> roots,
+    const std::function<bool(const HloInstruction& producer,
+                             const HloInstruction& consumer)>& boundary,
+    const std::function<bool(const HloInstruction& node)>& visit) {
+  bool result = false;
+  HloBfsConsumersFirstTraversal(roots, boundary,
+                                [&](const HloInstruction& node) {
+                                  if (visit(node)) {
+                                    result = true;
+                                    return TraversalResult::kAbortTraversal;
+                                  }
+                                  return TraversalResult::kVisitOperands;
+                                });
+  return result;
+}
+
 }  // namespace gpu
 }  // namespace xla
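
As a usage note, call sites in this commit (e.g. MayPreventVectorization above) pass the fusion roots, `DefaultFusionBoundaryFn`, and a predicate; the traversal aborts on the first match. A hedged sketch of such a call site (the `ContainsSort` wrapper is hypothetical; the types and helper it uses are the ones added in this diff):

// Hypothetical helper: true if any instruction reachable from the fusion
// roots, without crossing the default fusion boundary, is a sort.
bool ContainsSort(absl::Span<const HloInstruction* const> roots) {
  return HloAnyOf(roots, DefaultFusionBoundaryFn,
                  [](const HloInstruction& node) {
                    return node.opcode() == HloOpcode::kSort;
                  });
}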
@@ -41,7 +41,7 @@ using FusionBoundaryFn = std::function<bool(const HloInstruction& producer,
 bool DefaultFusionBoundaryFn(const HloInstruction& producer,
                              const HloInstruction& consumer);
 
-// Visit the HLO nodes starting from `root` in BFS order (consumers before
+// Visit the HLO nodes starting from `roots` in BFS order (consumers before
 // producers). Each node will be visited exactly once. The graph is not
 // traversed along edges for which `boundary` returns true.
 void HloBfsConsumersFirstTraversal(
@@ -50,6 +50,14 @@ void HloBfsConsumersFirstTraversal(
     const std::function<bool(const HloInstruction& producer,
                              const HloInstruction& consumer)>& boundary,
     const std::function<TraversalResult(const HloInstruction& node)>& visit);
 
+// Visits the HLO nodes starting from `roots`, returning true if `visit`
+// returns true for any of them.
+bool HloAnyOf(
+    absl::Span<const HloInstruction* const> roots,
+    const std::function<bool(const HloInstruction& producer,
+                             const HloInstruction& consumer)>& boundary,
+    const std::function<bool(const HloInstruction& node)>& visit);
+
 // Visit the producers of all parameters that are needed by the fusion.
 void FindFusionParameters(
     absl::Span<const HloInstruction* const> roots,