NpgsqlModificationCommandBatch.cs 7.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal;

/// <summary>
/// The Npgsql-specific implementation for <see cref="ModificationCommandBatch" />.
/// </summary>
/// <remarks>
/// The usual ModificationCommandBatch implementation is <see cref="AffectedCountModificationCommandBatch"/>,
/// which selects the number of rows modified via a SQL query.
///
/// PostgreSQL actually has no way of selecting the modified row count.
/// SQL defines GET DIAGNOSTICS which should provide this, but in PostgreSQL it's only available
/// in PL/pgSQL. See https://www.postgresql.org/docs/9.4/static/unsupported-features-sql-standard.html,
/// identifier F121-01.
///
/// Instead, the affected row count can be accessed in the PostgreSQL protocol itself, which seems
/// cleaner and more efficient anyway (no additional query).
/// </remarks>
public class NpgsqlModificationCommandBatch : ReaderModificationCommandBatch
{
    // Maximum number of commands per batch when the caller does not configure one.
    private const int DefaultBatchSize = 1000;

    private readonly int _maxBatchSize;

    // Running total of ColumnModifications.Count over the commands accepted by CanAddCommand;
    // reported through GetParameterCount and never allowed to exceed int.MaxValue.
    private int _parameterCount;

    /// <summary>
    /// Constructs an instance of the <see cref="NpgsqlModificationCommandBatch"/> class.
    /// </summary>
    /// <param name="dependencies">Service dependencies passed through to the base batch.</param>
    /// <param name="maxBatchSize">
    /// The maximum number of commands this batch accepts, or <see langword="null"/> to use the default of 1000.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="maxBatchSize"/> has a value that is zero or negative.
    /// </exception>
    public NpgsqlModificationCommandBatch(
        ModificationCommandBatchFactoryDependencies dependencies,
        int? maxBatchSize)
        : base(dependencies)
    {
        if (maxBatchSize.HasValue && maxBatchSize.Value <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(maxBatchSize), RelationalStrings.InvalidMaxBatchSize(maxBatchSize));
        }

