提交 8c63dc6a 编写于 作者: S Sam Harwell

Prevent SqlConnection from being held past an async boundary

Fixes #22650
上级 9577061e
......@@ -22,7 +22,7 @@ internal enum OpenFlags
// SQLITE_OPEN_TEMP_JOURNAL = 0x00001000, /* VFS only */
// SQLITE_OPEN_SUBJOURNAL = 0x00002000, /* VFS only */
// SQLITE_OPEN_MASTER_JOURNAL = 0x00004000, /* VFS only */
// SQLITE_OPEN_NOMUTEX = 0x00008000, /* Ok for sqlite3_open_v2() */
SQLITE_OPEN_NOMUTEX = 0x00008000, /* Ok for sqlite3_open_v2() */
// SQLITE_OPEN_FULLMUTEX = 0x00010000, /* Ok for sqlite3_open_v2() */
SQLITE_OPEN_SHAREDCACHE = 0x00020000, /* Ok for sqlite3_open_v2() */
// SQLITE_OPEN_PRIVATECACHE = 0x00040000, /* Ok for sqlite3_open_v2() */
......
......@@ -53,7 +53,7 @@ public static SqlConnection Create(IPersistentStorageFaultInjector faultInjector
// Enable shared cache so that multiple connections inside of same process share cache
// see https://sqlite.org/threadsafe.html for more detail
var flags = OpenFlags.SQLITE_OPEN_CREATE | OpenFlags.SQLITE_OPEN_READWRITE | OpenFlags.SQLITE_OPEN_SHAREDCACHE;
var flags = OpenFlags.SQLITE_OPEN_CREATE | OpenFlags.SQLITE_OPEN_READWRITE | OpenFlags.SQLITE_OPEN_NOMUTEX | OpenFlags.SQLITE_OPEN_SHAREDCACHE;
var result = (Result)raw.sqlite3_open_v2(databasePath, out var handle, (int)flags, vfs: null);
if (result != Result.OK)
......
......@@ -58,25 +58,31 @@ public async Task<Stream> ReadStreamAsync(TKey key, CancellationToken cancellati
if (!Storage._shutdownTokenSource.IsCancellationRequested)
{
bool haveDataId;
TDatabaseId dataId;
using (var pooledConnection = Storage.GetPooledConnection())
{
var connection = pooledConnection.Connection;
if (TryGetDatabaseId(connection, key, out var dataId))
{
// Ensure all pending document writes to this name are flushed to the DB so that
// we can find them below.
await FlushPendingWritesAsync(connection, key, cancellationToken).ConfigureAwait(false);
haveDataId = TryGetDatabaseId(pooledConnection.Connection, key, out dataId);
}
try
if (haveDataId)
{
// Ensure all pending document writes to this name are flushed to the DB so that
// we can find them below.
await FlushPendingWritesAsync(key, cancellationToken).ConfigureAwait(false);
try
{
using (var pooledConnection = Storage.GetPooledConnection())
{
// Lookup the row from the DocumentData table corresponding to our dataId.
return ReadBlob(connection, dataId);
}
catch (Exception ex)
{
StorageDatabaseLogger.LogException(ex);
return ReadBlob(pooledConnection.Connection, dataId);
}
}
catch (Exception ex)
{
StorageDatabaseLogger.LogException(ex);
}
}
}
......@@ -94,33 +100,36 @@ public async Task<Stream> ReadStreamAsync(TKey key, CancellationToken cancellati
if (!Storage._shutdownTokenSource.IsCancellationRequested)
{
bool haveDataId;
TDatabaseId dataId;
using (var pooledConnection = Storage.GetPooledConnection())
{
// Determine the appropriate data-id to store this stream at.
if (TryGetDatabaseId(pooledConnection.Connection, key, out var dataId))
{
var (bytes, length, pooled) = GetBytes(stream);
haveDataId = TryGetDatabaseId(pooledConnection.Connection, key, out dataId);
}
await AddWriteTaskAsync(key, con =>
if (haveDataId)
{
var (bytes, length, pooled) = GetBytes(stream);
await AddWriteTaskAsync(key, con =>
{
InsertOrReplaceBlob(con, dataId, bytes, length);
if (pooled)
{
InsertOrReplaceBlob(con, dataId, bytes, length);
if (pooled)
{
ReturnPooledBytes(bytes);
}
}, cancellationToken).ConfigureAwait(false);
return true;
}
ReturnPooledBytes(bytes);
}
}, cancellationToken).ConfigureAwait(false);
return true;
}
}
return false;
}
private Task FlushPendingWritesAsync(SqlConnection connection, TKey key, CancellationToken cancellationToken)
=> Storage.FlushSpecificWritesAsync(
connection, _writeQueueKeyToWrites, _writeQueueKeyToWriteTask, GetWriteQueueKey(key), cancellationToken);
private Task FlushPendingWritesAsync(TKey key, CancellationToken cancellationToken)
=> Storage.FlushSpecificWritesAsync(_writeQueueKeyToWrites, _writeQueueKeyToWriteTask, GetWriteQueueKey(key), cancellationToken);
private Task AddWriteTaskAsync(TKey key, Action<SqlConnection> action, CancellationToken cancellationToken)
=> Storage.AddWriteTaskAsync(_writeQueueKeyToWrites, GetWriteQueueKey(key), action, cancellationToken);
......
......@@ -233,6 +233,14 @@ private void CloseWorker()
}
}
/// <summary>
/// Gets an <see cref="SqlConnection"/> from the connection pool, or creates one if none are available.
/// </summary>
/// <remarks>
/// Database connections have a large amount of overhead, and should be returned to the pool when they are no
/// longer in use. In particular, make sure to avoid letting a connection lease cross an <see langword="await"/>
/// boundary, as it will prevent code in the asynchronous operation from using the existing connection.
/// </remarks>
private PooledConnection GetPooledConnection()
=> new PooledConnection(this, GetConnection());
......
......@@ -50,16 +50,15 @@ internal partial class SQLitePersistentStorage
}
private async Task FlushSpecificWritesAsync<TKey>(
SqlConnection connection,
MultiDictionary<TKey, Action<SqlConnection>> keyToWriteActions,
Dictionary<TKey, Task> keyToWriteTask,
TKey key, CancellationToken cancellationToken)
TKey key,
CancellationToken cancellationToken)
{
var writesToProcess = ArrayBuilder<Action<SqlConnection>>.GetInstance();
try
{
await FlushSpecificWritesAsync(
connection, keyToWriteActions, keyToWriteTask, key, writesToProcess, cancellationToken).ConfigureAwait(false);
await FlushSpecificWritesAsync(keyToWriteActions, keyToWriteTask, key, writesToProcess, cancellationToken).ConfigureAwait(false);
}
finally
{
......@@ -68,8 +67,9 @@ internal partial class SQLitePersistentStorage
}
private async Task FlushSpecificWritesAsync<TKey>(
SqlConnection connection, MultiDictionary<TKey, Action<SqlConnection>> keyToWriteActions,
Dictionary<TKey, Task> keyToWriteTask, TKey key,
MultiDictionary<TKey, Action<SqlConnection>> keyToWriteActions,
Dictionary<TKey, Task> keyToWriteTask,
TKey key,
ArrayBuilder<Action<SqlConnection>> writesToProcess,
CancellationToken cancellationToken)
{
......@@ -96,7 +96,10 @@ internal partial class SQLitePersistentStorage
// would be losing data.
Debug.Assert(taskCompletionSource != null);
ProcessWriteQueue(connection, writesToProcess);
using (var pooledConnection = GetPooledConnection())
{
ProcessWriteQueue(pooledConnection.Connection, writesToProcess);
}
}
catch (OperationCanceledException ex)
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册