// NpgsqlModificationCommandBatch.cs
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal;

/// <summary>
/// The Npgsql-specific implementation for <see cref="ModificationCommandBatch" />.
/// </summary>
/// <remarks>
/// The usual ModificationCommandBatch implementation is <see cref="AffectedCountModificationCommandBatch"/>,
/// which selects the number of rows modified via a SQL query.
///
/// PostgreSQL has no way of selecting the modified row count in plain SQL.
/// SQL defines GET DIAGNOSTICS, which should provide this, but in PostgreSQL it's only available
/// in PL/pgSQL. See http://www.postgresql.org/docs/9.4/static/unsupported-features-sql-standard.html,
/// identifier F121-01.
///
/// Instead, the affected row count is read from the PostgreSQL protocol itself (exposed per statement by
/// <see cref="NpgsqlDataReader"/>), which is cleaner and more efficient anyway (no additional query).
/// </remarks>
public class NpgsqlModificationCommandBatch : ReaderModificationCommandBatch
{
    /// <summary>
    /// Constructs an instance of the <see cref="NpgsqlModificationCommandBatch"/> class.
    /// </summary>
    /// <param name="dependencies">Service dependencies, passed through to the base batch.</param>
    /// <param name="maxBatchSize">The maximum number of modification commands this batch may hold.</param>
    public NpgsqlModificationCommandBatch(
        ModificationCommandBatchFactoryDependencies dependencies,
        int maxBatchSize)
        : base(dependencies)
        => MaxBatchSize = maxBatchSize;

    /// <summary>
    ///     The maximum number of <see cref="ModificationCommand"/> instances that can be added to a single batch,
    ///     as supplied through the constructor's <c>maxBatchSize</c> argument.
    /// </summary>
    protected override int MaxBatchSize { get; }

    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    /// <remarks>
    ///     Commands whose result-set mapping lacks <see cref="ResultSetMapping.HasResultRow"/> are only checked via the
    ///     row count Npgsql reports for their statement; a count of zero is reported as a concurrency failure.
    ///     Commands with a result row have that row read and propagated back to their entries; a missing row is
    ///     likewise reported as a concurrency failure.
    /// </remarks>
    protected override void Consume(RelationalDataReader reader)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                var nextPropagating = FindNextPropagatingCommand(commandIndex);

                // Every command before the next propagating one returns no row; the only thing we can verify
                // is that its statement affected at least one row.
                for (; commandIndex < nextPropagating; commandIndex++)
                {
#pragma warning disable 618
                    if (npgsqlReader.Statements[commandIndex].Rows == 0)
                    {
                        ThrowAggregateUpdateConcurrencyException(reader, commandIndex, 1, 0);
                    }
#pragma warning restore 618
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // Debug.Assert is conditional, so this NextResult call (and the check) only run in debug builds.
                    Debug.Assert(!npgsqlReader.NextResult(), "Expected less resultsets");
                    break;
                }

                // Propagate the results from the reader to the ModificationCommand.
                if (reader.Read())
                {
                    ModificationCommands[commandIndex].PropagateResults(reader);
                }
                else
                {
                    // If an interceptor suppresses the concurrency exception, this returns normally; the reader has
                    // no current row, so propagation must be skipped.
                    ThrowAggregateUpdateConcurrencyException(reader, commandIndex, 1, 0);
                }

