QuicTestBase.cs 13.8 KB
Newer Older
1 2 3
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

4
using System.Buffers;
5
using System.Collections.Generic;
T
Tomas Weinfurt 已提交
6
using System.Diagnostics;
7
using System.Net.Security;
8 9 10
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
11
using System.Threading.Tasks;
12
using Xunit;
T
Tomas Weinfurt 已提交
13
using Xunit.Abstractions;
14
using System.Diagnostics.Tracing;
15
using System.Net.Sockets;
16 17 18

namespace System.Net.Quic.Tests
{
    /// <summary>
    /// Base class for QUIC functional tests. Provides TLS option factories, listener/connection
    /// creation helpers, and client/server harnesses that run a pair of callbacks against
    /// connected QUIC connections or streams.
    /// </summary>
    public abstract class QuicTestBase
    {
        private static readonly byte[] s_ping = "PING"u8.ToArray();
        private static readonly byte[] s_pong = "PONG"u8.ToArray();

        /// <summary>True when both <see cref="QuicListener"/> and <see cref="QuicConnection"/> report support on this platform.</summary>
        public static bool IsSupported => QuicListener.IsSupported && QuicConnection.IsSupported;

        /// <summary>ALPN protocol advertised by every client and server created through this class.</summary>
        public static SslApplicationProtocol ApplicationProtocol { get; } = new SslApplicationProtocol("quictest");

        public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
        public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate();

        public ITestOutputHelper _output;

        /// <summary>Default timeout (4 minutes) used by the client/server harness methods.</summary>
        public const int PassingTestTimeoutMilliseconds = 4 * 60 * 1000;
        public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);

        public QuicTestBase(ITestOutputHelper output)
        {
            _output = output;
        }

        /// <summary>
        /// Client-side certificate validation: asserts that the server presented
        /// <see cref="ServerCertificate"/> (compared by certificate hash) and then accepts it.
        /// </summary>
        public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
        {
            Assert.Equal(ServerCertificate.GetCertHash(), certificate?.GetCertHash());
            return true;
        }

        /// <summary>
        /// Asserts that <paramref name="testCode"/> throws a <see cref="QuicException"/> carrying
        /// <paramref name="expectedError"/>, and returns that exception for further inspection.
        /// </summary>
        public async Task<QuicException> AssertThrowsQuicExceptionAsync(QuicError expectedError, Func<Task> testCode)
        {
            QuicException ex = await Assert.ThrowsAsync<QuicException>(testCode);
            Assert.Equal(expectedError, ex.QuicError);
            return ex;
        }

        /// <summary>Server connection options using <see cref="GetSslServerAuthenticationOptions"/>.</summary>
        public QuicServerConnectionOptions CreateQuicServerOptions()
        {
            return new QuicServerConnectionOptions()
            {
                ServerAuthenticationOptions = GetSslServerAuthenticationOptions()
            };
        }

        /// <summary>Server TLS options: the test ALPN protocol and <see cref="ServerCertificate"/>.</summary>
        public SslServerAuthenticationOptions GetSslServerAuthenticationOptions()
        {
            return new SslServerAuthenticationOptions()
            {
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                ServerCertificate = ServerCertificate
            };
        }

        /// <summary>
        /// Client TLS options: the test ALPN protocol, target host "localhost", and
        /// <see cref="RemoteCertificateValidationCallback"/> for server certificate checks.
        /// </summary>
        public SslClientAuthenticationOptions GetSslClientAuthenticationOptions()
        {
            return new SslClientAuthenticationOptions()
            {
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                RemoteCertificateValidationCallback = RemoteCertificateValidationCallback,
                TargetHost = "localhost"
            };
        }

        /// <summary>Client connection options targeting <paramref name="endpoint"/> with the default client TLS options.</summary>
        public QuicClientConnectionOptions CreateQuicClientOptions(EndPoint endpoint)
        {
            return new QuicClientConnectionOptions()
            {
                RemoteEndPoint = endpoint,
                ClientAuthenticationOptions = GetSslClientAuthenticationOptions()
            };
        }

