// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.PooledObjects;

namespace Microsoft.CodeAnalysis.Shared.Collections
{
    /// <summary>
    /// An interval tree represents an ordered tree data structure to store intervals of the form
    /// [start, end).  It allows you to efficiently find all intervals that intersect or overlap
    /// a provided interval.
    /// </summary>
    /// <remarks>
    /// Nodes are ordered by interval start (equal starts are placed to the right) and kept
    /// height-balanced by rotating whenever a node's balance factor reaches +2 or -2.
    /// Queries prune whole subtrees using each subtree's <c>MaxEndNode</c>
    /// (presumably the node with the largest end in that subtree — defined in the Node partial).
    /// </remarks>
    internal partial class IntervalTree<T> : IEnumerable<T>
    {
        /// <summary>A shared tree that contains no intervals.</summary>
        public static readonly IntervalTree<T> Empty = new IntervalTree<T>();

        /// <summary>Root of the tree; <see langword="null"/> when the tree is empty.</summary>
        protected Node root;

        /// <summary>
        /// Decides whether the interval described by <paramref name="value"/> matches the query
        /// described by <paramref name="start"/> and <paramref name="length"/>.
        /// </summary>
        private delegate bool TestInterval<TIntrospector>(T value, int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>;

        // Traversal stacks are pooled so that each query does not allocate a fresh Stack.
        private static readonly ObjectPool<Stack<(Node node, bool firstTime)>> s_stackPool
            = SharedPools.Default<Stack<(Node node, bool firstTime)>>();

        public IntervalTree()
        {
        }

        /// <summary>
        /// Builds a tree holding every element of <paramref name="values"/>, inserting them one at a time.
        /// </summary>
        /// <param name="introspector">Supplies the start and length of each value.</param>
        /// <param name="values">The values to store.</param>
        /// <returns>A new tree containing all of <paramref name="values"/>.</returns>
        public static IntervalTree<T> Create<TIntrospector>(in TIntrospector introspector, IEnumerable<T> values)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var result = new IntervalTree<T>();
            foreach (var value in values)
            {
                result.root = Insert(result.root, new Node(value), in introspector);
            }

            return result;
        }

        /// <summary>
        /// Returns <see langword="true"/> when the interval of <paramref name="value"/> contains the query.
        /// For a non-empty query, the query must lie within the interval (its end may equal the
        /// interval's end). For an empty query (<paramref name="length"/> == 0) the position must lie
        /// strictly before the interval's end.
        /// </summary>
        protected static bool Contains<TIntrospector>(T value, int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var otherStart = start;
            var otherEnd = start + length;

            var thisEnd = GetEnd(value, in introspector);
            var thisStart = introspector.GetStart(value);

            // make sure "Contains" test to be same as what TextSpan does
            if (length == 0)
            {
                return thisStart <= otherStart && otherEnd < thisEnd;
            }

            return thisStart <= otherStart && otherEnd <= thisEnd;
        }

        /// <summary>
        /// Returns <see langword="true"/> when the query and the interval of <paramref name="value"/>
        /// intersect. Both ends are treated inclusively, so intervals that merely touch
        /// (one's end equals the other's start) count as intersecting.
        /// </summary>
        private static bool IntersectsWith<TIntrospector>(T value, int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var otherStart = start;
            var otherEnd = start + length;

            var thisEnd = GetEnd(value, in introspector);
            var thisStart = introspector.GetStart(value);

            return otherStart <= thisEnd && otherEnd >= thisStart;
        }

        /// <summary>
        /// Returns <see langword="true"/> when the query and the interval of <paramref name="value"/>
        /// share a non-empty region. For an empty query (<paramref name="length"/> == 0) the position
        /// must lie strictly inside the interval (neither at its start nor at its end).
        /// </summary>
        private static bool OverlapsWith<TIntrospector>(T value, int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var otherStart = start;
            var otherEnd = start + length;

            var thisEnd = GetEnd(value, in introspector);
            var thisStart = introspector.GetStart(value);

            if (length == 0)
            {
                return thisStart < otherStart && otherStart < thisEnd;
            }

            var overlapStart = Math.Max(thisStart, otherStart);
            var overlapEnd = Math.Min(thisEnd, otherEnd);

            return overlapStart < overlapEnd;
        }