        _maxBatchSize = maxBatchSize ?? DefaultBatchSize;
    }

    /// <inheritdoc />
    protected override int GetParameterCount() => _parameterCount;

    /// <summary>
    /// Accepts <paramref name="modificationCommand"/> unless the batch is already at its maximum size or
    /// the accumulated column-modification count would overflow <see cref="int"/>.
    /// </summary>
    /// <remarks>On acceptance, the command's column modifications are added to the running parameter count.</remarks>
    protected override bool CanAddCommand(IReadOnlyModificationCommand modificationCommand)
    {
        if (ModificationCommands.Count >= _maxBatchSize)
        {
            return false;
        }

        // Widen to long so the overflow check itself cannot overflow.
        var newParamCount = (long)_parameterCount + modificationCommand.ColumnModifications.Count;
        if (newParamCount > int.MaxValue)
        {
            return false;
        }

        _parameterCount = (int)newParamCount;
        return true;
    }

    /// <summary>
    /// Always reports the command text as valid; this batch imposes no command-text limit of its own.
    /// </summary>
    protected override bool IsCommandTextValid()
        => true;

    /// <summary>
    /// Consumes the batch results. Each command that does not require result propagation must have
    /// affected at least one row. Each command that does require it reads one row from its result set
    /// and propagates the values back into the command.
    /// </summary>
    /// <exception cref="DbUpdateConcurrencyException">A command affected or returned no rows.</exception>
    /// <exception cref="DbUpdateException">Any other failure while reading results.</exception>
    protected override void Consume(RelationalDataReader reader)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                var nextPropagating = FindNextPropagatingCommand(commandIndex);

                // Every non-propagating command before the next propagating one must have affected a row.
                for (; commandIndex < nextPropagating; commandIndex++)
                {
                    ThrowIfNoRowsAffected(npgsqlReader, commandIndex);
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // NOTE: the NextResult call only runs in Debug builds (Debug.Assert is conditional).
                    Debug.Assert(!npgsqlReader.NextResult(), "Expected less resultsets");
                    break;
                }

                // Propagate results from the reader to the ModificationCommand. commandIndex is advanced
                // only after propagation succeeds, so a failure here is attributed to this command.
                var modificationCommand = ModificationCommands[commandIndex];

                if (!reader.Read())
                {
                    throw new DbUpdateConcurrencyException(
                        RelationalStrings.UpdateConcurrencyException(1, 0),
                        modificationCommand.Entries);
                }

                var valueBufferFactory = CreateValueBufferFactory(modificationCommand.ColumnModifications);
                modificationCommand.PropagateResults(valueBufferFactory.Create(npgsqlReader));
                commandIndex++;

                npgsqlReader.NextResult();
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw CreateUpdateStoreException(ex, commandIndex);
        }
    }

    /// <summary>
    /// Asynchronously consumes the batch results. Each command that does not require result propagation
    /// must have affected at least one row. Each command that does require it reads one row from its
    /// result set and propagates the values back into the command.
    /// </summary>
    /// <exception cref="DbUpdateConcurrencyException">A command affected or returned no rows.</exception>
    /// <exception cref="DbUpdateException">Any other failure while reading results.</exception>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="cancellationToken"/> was canceled; this is rethrown unwrapped.
    /// </exception>
    protected override async Task ConsumeAsync(
        RelationalDataReader reader,
        CancellationToken cancellationToken = default)
    {
        var npgsqlReader = (NpgsqlDataReader)reader.DbDataReader;

#pragma warning disable 618
        Debug.Assert(npgsqlReader.Statements.Count == ModificationCommands.Count, $"Reader has {npgsqlReader.Statements.Count} statements, expected {ModificationCommands.Count}");
#pragma warning restore 618

        var commandIndex = 0;

        try
        {
            while (true)
            {
                var nextPropagating = FindNextPropagatingCommand(commandIndex);

                // Every non-propagating command before the next propagating one must have affected a row.
                for (; commandIndex < nextPropagating; commandIndex++)
                {
                    ThrowIfNoRowsAffected(npgsqlReader, commandIndex);
                }

                if (nextPropagating == ModificationCommands.Count)
                {
                    // NOTE: the NextResultAsync call only runs in Debug builds (Debug.Assert is conditional).
                    Debug.Assert(!(await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false)), "Expected less resultsets");
                    break;
                }

                // Extract the result from the reader and propagate it. commandIndex is advanced only after
                // propagation succeeds, so a failure here is attributed to this command.
                var modificationCommand = ModificationCommands[commandIndex];

                if (!(await reader.ReadAsync(cancellationToken).ConfigureAwait(false)))
                {
                    throw new DbUpdateConcurrencyException(
                        RelationalStrings.UpdateConcurrencyException(1, 0),
                        modificationCommand.Entries);
                }

                var valueBufferFactory = CreateValueBufferFactory(modificationCommand.ColumnModifications);
                modificationCommand.PropagateResults(valueBufferFactory.Create(npgsqlReader));
                commandIndex++;

                await npgsqlReader.NextResultAsync(cancellationToken).ConfigureAwait(false);
            }
        }
        catch (Exception ex) when (ex is not DbUpdateException and not OperationCanceledException)
        {
            throw CreateUpdateStoreException(ex, commandIndex);
        }
    }

    /// <summary>
    /// Returns the index of the first command at or after <paramref name="startIndex"/> whose
    /// <c>RequiresResultPropagation</c> is true, or <c>ModificationCommands.Count</c> if there is none.
    /// </summary>
    private int FindNextPropagatingCommand(int startIndex)
    {
        var index = startIndex;
        while (index < ModificationCommands.Count && !ModificationCommands[index].RequiresResultPropagation)
        {
            index++;
        }

        return index;
    }

    /// <summary>
    /// Throws <see cref="DbUpdateConcurrencyException"/> when the reader statement at
    /// <paramref name="commandIndex"/> reports zero affected rows.
    /// </summary>
    private void ThrowIfNoRowsAffected(NpgsqlDataReader npgsqlReader, int commandIndex)
    {
#pragma warning disable 618
        var rowsAffected = npgsqlReader.Statements[commandIndex].Rows;
#pragma warning restore 618

        if (rowsAffected == 0)
        {
            throw new DbUpdateConcurrencyException(
                RelationalStrings.UpdateConcurrencyException(1, 0),
                ModificationCommands[commandIndex].Entries);
        }
    }

    /// <summary>
    /// Wraps <paramref name="innerException"/> in a <see cref="DbUpdateException"/> carrying the entries of
    /// the command that was being consumed when the failure occurred.
    /// </summary>
    private DbUpdateException CreateUpdateStoreException(Exception innerException, int commandIndex)
    {
        if (ModificationCommands.Count == 0)
        {
            return new DbUpdateException(RelationalStrings.UpdateStoreException, innerException);
        }

        // commandIndex can be one past the end once the last command has been consumed. Clamp it so the
        // error path cannot itself throw an out-of-range exception and hide the original failure.
        var index = Math.Min(commandIndex, ModificationCommands.Count - 1);

        return new DbUpdateException(
            RelationalStrings.UpdateStoreException,
            innerException,
            ModificationCommands[index].Entries);
    }
}