From 4b534152951fe9b1af8a1d3311da766e1d42392a Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Sat, 1 Nov 2014 17:29:14 -0600 Subject: [PATCH] Use a bound parameter for the "id = " portion of update statements We need to re-order the bind parameters since the AST returned by the relation will have the where statement as the first bp, which breaks on PG. --- activerecord/lib/active_record/relation.rb | 10 ++++++++-- .../lib/active_record/relation/query_methods.rb | 9 ++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index a25e6e321f..03bce4f5b7 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -79,12 +79,18 @@ def _update_record(values, id, id_was) # :nodoc: scope.unscope!(where: @klass.inheritance_column) end - um = scope.where(@klass.arel_table[@klass.primary_key].eq(id_was || id)).arel.compile_update(substitutes, @klass.primary_key) + relation = scope.where(@klass.primary_key => (id_was || id)) + bvs = binds + relation.bind_values + um = relation + .arel + .compile_update(substitutes, @klass.primary_key) + reorder_bind_params(um.ast, bvs) @klass.connection.update( um, 'SQL', - binds) + bvs, + ) end def substitute_values(values) # :nodoc: diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index f65ee7790e..a686e3263b 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -883,12 +883,15 @@ def build_arel # Reorder bind indexes if joins produced bind values bvs = arel.bind_values + bind_values - arel.ast.grep(Arel::Nodes::BindParam).each_with_index do |bp, i| + reorder_bind_params(arel.ast, bvs) + arel + end + + def reorder_bind_params(ast, bvs) + ast.grep(Arel::Nodes::BindParam).each_with_index do |bp, i| column = bvs[i].first bp.replace connection.substitute_at(column, i) end - - arel end def symbol_unscoping(scope) -- GitLab