QuicTestBase.cs 13.5 KB
Newer Older
1 2 3
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

4
using System.Buffers;
5
using System.Collections.Generic;
T
Tomas Weinfurt 已提交
6
using System.Diagnostics;
7
using System.Net.Security;
8 9 10
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
11
using System.Threading.Tasks;
12
using Xunit;
T
Tomas Weinfurt 已提交
13
using Xunit.Abstractions;
14
using System.Diagnostics.Tracing;
15
using System.Net.Sockets;
16 17 18

namespace System.Net.Quic.Tests
{
    /// <summary>
    /// Shared infrastructure for System.Net.Quic functional tests: certificate and ALPN
    /// configuration, listener/connection factories, and client/server test drivers.
    /// </summary>
    public abstract class QuicTestBase
    {
        private static readonly byte[] s_ping = "PING"u8.ToArray();
        private static readonly byte[] s_pong = "PONG"u8.ToArray();

        /// <summary>True when both <see cref="QuicListener"/> and <see cref="QuicConnection"/> report support.</summary>
        public static bool IsSupported => QuicListener.IsSupported && QuicConnection.IsSupported;

        /// <summary>ALPN protocol advertised by every client and listener created through this class.</summary>
        public static SslApplicationProtocol ApplicationProtocol { get; } = new SslApplicationProtocol("quictest");

        public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
        public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate();

        // Public (not private) so derived test classes can log through it; name kept for compatibility.
        public ITestOutputHelper _output;

        public const int PassingTestTimeoutMilliseconds = 4 * 60 * 1000;
        public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);

        public QuicTestBase(ITestOutputHelper output)
        {
            _output = output;
        }

        /// <summary>
        /// Client-side certificate validation: asserts that the presented certificate is
        /// <see cref="ServerCertificate"/> (by hash) and then accepts it, ignoring
        /// <paramref name="sslPolicyErrors"/>.
        /// </summary>
        public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
        {
            Assert.Equal(ServerCertificate.GetCertHash(), certificate?.GetCertHash());
            return true;
        }

        /// <summary>Server connection options that authenticate with <see cref="ServerCertificate"/>.</summary>
        public QuicServerConnectionOptions CreateQuicServerOptions()
        {
            return new QuicServerConnectionOptions()
            {
                ServerAuthenticationOptions = GetSslServerAuthenticationOptions()
            };
        }

        /// <summary>TLS server options advertising <see cref="ApplicationProtocol"/> and presenting <see cref="ServerCertificate"/>.</summary>
        public SslServerAuthenticationOptions GetSslServerAuthenticationOptions()
        {
            return new SslServerAuthenticationOptions()
            {
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                ServerCertificate = ServerCertificate
            };
        }

        /// <summary>
        /// TLS client options advertising <see cref="ApplicationProtocol"/>, targeting "localhost",
        /// and validating the server through <see cref="RemoteCertificateValidationCallback"/>.
        /// </summary>
        public SslClientAuthenticationOptions GetSslClientAuthenticationOptions()
        {
            return new SslClientAuthenticationOptions()
            {
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                RemoteCertificateValidationCallback = RemoteCertificateValidationCallback,
                TargetHost = "localhost"
            };
        }

        /// <summary>Client connection options pointing at <paramref name="endpoint"/> with the default client TLS options.</summary>
        public QuicClientConnectionOptions CreateQuicClientOptions(EndPoint endpoint)
        {
            return new QuicClientConnectionOptions()
            {
                RemoteEndPoint = endpoint,
                ClientAuthenticationOptions = GetSslClientAuthenticationOptions()
            };
        }

        /// <summary>Connects to <paramref name="endpoint"/> using <see cref="CreateQuicClientOptions"/>.</summary>
        internal ValueTask<QuicConnection> CreateQuicConnection(IPEndPoint endpoint)
        {
            QuicClientConnectionOptions options = CreateQuicClientOptions(endpoint);
            return CreateQuicConnection(options);
        }

        /// <summary>Thin wrapper over <see cref="QuicConnection.ConnectAsync"/>.</summary>
        internal ValueTask<QuicConnection> CreateQuicConnection(QuicClientConnectionOptions clientOptions)
        {
            return QuicConnection.ConnectAsync(clientOptions);
        }

