未验证 提交 68242c7d 编写于 作者: D Dan Moseley 提交者: GitHub

fix pooled array leak (#88810)

* ntlm

* test base

* qpack

* FileSys

* JSON

* interp tests

* fix NTAuthentication leak

* More in JSON

* more tests

* ntlmserver disposable

* more tests

* more tests

* tar tests

* feedback
上级 dfe7a6c9
......@@ -24,7 +24,7 @@ namespace System.Net.Security
// and responses for unit test purposes. The validation checks the
// structure of the messages, their integrity and use of specified
// features (eg. MIC).
internal class FakeNtlmServer
internal class FakeNtlmServer : IDisposable
{
public FakeNtlmServer(NetworkCredential expectedCredential)
{
......@@ -142,6 +142,14 @@ private enum AvFlags : uint
UntrustedSPN = 4,
}
public void Dispose()
{
_clientSeal?.Dispose();
_clientSeal = null;
_serverSeal?.Dispose();
_serverSeal = null;
}
private static ReadOnlySpan<byte> GetField(ReadOnlySpan<byte> payload, int fieldOffset)
{
uint offset = BinaryPrimitives.ReadUInt32LittleEndian(payload.Slice(fieldOffset + 4));
......@@ -396,6 +404,9 @@ private void ValidateAuthentication(byte[] incomingBlob)
public void ResetKeys()
{
_clientSeal?.Dispose();
_serverSeal?.Dispose();
_clientSeal = new RC4(_clientSealingKey);
_serverSeal = new RC4(_serverSealingKey);
}
......
......@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text;
......@@ -201,17 +202,24 @@ private unsafe string GetTestDirectoryActualCasing()
if (!handle.IsInvalid)
{
const int InitialBufferSize = 4096;
char[]? buffer = ArrayPool<char>.Shared.Rent(InitialBufferSize);
uint result = GetFinalPathNameByHandle(handle, buffer);
char[] buffer = new char[4096];
uint result;
fixed (char* bufPtr = buffer)
{
result = Interop.Kernel32.GetFinalPathNameByHandle(handle, bufPtr, (uint)buffer.Length, Interop.Kernel32.FILE_NAME_NORMALIZED);
}
if (result == 0)
{
throw new Win32Exception();
}
Debug.Assert(result <= buffer.Length);
// Remove extended prefix
int skip = PathInternal.IsExtended(buffer) ? 4 : 0;
return new string(
buffer,
skip,
(int)result - skip);
return new string(buffer, skip, (int)result - skip);
}
}
catch { }
......@@ -219,14 +227,6 @@ private unsafe string GetTestDirectoryActualCasing()
return TestDirectory;
}
private unsafe uint GetFinalPathNameByHandle(SafeFileHandle handle, char[] buffer)
{
fixed (char* bufPtr = buffer)
{
return Interop.Kernel32.GetFinalPathNameByHandle(handle, bufPtr, (uint)buffer.Length, Interop.Kernel32.FILE_NAME_NORMALIZED);
}
}
protected string CreateTestDirectory(params string[] paths)
{
string dir = Path.Combine(paths);
......
......@@ -16,7 +16,7 @@ public sealed class MultiArrayBufferTests
[Fact]
public void BasicTest()
{
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
Assert.True(buffer.IsEmpty);
Assert.True(buffer.ActiveMemory.IsEmpty);
......@@ -98,7 +98,7 @@ public void AddByteByByteAndConsumeByteByByte_Success()
{
const int Size = 64 * 1024 + 1;
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < Size; i++)
{
......@@ -124,7 +124,7 @@ public void AddSeveralBytesRepeatedlyAndConsumeSeveralBytesRepeatedly_Success()
const int ByteCount = 7;
const int RepeatCount = 8 * 1024; // enough to ensure we cross several block boundaries
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < RepeatCount; i++)
{
......@@ -156,7 +156,7 @@ public void AddSeveralBytesRepeatedlyAndConsumeSeveralBytesRepeatedly_UsingSlice
const int ByteCount = 7;
const int RepeatCount = 8 * 1024; // enough to ensure we cross several block boundaries
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < RepeatCount; i++)
{
......@@ -188,7 +188,7 @@ public void AddSeveralBytesRepeatedlyAndConsumeSeveralBytesRepeatedly_UsingSlice
const int ByteCount = 7;
const int RepeatCount = 8 * 1024; // enough to ensure we cross several block boundaries
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < RepeatCount; i++)
{
......@@ -221,7 +221,7 @@ public void CopyFromRepeatedlyAndCopyToRepeatedly_Success()
const int RepeatCount = 8 * 1024; // enough to ensure we cross several block boundaries
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < RepeatCount; i++)
{
......@@ -250,7 +250,7 @@ public void CopyFromRepeatedlyAndCopyToRepeatedly_LargeCopies_Success()
const int RepeatCount = 13;
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
for (int i = 0; i < RepeatCount; i++)
{
......@@ -291,7 +291,7 @@ public void EmptyMultiMemoryTest()
[Fact]
public void EnsureAvailableSpaceTest()
{
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
Assert.Equal(0, buffer.ActiveMemory.Length);
Assert.Equal(0, buffer.AvailableMemory.Length);
......@@ -423,7 +423,7 @@ public void EnsureAvailableSpaceTest()
[Fact]
public void EnsureAvailableSpaceUpToLimitTest()
{
MultiArrayBuffer buffer = new MultiArrayBuffer(0);
using MultiArrayBuffer buffer = new MultiArrayBuffer(0);
Assert.Equal(0, buffer.ActiveMemory.Length);
Assert.Equal(0, buffer.AvailableMemory.Length);
......
......@@ -15,7 +15,7 @@
namespace System.Net.Http.Unit.Tests.QPack
{
public class QPackDecoderTests
public class QPackDecoderTests : IDisposable
{
private const int MaxHeaderFieldSize = 8192;
......@@ -64,6 +64,11 @@ public QPackDecoderTests()
_decoder = new QPackDecoder(MaxHeaderFieldSize);
}
public void Dispose()
{
_decoder.Dispose();
}
[Fact]
public void DecodesIndexedHeaderField_StaticTableWithValue()
{
......@@ -318,7 +323,7 @@ private static void TestDecodeWithoutIndexing(byte[] encoded, KeyValuePair<strin
private static void TestDecode(byte[] encoded, KeyValuePair<string, string>[] expectedValues, bool expectDynamicTableEntry, int? bytesAtATime)
{
QPackDecoder decoder = new QPackDecoder(MaxHeaderFieldSize);
using QPackDecoder decoder = new QPackDecoder(MaxHeaderFieldSize);
TestHttpHeadersHandler handler = new TestHttpHeadersHandler();
// Read past header
......
......@@ -180,6 +180,7 @@ public void AsSpan_ReturnsCorrectValue_DoesntClearBuilder()
Assert.NotEqual(0, sb.Length);
Assert.Equal(sb.Length, vsb.Length);
Assert.Equal(sb.ToString(), vsb.ToString());
}
[Fact]
......@@ -275,6 +276,7 @@ public unsafe void Indexer()
Assert.Equal('b', vsb[3]);
vsb[3] = 'c';
Assert.Equal('c', vsb[3]);
vsb.Dispose();
}
[Fact]
......@@ -297,6 +299,7 @@ public void EnsureCapacity_IfBufferTimesTwoWins()
builder.EnsureCapacity(33);
Assert.Equal(64, builder.Capacity);
builder.Dispose();
}
[Fact]
......@@ -309,6 +312,7 @@ public void EnsureCapacity_NoAllocIfNotNeeded()
builder.EnsureCapacity(16);
Assert.Equal(64, builder.Capacity);
builder.Dispose();
}
}
}
......@@ -340,7 +340,7 @@ public void UnixFileModes_RestrictiveParentDir(bool overwrite)
AssertFileModeEquals(filePath, TestPermission1);
}
[Fact]
[ConditionalFact(typeof(MountHelper), nameof(MountHelper.CanCreateSymbolicLinks))]
public void LinkBeforeTarget()
{
using TempDirectory source = new TempDirectory();
......
......@@ -362,7 +362,7 @@ public async Task UnixFileModes_RestrictiveParentDir_Async()
AssertFileModeEquals(filePath, TestPermission1);
}
[Fact]
[ConditionalFact(typeof(MountHelper), nameof(MountHelper.CanCreateSymbolicLinks))]
public async Task LinkBeforeTargetAsync()
{
using TempDirectory source = new TempDirectory();
......
......@@ -90,7 +90,7 @@ public static void MemoryPoolPin(int elementIndex)
public static void MemoryPoolPinBadOffset(int elementIndex)
{
MemoryPool<int> pool = MemoryPool<int>.Shared;
IMemoryOwner<int> block = pool.Rent(10);
using IMemoryOwner<int> block = pool.Rent(10);
Memory<int> memory = block.Memory;
Span<int> sp = memory.Span;
Assert.Equal(memory.Length, sp.Length);
......@@ -101,7 +101,7 @@ public static void MemoryPoolPinBadOffset(int elementIndex)
public static void MemoryPoolPinOffsetAtEnd()
{
MemoryPool<int> pool = MemoryPool<int>.Shared;
IMemoryOwner<int> block = pool.Rent(10);
using IMemoryOwner<int> block = pool.Rent(10);
Memory<int> memory = block.Memory;
Span<int> sp = memory.Span;
Assert.Equal(memory.Length, sp.Length);
......@@ -122,7 +122,7 @@ public static void MemoryPoolPinOffsetAtEnd()
public static void MemoryPoolPinBadOffsetTooLarge()
{
MemoryPool<int> pool = MemoryPool<int>.Shared;
IMemoryOwner<int> block = pool.Rent(10);
using IMemoryOwner<int> block = pool.Rent(10);
Memory<int> memory = block.Memory;
Span<int> sp = memory.Span;
Assert.Equal(memory.Length, sp.Length);
......
......@@ -67,7 +67,7 @@ private static async Task WriteAsyncTest(Encoding targetEncoding, string message
var model = new TestModel { Message = message };
var stream = new MemoryStream();
var transcodingStream = new TranscodingWriteStream(stream, targetEncoding);
using var transcodingStream = new TranscodingWriteStream(stream, targetEncoding);
await JsonSerializer.SerializeAsync(transcodingStream, model, model.GetType());
// The transcoding streams use Encoders and Decoders that have internal buffers. We need to flush these
// when there is no more data to be written. Stream.FlushAsync isn't suitable since it's
......
......@@ -107,6 +107,8 @@ internal static async Task HandleAuthenticationRequestWithFakeServer(LoopbackSer
}
while (!isAuthenticated);
fakeNtlmServer?.Dispose();
await connection.SendResponseAsync(HttpStatusCode.OK);
}
......
......@@ -156,7 +156,7 @@ async ValueTask SendMessageAsync(string text)
else if (parts[1].Equals("GSSAPI", StringComparison.OrdinalIgnoreCase))
{
Debug.Assert(ExpectedGssapiCredential != null);
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(ExpectedGssapiCredential) { ForceNegotiateVersion = true };
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(ExpectedGssapiCredential) { ForceNegotiateVersion = true };
FakeNegotiateServer fakeNegotiateServer = new FakeNegotiateServer(fakeNtlmServer);
try
......
......@@ -773,6 +773,10 @@ private static byte[] DeriveKey(ReadOnlySpan<byte> exportedSessionKey, ReadOnlyS
private void ResetKeys()
{
// Release buffers to pool
_clientSeal?.Dispose();
_serverSeal?.Dispose();
_clientSeal = new RC4(_clientSealingKey);
_serverSeal = new RC4(_serverSealingKey);
}
......
......@@ -41,7 +41,7 @@ public void RemoteIdentity_ThrowsOnUnauthenticated()
[ConditionalFact(nameof(IsNtlmAvailable))]
public void RemoteIdentity_ThrowsOnDisposed()
{
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
NegotiateAuthentication negotiateAuthentication = new NegotiateAuthentication(
new NegotiateAuthenticationClientOptions
{
......@@ -98,7 +98,7 @@ public void NtlmProtocolExampleTest()
{
// Mirrors the NTLMv2 example in the NTLM specification:
NetworkCredential credential = new NetworkCredential("User", "Password", "Domain");
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(credential);
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(credential);
fakeNtlmServer.SendTimestamp = false;
fakeNtlmServer.TargetIsServer = true;
fakeNtlmServer.PreferUnicode = false;
......@@ -151,7 +151,7 @@ public void NtlmProtocolExampleTest()
[ConditionalFact(nameof(IsNtlmAvailable))]
public void NtlmCorrectExchangeTest()
{
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
NegotiateAuthentication ntAuth = new NegotiateAuthentication(
new NegotiateAuthenticationClientOptions
{
......@@ -175,7 +175,7 @@ public void NtlmCorrectExchangeTest()
[ConditionalFact(nameof(IsNtlmAvailable))]
public void NtlmIncorrectExchangeTest()
{
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
NegotiateAuthentication ntAuth = new NegotiateAuthentication(
new NegotiateAuthenticationClientOptions
{
......@@ -194,7 +194,7 @@ public void NtlmIncorrectExchangeTest()
[ActiveIssue("https://github.com/dotnet/runtime/issues/65678", TestPlatforms.OSX | TestPlatforms.iOS | TestPlatforms.MacCatalyst)]
public void NtlmSignatureTest()
{
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight);
NegotiateAuthentication ntAuth = new NegotiateAuthentication(
new NegotiateAuthenticationClientOptions
{
......@@ -250,7 +250,7 @@ private void DoNtlmExchange(FakeNtlmServer fakeNtlmServer, NegotiateAuthenticati
public void NegotiateCorrectExchangeTest(bool requestMIC, bool requestConfidentiality)
{
// Older versions of gss-ntlmssp on Linux generate MIC at incorrect offset unless ForceNegotiateVersion is specified
FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight) { ForceNegotiateVersion = true };
using FakeNtlmServer fakeNtlmServer = new FakeNtlmServer(s_testCredentialRight) { ForceNegotiateVersion = true };
FakeNegotiateServer fakeNegotiateServer = new FakeNegotiateServer(fakeNtlmServer) { RequestMIC = requestMIC };
NegotiateAuthentication ntAuth = new NegotiateAuthentication(
new NegotiateAuthenticationClientOptions
......
......@@ -51,7 +51,13 @@ public static string TranslateWin32Expression(string? expression)
}
}
return modified ? sb.ToString() : expression;
if (!modified)
{
sb.Dispose();
return expression;
}
return sb.ToString();
}
/// <summary>Verifies whether the given Win32 expression matches the given name. Supports the following wildcards: '*', '?', '&lt;', '&gt;', '"'. The backslash character '\' escapes.</summary>
......
......@@ -18,17 +18,17 @@ public class DefaultInterpolatedStringHandlerTests
[InlineData(-16, 1)]
public void LengthAndHoleArguments_Valid(int literalLength, int formattedCount)
{
new DefaultInterpolatedStringHandler(literalLength, formattedCount);
new DefaultInterpolatedStringHandler(literalLength, formattedCount).ToStringAndClear();
Span<char> scratch1 = stackalloc char[1];
foreach (IFormatProvider provider in new IFormatProvider[] { null, new ConcatFormatter(), CultureInfo.InvariantCulture, CultureInfo.CurrentCulture, new CultureInfo("en-US"), new CultureInfo("fr-FR") })
{
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider);
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider).ToStringAndClear();
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, default);
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, scratch1);
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, Array.Empty<char>());
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, new char[256]);
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, default).ToStringAndClear();
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, scratch1).ToStringAndClear();
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, Array.Empty<char>()).ToStringAndClear();
new DefaultInterpolatedStringHandler(literalLength, formattedCount, provider, new char[256]).ToStringAndClear();
}
}
......@@ -244,6 +244,8 @@ public void AppendFormatted_ReferenceTypes_CreateProviderFlowed(bool useScratch)
handler.AppendFormatted(tss, 1, "X2");
Assert.Same(provider, tss.ToStringState.LastProvider);
}
handler.ToStringAndClear();
}
[Fact]
......@@ -357,6 +359,8 @@ void Test<T>(T t)
handler.AppendFormatted(t, 1, "X2");
Assert.Same(provider, ((IHasToStringState)t).ToStringState.LastProvider);
handler.ToStringAndClear();
}
Test(new FormattableInt32Wrapper(42));
......
......@@ -57,6 +57,8 @@ public static void MemoryGetsCleared(int byteLength)
{
Assert.Equal(0, testSpan[i]);
}
ArrayPool<byte>.Shared.Return(rented, clearArray: false);
}
[Fact]
......
......@@ -118,12 +118,12 @@ public virtual void RoundTripTypeNameClash()
[InlineData("{ \"key\" : \"value\" }")]
public void RoundtripJsonDocument(string json)
{
JsonDocument jsonDocument = JsonDocument.Parse(json);
using JsonDocument jsonDocument = JsonDocument.Parse(json);
string actualJson = JsonSerializer.Serialize(jsonDocument, DefaultContext.JsonDocument);
JsonTestHelper.AssertJsonEqual(json, actualJson);
JsonDocument actualJsonDocument = JsonSerializer.Deserialize(actualJson, DefaultContext.JsonDocument);
using JsonDocument actualJsonDocument = JsonSerializer.Deserialize(actualJson, DefaultContext.JsonDocument);
JsonTestHelper.AssertJsonEqual(jsonDocument.RootElement, actualJsonDocument.RootElement);
}
......@@ -135,7 +135,8 @@ public void RoundtripJsonDocument(string json)
[InlineData("{ \"key\" : \"value\" }")]
public void RoundtripJsonElement(string json)
{
JsonElement jsonElement = JsonDocument.Parse(json).RootElement;
using JsonDocument jsonDocument = JsonDocument.Parse(json);
JsonElement jsonElement = jsonDocument.RootElement;
string actualJson = JsonSerializer.Serialize(jsonElement, DefaultContext.JsonElement);
JsonTestHelper.AssertJsonEqual(json, actualJson);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册