// NpgsqlModificationCommandBatch.cs
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal;

/// <summary>
/// The Npgsql-specific implementation for <see cref="ModificationCommandBatch" />.
/// </summary>
/// <remarks>
/// The usual ModificationCommandBatch implementation is <see cref="AffectedCountModificationCommandBatch"/>,
/// which selects the number of rows modified via a SQL query.
///
/// PostgreSQL actually has no way of selecting the modified row count.
/// SQL defines GET DIAGNOSTICS which should provide this, but in PostgreSQL it's only available
/// in PL/pgSQL. See http://www.postgresql.org/docs/9.4/static/unsupported-features-sql-standard.html,
/// identifier F121-01.
///
/// Instead, the affected row count can be accessed in the PostgreSQL protocol itself, which seems
/// cleaner and more efficient anyway (no additional query).
/// </remarks>
public class NpgsqlModificationCommandBatch : ReaderModificationCommandBatch
{
    /// <summary>
    /// Constructs an instance of the <see cref="NpgsqlModificationCommandBatch"/> class.
    /// </summary>
    /// <param name="dependencies">Service dependencies, forwarded to the base class.</param>
    /// <param name="maxBatchSize">The value exposed through <see cref="MaxBatchSize"/>.</param>
    public NpgsqlModificationCommandBatch(
        ModificationCommandBatchFactoryDependencies dependencies,
        int maxBatchSize)
        : base(dependencies)
        => MaxBatchSize = maxBatchSize;

    /// <summary>
    ///     The maximum number of <see cref="ModificationCommand"/> instances that can be added to a single batch,
    ///     as supplied to the constructor.
    /// </summary>
    protected override int MaxBatchSize { get; }

    /// <summary>
    ///     Consumes the data reader created by executing the batch: verifies that every command which does not
    ///     require result propagation affected at least one row, and propagates the returned row into every
    ///     command which does.
    /// </summary>
    /// <param name="reader">The data reader; its underlying reader must be an <see cref="NpgsqlDataReader"/>.</param>
    /// <exception cref="DbUpdateConcurrencyException">
    ///     A command affected no rows or returned no row, and the logged event was not suppressed.
    /// </exception>
    /// <exception cref="DbUpdateException">
    ///     Wraps any other exception raised while consuming the reader, except <see cref="OperationCanceledException"/>.
    /// </exception>
    protected override void Consume(RelationalDataReader reader)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                // Find the next propagating command, if any
                var nextPropagating = FindNextPropagating(commandIndex);

                // Go over all non-propagating commands before the next propagating one,
                // make sure they executed. Statements are indexed like ModificationCommands
                // (one statement per command, as asserted above).
                for (; commandIndex < nextPropagating; commandIndex++)
                {
#pragma warning disable 618
                    if (npgsqlReader.Statements[commandIndex].Rows == 0)
                    {
                        ThrowAggregateUpdateConcurrencyException(reader, commandIndex, 1, 0);
                    }
#pragma warning restore 618
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // Debug.Assert is compiled out of non-DEBUG builds, so this NextResult call
                    // only happens in debug builds; it is purely a sanity check.
                    Debug.Assert(!npgsqlReader.NextResult(), "Expected less resultsets");
                    break;
                }

                // Propagate the results from the reader to the ModificationCommand
                var modificationCommand = ModificationCommands[commandIndex];

                if (!reader.Read())
                {
                    ThrowAggregateUpdateConcurrencyException(reader, commandIndex, 1, 0);
                }

                Check.DebugAssert(modificationCommand.RequiresResultPropagation, "RequiresResultPropagation is false");

                modificationCommand.PropagateResults(reader);

                npgsqlReader.NextResult();