        /// <summary>Returns every stored value that overlaps the query, in in-order traversal order.</summary>
        public ImmutableArray<T> GetIntervalsThatOverlapWith<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.GetIntervalsThatMatch(start, length, Tests<TIntrospector>.OverlapsWithTest, in introspector);

        /// <summary>Returns every stored value that intersects the query, in in-order traversal order.</summary>
        public ImmutableArray<T> GetIntervalsThatIntersectWith<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.GetIntervalsThatMatch(start, length, Tests<TIntrospector>.IntersectsWithTest, in introspector);

        /// <summary>Returns every stored value that contains the query, in in-order traversal order.</summary>
        public ImmutableArray<T> GetIntervalsThatContain<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.GetIntervalsThatMatch(start, length, Tests<TIntrospector>.ContainsTest, in introspector);

        /// <summary>Appends to <paramref name="builder"/> every stored value that overlaps the query.</summary>
        public void FillWithIntervalsThatOverlapWith<TIntrospector>(int start, int length, ArrayBuilder<T> builder, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.OverlapsWithTest, builder, in introspector, stopAfterFirst: false);

        /// <summary>Appends to <paramref name="builder"/> every stored value that intersects the query.</summary>
        public void FillWithIntervalsThatIntersectWith<TIntrospector>(int start, int length, ArrayBuilder<T> builder, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.IntersectsWithTest, builder, in introspector, stopAfterFirst: false);

        /// <summary>Appends to <paramref name="builder"/> every stored value that contains the query.</summary>
        public void FillWithIntervalsThatContain<TIntrospector>(int start, int length, ArrayBuilder<T> builder, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => this.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.ContainsTest, builder, in introspector, stopAfterFirst: false);

