提交 0f2ec79e 编写于 作者: C CyrusNajmabadi

Merge pull request #10859 from CyrusNajmabadi/stackOverflowBKTree

Stack overflow when searching a bk tree

Fixes #10813
...@@ -278,7 +278,12 @@ private int BinarySearch(string name) ...@@ -278,7 +278,12 @@ private int BinarySearch(string name)
return info; return info;
} }
info = await LoadOrCreateSymbolTreeInfoAsync(solution, assembly, reference.FilePath, loadOnly, cancellationToken).ConfigureAwait(false); // We don't include internals from metadata assemblies. It's less likely that
// a project would have IVT to it and so it helps us save on memory. It also
// means we can avoid loading lots and lots of obfuscated code in the case the
// dll was obfuscated.
info = await LoadOrCreateSymbolTreeInfoAsync(solution, assembly, reference.FilePath,
loadOnly, includeInternal: false, cancellationToken: cancellationToken).ConfigureAwait(false);
if (info == null && loadOnly) if (info == null && loadOnly)
{ {
return null; return null;
...@@ -293,12 +298,15 @@ private int BinarySearch(string name) ...@@ -293,12 +298,15 @@ private int BinarySearch(string name)
{ {
var compilation = await project.GetCompilationAsync(cancellationToken).ConfigureAwait(false); var compilation = await project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
// We want to know about internal symbols from source assemblies. There's a reasonable
// chance a project might have IVT access to it.
return await LoadOrCreateSymbolTreeInfoAsync( return await LoadOrCreateSymbolTreeInfoAsync(
project.Solution, compilation.Assembly, project.FilePath, loadOnly: false, cancellationToken: cancellationToken).ConfigureAwait(false); project.Solution, compilation.Assembly, project.FilePath,
loadOnly: false, includeInternal: true, cancellationToken: cancellationToken).ConfigureAwait(false);
} }
internal static SymbolTreeInfo CreateSymbolTreeInfo( internal static SymbolTreeInfo CreateSymbolTreeInfo(
Solution solution, VersionStamp version, IAssemblySymbol assembly, string filePath, CancellationToken cancellationToken) Solution solution, VersionStamp version, IAssemblySymbol assembly, string filePath, bool includeInternal, CancellationToken cancellationToken)
{ {
if (assembly == null) if (assembly == null)
{ {
...@@ -306,7 +314,8 @@ private int BinarySearch(string name) ...@@ -306,7 +314,8 @@ private int BinarySearch(string name)
} }
var list = new List<Node>(); var list = new List<Node>();
GenerateNodes(assembly.GlobalNamespace, list); var lookup = includeInternal ? s_getMembersNoPrivate : s_getMembersNoPrivateOrInternal;
GenerateNodes(assembly.GlobalNamespace, list, lookup);
var sortedNodes = SortNodes(list); var sortedNodes = SortNodes(list);
var createSpellCheckerTask = GetSpellCheckerTask(solution, version, assembly, filePath, sortedNodes); var createSpellCheckerTask = GetSpellCheckerTask(solution, version, assembly, filePath, sortedNodes);
...@@ -387,44 +396,65 @@ private static int CompareNodes(Node x, Node y, IReadOnlyList<Node> nodeList) ...@@ -387,44 +396,65 @@ private static int CompareNodes(Node x, Node y, IReadOnlyList<Node> nodeList)
} }
// generate nodes for the global namespace and all descendants // generate nodes for the global namespace and all descendants
private static void GenerateNodes(INamespaceSymbol globalNamespace, List<Node> list) private static void GenerateNodes(
INamespaceSymbol globalNamespace,
List<Node> list,
Func<ISymbol, IEnumerable<ISymbol>> lookup)
{ {
var node = new Node(globalNamespace.Name, Node.RootNodeParentIndex); var node = new Node(globalNamespace.Name, Node.RootNodeParentIndex);
list.Add(node); list.Add(node);
// Add all child members // Add all child members
var memberLookup = s_getMembers(globalNamespace).ToLookup(c => c.Name); var memberLookup = lookup(globalNamespace).ToLookup(c => c.Name);
foreach (var grouping in memberLookup) foreach (var grouping in memberLookup)
{ {
GenerateNodes(grouping.Key, 0 /*index of root node*/, grouping, list); GenerateNodes(grouping.Key, 0 /*index of root node*/, grouping, list, lookup);
} }
} }
private static readonly Func<ISymbol, bool> s_useSymbol = private static readonly Func<ISymbol, bool> s_useSymbolNoPrivate =
s => s.CanBeReferencedByName && s.DeclaredAccessibility != Accessibility.Private; s => s.CanBeReferencedByName && s.DeclaredAccessibility != Accessibility.Private;
private static readonly Func<ISymbol, bool> s_useSymbolNoPrivateOrInternal =
s => s.CanBeReferencedByName &&
s.DeclaredAccessibility != Accessibility.Private &&
s.DeclaredAccessibility != Accessibility.Internal;
// generate nodes for symbols that share the same name, and all their descendants // generate nodes for symbols that share the same name, and all their descendants
private static void GenerateNodes(string name, int parentIndex, IEnumerable<ISymbol> symbolsWithSameName, List<Node> list) private static void GenerateNodes(
string name,
int parentIndex,
IEnumerable<ISymbol> symbolsWithSameName,
List<Node> list,
Func<ISymbol, IEnumerable<ISymbol>> lookup)
{ {
var node = new Node(name, parentIndex); var node = new Node(name, parentIndex);
var nodeIndex = list.Count; var nodeIndex = list.Count;
list.Add(node); list.Add(node);
// Add all child members // Add all child members
var membersByName = symbolsWithSameName.SelectMany(s_getMembers).ToLookup(s => s.Name); var membersByName = symbolsWithSameName.SelectMany(lookup).ToLookup(s => s.Name);
foreach (var grouping in membersByName) foreach (var grouping in membersByName)
{ {
GenerateNodes(grouping.Key, nodeIndex, grouping, list); GenerateNodes(grouping.Key, nodeIndex, grouping, list, lookup);
} }
} }
private static Func<ISymbol, IEnumerable<ISymbol>> s_getMembers = symbol => private static Func<ISymbol, IEnumerable<ISymbol>> s_getMembersNoPrivate = symbol =>
{
var nt = symbol as INamespaceOrTypeSymbol;
return nt != null
? nt.GetMembers().Where(s_useSymbolNoPrivate)
: SpecializedCollections.EmptyEnumerable<ISymbol>();
};
private static Func<ISymbol, IEnumerable<ISymbol>> s_getMembersNoPrivateOrInternal = symbol =>
{ {
var nt = symbol as INamespaceOrTypeSymbol; var nt = symbol as INamespaceOrTypeSymbol;
return nt != null return nt != null
? nt.GetMembers().Where(s_useSymbol) ? nt.GetMembers().Where(s_useSymbolNoPrivateOrInternal)
: SpecializedCollections.EmptyEnumerable<ISymbol>(); : SpecializedCollections.EmptyEnumerable<ISymbol>();
}; };
......
...@@ -26,6 +26,7 @@ internal partial class SymbolTreeInfo : IObjectWritable ...@@ -26,6 +26,7 @@ internal partial class SymbolTreeInfo : IObjectWritable
IAssemblySymbol assembly, IAssemblySymbol assembly,
string filePath, string filePath,
bool loadOnly, bool loadOnly,
bool includeInternal,
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
return LoadOrCreateAsync( return LoadOrCreateAsync(
...@@ -33,7 +34,7 @@ internal partial class SymbolTreeInfo : IObjectWritable ...@@ -33,7 +34,7 @@ internal partial class SymbolTreeInfo : IObjectWritable
assembly, assembly,
filePath, filePath,
loadOnly, loadOnly,
create: version => CreateSymbolTreeInfo(solution, version, assembly, filePath, cancellationToken), create: version => CreateSymbolTreeInfo(solution, version, assembly, filePath, includeInternal, cancellationToken),
keySuffix: "", keySuffix: "",
getVersion: info => info._version, getVersion: info => info._version,
readObject: reader => ReadSymbolTreeInfo(reader, (version, nodes) => GetSpellCheckerTask(solution, version, assembly, filePath, nodes)), readObject: reader => ReadSymbolTreeInfo(reader, (version, nodes) => GetSpellCheckerTask(solution, version, assembly, filePath, nodes)),
......
...@@ -88,7 +88,7 @@ public IList<string> Find(string value, int? threshold = null) ...@@ -88,7 +88,7 @@ public IList<string> Find(string value, int? threshold = null)
threshold = threshold ?? WordSimilarityChecker.GetThreshold(value); threshold = threshold ?? WordSimilarityChecker.GetThreshold(value);
var result = new List<string>(); var result = new List<string>();
Lookup(_nodes[0], lowerCaseCharacters, value.Length, threshold.Value, result); Lookup(_nodes[0], lowerCaseCharacters, value.Length, threshold.Value, result, recursionCount: 0);
return result; return result;
} }
finally finally
...@@ -97,8 +97,28 @@ public IList<string> Find(string value, int? threshold = null) ...@@ -97,8 +97,28 @@ public IList<string> Find(string value, int? threshold = null)
} }
} }
private void Lookup(Node currentNode, char[] queryCharacters, int queryLength, int threshold, List<string> result) private void Lookup(
Node currentNode,
char[] queryCharacters,
int queryLength,
int threshold,
List<string> result,
int recursionCount)
{ {
// Don't bother recursing too deeply in the case of pathological trees.
// This really only happens when the actual code is strange (like
// 10,000 symbols all a single letter long). In that case, searching
// down this path will be fairly fruitless anyways.
//
// Note: this won't affect good searches against good data even if this
// pathological chain exists. That's because the good items will still
// cluster near the root node in the tree, and won't be off the end of
// this long chain.
if (recursionCount > 256)
{
return;
}
// We always want to compute the real edit distance (ignoring any thresholds). This is // We always want to compute the real edit distance (ignoring any thresholds). This is
// because we need that edit distance to appropriately determine which edges to walk // because we need that edit distance to appropriately determine which edges to walk
// in the tree. // in the tree.
...@@ -124,7 +144,8 @@ private void Lookup(Node currentNode, char[] queryCharacters, int queryLength, i ...@@ -124,7 +144,8 @@ private void Lookup(Node currentNode, char[] queryCharacters, int queryLength, i
if (min <= childEditDistance && childEditDistance <= max) if (min <= childEditDistance && childEditDistance <= max)
{ {
Lookup(_nodes[_edges[i].ChildNodeIndex], Lookup(_nodes[_edges[i].ChildNodeIndex],
queryCharacters, queryLength, threshold, result); queryCharacters, queryLength, threshold, result,
recursionCount + 1);
} }
} }
} }
......
...@@ -11,7 +11,7 @@ namespace Roslyn.Utilities ...@@ -11,7 +11,7 @@ namespace Roslyn.Utilities
{ {
internal class SpellChecker internal class SpellChecker
{ {
private const string SerializationFormat = "1"; private const string SerializationFormat = "2";
public VersionStamp Version { get; } public VersionStamp Version { get; }
private readonly BKTree _bkTree; private readonly BKTree _bkTree;
......
...@@ -577,7 +577,8 @@ public static async Task TestSymbolTreeInfoSerialization() ...@@ -577,7 +577,8 @@ public static async Task TestSymbolTreeInfoSerialization()
// create symbol tree info from assembly // create symbol tree info from assembly
var version = VersionStamp.Create(); var version = VersionStamp.Create();
var info = SymbolTreeInfo.CreateSymbolTreeInfo(solution, version, assembly, "", CancellationToken.None); var info = SymbolTreeInfo.CreateSymbolTreeInfo(
solution, version, assembly, "", includeInternal: true, cancellationToken: CancellationToken.None);
using (var writerStream = new MemoryStream()) using (var writerStream = new MemoryStream())
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册