                commandIndex++;
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw CreateUpdateStoreException(ex, commandIndex);
        }
    }

    /// <summary>
    ///     Asynchronous counterpart of <see cref="Consume"/>: verifies that every command which does not require
    ///     result propagation affected at least one row, and propagates the returned row into every command
    ///     which does.
    /// </summary>
    /// <param name="reader">The data reader; its underlying reader must be an <see cref="NpgsqlDataReader"/>.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
    /// <returns>A task that represents the asynchronous operation.</returns>
    /// <exception cref="DbUpdateConcurrencyException">
    ///     A command affected no rows or returned no row, and the logged event was not suppressed.
    /// </exception>
    /// <exception cref="DbUpdateException">
    ///     Wraps any other exception raised while consuming the reader, except <see cref="OperationCanceledException"/>.
    /// </exception>
    protected override async Task ConsumeAsync(
        RelationalDataReader reader,
        CancellationToken cancellationToken = default)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                // Find the next propagating command, if any
                var nextPropagating = FindNextPropagating(commandIndex);

                // Go over all non-propagating commands before the next propagating one,
                // make sure they executed. Statements are indexed like ModificationCommands
                // (one statement per command, as asserted above).
                for (; commandIndex < nextPropagating; commandIndex++)
                {
#pragma warning disable 618
                    if (npgsqlReader.Statements[commandIndex].Rows == 0)
                    {
                        await ThrowAggregateUpdateConcurrencyExceptionAsync(reader, commandIndex, 1, 0, cancellationToken)
                            .ConfigureAwait(false);
                    }
#pragma warning restore 618
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // Debug.Assert is compiled out of non-DEBUG builds, so this NextResultAsync call
                    // only happens in debug builds; it is purely a sanity check.
                    Debug.Assert(!(await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false)), "Expected less resultsets");
                    break;
                }

                // Extract result from the command and propagate it
                var modificationCommand = ModificationCommands[commandIndex];

                if (!(await reader.ReadAsync(cancellationToken).ConfigureAwait(false)))
                {
                    await ThrowAggregateUpdateConcurrencyExceptionAsync(reader, commandIndex, 1, 0, cancellationToken)
                        .ConfigureAwait(false);
                }

                Check.DebugAssert(modificationCommand.RequiresResultPropagation, "RequiresResultPropagation is false");

                modificationCommand.PropagateResults(reader);

                await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false);

                commandIndex++;
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw CreateUpdateStoreException(ex, commandIndex);
        }
    }

    /// <summary>
    ///     Returns the index of the first command at or after <paramref name="startIndex"/> that requires result
    ///     propagation, or <c>ModificationCommands.Count</c> if there is none.
    /// </summary>
    private int FindNextPropagating(int startIndex)
    {
        var index = startIndex;
        while (index < ModificationCommands.Count && !ModificationCommands[index].RequiresResultPropagation)
        {
            index++;
        }

        return index;
    }

    /// <summary>
    ///     Wraps an unexpected exception raised while consuming the reader in a <see cref="DbUpdateException"/>
    ///     carrying the entries of the command that was being processed.
    /// </summary>
    /// <remarks>
    ///     <paramref name="commandIndex"/> equals <c>ModificationCommands.Count</c> once every command has been
    ///     processed, so it is clamped to the last command; otherwise the indexer would throw from inside the
    ///     catch handler and hide <paramref name="innerException"/>.
    /// </remarks>
    private DbUpdateException CreateUpdateStoreException(Exception innerException, int commandIndex)
    {
        var lastIndex = ModificationCommands.Count - 1;
        var entryIndex = commandIndex > lastIndex ? lastIndex : commandIndex;

        return entryIndex < 0
            ? new DbUpdateException(RelationalStrings.UpdateStoreException, innerException)
            : new DbUpdateException(
                RelationalStrings.UpdateStoreException,
                innerException,
                ModificationCommands[entryIndex].Entries);
    }

    /// <summary>
    ///     Collects the entries of the <paramref name="commandCount"/> commands that end just before
    ///     <paramref name="endIndex"/> (exclusive) into a single list.
    /// </summary>
    private IReadOnlyList<IUpdateEntry> AggregateEntries(int endIndex, int commandCount)
    {
        var entries = new List<IUpdateEntry>();
        for (var i = endIndex - commandCount; i < endIndex; i++)
        {
            entries.AddRange(ModificationCommands[i].Entries);
        }

        return entries;
    }

    /// <summary>
    ///     Builds the concurrency exception reported for the given entries and row counts.
    /// </summary>
    private static DbUpdateConcurrencyException CreateConcurrencyException(
        IReadOnlyList<IUpdateEntry> entries,
        int expectedRowsAffected,
        int rowsAffected)
        => new(
            RelationalStrings.UpdateConcurrencyException(expectedRowsAffected, rowsAffected),
            entries);

    /// <summary>
    ///     Throws an exception indicating the command affected an unexpected number of rows.
    ///     Nothing is thrown if the update logger reports the event as suppressed.
    /// </summary>
    /// <param name="reader">The data reader.</param>
    /// <param name="commandIndex">The ordinal of the command.</param>
    /// <param name="expectedRowsAffected">The expected number of rows affected.</param>
    /// <param name="rowsAffected">The actual number of rows affected.</param>
    protected virtual void ThrowAggregateUpdateConcurrencyException(
        RelationalDataReader reader,
        int commandIndex,
        int expectedRowsAffected,
        int rowsAffected)
    {
        var entries = AggregateEntries(commandIndex + 1, expectedRowsAffected);
        var exception = CreateConcurrencyException(entries, expectedRowsAffected, rowsAffected);

        if (!Dependencies.UpdateLogger.OptimisticConcurrencyException(
                Dependencies.CurrentContext.Context,
                entries,
                exception,
                (c, ex, e, d) => CreateConcurrencyExceptionEventData(c, reader, ex, e, d)).IsSuppressed)
        {
            throw exception;
        }
    }

    /// <summary>
    ///     Throws an exception indicating the command affected an unexpected number of rows.
    ///     Nothing is thrown if the update logger reports the event as suppressed.
    /// </summary>
    /// <param name="reader">The data reader.</param>
    /// <param name="commandIndex">The ordinal of the command.</param>
    /// <param name="expectedRowsAffected">The expected number of rows affected.</param>
    /// <param name="rowsAffected">The actual number of rows affected.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
    /// <returns> A task that represents the asynchronous operation.</returns>
    /// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
    protected virtual async Task ThrowAggregateUpdateConcurrencyExceptionAsync(
        RelationalDataReader reader,
        int commandIndex,
        int expectedRowsAffected,
        int rowsAffected,
        CancellationToken cancellationToken)
    {
        var entries = AggregateEntries(commandIndex + 1, expectedRowsAffected);
        var exception = CreateConcurrencyException(entries, expectedRowsAffected, rowsAffected);

        if (!(await Dependencies.UpdateLogger.OptimisticConcurrencyExceptionAsync(
                    Dependencies.CurrentContext.Context,
                    entries,
                    exception,
                    (c, ex, e, d) => CreateConcurrencyExceptionEventData(c, reader, ex, e, d),
                    cancellationToken: cancellationToken)
                .ConfigureAwait(false)).IsSuppressed)
        {
            throw exception;
        }
    }

    /// <summary>
    ///     Builds the event payload for a concurrency exception from the reader's connection, command and
    ///     reader details together with the affected entries.
    /// </summary>
    private static RelationalConcurrencyExceptionEventData CreateConcurrencyExceptionEventData(
        DbContext context,
        RelationalDataReader reader,
        DbUpdateConcurrencyException exception,
        IReadOnlyList<IUpdateEntry> entries,
        EventDefinition<Exception> definition)
        => new(
            definition,
            (definition1, payload)
                => ((EventDefinition<Exception>)definition1).GenerateMessage(((ConcurrencyExceptionEventData)payload).Exception),
            context,
            reader.RelationalConnection.DbConnection,
            reader.DbCommand,
            reader.DbDataReader,
            reader.CommandId,
            reader.RelationalConnection.ConnectionId,
            entries,
            exception);
}