提交 1ac40f16 编写于 作者: R Ryuta Kamizono

Move Arel attribute normalization into `arel_table`

In Active Record internal, `arel_table` is not directly used but
`arel_attribute` is used, since `arel_table` doesn't normalize an
attribute name as a string, and doesn't resolve attribute aliases.

For the above reason, `arel_attribute` should be used rather than
`arel_table`, but most people directly use `arel_table`, both
`arel_table` and `arel_attribute` are private API though.

Although I'd not recommend using private API, `arel_table` is actually
widely used, and it is also problematic for unscopeable queries and
hash-like relation merging friendly, as I explained at #39863.

To resolve the issue, this change moves Arel attribute normalization
(attribute name as a string, and attribute alias resolution) into
`arel_table`.
上级 c418fb9f
......@@ -51,11 +51,11 @@ def initialize(connection, aliases)
@connection = connection
end
def aliased_table_for(table_name, aliased_name, type_caster)
def aliased_table_for(table_name, aliased_name, klass)
if aliases[table_name].zero?
# If it's zero, we can have our table_name
aliases[table_name] = 1
Arel::Table.new(table_name, type_caster: type_caster)
Arel::Table.new(table_name, klass: klass)
else
# Otherwise, we need to use an alias
aliased_name = @connection.table_alias_for(aliased_name)
......@@ -68,7 +68,7 @@ def aliased_table_for(table_name, aliased_name, type_caster)
else
aliased_name
end
Arel::Table.new(table_name, type_caster: type_caster).alias(table_alias)
Arel::Table.new(table_name, klass: klass).alias(table_alias)
end
end
......
......@@ -109,7 +109,7 @@ def get_chain(reflection, association, tracker)
aliased_table = tracker.aliased_table_for(
refl.table_name,
refl.alias_candidate(name),
refl.klass.type_caster
refl.klass
)
chain << ReflectionProxy.new(refl, aliased_table)
end
......
......@@ -176,7 +176,7 @@ def make_constraints(parent, child, join_type)
alias_tracker.aliased_table_for(
reflection.table_name,
table_alias_for(reflection, parent, reflection != child.reflection),
reflection.klass.type_caster
reflection.klass
)
end.concat child.children.flat_map { |c| make_constraints(child, c, join_type) }
end
......
......@@ -289,12 +289,10 @@ def ===(object) # :nodoc:
# scope :published_and_commented, -> { published.and(arel_table[:comments_count].gt(0)) }
# end
def arel_table # :nodoc:
@arel_table ||= Arel::Table.new(table_name, type_caster: type_caster)
@arel_table ||= Arel::Table.new(table_name, klass: self)
end
def arel_attribute(name, table = arel_table) # :nodoc:
name = name.to_s
name = attribute_aliases[name] || name
table[name]
end
......
......@@ -260,7 +260,7 @@ def find_sti_class(type_name)
end
def type_condition(table = arel_table)
sti_column = arel_attribute(inheritance_column, table)
sti_column = table[inheritance_column]
sti_names = ([self] + descendants).map(&:sti_name)
predicate_builder.build(sti_column, sti_names)
......
......@@ -414,7 +414,7 @@ def discriminate_class_for_record(record)
def _substitute_values(values)
values.map do |name, value|
attr = arel_attribute(name)
attr = arel_table[name]
bind = predicate_builder.build_bind_attribute(name, value)
[attr, bind]
end
......
......@@ -1052,7 +1052,7 @@ def klass
end
def aliased_table
@aliased_table ||= Arel::Table.new(table_name, type_caster: klass.type_caster)
@aliased_table ||= Arel::Table.new(table_name, klass: klass)
end
def join_primary_key(klass = self.klass)
......
......@@ -39,7 +39,7 @@ def initialize_copy(other)
end
def arel_attribute(name) # :nodoc:
klass.arel_attribute(name, table)
table[name]
end
def bind_attribute(name, value) # :nodoc:
......@@ -48,7 +48,7 @@ def bind_attribute(name, value) # :nodoc:
value = value.read_attribute(reflection.klass.primary_key) unless value.nil?
end
attr = arel_attribute(name)
attr = table[name]
bind = predicate_builder.build_bind_attribute(attr.name, value)
yield attr, bind
end
......@@ -352,7 +352,7 @@ def compute_cache_version(timestamp_column) # :nodoc:
else
collection = eager_loading? ? apply_join_dependency : self
column = connection.visitor.compile(arel_attribute(timestamp_column))
column = connection.visitor.compile(table[timestamp_column])
select_values = "COUNT(*) AS #{connection.quote_column_name("size")}, MAX(%s) AS timestamp"
if collection.has_limit_or_offset?
......@@ -447,7 +447,7 @@ def update_all(updates)
stmt = Arel::UpdateManager.new
stmt.table(arel.join_sources.empty? ? table : arel.source)
stmt.key = arel_attribute(primary_key)
stmt.key = table[primary_key]
stmt.take(arel.limit)
stmt.offset(arel.offset)
stmt.order(*arel.orders)
......@@ -457,7 +457,7 @@ def update_all(updates)
if klass.locking_enabled? &&
!updates.key?(klass.locking_column) &&
!updates.key?(klass.locking_column.to_sym)
attr = arel_attribute(klass.locking_column)
attr = table[klass.locking_column]
updates[attr.name] = _increment_attribute(attr)
end
stmt.set _substitute_values(updates)
......@@ -493,7 +493,7 @@ def update_counters(counters)
updates = {}
counters.each do |counter_name, value|
attr = arel_attribute(counter_name)
attr = table[counter_name]
updates[attr.name] = _increment_attribute(attr, value)
end
......@@ -589,7 +589,7 @@ def delete_all
stmt = Arel::DeleteManager.new
stmt.from(arel.join_sources.empty? ? table : arel.source)
stmt.key = arel_attribute(primary_key)
stmt.key = table[primary_key]
stmt.take(arel.limit)
stmt.offset(arel.offset)
stmt.order(*arel.orders)
......@@ -813,7 +813,7 @@ def _scoping(scope)
def _substitute_values(values)
values.map do |name, value|
attr = arel_attribute(name)
attr = table[name]
unless Arel.arel_node?(value)
type = klass.type_for_attribute(attr.name)
value = predicate_builder.build_bind_attribute(attr.name, type.cast(value))
......
......@@ -280,7 +280,7 @@ def apply_finish_limit(relation, finish, order)
end
def batch_order(order)
arel_attribute(primary_key).public_send(order)
table[primary_key].public_send(order)
end
def act_on_ignored_order(error_on_ignore)
......
......@@ -410,7 +410,7 @@ def apply_join_dependency(eager_loading: group_values.empty?)
def limited_ids_for(relation)
values = @klass.connection.columns_for_distinct(
connection.visitor.compile(arel_attribute(primary_key)),
connection.visitor.compile(table[primary_key]),
relation.order_values
)
......@@ -562,9 +562,9 @@ def find_last(limit)
def ordered_relation
if order_values.empty? && (implicit_order_column || primary_key)
if implicit_order_column && primary_key && implicit_order_column != primary_key
order(arel_attribute(implicit_order_column).asc, arel_attribute(primary_key).asc)
order(table[implicit_order_column].asc, table[primary_key].asc)
else
order(arel_attribute(implicit_order_column || primary_key).asc)
order(table[implicit_order_column || primary_key].asc)
end
else
self
......
......@@ -9,7 +9,7 @@ def call(attribute, value)
end
if value.select_values.empty?
value = value.select(value.arel_attribute(value.klass.primary_key))
value = value.select(value.table[value.klass.primary_key])
end
attribute.in(value.arel)
......
......@@ -1304,7 +1304,7 @@ def build_select(arel)
if select_values.any?
arel.project(*arel_columns(select_values.uniq))
elsif klass.ignored_columns.any?
arel.project(*klass.column_names.map { |field| arel_attribute(field) })
arel.project(*klass.column_names.map { |field| table[field] })
else
arel.project(table[Arel.star])
end
......@@ -1332,7 +1332,7 @@ def arel_column(field)
from = from_clause.name || from_clause.value
if klass.columns_hash.key?(field) && (!from || table_name_matches?(from))
arel_attribute(field)
table[field]
elsif field.match?(/\A\w+\.\w+\z/)
table, column = field.split(".")
predicate_builder.resolve_arel_attribute(table, column) do
......@@ -1351,7 +1351,7 @@ def table_name_matches?(from)
def reverse_sql_order(order_query)
if order_query.empty?
return [arel_attribute(primary_key).desc] if primary_key
return [table[primary_key].desc] if primary_key
raise IrreversibleOrderError,
"Relation has no current order and table has no primary key to be used as default order"
end
......@@ -1457,7 +1457,7 @@ def column_references(order_args)
def order_column(field)
arel_column(field) do |attr_name|
if attr_name == "count" && !group_values.empty?
arel_attribute(attr_name)
table[attr_name]
else
Arel.sql(connection.quote_table_name(attr_name))
end
......
......@@ -11,11 +11,7 @@ def initialize(klass, arel_table, reflection = nil)
end
def arel_attribute(column_name)
if klass
klass.arel_attribute(column_name, arel_table)
else
arel_table[column_name]
end
arel_table[column_name]
end
def type(column_name)
......
......@@ -8,7 +8,7 @@ class TableAlias < Arel::Nodes::Binary
alias :table_alias :name
def [](name)
Attribute.new(self, name)
relation.is_a?(Table) ? relation[name, self] : Attribute.new(self, name)
end
def table_name
......
......@@ -14,8 +14,9 @@ class << self; attr_accessor :engine; end
# TableAlias and Table both have a #table_name which is the name of the underlying table
alias :table_name :name
def initialize(name, as: nil, type_caster: nil)
def initialize(name, as: nil, klass: nil, type_caster: klass&.type_caster)
@name = name.to_s
@klass = klass
@type_caster = type_caster
# Sometime AR sends an :as parameter to table, to let the table know
......@@ -79,8 +80,10 @@ def having(expr)
from.having expr
end
def [](name)
::Arel::Attribute.new self, name
def [](name, table = self)
name = name.to_s if name.is_a?(Symbol)
name = @klass.attribute_aliases[name] || name if @klass
Attribute.new(table, name)
end
def hash
......
......@@ -8,19 +8,19 @@ class Default < ActiveRecord::Base; end
def test_case_insensitiveness
connection = ActiveRecord::Base.connection
attr = Default.arel_attribute(:char1)
attr = Default.arel_table[:char1]
comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:char2)
attr = Default.arel_table[:char2]
comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:char3)
attr = Default.arel_table[:char3]
comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:multiline_default)
attr = Default.arel_table[:multiline_default]
comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql)
end
......
......@@ -188,7 +188,7 @@ class TableTest < Arel::Spec
describe "when given a Symbol" do
it "manufactures an attribute if the symbol names an attribute within the relation" do
column = @relation[:id]
_(column.name).must_equal :id
_(column.name).must_equal "id"
end
end
end
......
......@@ -71,7 +71,7 @@ def test_user_supplied_joins_order_should_be_preserved
def test_deduplicate_joins
posts = Post.arel_table
constraint = posts[:author_id].eq(Author.arel_attribute(:id))
constraint = posts[:author_id].eq(Author.arel_table[:id])
authors = Author.joins(posts.create_join(posts, posts.create_on(constraint)))
authors = authors.joins(:author_address).merge(authors.where("posts.type": "SpecialPost"))
......
......@@ -78,6 +78,11 @@ def test_generated_relation_methods_module_name
assert_equal "Post::GeneratedRelationMethods", mod.inspect
end
def test_arel_attribute_normalization
assert_equal Post.arel_table["body"], Post.arel_table[:body]
assert_equal Post.arel_table["body"], Post.arel_table[:text]
end
def test_incomplete_schema_loading
topic = Topic.first
payload = { foo: 42 }
......
......@@ -1261,7 +1261,7 @@ def test_first_or_create
assert_predicate same_parrot, :persisted?
assert_equal parrot, same_parrot
canary = Bird.where(Bird.arel_attribute(:color).is_distinct_from("green")).first_or_create(name: "canary")
canary = Bird.where(Bird.arel_table[:color].is_distinct_from("green")).first_or_create(name: "canary")
assert_equal "canary", canary.name
assert_nil canary.color
end
......@@ -1385,7 +1385,7 @@ def test_first_or_initialize
assert_equal "parrot", parrot.name
assert_equal "green", parrot.color
canary = Bird.where(Bird.arel_attribute(:color).is_distinct_from("green")).first_or_initialize(name: "canary")
canary = Bird.where(Bird.arel_table[:color].is_distinct_from("green")).first_or_initialize(name: "canary")
assert_equal "canary", canary.name
assert_nil canary.color
end
......@@ -1963,7 +1963,7 @@ def test_destroy_by
assert_equal post, custom_post_relation.joins(:author).where!(title: post.title).take
end
test "arel_attribute respects a custom table" do
test "arel_table respects a custom table" do
assert_equal [posts(:sti_comments)], custom_post_relation.ranked_by_comments.limit_by(1).to_a
end
......@@ -2093,7 +2093,7 @@ def test_unscope_with_table_name_qualified_hash
end
def test_unscope_with_arel_sql
posts = Post.where(Arel.sql("'Welcome to the weblog'").eq(Post.arel_attribute(:title)))
posts = Post.where(Arel.sql("'Welcome to the weblog'").eq(Post.arel_table[:title]))
assert_equal 1, posts.count
assert_equal Post.count, posts.unscope(where: :title).count
......
......@@ -28,7 +28,7 @@ def greeting
scope :containing_the_letter_a, -> { where("body LIKE '%a%'") }
scope :titled_with_an_apostrophe, -> { where("title LIKE '%''%'") }
scope :ranked_by_comments, -> { order(arel_attribute(:comments_count).desc) }
scope :ranked_by_comments, -> { order(table[:comments_count].desc) }
scope :limit_by, lambda { |l| limit(l) }
scope :locked, -> { lock }
......@@ -339,10 +339,6 @@ def sanitize_sql_for_order(sql)
sql
end
def arel_attribute(name, table)
table[name]
end
def disallow_raw_sql!(*args)
# noop
end
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册