diff --git a/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs b/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs index 6f810f920ef4def45fb434d0770caf8509539238..965e21ea408981594a948551028ee55fb27dd240 100644 --- a/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs +++ b/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs @@ -1118,7 +1118,12 @@ public new MetadataReference GetMetadataReference(IAssemblySymbol assemblySymbol private protected override MetadataReference CommonGetMetadataReference(IAssemblySymbol assemblySymbol) { - return GetMetadataReference(assemblySymbol.EnsureCSharpSymbolOrNull(nameof(assemblySymbol))); + if (assemblySymbol is Symbols.PublicModel.AssemblySymbol { UnderlyingAssemblySymbol: var underlyingSymbol }) + { + return GetMetadataReference(underlyingSymbol); + } + + return null; } internal MetadataReference GetMetadataReference(AssemblySymbol assemblySymbol) diff --git a/src/Compilers/CSharp/Test/Symbol/Compilation/CompilationAPITests.cs b/src/Compilers/CSharp/Test/Symbol/Compilation/CompilationAPITests.cs index 6b1118bbe20f9823b5ef8052a3db170869dbb57d..43dc2f4c2eec117ebb94ed7bfa8277d0483fdf64 100644 --- a/src/Compilers/CSharp/Test/Symbol/Compilation/CompilationAPITests.cs +++ b/src/Compilers/CSharp/Test/Symbol/Compilation/CompilationAPITests.cs @@ -2267,6 +2267,20 @@ public void GetMetadataReferenceAPITest() Assert.NotNull(reference2); } + [Fact] + [WorkItem(40466, "https://github.com/dotnet/roslyn/issues/40466")] + public void GetMetadataReference_VisualBasicSymbols() + { + var comp = CreateCompilation(""); + + var vbComp = CreateVisualBasicCompilation("", referencedAssemblies: TargetFrameworkUtil.GetReferences(TargetFramework.Standard)); + var assembly = (IAssemblySymbol)vbComp.GetBoundReferenceManager().GetReferencedAssemblies().First().Value; + + Assert.Null(comp.GetMetadataReference(assembly)); + Assert.Null(comp.GetMetadataReference(vbComp.Assembly)); + Assert.Null(comp.GetMetadataReference((IAssemblySymbol)null)); + } + [Fact] public void ConsistentParseOptions() { diff --git a/src/Compilers/VisualBasic/Portable/Compilation/VisualBasicCompilation.vb b/src/Compilers/VisualBasic/Portable/Compilation/VisualBasicCompilation.vb index 0b32090639c7addf94f6e5dc69475138f075ca77..95c92632b2e4a010aae3531e0905bcb944d8a362 100644 --- a/src/Compilers/VisualBasic/Portable/Compilation/VisualBasicCompilation.vb +++ b/src/Compilers/VisualBasic/Portable/Compilation/VisualBasicCompilation.vb @@ -1232,7 +1232,12 @@ Namespace Microsoft.CodeAnalysis.VisualBasic End Function Private Protected Overrides Function CommonGetMetadataReference(assemblySymbol As IAssemblySymbol) As MetadataReference - Return GetMetadataReference(assemblySymbol.EnsureVbSymbolOrNothing(Of AssemblySymbol)(NameOf(assemblySymbol))) + Dim symbol = TryCast(assemblySymbol, AssemblySymbol) + If symbol IsNot Nothing Then + Return GetMetadataReference(symbol) + End If + + Return Nothing End Function Public Overrides ReadOnly Property ReferencedAssemblyNames As IEnumerable(Of AssemblyIdentity) diff --git a/src/Compilers/VisualBasic/Test/Semantic/Compilation/CompilationAPITests.vb b/src/Compilers/VisualBasic/Test/Semantic/Compilation/CompilationAPITests.vb index 22bd2f263cc4bc21eef995c2ae26880d34dfccab..da74bd7e6e7fcf86aaa0716c782fb6bfa5c9699d 100644 --- a/src/Compilers/VisualBasic/Test/Semantic/Compilation/CompilationAPITests.vb +++ b/src/Compilers/VisualBasic/Test/Semantic/Compilation/CompilationAPITests.vb @@ -2257,6 +2257,19 @@ End Class Assert.NotNull(reference2) End Sub + + + Public Sub GetMetadataReference_CSharpSymbols() + Dim comp As Compilation = CreateCompilation("") + + Dim csComp = CreateCSharpCompilation("", referencedAssemblies:=TargetFrameworkUtil.GetReferences(TargetFramework.Standard)) + Dim assembly = csComp.GetBoundReferenceManager().GetReferencedAssemblies().First().Value + + Assert.Null(comp.GetMetadataReference(DirectCast(assembly.GetISymbol(), IAssemblySymbol))) + Assert.Null(comp.GetMetadataReference(csComp.Assembly)) + Assert.Null(comp.GetMetadataReference(Nothing)) + End Sub + Public Sub EqualityOfMergedNamespaces()