        /// <summary>
        /// Listener options bound to an ephemeral loopback port; every incoming connection is
        /// configured with <see cref="CreateQuicServerOptions"/>.
        /// </summary>
        internal QuicListenerOptions CreateQuicListenerOptions()
        {
            return new QuicListenerOptions()
            {
                ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(CreateQuicServerOptions())
            };
        }

        /// <summary>Starts a listener configured by <see cref="CreateQuicListenerOptions"/>.</summary>
        /// <param name="maxUnidirectionalStreams">
        /// NOTE(review): currently NOT applied to the listener or its connections; kept only for
        /// source compatibility with existing callers.
        /// </param>
        /// <param name="maxBidirectionalStreams">NOTE(review): currently NOT applied; see above.</param>
        internal ValueTask<QuicListener> CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
        {
            QuicListenerOptions options = CreateQuicListenerOptions();
            return CreateQuicListener(options);
        }

        /// <summary>Starts a listener on <paramref name="endpoint"/> with the default ALPN and server options.</summary>
        internal ValueTask<QuicListener> CreateQuicListener(IPEndPoint endpoint)
        {
            var options = new QuicListenerOptions()
            {
                ListenEndPoint = endpoint,
                ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
                ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(CreateQuicServerOptions())
            };
            return CreateQuicListener(options);
        }

        /// <summary>Thin wrapper over <see cref="QuicListener.ListenAsync"/>.</summary>
        internal ValueTask<QuicListener> CreateQuicListener(QuicListenerOptions options) => QuicListener.ListenAsync(options);