        /// <summary>
        /// Returns <see langword="true"/> if any stored value intersects the single
        /// <paramref name="position"/> (an empty query of length 0).
        /// </summary>
        public bool HasIntervalThatIntersectsWith<TIntrospector>(int position, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => HasIntervalThatIntersectsWith(position, 0, in introspector);

        /// <summary>Returns <see langword="true"/> if any stored value intersects the query.</summary>
        public bool HasIntervalThatIntersectsWith<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => Any(start, length, Tests<TIntrospector>.IntersectsWithTest, in introspector);

        /// <summary>Returns <see langword="true"/> if any stored value overlaps the query.</summary>
        public bool HasIntervalThatOverlapsWith<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => Any(start, length, Tests<TIntrospector>.OverlapsWithTest, in introspector);

        /// <summary>Returns <see langword="true"/> if any stored value contains the query.</summary>
        public bool HasIntervalThatContains<TIntrospector>(int start, int length, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => Any(start, length, Tests<TIntrospector>.ContainsTest, in introspector);

        /// <summary>
        /// Returns <see langword="true"/> if at least one stored value satisfies <paramref name="testInterval"/>;
        /// the traversal stops at the first match.
        /// </summary>
        private bool Any<TIntrospector>(int start, int length, TestInterval<TIntrospector> testInterval, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            using var _ = ArrayBuilder<T>.GetInstance(out var builder);
            FillWithIntervalsThatMatch(start, length, testInterval, builder, in introspector, stopAfterFirst: true);
            return builder.Count > 0;
        }

        /// <summary>Collects every stored value satisfying <paramref name="testInterval"/> into an immutable array.</summary>
        private ImmutableArray<T> GetIntervalsThatMatch<TIntrospector>(
            int start, int length, TestInterval<TIntrospector> testInterval, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var result = ArrayBuilder<T>.GetInstance();
            FillWithIntervalsThatMatch(start, length, testInterval, result, in introspector, stopAfterFirst: false);
            return result.ToImmutableAndFree();
        }

        /// <summary>
        /// Appends to <paramref name="builder"/> the stored values satisfying <paramref name="testInterval"/>,
        /// in in-order traversal order. When <paramref name="stopAfterFirst"/> is set, at most one value is added.
        /// </summary>
        private void FillWithIntervalsThatMatch<TIntrospector>(
            int start, int length, TestInterval<TIntrospector> testInterval,
            ArrayBuilder<T> builder, in TIntrospector introspector,
            bool stopAfterFirst)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            if (root == null)
            {
                return;
            }

            var candidates = s_stackPool.Allocate();
            try
            {
                FillWithIntervalsThatMatch(
                    start, length, testInterval,
                    builder, in introspector,
                    stopAfterFirst, candidates);
            }
            finally
            {
                // Hand the stack back (cleared) on every exit path, including an early
                // stop or an exception thrown by the test or the introspector.
                s_stackPool.ClearAndFree(candidates);
            }
        }

        /// <summary>
        /// Iterative in-order walk using <paramref name="candidates"/> as an explicit stack.
        /// Each node is pushed twice: first to expand its children, then (with firstTime: false)
        /// to be tested, which yields results in in-order sequence.
        /// </summary>
        private void FillWithIntervalsThatMatch<TIntrospector>(
            int start, int length, TestInterval<TIntrospector> testInterval,
            ArrayBuilder<T> builder, in TIntrospector introspector,
            bool stopAfterFirst, Stack<(Node node, bool firstTime)> candidates)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var end = start + length;

            candidates.Push((root, firstTime: true));

            while (candidates.Count > 0)
            {
                var currentTuple = candidates.Pop();
                var currentNode = currentTuple.node;
                Debug.Assert(currentNode != null);

                var firstTime = currentTuple.firstTime;

                if (!firstTime)
                {
                    // We're seeing this node for the second time (as we walk back up the left
                    // side of it).  Now see if it matches our test, and if so return it out.
                    if (testInterval(currentNode.Value, start, length, in introspector))
                    {
                        builder.Add(currentNode.Value);

                        if (stopAfterFirst)
                        {
                            return;
                        }
                    }
                }
                else
                {
                    // First time we're seeing this node.  In order to see the node 'in-order',
                    // we push the right side, then the node again, then the left side.  This
                    // time we mark the current node with 'false' to indicate that it's the
                    // second time we're seeing it the next time it comes around.

                    // Right children never start before their parent's start, so the right
                    // subtree can only hold matches if this node starts at or before the
                    // query's end.  It is also skipped when its largest end precedes the
                    // query's start.
                    if (introspector.GetStart(currentNode.Value) <= end)
                    {
                        var right = currentNode.Right;
                        if (right != null && GetEnd(right.MaxEndNode.Value, in introspector) >= start)
                        {
                            candidates.Push((right, firstTime: true));
                        }
                    }

                    candidates.Push((currentNode, firstTime: false));

                    // Only consider the left subtree if its largest end reaches the query's start.
                    var left = currentNode.Left;
                    if (left != null && GetEnd(left.MaxEndNode.Value, in introspector) >= start)
                    {
                        candidates.Push((left, firstTime: true));
                    }
                }
            }
        }

        /// <summary>Returns <see langword="true"/> when the tree holds no intervals.</summary>
        public bool IsEmpty() => this.root == null;

        /// <summary>
        /// Inserts <paramref name="newNode"/> into the subtree rooted at <paramref name="root"/>
        /// and returns the (possibly new) root of that subtree after rebalancing.
        /// </summary>
        protected static Node Insert<TIntrospector>(Node root, Node newNode, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var newNodeStart = introspector.GetStart(newNode.Value);
            return Insert(root, newNode, newNodeStart, in introspector);
        }