        /// <summary>Starts a client connection to <paramref name="endpoint"/> using <see cref="CreateQuicClientOptions"/>.</summary>
        internal ValueTask<QuicConnection> CreateQuicConnection(IPEndPoint endpoint)
        {
            QuicClientConnectionOptions options = CreateQuicClientOptions(endpoint);
            return CreateQuicConnection(options);
        }

        /// <summary>Starts a client connection using <paramref name="clientOptions"/> as given.</summary>
        internal ValueTask<QuicConnection> CreateQuicConnection(QuicClientConnectionOptions clientOptions)
        {
            return QuicConnection.ConnectAsync(clientOptions);
        }

        /// <summary>
        /// Listener options bound to loopback on an ephemeral port (port 0), advertising the test
        /// ALPN protocol and answering each incoming connection with <see cref="CreateQuicServerOptions"/>.
        /// </summary>
        internal QuicListenerOptions CreateQuicListenerOptions()
        {
            return new QuicListenerOptions()
            {
                ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(CreateQuicServerOptions())
            };
        }

        /// <summary>Creates a listener from <see cref="CreateQuicListenerOptions"/>.</summary>
        /// <remarks>
        /// NOTE(review): <paramref name="maxUnidirectionalStreams"/> and <paramref name="maxBidirectionalStreams"/>
        /// are not applied to the options; they are kept only for source compatibility with callers.
        /// </remarks>
        internal ValueTask<QuicListener> CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
        {
            QuicListenerOptions options = CreateQuicListenerOptions();
            return CreateQuicListener(options);
        }

        /// <summary>Creates a listener with the default options but bound to <paramref name="endpoint"/>.</summary>
        internal ValueTask<QuicListener> CreateQuicListener(IPEndPoint endpoint)
        {
            // Reuse the default options so both listener factories stay in sync; only the bind address differs.
            QuicListenerOptions options = CreateQuicListenerOptions();
            options.ListenEndPoint = endpoint;
            return CreateQuicListener(options);
        }

        internal ValueTask<QuicListener> CreateQuicListener(QuicListenerOptions options) => QuicListener.ListenAsync(options);