        /// <summary>Creates a connected (client, server) pair against an existing <paramref name="listener"/> using default client options.</summary>
        internal Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicListener listener) => CreateConnectedQuicConnection(null, listener);

        /// <summary>
        /// Starts a temporary listener from <paramref name="listenerOptions"/>, creates a connected
        /// (client, server) pair through it, and disposes the listener before returning. The
        /// returned connections are owned by the caller.
        /// </summary>
        internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions, QuicListenerOptions listenerOptions)
        {
            await using (QuicListener listener = await CreateQuicListener(listenerOptions))
            {
                // The listener overload supplies default client options when none are given and
                // re-targets the client at the listener's actual (ephemeral) endpoint.
                return await CreateConnectedQuicConnection(clientOptions, listener);
            }
        }

        /// <summary>
        /// Creates a connected (client, server) connection pair.
        /// </summary>
        /// <param name="clientOptions">
        /// Client options; defaults to <see cref="CreateQuicClientOptions"/> for the listener's endpoint.
        /// If its <c>RemoteEndPoint</c> is an <see cref="IPEndPoint"/> different from the listener's
        /// endpoint, it is overwritten in place (the caller's object is mutated).
        /// </param>
        /// <param name="listener">
        /// Listener to accept on. When null, a listener is created for this call only and disposed
        /// before returning (also when connecting fails).
        /// </param>
        /// <returns>The (client, server) connections; both are asserted to be connected.</returns>
        internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnection(QuicClientConnectionOptions? clientOptions = null, QuicListener? listener = null)
        {
            bool disposeListener = false;
            if (listener == null)
            {
                listener = await CreateQuicListener();
                disposeListener = true;
            }

            QuicConnection clientConnection;
            QuicConnection serverConnection;
            try
            {
                clientOptions ??= CreateQuicClientOptions(listener.LocalEndPoint);
                if (clientOptions.RemoteEndPoint is IPEndPoint iPEndPoint && !iPEndPoint.Equals(listener.LocalEndPoint))
                {
                    clientOptions.RemoteEndPoint = listener.LocalEndPoint;
                }

                // Start accepting before the client connects.
                ValueTask<QuicConnection> serverTask = listener.AcceptConnectionAsync();
                clientConnection = await ConnectClientWithRetryAsync(clientOptions);

                // A ValueTask may be consumed only once; keep the awaited result instead of
                // touching serverTask.Result afterwards.
                serverConnection = await serverTask.ConfigureAwait(false);
            }
            finally
            {
                if (disposeListener)
                {
                    await listener.DisposeAsync();
                }
            }

            Assert.True(serverConnection.Connected);
            Assert.True(clientConnection.Connected);

            return (clientConnection, serverConnection);
        }

        /// <summary>
        /// Creates a client connection and connects it, making up to three attempts. Only a
        /// <see cref="QuicException"/> whose HResult equals <see cref="SocketError.ConnectionRefused"/>
        /// is retried, with a 25 ms delay that doubles after each failed attempt.
        /// Every failed connection is disposed so nothing leaks across attempts.
        /// </summary>
        private async Task<QuicConnection> ConnectClientWithRetryAsync(QuicClientConnectionOptions clientOptions)
        {
            const int MaxAttempts = 3;
            int delayMilliseconds = 25;

            for (int attempt = 1; ; attempt++)
            {
                QuicConnection clientConnection = await CreateQuicConnection(clientOptions);
                try
                {
                    await clientConnection.ConnectAsync().ConfigureAwait(false);
                    return clientConnection;
                }
                catch (QuicException ex) when (ex.HResult == (int)SocketError.ConnectionRefused)
                {
                    // Build the message before disposing, since it reads RemoteEndPoint.
                    string message = $"ConnectAsync to {clientConnection.RemoteEndPoint} failed with {ex.Message}";
                    _output.WriteLine(message);
                    clientConnection.Dispose();

                    if (attempt == MaxAttempts)
                    {
                        // Debug.Fail is compiled only into DEBUG builds.
                        Debug.Fail(message);
                        throw; // rethrow preserving the original stack trace
                    }
                }
                catch
                {
                    // Any other failure is not retried, but the connection must still be released.
                    clientConnection.Dispose();
                    throw;
                }

                await Task.Delay(delayMilliseconds);
                delayMilliseconds *= 2;
            }
        }

        /// <summary>
        /// Opens a bidirectional stream from <paramref name="client"/>, sends "PING" to
        /// <paramref name="server"/>, answers with "PONG", and asserts both payloads arrive intact.
        /// </summary>
        internal async Task PingPong(QuicConnection client, QuicConnection server)
        {
            using QuicStream clientStream = await client.OpenBidirectionalStreamAsync();
            // Start the write before accepting so the server side has something to accept.
            ValueTask pingWrite = clientStream.WriteAsync(s_ping);
            using QuicStream serverStream = await server.AcceptStreamAsync();

            byte[] ping = await ReadExactlyAsync(serverStream, s_ping.Length);
            Assert.Equal(s_ping, ping);
            await pingWrite;

            ValueTask pongWrite = serverStream.WriteAsync(s_pong);
            byte[] pong = await ReadExactlyAsync(clientStream, s_pong.Length);
            Assert.Equal(s_pong, pong);
            await pongWrite;
        }

        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from <paramref name="stream"/>, asserting
        /// that every read makes progress (i.e. the stream does not end early).
        /// </summary>
        private static async Task<byte[]> ReadExactlyAsync(QuicStream stream, int count)
        {
            byte[] buffer = new byte[count];
            int offset = 0;
            while (offset < count)
            {
                int readLength = await stream.ReadAsync(buffer, offset, count - offset);
                Assert.True(readLength > 0);
                offset += readLength;
            }

            return buffer;
        }

        /// <summary>
        /// Runs <paramref name="clientFunction"/> and <paramref name="serverFunction"/> concurrently
        /// on a freshly connected pair, <paramref name="iterations"/> times, sharing one listener.
        /// </summary>
        /// <remarks>
        /// Each side signals when its function completes and then waits for the other side, so
        /// neither connection is closed while the peer function is still running. Both functions
        /// together must finish within <paramref name="millisecondsTimeout"/>.
        /// </remarks>
        internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, Func<QuicConnection, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds, QuicListenerOptions? listenerOptions = null)
        {
            const long ClientCloseErrorCode = 11111;
            const long ServerCloseErrorCode = 22222;

            await using QuicListener listener = await CreateQuicListener(listenerOptions ?? CreateQuicListenerOptions());

            using var serverFinished = new SemaphoreSlim(0);
            using var clientFinished = new SemaphoreSlim(0);

            for (int i = 0; i < iterations; ++i)
            {
                (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
                using (clientConnection)
                using (serverConnection)
                {
                    await new[]
                    {
                        Task.Run(async () =>
                        {
                            await serverFunction(serverConnection);
                            serverFinished.Release();
                            await clientFinished.WaitAsync();
                        }),
                        Task.Run(async () =>
                        {
                            await clientFunction(clientConnection);
                            clientFinished.Release();
                            await serverFinished.WaitAsync();
                        })
                    }.WhenAllOrAnyFailed(millisecondsTimeout);
                    await serverConnection.CloseAsync(ServerCloseErrorCode);
                    await clientConnection.CloseAsync(ClientCloseErrorCode);
                }
            }
        }

        /// <summary>
        /// Like <see cref="RunClientServer"/>, but hands each side a single stream: the client opens a
        /// bidirectional or unidirectional stream (per <paramref name="bidi"/>), the server accepts it.
        /// Both sides shut the stream down after their function returns.
        /// </summary>
        internal async Task RunStreamClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, bool bidi, int iterations, int millisecondsTimeout)
        {
            // Single marker byte written by the client; the server reads it into its own buffer so
            // the two sides never touch the same array concurrently.
            byte[] clientBuffer = new byte[1] { 42 };

            await RunClientServer(
                clientFunction: async connection =>
                {
                    await using QuicStream stream = bidi ? await connection.OpenBidirectionalStreamAsync() : await connection.OpenUnidirectionalStreamAsync();
                    // Open(Bi|Uni)directionalStream only allocates an ID. Force the stream open by
                    // writing to it so the peer can accept it and receive the data.
                    await stream.WriteAsync(clientBuffer);

                    await clientFunction(stream);

                    stream.Shutdown();
                    await stream.ShutdownCompleted();
                },
                serverFunction: async connection =>
                {
                    await using QuicStream stream = await connection.AcceptStreamAsync();
                    byte[] serverBuffer = new byte[1];
                    Assert.Equal(1, await stream.ReadAsync(serverBuffer));

                    await serverFunction(stream);

                    stream.Shutdown();
                    await stream.ShutdownCompleted();
                },
                iterations,
                millisecondsTimeout
            );
        }

        /// <summary><see cref="RunStreamClientServer"/> over a bidirectional stream.</summary>
        internal Task RunBidirectionalClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds)
            => RunStreamClientServer(clientFunction, serverFunction, bidi: true, iterations, millisecondsTimeout);

        /// <summary><see cref="RunStreamClientServer"/> over a unidirectional stream.</summary>
        /// <remarks>The misspelled name ("Unirectional") is kept so existing callers still compile.</remarks>
        internal Task RunUnirectionalClientServer(Func<QuicStream, Task> clientFunction, Func<QuicStream, Task> serverFunction, int iterations = 1, int millisecondsTimeout = PassingTestTimeoutMilliseconds)
            => RunStreamClientServer(clientFunction, serverFunction, bidi: false, iterations, millisecondsTimeout);

        /// <summary>
        /// Reads from <paramref name="stream"/> into <paramref name="buffer"/> until a read returns 0
        /// and returns the total number of bytes read.
        /// </summary>
        /// <remarks>
        /// Once <paramref name="buffer"/> is full, the next read is issued with an empty buffer.
        /// NOTE(review): if that read returns 0, the method stops without the stream having reached
        /// its end. Confirm callers size the buffer for the whole payload.
        /// </remarks>
        internal static async Task<int> ReadAll(QuicStream stream, byte[] buffer)
        {
            Memory<byte> memory = buffer;
            int bytesRead = 0;
            while (true)
            {
                int res = await stream.ReadAsync(memory);
                if (res == 0)
                {
                    break;
                }
                bytesRead += res;
                memory = memory[res..];
            }

            return bytesRead;
        }

        /// <summary>
        /// Writes <paramref name="size"/> bytes to <paramref name="stream"/> repeatedly until a write
        /// throws; never completes successfully. The payload content is unspecified (pooled memory).
        /// </summary>
        internal static async Task<int> WriteForever(QuicStream stream, int size = 1)
        {
            byte[] buffer = ArrayPool<byte>.Shared.Rent(size);
            try
            {
                // Rent may return a larger array than requested; write only the requested size.
                ReadOnlyMemory<byte> chunk = buffer.AsMemory(0, size);
                while (true)
                {
                    await stream.WriteAsync(chunk);
                }
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }
    }
}