提交 6d40d2d3 编写于 作者: R Ryuta Kamizono

Handle UPDATE/DELETE with OFFSET in Arel

上级 322c5704
...@@ -356,11 +356,12 @@ def update_all(updates) ...@@ -356,11 +356,12 @@ def update_all(updates)
stmt.set Arel.sql(klass.sanitize_sql_for_assignment(updates, table.name)) stmt.set Arel.sql(klass.sanitize_sql_for_assignment(updates, table.name))
end end
if has_join_values? || offset_value if has_join_values?
@klass.connection.join_to_update(stmt, arel, arel_attribute(primary_key)) @klass.connection.join_to_update(stmt, arel, arel_attribute(primary_key))
else else
stmt.key = arel_attribute(primary_key) stmt.key = arel_attribute(primary_key)
stmt.take(arel.limit) stmt.take(arel.limit)
stmt.offset(arel.offset)
stmt.order(*arel.orders) stmt.order(*arel.orders)
stmt.wheres = arel.constraints stmt.wheres = arel.constraints
end end
...@@ -484,11 +485,12 @@ def delete_all ...@@ -484,11 +485,12 @@ def delete_all
stmt = Arel::DeleteManager.new stmt = Arel::DeleteManager.new
stmt.from(table) stmt.from(table)
if has_join_values? || offset_value if has_join_values?
@klass.connection.join_to_delete(stmt, arel, arel_attribute(primary_key)) @klass.connection.join_to_delete(stmt, arel, arel_attribute(primary_key))
else else
stmt.key = arel_attribute(primary_key) stmt.key = arel_attribute(primary_key)
stmt.take(arel.limit) stmt.take(arel.limit)
stmt.offset(arel.offset)
stmt.order(*arel.orders) stmt.order(*arel.orders)
stmt.wheres = arel.constraints stmt.wheres = arel.constraints
end end
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
module Arel # :nodoc: all module Arel # :nodoc: all
module Nodes module Nodes
class DeleteStatement < Arel::Nodes::Node class DeleteStatement < Arel::Nodes::Node
attr_accessor :left, :right, :orders, :limit, :key attr_accessor :left, :right, :orders, :limit, :offset, :key
alias :relation :left alias :relation :left
alias :relation= :left= alias :relation= :left=
...@@ -16,6 +16,7 @@ def initialize(relation = nil, wheres = []) ...@@ -16,6 +16,7 @@ def initialize(relation = nil, wheres = [])
@right = wheres @right = wheres
@orders = [] @orders = []
@limit = nil @limit = nil
@offset = nil
@key = nil @key = nil
end end
...@@ -26,7 +27,7 @@ def initialize_copy(other) ...@@ -26,7 +27,7 @@ def initialize_copy(other)
end end
def hash def hash
[self.class, @left, @right, @orders, @limit, @key].hash [self.class, @left, @right, @orders, @limit, @offset, @key].hash
end end
def eql?(other) def eql?(other)
...@@ -35,6 +36,7 @@ def eql?(other) ...@@ -35,6 +36,7 @@ def eql?(other)
self.right == other.right && self.right == other.right &&
self.orders == other.orders && self.orders == other.orders &&
self.limit == other.limit && self.limit == other.limit &&
self.offset == other.offset &&
self.key == other.key self.key == other.key
end end
alias :== :eql? alias :== :eql?
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
module Arel # :nodoc: all module Arel # :nodoc: all
module Nodes module Nodes
class UpdateStatement < Arel::Nodes::Node class UpdateStatement < Arel::Nodes::Node
attr_accessor :relation, :wheres, :values, :orders, :limit, :key attr_accessor :relation, :wheres, :values, :orders, :limit, :offset, :key
def initialize def initialize
@relation = nil @relation = nil
...@@ -11,6 +11,7 @@ def initialize ...@@ -11,6 +11,7 @@ def initialize
@values = [] @values = []
@orders = [] @orders = []
@limit = nil @limit = nil
@offset = nil
@key = nil @key = nil
end end
...@@ -21,7 +22,7 @@ def initialize_copy(other) ...@@ -21,7 +22,7 @@ def initialize_copy(other)
end end
def hash def hash
[@relation, @wheres, @values, @orders, @limit, @key].hash [@relation, @wheres, @values, @orders, @limit, @offset, @key].hash
end end
def eql?(other) def eql?(other)
...@@ -31,6 +32,7 @@ def eql?(other) ...@@ -31,6 +32,7 @@ def eql?(other)
self.values == other.values && self.values == other.values &&
self.orders == other.orders && self.orders == other.orders &&
self.limit == other.limit && self.limit == other.limit &&
self.offset == other.offset &&
self.key == other.key self.key == other.key
end end
alias :== :eql? alias :== :eql?
......
...@@ -10,6 +10,11 @@ def take(limit) ...@@ -10,6 +10,11 @@ def take(limit)
self self
end end
def offset(offset)
@ast.offset = Nodes::Offset.new(Nodes.build_quoted(offset)) if offset
self
end
def order(*expr) def order(*expr)
@ast.orders = expr @ast.orders = expr
self self
......
...@@ -56,18 +56,6 @@ def visit_Arel_Nodes_SelectCore(o, collector) ...@@ -56,18 +56,6 @@ def visit_Arel_Nodes_SelectCore(o, collector)
super super
end end
def visit_Arel_Nodes_UpdateStatement(o, collector)
collector << "UPDATE "
collector = visit o.relation, collector
unless o.values.empty?
collector << " SET "
collector = inject_join o.values, collector, ", "
end
collect_where_for(o, collector)
end
def visit_Arel_Nodes_Concat(o, collector) def visit_Arel_Nodes_Concat(o, collector)
collector << " CONCAT(" collector << " CONCAT("
visit o.left, collector visit o.left, collector
...@@ -77,7 +65,23 @@ def visit_Arel_Nodes_Concat(o, collector) ...@@ -77,7 +65,23 @@ def visit_Arel_Nodes_Concat(o, collector)
collector collector
end end
def build_subselect(key, o)
subselect = super
# Materialize subquery by adding distinct
# to work with MySQL 5.7.6 which sets optimizer_switch='derived_merge=on'
subselect.distinct unless subselect.limit || subselect.offset || subselect.orders.any?
Nodes::SelectStatement.new.tap do |stmt|
core = stmt.cores.last
core.froms = Nodes::Grouping.new(subselect).as("__active_record_temp")
core.projections = [Arel.sql(quote_column_name(key.name))]
end
end
def collect_where_for(o, collector) def collect_where_for(o, collector)
return super if o.offset
unless o.wheres.empty? unless o.wheres.empty?
collector << " WHERE " collector << " WHERE "
collector = inject_join o.wheres, collector, " AND " collector = inject_join o.wheres, collector, " AND "
......
...@@ -88,6 +88,7 @@ def build_subselect(key, o) ...@@ -88,6 +88,7 @@ def build_subselect(key, o)
core.wheres = o.wheres core.wheres = o.wheres
core.projections = [key] core.projections = [key]
stmt.limit = o.limit stmt.limit = o.limit
stmt.offset = o.offset
stmt.orders = o.orders stmt.orders = o.orders
stmt stmt
end end
...@@ -800,7 +801,7 @@ def inject_join(list, collector, join_str) ...@@ -800,7 +801,7 @@ def inject_join(list, collector, join_str)
end end
def collect_where_for(o, collector) def collect_where_for(o, collector)
if o.orders.empty? && o.limit.nil? if o.orders.empty? && o.limit.nil? && o.offset.nil?
wheres = o.wheres wheres = o.wheres
else else
wheres = [Nodes::In.new(o.key, [build_subselect(o.key, o)])] wheres = [Nodes::In.new(o.key, [build_subselect(o.key, o)])]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册