diff --git a/src/libraries/Common/tests/System/Diagnostics/Tracing/TestEventListener.cs b/src/libraries/Common/tests/System/Diagnostics/Tracing/TestEventListener.cs
index b12f5022c62f18188dd616267e5be85e6a5ab3ff..74c0d9fa9395a6d423c1d953fa8ad077e1598f68 100644
--- a/src/libraries/Common/tests/System/Diagnostics/Tracing/TestEventListener.cs
+++ b/src/libraries/Common/tests/System/Diagnostics/Tracing/TestEventListener.cs
@@ -9,90 +9,91 @@ namespace System.Diagnostics.Tracing
/// Simple event listener than invokes a callback for each event received.
internal sealed class TestEventListener : EventListener
{
- private readonly string _targetSourceName;
- private readonly Guid _targetSourceGuid;
- private readonly EventLevel _level;
+ private class Settings
+ {
+ public EventLevel Level;
+ public EventKeywords Keywords;
+ }
+
+ private readonly Dictionary _names = new Dictionary();
+ private readonly Dictionary _guids = new Dictionary();
+
private readonly double? _eventCounterInterval;
private Action _eventWritten;
- private List _tmpEventSourceList = new List();
+ private readonly List _eventSourceList = new List();
public TestEventListener(string targetSourceName, EventLevel level, double? eventCounterInterval = null)
{
- // Store the arguments
- _targetSourceName = targetSourceName;
- _level = level;
_eventCounterInterval = eventCounterInterval;
-
- LoadSourceList();
+ AddSource(targetSourceName, level);
}
public TestEventListener(Guid targetSourceGuid, EventLevel level, double? eventCounterInterval = null)
{
- // Store the arguments
- _targetSourceGuid = targetSourceGuid;
- _level = level;
_eventCounterInterval = eventCounterInterval;
-
- LoadSourceList();
+ AddSource(targetSourceGuid, level);
}
- private void LoadSourceList()
+ public void AddSource(string name, EventLevel level, EventKeywords keywords = EventKeywords.All) =>
+ AddSource(name, null, level, keywords);
+
+ public void AddSource(Guid guid, EventLevel level, EventKeywords keywords = EventKeywords.All) =>
+ AddSource(null, guid, level, keywords);
+
+ private void AddSource(string name, Guid? guid, EventLevel level, EventKeywords keywords)
{
- // The base constructor, which is called before this constructor,
- // will invoke the virtual OnEventSourceCreated method for each
- // existing EventSource, which means OnEventSourceCreated will be
- // called before _targetSourceGuid and _level have been set. As such,
- // we store a temporary list that just exists from the moment this instance
- // is created (instance field initializers run before the base constructor)
- // and until we finish construction... in that window, OnEventSourceCreated
- // will store the sources into the list rather than try to enable them directly,
- // and then here we can enumerate that list, then clear it out.
- List sources;
- lock (_tmpEventSourceList)
+ lock (_eventSourceList)
{
- sources = _tmpEventSourceList;
- _tmpEventSourceList = null;
- }
- foreach (EventSource source in sources)
- {
- EnableSourceIfMatch(source);
+ var settings = new Settings()
+ {
+ Level = level,
+ Keywords = keywords
+ };
+
+ if (name is not null)
+ _names.Add(name, settings);
+
+ if (guid.HasValue)
+ _guids.Add(guid.Value, settings);
+
+ foreach (EventSource source in _eventSourceList)
+ {
+ if (name == source.Name || guid == source.Guid)
+ {
+ EnableEventSource(source, level, keywords);
+ }
+ }
}
}
+ public void AddActivityTracking() =>
+ AddSource("System.Threading.Tasks.TplEventSource", EventLevel.Informational, (EventKeywords)0x80 /* TasksFlowActivityIds */);
+
protected override void OnEventSourceCreated(EventSource eventSource)
{
- List tmp = _tmpEventSourceList;
- if (tmp != null)
+ lock (_eventSourceList)
{
- lock (tmp)
+ _eventSourceList.Add(eventSource);
+
+ if (_names.TryGetValue(eventSource.Name, out Settings settings) ||
+ _guids.TryGetValue(eventSource.Guid, out settings))
{
- if (_tmpEventSourceList != null)
- {
- _tmpEventSourceList.Add(eventSource);
- return;
- }
+ EnableEventSource(eventSource, settings.Level, settings.Keywords);
}
}
-
- EnableSourceIfMatch(eventSource);
}
- private void EnableSourceIfMatch(EventSource source)
+ private void EnableEventSource(EventSource source, EventLevel level, EventKeywords keywords)
{
- if (source.Name.Equals(_targetSourceName) ||
- source.Guid.Equals(_targetSourceGuid))
+ var args = new Dictionary();
+
+ if (_eventCounterInterval != null)
{
- if (_eventCounterInterval != null)
- {
- var args = new Dictionary { { "EventCounterIntervalSec", _eventCounterInterval?.ToString() } };
- EnableEvents(source, _level, EventKeywords.All, args);
- }
- else
- {
- EnableEvents(source, _level);
- }
+ args.Add("EventCounterIntervalSec", _eventCounterInterval.ToString());
}
+
+ EnableEvents(source, level, keywords, args);
}
public void RunWithCallback(Action handler, Action body)
diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
index 98cb7ccbc0b559d108e13e1ffac76080a1075e25..60cc964ec08ceab470ca01788582b8585dd2f3b6 100644
--- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
+++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
@@ -466,34 +466,9 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR
if (NameResolutionTelemetry.Log.IsEnabled())
{
- ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(hostName);
-
- Task coreTask;
- try
- {
- coreTask = NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses);
- }
- catch when (LogFailure(stopwatch))
- {
- Debug.Fail("LogFailure should return false");
- throw;
- }
-
- coreTask.ContinueWith(
- (task, state) =>
- {
- NameResolutionTelemetry.Log.AfterResolution(
- stopwatch: (ValueStopwatch)state!,
- successful: task.IsCompletedSuccessfully);
- },
- state: stopwatch,
- cancellationToken: default,
- TaskContinuationOptions.ExecuteSynchronously,
- TaskScheduler.Default);
-
- // coreTask is not actually a base Task, but Task / Task
- // We have to return it and not the continuation
- return coreTask;
+ return justAddresses
+ ? (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses)
+ : (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses);
}
else
{
@@ -506,6 +481,23 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR
RunAsync(s => GetHostEntryCore((string)s), hostName);
}
+ private static async Task GetAddrInfoWithTelemetryAsync(string hostName, bool justAddresses)
+ where T : class
+ {
+ ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(hostName);
+
+ T? result = null;
+ try
+ {
+ result = await ((Task)NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses)).ConfigureAwait(false);
+ return result;
+ }
+ finally
+ {
+ NameResolutionTelemetry.Log.AfterResolution(stopwatch, successful: result is not null);
+ }
+ }
+
private static Task RunAsync(Func