        internal Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicListener listener) => CreateConnectedQuicConnection(null, listener);

        /// <summary>
        /// Creates a temporary listener from <paramref name="listenerOptions"/>, connects a client to it,
        /// and disposes the listener once the pair is established.
        /// </summary>
        /// <returns>The (client, server) connection pair.</returns>
        internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions, QuicListenerOptions listenerOptions)
        {
            await using (QuicListener listener = await CreateQuicListener(listenerOptions))
            {
                // The overload below supplies default client options and retargets the client at the
                // listener's actual endpoint, so no option preparation is needed here.
                return await CreateConnectedQuicConnection(clientOptions, listener);
            }
        }

        /// <summary>
        /// Connects a client to <paramref name="listener"/> (or to a freshly created listener when none is
        /// given) and accepts the matching server-side connection.
        /// </summary>
        /// <param name="clientOptions">
        /// Client options; defaults to <see cref="CreateQuicClientOptions"/>. An <see cref="IPEndPoint"/>
        /// remote endpoint that differs from the listener's is overwritten with the listener's endpoint.
        /// </param>
        /// <param name="listener">Listener to connect to. If null, one is created here and disposed before returning.</param>
        /// <returns>The (client, server) connection pair, both asserted to be connected.</returns>
        /// <exception cref="QuicException">Rethrown when all connection attempts fail with a connection-refused error.</exception>
        internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions = null, QuicListener? listener = null)
        {
            const int MaxAttempts = 3;
            int delayMilliseconds = 25;

            bool disposeListener = listener is null;
            listener ??= await CreateQuicListener();

            QuicConnection clientConnection;
            QuicConnection serverConnection;
            try
            {
                clientOptions ??= CreateQuicClientOptions(listener.LocalEndPoint);
                if (clientOptions.RemoteEndPoint is IPEndPoint remoteEndPoint && !remoteEndPoint.Equals(listener.LocalEndPoint))
                {
                    clientOptions.RemoteEndPoint = listener.LocalEndPoint;
                }

                ValueTask<QuicConnection> serverTask = listener.AcceptConnectionAsync();

                for (int attempt = 1; ; attempt++)
                {
                    clientConnection = await CreateQuicConnection(clientOptions);
                    try
                    {
                        await clientConnection.ConnectAsync().ConfigureAwait(false);
                        break;
                    }
                    // NOTE(review): retries only when the QuicException's HResult equals SocketError.ConnectionRefused;
                    // confirm this is how the QUIC layer reports a refused connection.
                    catch (QuicException ex) when (ex.HResult == (int)SocketError.ConnectionRefused)
                    {
                        string message = $"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}";
                        _output.WriteLine(message);

                        // This attempt's connection is abandoned either way; release it.
                        clientConnection.Dispose();

                        if (attempt == MaxAttempts)
                        {
                            Debug.Fail(message);
                            throw; // preserves the original stack trace, unlike "throw ex;"
                        }

                        // Exponential backoff, only when another attempt will actually follow.
                        await Task.Delay(delayMilliseconds);
                        delayMilliseconds *= 2;
                    }
                }

                // A ValueTask may be consumed only once: keep the awaited result rather than reading .Result later.
                serverConnection = await serverTask.ConfigureAwait(false);
            }
            finally
            {
                // Dispose the listener only if this method created it, including when connecting failed.
                if (disposeListener)
                {
                    await listener.DisposeAsync();
                }
            }

            Assert.True(serverConnection.Connected);
            Assert.True(clientConnection.Connected);

            return (clientConnection, serverConnection);
        }

        /// <summary>
        /// Opens a bidirectional stream from <paramref name="client"/>, sends "PING" to the server side,
        /// answers with "PONG", and asserts that both payloads arrive intact.
        /// </summary>
        internal async Task PingPong(QuicConnection client, QuicConnection server)
        {
            using QuicStream clientStream = await client.OpenBidirectionalStreamAsync();
            ValueTask clientWrite = clientStream.WriteAsync(s_ping);
            using QuicStream serverStream = await server.AcceptStreamAsync();

            byte[] received = await ReadExactBytesAsync(serverStream, s_ping.Length);
            Assert.Equal(s_ping, received);
            await clientWrite;

            ValueTask serverWrite = serverStream.WriteAsync(s_pong);
            received = await ReadExactBytesAsync(clientStream, s_pong.Length);
            Assert.Equal(s_pong, received);
            await serverWrite;
        }

        /// <summary>Reads exactly <paramref name="length"/> bytes from <paramref name="stream"/>, failing the test if a read returns 0 first.</summary>
        private static async Task<byte[]> ReadExactBytesAsync(QuicStream stream, int length)
        {
            byte[] buffer = new byte[length];
            int offset = 0;
            while (offset < length)
            {
                int readLength = await stream.ReadAsync(buffer, offset, length - offset);
                Assert.True(readLength > 0);
                offset += readLength;
            }

            return buffer;
        }

        /// <summary>
        /// Runs <paramref name="clientFunction"/> and <paramref name="serverFunction"/> concurrently on a
        /// freshly connected pair, <paramref name="iterations"/> times, sharing one listener.
        /// </summary>
        /// <param name="millisecondsTimeout">Timeout applied to each iteration's pair of callbacks.</param>
        /// <param name="listenerOptions">Listener options; defaults to <see cref="CreateQuicListenerOptions"/>.</param>
        internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, Func<QuicConnection, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds, QuicListenerOptions? listenerOptions = null)
        {
            const long ClientCloseErrorCode = 11111;
            const long ServerCloseErrorCode = 22222;

            await using QuicListener listener = await CreateQuicListener(listenerOptions ?? CreateQuicListenerOptions());

            using var serverFinished = new SemaphoreSlim(0);
            using var clientFinished = new SemaphoreSlim(0);

            for (int i = 0; i < iterations; ++i)
            {
                (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
                using (clientConnection)
                using (serverConnection)
                {
                    // Each side signals its own completion and then waits for the peer's, so neither
                    // task finishes while the other side is still using its connection.
                    await new[]
                    {
                        Task.Run(async () =>
                        {
                            await serverFunction(serverConnection);
                            serverFinished.Release();
                            await clientFinished.WaitAsync();
                        }),
                        Task.Run(async () =>
                        {
                            await clientFunction(clientConnection);
                            clientFinished.Release();
                            await serverFinished.WaitAsync();
                        })
                    }.WhenAllOrAnyFailed(millisecondsTimeout);
                    await serverConnection.CloseAsync(ServerCloseErrorCode);
                    await clientConnection.CloseAsync(ClientCloseErrorCode);
                }
            }
        }

        /// <summary>
        /// Stream-level variant of <see cref="RunClientServer"/>: the client opens a bidirectional or
        /// unidirectional stream (per <paramref name="bidi"/>), the server accepts it, each callback
        /// runs on its end, and both ends are then shut down.
        /// </summary>
        internal async Task RunStreamClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, bool bidi, int iterations, int millisecondsTimeout)
        {
            // Separate buffers: the client write and the server read run concurrently and must not share an array.
            byte[] clientBuffer = new byte[1] { 42 };
            byte[] serverBuffer = new byte[1];

            await RunClientServer(
                clientFunction: async connection =>
                {
                    await using QuicStream stream = bidi ? await connection.OpenBidirectionalStreamAsync() : await connection.OpenUnidirectionalStreamAsync();
                    // Open(Bi|Uni)directionalStream only allocates ID. We will force stream opening
                    // by Writing there and receiving data on the other side.
                    await stream.WriteAsync(clientBuffer);

                    await clientFunction(stream);

                    stream.Shutdown();
                    await stream.ShutdownCompleted();
                },
                serverFunction: async connection =>
                {
                    await using QuicStream stream = await connection.AcceptStreamAsync();
                    Assert.Equal(1, await stream.ReadAsync(serverBuffer));

                    await serverFunction(stream);

                    stream.Shutdown();
                    await stream.ShutdownCompleted();
                },
                iterations,
                millisecondsTimeout
            );
        }

        internal Task RunBidirectionalClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds)
            => RunStreamClientServer(clientFunction, serverFunction, bidi: true, iterations, millisecondsTimeout);

        // NOTE: "Unirectional" is a historical misspelling kept so existing callers continue to compile.
        internal Task RunUnirectionalClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds)
            => RunStreamClientServer(clientFunction, serverFunction, bidi: false, iterations, millisecondsTimeout);

        /// <summary>
        /// Reads from <paramref name="stream"/> into <paramref name="buffer"/> until a read returns 0,
        /// and returns the total number of bytes read.
        /// </summary>
        /// <remarks>
        /// NOTE(review): once <paramref name="buffer"/> is full, the next call is a zero-length read; whether
        /// that returns 0 immediately (truncating) or waits depends on QuicStream's zero-byte read semantics.
        /// Size the buffer to hold the whole expected payload.
        /// </remarks>
        internal static async Task<int> ReadAll(QuicStream stream, byte[] buffer)
        {
            Memory<byte> memory = buffer;
            int bytesRead = 0;
            while (true)
            {
                int res = await stream.ReadAsync(memory);
                if (res == 0)
                {
                    break;
                }
                bytesRead += res;
                memory = memory[res..];
            }

            return bytesRead;
        }

        /// <summary>
        /// Writes <paramref name="size"/> bytes to <paramref name="stream"/> in an endless loop. The method
        /// never completes normally; it ends only when a write throws.
        /// </summary>
        internal static async Task<int> WriteForever(QuicStream stream, int size = 1)
        {
            byte[] buffer = ArrayPool<byte>.Shared.Rent(size);
            try
            {
                // Rent may hand back a larger array than requested; send exactly `size` bytes per write.
                ReadOnlyMemory<byte> payload = buffer.AsMemory(0, size);
                while (true)
                {
                    await stream.WriteAsync(payload);
                }
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }
    }
}