        /// <summary>
        /// Recursive insert keyed on <paramref name="newNodeStart"/>: nodes that start strictly
        /// before <paramref name="root"/> go left, all others (including equal starts) go right.
        /// </summary>
        private static Node Insert<TIntrospector>(Node root, Node newNode, int newNodeStart, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            if (root == null)
            {
                return newNode;
            }

            Node newLeft, newRight;

            if (newNodeStart < introspector.GetStart(root.Value))
            {
                newLeft = Insert(root.Left, newNode, newNodeStart, in introspector);
                newRight = root.Right;
            }
            else
            {
                newLeft = root.Left;
                newRight = Insert(root.Right, newNode, newNodeStart, in introspector);
            }

            root.SetLeftRight(newLeft, newRight, in introspector);
            return Balance(root, in introspector);
        }

        /// <summary>
        /// Restores balance at <paramref name="node"/> when its balance factor
        /// (left height minus right height) is +2 or -2, using a single or double rotation.
        /// Returns the root of the rebalanced subtree, or <paramref name="node"/> itself if no rotation was needed.
        /// </summary>
        private static Node Balance<TIntrospector>(Node node, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            var balanceFactor = BalanceFactor(node);
            if (balanceFactor == -2)
            {
                // Right-heavy.
                var rightBalance = BalanceFactor(node.Right);
                if (rightBalance == -1)
                {
                    return node.LeftRotation(in introspector);
                }
                else
                {
                    Debug.Assert(rightBalance == 1);
                    return node.InnerRightOuterLeftRotation(in introspector);
                }
            }
            else if (balanceFactor == 2)
            {
                // Left-heavy.
                var leftBalance = BalanceFactor(node.Left);
                if (leftBalance == 1)
                {
                    return node.RightRotation(in introspector);
                }
                else
                {
                    Debug.Assert(leftBalance == -1);
                    return node.InnerLeftOuterRightRotation(in introspector);
                }
            }

            return node;
        }

        /// <summary>Enumerates the stored values in in-order traversal order.</summary>
        public IEnumerator<T> GetEnumerator()
        {
            if (root == null)
            {
                yield break;
            }

            var candidates = new Stack<(Node node, bool firstTime)>();
            candidates.Push((root, firstTime: true));
            while (candidates.Count != 0)
            {
                var (currentNode, firstTime) = candidates.Pop();
                if (currentNode != null)
                {
                    if (firstTime)
                    {
                        // First time seeing this node.  Mark that we've been seen and recurse
                        // down the left side.  The next time we see this node we'll yield it
                        // out.
                        candidates.Push((currentNode.Right, firstTime: true));
                        candidates.Push((currentNode, firstTime: false));
                        candidates.Push((currentNode.Left, firstTime: true));
                    }
                    else
                    {
                        yield return currentNode.Value;
                    }
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
            => this.GetEnumerator();

        /// <summary>Exclusive end of <paramref name="value"/>'s interval: start + length.</summary>
        protected static int GetEnd<TIntrospector>(T value, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => introspector.GetStart(value) + introspector.GetLength(value);

        /// <summary>End of the subtree's <c>MaxEndNode</c> interval, or 0 for a <see langword="null"/> subtree.</summary>
        protected static int MaxEndValue<TIntrospector>(Node node, in TIntrospector introspector)
            where TIntrospector : struct, IIntervalIntrospector<T>
            => node == null ? 0 : GetEnd(node.MaxEndNode.Value, in introspector);

        /// <summary>Height of <paramref name="node"/>, treating a <see langword="null"/> subtree as height 0.</summary>
        private static int Height(Node node)
            => node == null ? 0 : node.Height;

        /// <summary>Left subtree height minus right subtree height; 0 for a <see langword="null"/> node.</summary>
        private static int BalanceFactor(Node node)
            => node == null ? 0 : Height(node.Left) - Height(node.Right);

        // The test delegates are cached in static fields, presumably so that each query does not
        // allocate a new delegate from a method-group conversion.
        private static class Tests<TIntrospector>
            where TIntrospector : struct, IIntervalIntrospector<T>
        {
            public static readonly TestInterval<TIntrospector> IntersectsWithTest = IntersectsWith;
            public static readonly TestInterval<TIntrospector> ContainsTest = Contains;
            public static readonly TestInterval<TIntrospector> OverlapsWithTest = OverlapsWith;
        }
    }
}