提交 3cba63ad 编写于 作者: B Benjamin Kramer 提交者: TensorFlower Gardener

[XLA:GPU] Give HloFusionAnalysis an entry point for custom boundary functions

This can be used to supply it with a graph cut instead of a fusion computation

PiperOrigin-RevId: 563739866
上级 c8cee324
......@@ -79,12 +79,13 @@ bool AllSliceInputsAreCompatible(
});
}
bool MayPreventVectorization(const std::vector<HloInstruction*>& fusion_roots) {
bool MayPreventVectorization(const std::vector<HloInstruction*>& fusion_roots,
const FusionBoundaryFn& fusion_boundary_fn) {
// An empirically chosen constant: unrolling concat with a large amount of
// arguments causes excessive register spilling.
static constexpr int kMaxConcatArgumentsForUnrolling = 10;
return HloAnyOf(
fusion_roots, DefaultFusionBoundaryFn, [&](const HloInstruction& node) {
fusion_roots, fusion_boundary_fn, [&](const HloInstruction& node) {
switch (node.opcode()) {
case HloOpcode::kReduceWindow:
case HloOpcode::kSort:
......@@ -275,12 +276,8 @@ std::optional<TransposeDescription> FindConsistentTransposeHero(
// static
StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
const HloFusionInstruction* fusion, const GpuDeviceInfo* device_info) {
CHECK(device_info != nullptr);
TF_ASSIGN_OR_RETURN(auto backend_config,
fusion->backend_config<FusionBackendConfig>());
auto hlo_roots = GetFusionRoots(*fusion->fused_instructions_computation());
FusionBackendConfig backend_config, std::vector<HloInstruction*> hlo_roots,
FusionBoundaryFn boundary_fn, const GpuDeviceInfo* device_info) {
std::vector<const HloInstruction*> heroes;
heroes.reserve(hlo_roots.size());
for (auto* root : hlo_roots) {
......@@ -288,7 +285,7 @@ StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
}
std::vector<const HloInstruction*> fusion_parameter_inputs;
FindFusionParameters(hlo_roots, DefaultFusionBoundaryFn,
FindFusionParameters(hlo_roots, boundary_fn,
[&](const HloInstruction& parameter) {
fusion_parameter_inputs.push_back(&parameter);
});
......@@ -296,10 +293,22 @@ StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
std::optional<TransposeDescription> tiled_transpose_hero =
FindConsistentTransposeHero(hlo_roots, heroes);
return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots),
std::move(fusion_parameter_inputs),
std::move(heroes), device_info,
tiled_transpose_hero);
return HloFusionAnalysis(
std::move(backend_config), std::move(hlo_roots), std::move(boundary_fn),
std::move(fusion_parameter_inputs), std::move(heroes), device_info,
tiled_transpose_hero);
}
// static
StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
const HloFusionInstruction* fusion, const GpuDeviceInfo* device_info) {
CHECK(device_info != nullptr);
TF_ASSIGN_OR_RETURN(auto backend_config,
fusion->backend_config<FusionBackendConfig>());
auto hlo_roots = GetFusionRoots(*fusion->fused_instructions_computation());
return Create(std::move(backend_config), std::move(hlo_roots),
DefaultFusionBoundaryFn, device_info);
}
// Returns true if the fusion has consistent transpose heros.
......@@ -466,7 +475,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
int64_t n_threads_max =
device_info_->threads_per_core_limit * device_info_->core_count;
if (num_elements >= n_threads_max &&
!MayPreventVectorization(fusion_roots_)) {
!MayPreventVectorization(fusion_roots_, fusion_boundary_fn_)) {
unroll_factor = ComputeMaxUnrollFactor(num_elements);
}
VLOG(2) << "Unroll factor: " << unroll_factor;
......@@ -482,7 +491,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
std::tie(row_vectorized, num_big_inputs) =
RowVectorizationEnabled(fusion_roots(), GetElementShape().rank());
bool few_waves = !HloAnyOf(
fusion_roots_, DefaultFusionBoundaryFn, [&](const HloInstruction& instr) {
fusion_roots_, fusion_boundary_fn_, [&](const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kParameter ||
instr.opcode() == HloOpcode::kConstant ||
HloInstruction::IsOpElementwise(instr.opcode())) {
......@@ -603,7 +612,7 @@ HloFusionAnalysis::GroupDisjointReductions() const {
for (auto* instruction : reachable_outputs[&consumer]) {
producer_reachable.insert(instruction);
}
return DefaultFusionBoundaryFn(producer, consumer);
return fusion_boundary_fn_(producer, consumer);
},
[&](const HloInstruction& node) {
instructions.push_back(&node);
......@@ -686,15 +695,22 @@ bool HloFusionAnalysis::IsUnrollingColumnReductionBeneficial(
HloBfsConsumersFirstTraversal(
fusion_roots_,
[&](const HloInstruction& producer, const HloInstruction& consumer) {
// We check if the consumer is elementwise, unless this edge is a
// virtual edge that only exists in partially fused HLO. There are two
// types of such edges:
// 1. Edges from producers outside a fusion to a parameter instruction
// within a fusion. Here, the producer is a parameter of the fusion
// instruction.
// 2. Edges from fusion roots to fusion nodes.
if (reachable_through_non_elementwise.contains(&consumer) ||
(consumer.opcode() != HloOpcode::kParameter &&
consumer.opcode() != HloOpcode::kFusion &&
!use_chain_endings.contains(&consumer) &&
!consumer.IsElementwise())) {
(!(consumer.opcode() == HloOpcode::kParameter ||
consumer.opcode() == HloOpcode::kFusion ||
consumer.IsElementwise()) &&
!use_chain_endings.contains(&consumer))) {
reachable_through_non_elementwise.insert(&producer);
}
return DefaultFusionBoundaryFn(producer, consumer);
return fusion_boundary_fn_(producer, consumer);
},
[&](const HloInstruction& node) {
return TraversalResult::kVisitOperands;
......@@ -735,7 +751,7 @@ bool HloFusionAnalysis::CanVectorizeReduction(
}
if (reduction_dimensions.dimensions[kDimX] % 2 != 0 ||
MayPreventVectorization(fusion_roots_)) {
MayPreventVectorization(fusion_roots_, fusion_boundary_fn_)) {
return false;
}
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/kernel_mapping_scheme.h"
#include "xla/service/gpu/launch_dimensions.h"
......@@ -47,6 +48,10 @@ class HloFusionAnalysis {
kScatter,
};
static StatusOr<HloFusionAnalysis> Create(
FusionBackendConfig backend_config,
std::vector<HloInstruction*> hlo_roots, FusionBoundaryFn boundary_fn,
const GpuDeviceInfo* device_info);
static StatusOr<HloFusionAnalysis> Create(const HloFusionInstruction* fusion,
const GpuDeviceInfo* device_info);
......@@ -79,12 +84,14 @@ class HloFusionAnalysis {
private:
HloFusionAnalysis(FusionBackendConfig fusion_backend_config,
std::vector<HloInstruction*> fusion_roots,
FusionBoundaryFn fusion_boundary_fn,
std::vector<const HloInstruction*> fusion_parameters,
std::vector<const HloInstruction*> fusion_heroes,
const GpuDeviceInfo* device_info,
std::optional<TransposeDescription> tiled_transpose)
: fusion_backend_config_(std::move(fusion_backend_config)),
fusion_roots_(std::move(fusion_roots)),
fusion_boundary_fn_(std::move(fusion_boundary_fn)),
fusion_parameter_inputs_(std::move(fusion_parameters)),
fusion_heroes_(std::move(fusion_heroes)),
device_info_(device_info),
......@@ -109,6 +116,7 @@ class HloFusionAnalysis {
FusionBackendConfig fusion_backend_config_;
std::vector<HloInstruction*> fusion_roots_;
FusionBoundaryFn fusion_boundary_fn_;
// The HLO instructions that are inputs into the fusion. These instructions
// are /outside/ the fusion.
std::vector<const HloInstruction*> fusion_parameter_inputs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册