未验证 提交 296f4de3 编写于 作者: S Stephen Toub 提交者: GitHub

Allow SocketsHttpHandler.ConnectCallback for sync requests (#45300)

HTTP/1.1 support with sync requests currently doesn't work when a ConnectCallback is specified, even though a developer who wanted to make synchronous requests could provide a synchronously-completing callback.

This also consolidates the connect logic across sync/async, avoids an extra delegate, etc.
上级 ebc7d5e8
...@@ -32,52 +32,6 @@ public CertificateCallbackMapper(Func<HttpRequestMessage, X509Certificate2?, X50 ...@@ -32,52 +32,6 @@ public CertificateCallbackMapper(Func<HttpRequestMessage, X509Certificate2?, X50
} }
} }
public static async ValueTask<Stream> ConnectAsync(Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>> callback, DnsEndPoint endPoint, HttpRequestMessage requestMessage, CancellationToken cancellationToken)
{
Stream stream;
try
{
stream = await callback(new SocketsHttpConnectionContext(endPoint, requestMessage), cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException ex) when (ex.CancellationToken == cancellationToken)
{
throw CancellationHelper.CreateOperationCanceledException(innerException: null, cancellationToken);
}
catch (Exception ex)
{
throw CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken);
}
if (stream == null)
{
throw new HttpRequestException(SR.net_http_null_from_connect_callback);
}
return stream;
}
public static Stream Connect(string host, int port, CancellationToken cancellationToken)
{
// For synchronous connections, we can just create a socket and make the connection.
cancellationToken.ThrowIfCancellationRequested();
var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try
{
socket.NoDelay = true;
using (cancellationToken.UnsafeRegister(static s => ((Socket)s!).Dispose(), socket))
{
socket.Connect(new DnsEndPoint(host, port));
}
return new NetworkStream(socket, ownsSocket: true);
}
catch (Exception e)
{
socket.Dispose();
throw CreateWrappedException(e, host, port, cancellationToken);
}
}
public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken) public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken)
{ {
// If there's a cert validation callback, and if it came from HttpClientHandler, // If there's a cert validation callback, and if it came from HttpClientHandler,
...@@ -161,7 +115,7 @@ public static async ValueTask<QuicConnection> ConnectQuicAsync(QuicImplementatio ...@@ -161,7 +115,7 @@ public static async ValueTask<QuicConnection> ConnectQuicAsync(QuicImplementatio
} }
} }
private static Exception CreateWrappedException(Exception error, string host, int port, CancellationToken cancellationToken) internal static Exception CreateWrappedException(Exception error, string host, int port, CancellationToken cancellationToken)
{ {
return CancellationHelper.ShouldWrapInOperationCanceledException(error, cancellationToken) ? return CancellationHelper.ShouldWrapInOperationCanceledException(error, cancellationToken) ?
CancellationHelper.CreateOperationCanceledException(error, cancellationToken) : CancellationHelper.CreateOperationCanceledException(error, cancellationToken) :
......
...@@ -1282,49 +1282,62 @@ private async ValueTask<(Stream?, TransportContext?, HttpResponseMessage?)> Conn ...@@ -1282,49 +1282,62 @@ private async ValueTask<(Stream?, TransportContext?, HttpResponseMessage?)> Conn
} }
} }
private static async ValueTask<Stream> DefaultConnectAsync(SocketsHttpConnectionContext context, CancellationToken cancellationToken) private async ValueTask<Stream> ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken)
{ {
Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); cancellationToken.ThrowIfCancellationRequested();
socket.NoDelay = true;
var endPoint = new DnsEndPoint(host, port);
Socket? socket = null;
try try
{ {
await socket.ConnectAsync(context.DnsEndPoint, cancellationToken).ConfigureAwait(false); // If a ConnectCallback was supplied, use that to establish the connection.
return new NetworkStream(socket, ownsSocket: true); if (Settings._connectCallback != null)
} {
catch ValueTask<Stream> streamTask = Settings._connectCallback(new SocketsHttpConnectionContext(endPoint, initialRequest), cancellationToken);
{
socket.Dispose();
throw;
}
}
private static readonly Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>> s_defaultConnectCallback = DefaultConnectAsync;
private ValueTask<Stream> ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken)
{
if (async)
{
Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>> connectCallback = Settings._connectCallback ?? s_defaultConnectCallback;
var endPoint = new DnsEndPoint(host, port); Stream stream;
return ConnectHelper.ConnectAsync(connectCallback, endPoint, initialRequest, cancellationToken); if (async || streamTask.IsCompleted)
} {
stream = await streamTask.ConfigureAwait(false);
}
else
{
// User-provided ConnectCallback is completing asynchronously but the user is making a synchronous request; if the user cares, they should
// set it up so that synchronous requests are made on a handler with a synchronously-completing ConnectCallback supplied. If in the future,
// we could add a Boolean to SocketsHttpConnectionContext (https://github.com/dotnet/runtime/issues/44876) to let the callback know whether
// this request is sync or async. For now, log it and block.
Trace($"{nameof(SocketsHttpHandler.ConnectCallback)} completing asynchronously for a synchronous request.");
stream = streamTask.AsTask().GetAwaiter().GetResult();
}
// Synchronous path. return stream ?? throw new HttpRequestException(SR.net_http_null_from_connect_callback);
}
else
{
// Otherwise, create and connect a socket using default settings.
socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
if (Settings._connectCallback is not null) if (async)
{ {
throw new NotSupportedException(SR.net_http_sync_operations_not_allowed_with_connect_callback); await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);
} }
else
{
using (cancellationToken.UnsafeRegister(static s => ((Socket)s!).Dispose(), socket))
{
socket.Connect(endPoint);
}
}
try return new NetworkStream(socket, ownsSocket: true);
{ }
return new ValueTask<Stream>(ConnectHelper.Connect(host, port, cancellationToken));
} }
catch (Exception ex) catch (Exception ex)
{ {
return ValueTask.FromException<Stream>(ex); socket?.Dispose();
throw ex is OperationCanceledException oce && oce.CancellationToken == cancellationToken ?
CancellationHelper.CreateOperationCanceledException(innerException: null, cancellationToken) :
ConnectHelper.CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken);
} }
} }
......
...@@ -2335,22 +2335,19 @@ public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandler ...@@ -2335,22 +2335,19 @@ public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandler
{ {
public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { } public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { }
[Fact] [Theory]
public void ConnectCallback_SyncRequest_Fails() [InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public async Task ConnectCallback_ContextHasCorrectProperties_Success(bool syncRequest, bool syncCallback)
{ {
using SocketsHttpHandler handler = new SocketsHttpHandler if (syncRequest && UseVersion > HttpVersion.Version11)
{ {
ConnectCallback = (context, token) => default, // Sync requests are only supported on 1.x
}; return;
}
using HttpClient client = CreateHttpClient(handler);
Assert.ThrowsAny<NotSupportedException>(() => client.Send(new HttpRequestMessage(HttpMethod.Get, $"http://bing.com")));
}
[Fact]
public async Task ConnectCallback_ContextHasCorrectProperties_Success()
{
await LoopbackServerFactory.CreateClientAndServerAsync( await LoopbackServerFactory.CreateClientAndServerAsync(
async uri => async uri =>
{ {
...@@ -2367,14 +2364,23 @@ public async Task ConnectCallback_ContextHasCorrectProperties_Success() ...@@ -2367,14 +2364,23 @@ public async Task ConnectCallback_ContextHasCorrectProperties_Success()
Assert.Equal(uri.Port, context.DnsEndPoint.Port); Assert.Equal(uri.Port, context.DnsEndPoint.Port);
Assert.Equal(requestMessage, context.InitialRequestMessage); Assert.Equal(requestMessage, context.InitialRequestMessage);
Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await s.ConnectAsync(context.DnsEndPoint, token); if (syncCallback)
{
s.Connect(context.DnsEndPoint);
}
else
{
await s.ConnectAsync(context.DnsEndPoint, token);
}
return new NetworkStream(s, ownsSocket: true); return new NetworkStream(s, ownsSocket: true);
}; };
using HttpClient client = CreateHttpClient(handler); using HttpClient client = CreateHttpClient(handler);
HttpResponseMessage response = await client.SendAsync(requestMessage); HttpResponseMessage response = await (syncRequest ?
Task.Run(() => client.Send(requestMessage)) :
client.SendAsync(requestMessage));
Assert.Equal("foo", await response.Content.ReadAsStringAsync()); Assert.Equal("foo", await response.Content.ReadAsStringAsync());
}, },
async server => async server =>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册