                npgsqlReader.NextResult();
                commandIndex++;
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw new DbUpdateException(
                RelationalStrings.UpdateStoreException,
                ex,
                GetEntriesForStoreException(commandIndex));
        }
    }

    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    /// <remarks>
    ///     Asynchronous counterpart of <see cref="Consume(RelationalDataReader)"/>, with the same checks and propagation.
    /// </remarks>
    protected override async Task ConsumeAsync(RelationalDataReader reader, CancellationToken cancellationToken = default)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                var nextPropagating = FindNextPropagatingCommand(commandIndex);

                // Every command before the next propagating one returns no row; the only thing we can verify
                // is that its statement affected at least one row.
                for (; commandIndex < nextPropagating; commandIndex++)
                {
#pragma warning disable 618
                    if (npgsqlReader.Statements[commandIndex].Rows == 0)
                    {
                        await ThrowAggregateUpdateConcurrencyExceptionAsync(reader, commandIndex, 1, 0, cancellationToken)
                            .ConfigureAwait(false);
                    }
#pragma warning restore 618
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // Debug.Assert is conditional, so this NextResultAsync call (and the check) only run in debug builds.
                    Debug.Assert(!(await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false)), "Expected less resultsets");
                    break;
                }

                // Extract the result row from the reader and propagate it to the ModificationCommand.
                if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
                {
                    ModificationCommands[commandIndex].PropagateResults(reader);
                }
                else
                {
                    // If an interceptor suppresses the concurrency exception, this returns normally; the reader has
                    // no current row, so propagation must be skipped.
                    await ThrowAggregateUpdateConcurrencyExceptionAsync(reader, commandIndex, 1, 0, cancellationToken)
                        .ConfigureAwait(false);
                }

                await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false);
                commandIndex++;
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw new DbUpdateException(
                RelationalStrings.UpdateStoreException,
                ex,
                GetEntriesForStoreException(commandIndex));
        }
    }

    /// <summary>
    ///     Returns the index of the first command at or after <paramref name="startIndex"/> whose result-set mapping
    ///     has <see cref="ResultSetMapping.HasResultRow"/>, or <c>ModificationCommands.Count</c> if there is none.
    /// </summary>
    private int FindNextPropagatingCommand(int startIndex)
    {
        var index = startIndex;
        while (index < ModificationCommands.Count
               && !ResultSetMappings[index].HasFlag(ResultSetMapping.HasResultRow))
        {
            index++;
        }

        return index;
    }

    /// <summary>
    ///     Returns the entries to attach to a wrapped store exception: those of the command being processed, or, when
    ///     every command has already been processed (<paramref name="commandIndex"/> is past the end), those of the
    ///     whole batch. This avoids an out-of-range access masking the original exception.
    /// </summary>
    private IReadOnlyList<IUpdateEntry> GetEntriesForStoreException(int commandIndex)
        => commandIndex < ModificationCommands.Count
            ? ModificationCommands[commandIndex].Entries
            : AggregateEntries(ModificationCommands.Count, ModificationCommands.Count);

    /// <summary>
    ///     Collects the entries of the <paramref name="commandCount"/> commands that end just before
    ///     <paramref name="endIndex"/> (i.e. indexes <c>endIndex - commandCount</c> to <c>endIndex - 1</c>).
    /// </summary>
    private IReadOnlyList<IUpdateEntry> AggregateEntries(int endIndex, int commandCount)
    {
        var entries = new List<IUpdateEntry>();
        for (var i = endIndex - commandCount; i < endIndex; i++)
        {
            entries.AddRange(ModificationCommands[i].Entries);
        }

        return entries;
    }

    /// <summary>
    ///     Throws an exception indicating the command affected an unexpected number of rows.
    /// </summary>
    /// <remarks>
    ///     The exception is first reported to the update logger; if that reports it as suppressed, this method
    ///     returns normally instead of throwing.
    /// </remarks>
    /// <param name="reader">The data reader.</param>
    /// <param name="commandIndex">
    ///     The ordinal of the command; entries are aggregated from the <paramref name="expectedRowsAffected"/> commands
    ///     ending at this index.
    /// </param>
    /// <param name="expectedRowsAffected">The expected number of rows affected.</param>
    /// <param name="rowsAffected">The actual number of rows affected.</param>
    /// <exception cref="DbUpdateConcurrencyException">Unless suppressed.</exception>
    protected virtual void ThrowAggregateUpdateConcurrencyException(
        RelationalDataReader reader,
        int commandIndex,
        int expectedRowsAffected,
        int rowsAffected)
    {
        var entries = AggregateEntries(commandIndex + 1, expectedRowsAffected);
        var exception = new DbUpdateConcurrencyException(
            RelationalStrings.UpdateConcurrencyException(expectedRowsAffected, rowsAffected),
            entries);

        if (!Dependencies.UpdateLogger.OptimisticConcurrencyException(
                Dependencies.CurrentContext.Context,
                entries,
                exception,
                (c, ex, e, d) => CreateConcurrencyExceptionEventData(c, reader, ex, e, d)).IsSuppressed)
        {
            throw exception;
        }
    }

    /// <summary>
    ///     Throws an exception indicating the command affected an unexpected number of rows.
    /// </summary>
    /// <remarks>
    ///     The exception is first reported to the update logger; if that reports it as suppressed, the returned task
    ///     completes normally instead of faulting.
    /// </remarks>
    /// <param name="reader">The data reader.</param>
    /// <param name="commandIndex">
    ///     The ordinal of the command; entries are aggregated from the <paramref name="expectedRowsAffected"/> commands
    ///     ending at this index.
    /// </param>
    /// <param name="expectedRowsAffected">The expected number of rows affected.</param>
    /// <param name="rowsAffected">The actual number of rows affected.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
    /// <returns> A task that represents the asynchronous operation.</returns>
    /// <exception cref="DbUpdateConcurrencyException">Unless suppressed.</exception>
    /// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
    protected virtual async Task ThrowAggregateUpdateConcurrencyExceptionAsync(
        RelationalDataReader reader,
        int commandIndex,
        int expectedRowsAffected,
        int rowsAffected,
        CancellationToken cancellationToken)
    {
        var entries = AggregateEntries(commandIndex + 1, expectedRowsAffected);
        var exception = new DbUpdateConcurrencyException(
            RelationalStrings.UpdateConcurrencyException(expectedRowsAffected, rowsAffected),
            entries);

        if (!(await Dependencies.UpdateLogger.OptimisticConcurrencyExceptionAsync(
                    Dependencies.CurrentContext.Context,
                    entries,
                    exception,
                    (c, ex, e, d) => CreateConcurrencyExceptionEventData(c, reader, ex, e, d),
                    cancellationToken: cancellationToken)
                .ConfigureAwait(false)).IsSuppressed)
        {
            throw exception;
        }
    }

    /// <summary>
    ///     Builds the logging/interception payload for a concurrency exception, capturing the connection, command,
    ///     reader and identifiers from <paramref name="reader"/>.
    /// </summary>
    private static RelationalConcurrencyExceptionEventData CreateConcurrencyExceptionEventData(
        DbContext context,
        RelationalDataReader reader,
        DbUpdateConcurrencyException exception,
        IReadOnlyList<IUpdateEntry> entries,
        EventDefinition<Exception> definition)
        => new(
            definition,
            (definition1, payload)
                => ((EventDefinition<Exception>)definition1).GenerateMessage(((ConcurrencyExceptionEventData)payload).Exception),
            context,
            reader.RelationalConnection.DbConnection,
            reader.DbCommand,
            reader.DbDataReader,
            reader.CommandId,
            reader.RelationalConnection.ConnectionId,
            entries,
            exception);
}