未验证 提交 f3382bfb 编写于 作者: R Ryuta Kamizono 提交者: GitHub

Merge pull request #39141 from kamipo/fix_rewhere

Fix `rewhere` to truly overwrite collided where clause by new where clause
* Fix `rewhere` to truly overwrite collided where clause by new where clause.
```ruby
steve = Person.find_by(name: "Steve")
david = Author.find_by(name: "David")
relation = Essay.where(writer: steve)
# Before
relation.rewhere(writer: david).to_a # => []
# After
relation.rewhere(writer: david).to_a # => [david]
```
*Ryuta Kamizono*
* Inspect time attributes with subsec. * Inspect time attributes with subsec.
```ruby ```ruby
......
...@@ -483,7 +483,7 @@ def unscope!(*args) # :nodoc: ...@@ -483,7 +483,7 @@ def unscope!(*args) # :nodoc:
raise ArgumentError, "Hash arguments in .unscope(*args) must have :where as the key." raise ArgumentError, "Hash arguments in .unscope(*args) must have :where as the key."
end end
target_values = Array(target_value).map(&:to_s) target_values = Array(target_value)
self.where_clause = where_clause.except(*target_values) self.where_clause = where_clause.except(*target_values)
end end
else else
...@@ -683,9 +683,7 @@ def where(opts = :chain, *rest) ...@@ -683,9 +683,7 @@ def where(opts = :chain, *rest)
end end
def where!(opts, *rest) # :nodoc: def where!(opts, *rest) # :nodoc:
opts = sanitize_forbidden_attributes(opts) self.where_clause += build_where_clause(opts, *rest)
references!(PredicateBuilder.references(opts)) if Hash === opts
self.where_clause += where_clause_factory.build(opts, rest)
self self
end end
...@@ -703,7 +701,17 @@ def where!(opts, *rest) # :nodoc: ...@@ -703,7 +701,17 @@ def where!(opts, *rest) # :nodoc:
# This is short-hand for <tt>unscope(where: conditions.keys).where(conditions)</tt>. # This is short-hand for <tt>unscope(where: conditions.keys).where(conditions)</tt>.
# Note that unlike reorder, we're only unscoping the named conditions -- not the entire where statement. # Note that unlike reorder, we're only unscoping the named conditions -- not the entire where statement.
def rewhere(conditions) def rewhere(conditions)
unscope(where: conditions.keys).where(conditions) attrs = []
scope = spawn
where_clause = scope.build_where_clause(conditions)
where_clause.each_attribute do |attr|
attrs << attr
end
scope.unscope!(where: attrs)
scope.where_clause += where_clause
scope
end end
# Returns a new relation, which is the logical union of this relation and the one passed as an # Returns a new relation, which is the logical union of this relation and the one passed as an
...@@ -1078,6 +1086,12 @@ def build_subquery(subquery_alias, select_value) # :nodoc: ...@@ -1078,6 +1086,12 @@ def build_subquery(subquery_alias, select_value) # :nodoc:
end end
end end
def build_where_clause(opts, *rest)
opts = sanitize_forbidden_attributes(opts)
references!(PredicateBuilder.references(opts)) if Hash === opts
where_clause_factory.build(opts, rest)
end
private private
def assert_mutability! def assert_mutability!
raise ImmutableRelation if @loaded raise ImmutableRelation if @loaded
......
...@@ -92,6 +92,12 @@ def contradiction? ...@@ -92,6 +92,12 @@ def contradiction?
end end
end end
def each_attribute(&block)
predicates.each do |node|
Arel.fetch_attribute(node, &block)
end
end
protected protected
attr_reader :predicates attr_reader :predicates
...@@ -141,7 +147,15 @@ def invert_predicate(node) ...@@ -141,7 +147,15 @@ def invert_predicate(node)
def except_predicates(columns) def except_predicates(columns)
predicates.reject do |node| predicates.reject do |node|
Arel.fetch_attribute(node) { |attr| columns.include?(attr.name.to_s) } Arel.fetch_attribute(node) do |attr|
columns.any? do |column|
if column.is_a?(Arel::Attributes::Attribute)
attr == column
else
attr.name.to_s == column.to_s
end
end
end
end end
end end
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
require "cases/helper" require "cases/helper"
require "models/post" require "models/post"
require "models/author" require "models/author"
require "models/man"
require "models/essay" require "models/essay"
require "models/comment" require "models/comment"
require "models/categorization" require "models/categorization"
...@@ -10,7 +11,7 @@ ...@@ -10,7 +11,7 @@
module ActiveRecord module ActiveRecord
class WhereChainTest < ActiveRecord::TestCase class WhereChainTest < ActiveRecord::TestCase
fixtures :posts, :authors, :essays fixtures :posts, :comments, :authors, :men, :essays
def test_missing_with_association def test_missing_with_association
assert posts(:authorless).author.blank? assert posts(:authorless).author.blank?
...@@ -89,9 +90,23 @@ def test_rewhere_with_one_overwriting_condition_and_one_unrelated ...@@ -89,9 +90,23 @@ def test_rewhere_with_one_overwriting_condition_and_one_unrelated
assert_equal expected.to_a, relation.to_a assert_equal expected.to_a, relation.to_a
end end
def test_rewhere_with_alias_condition
relation = Post.where(text: "hello").where(text: "world").rewhere(text: "hullo")
expected = Post.where(text: "hullo")
assert_equal expected.to_a, relation.to_a
end
def test_rewhere_with_nested_condition
relation = Post.where.missing(:comments).rewhere("comments.id": comments(:does_it_hurt))
expected = Post.left_joins(:comments).where("comments.id": comments(:does_it_hurt))
assert_equal expected.to_a, relation.to_a
end
def test_rewhere_with_polymorphic_association def test_rewhere_with_polymorphic_association
relation = Essay.where(writer: authors(:david)).rewhere(writer_id: "Mary") relation = Essay.where(writer: authors(:david)).rewhere(writer: men(:steve))
expected = Essay.where(writer: authors(:mary)) expected = Essay.where(writer: men(:steve))
assert_equal expected.to_a, relation.to_a assert_equal expected.to_a, relation.to_a
end end
......
...@@ -9,3 +9,8 @@ mary_stay_home: ...@@ -9,3 +9,8 @@ mary_stay_home:
name: Stay Home name: Stay Home
writer_type: Author writer_type: Author
writer_id: Mary writer_id: Mary
steve_connecting_the_dots:
name: Connecting The Dots
writer_type: Man
writer_id: Steve
...@@ -23,6 +23,8 @@ def greeting ...@@ -23,6 +23,8 @@ def greeting
end end
end end
alias_attribute :text, :body
scope :containing_the_letter_a, -> { where("body LIKE '%a%'") } scope :containing_the_letter_a, -> { where("body LIKE '%a%'") }
scope :titled_with_an_apostrophe, -> { where("title LIKE '%''%'") } scope :titled_with_an_apostrophe, -> { where("title LIKE '%''%'") }
scope :ranked_by_comments, -> { order(arel_attribute(:comments_count).desc) } scope :ranked_by_comments, -> { order(arel_attribute(:comments_count).desc) }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册