提交 9641d0b2 编写于 作者: R Ryuta Kamizono

Support type casting for grandchild's attributes

Related to #39292.

Fixes #39460.
上级 63680ed3
...@@ -309,7 +309,7 @@ def execute_simple_calculation(operation, column_name, distinct) #:nodoc: ...@@ -309,7 +309,7 @@ def execute_simple_calculation(operation, column_name, distinct) #:nodoc:
type_cast_calculated_value(result.cast_values.first, operation) do |value| type_cast_calculated_value(result.cast_values.first, operation) do |value|
type = column.try(:type_caster) || type = column.try(:type_caster) ||
lookup_cast_type_from_join_dependencies(column_name.to_s, build_join_dependencies) || Type.default_value lookup_cast_type_from_join_dependencies(column_name.to_s) || Type.default_value
type.deserialize(value) type.deserialize(value)
end end
end end
...@@ -388,7 +388,7 @@ def execute_grouped_calculation(operation, column_name, distinct) #:nodoc: ...@@ -388,7 +388,7 @@ def execute_grouped_calculation(operation, column_name, distinct) #:nodoc:
result[key] = type_cast_calculated_value(row[column_alias], operation) do |value| result[key] = type_cast_calculated_value(row[column_alias], operation) do |value|
type ||= column.try(:type_caster) || type ||= column.try(:type_caster) ||
lookup_cast_type_from_join_dependencies(column_name.to_s, build_join_dependencies) || Type.default_value lookup_cast_type_from_join_dependencies(column_name.to_s) || Type.default_value
type.deserialize(value) type.deserialize(value)
end end
end end
...@@ -416,19 +416,10 @@ def type_for(field, &block) ...@@ -416,19 +416,10 @@ def type_for(field, &block)
@klass.type_for_attribute(field_name, &block) @klass.type_for_attribute(field_name, &block)
end end
def build_join_dependencies def lookup_cast_type_from_join_dependencies(name, join_dependencies = build_join_dependencies)
join_dependencies = [] each_join_dependencies(join_dependencies) do |join|
join_dependencies.unshift construct_join_dependency( type = join.base_klass.attribute_types.fetch(name, nil)
select_association_list(joins_values + left_outer_joins_values, join_dependencies), nil return type if type
)
end
def lookup_cast_type_from_join_dependencies(name, join_dependencies)
join_dependencies.each do |join_dependency|
join_dependency.each do |join|
type = join.base_klass.attribute_types.fetch(name, nil)
return type if type
end
end end
nil nil
end end
......
...@@ -14,11 +14,11 @@ def initialize(table) ...@@ -14,11 +14,11 @@ def initialize(table)
register_handler(Set, ArrayHandler.new(self)) register_handler(Set, ArrayHandler.new(self))
end end
def build_from_hash(attributes) def build_from_hash(attributes, &block)
attributes = attributes.stringify_keys attributes = attributes.stringify_keys
attributes = convert_dot_notation_to_hash(attributes) attributes = convert_dot_notation_to_hash(attributes)
expand_from_hash(attributes) expand_from_hash(attributes, &block)
end end
def self.references(attributes) def self.references(attributes)
...@@ -61,17 +61,18 @@ def build_bind_attribute(column_name, value) ...@@ -61,17 +61,18 @@ def build_bind_attribute(column_name, value)
Arel::Nodes::BindParam.new(attr) Arel::Nodes::BindParam.new(attr)
end end
def resolve_arel_attribute(table_name, column_name) def resolve_arel_attribute(table_name, column_name, &block)
table.associated_table(table_name).arel_attribute(column_name) table.associated_table(table_name, &block).arel_attribute(column_name)
end end
protected protected
def expand_from_hash(attributes) def expand_from_hash(attributes, &block)
return ["1=0"] if attributes.empty? return ["1=0"] if attributes.empty?
attributes.flat_map do |key, value| attributes.flat_map do |key, value|
if value.is_a?(Hash) && !table.has_column?(key) if value.is_a?(Hash) && !table.has_column?(key)
table.associated_predicate_builder(key).expand_from_hash(value) table.associated_table(key, &block)
.predicate_builder.expand_from_hash(value.stringify_keys)
elsif table.associated_with?(key) elsif table.associated_with?(key)
# Find the foreign key when using queries such as: # Find the foreign key when using queries such as:
# Post.where(author: author) # Post.where(author: author)
......
...@@ -1077,11 +1077,39 @@ def build_subquery(subquery_alias, select_value) # :nodoc: ...@@ -1077,11 +1077,39 @@ def build_subquery(subquery_alias, select_value) # :nodoc:
def build_where_clause(opts, rest = []) # :nodoc: def build_where_clause(opts, rest = []) # :nodoc:
opts = sanitize_forbidden_attributes(opts) opts = sanitize_forbidden_attributes(opts)
self.references_values |= PredicateBuilder.references(opts) if Hash === opts self.references_values |= PredicateBuilder.references(opts) if Hash === opts
where_clause_factory.build(opts, rest) where_clause_factory.build(opts, rest) do |table_name|
lookup_reflection_from_join_dependencies(table_name)
end
end end
alias :build_having_clause :build_where_clause alias :build_having_clause :build_where_clause
private private
def lookup_reflection_from_join_dependencies(table_name)
each_join_dependencies do |join|
return join.reflection if table_name == join.table_name
end
nil
end
def each_join_dependencies(join_dependencies = build_join_dependencies)
join_dependencies.each do |join_dependency|
join_dependency.each do |join|
yield join
end
end
end
def build_join_dependencies
associations = joins_values | left_outer_joins_values
associations |= eager_load_values unless eager_load_values.empty?
associations |= includes_values unless includes_values.empty?
join_dependencies = []
join_dependencies.unshift construct_join_dependency(
select_association_list(associations, join_dependencies), nil
)
end
def assert_mutability! def assert_mutability!
raise ImmutableRelation if @loaded raise ImmutableRelation if @loaded
raise ImmutableRelation if defined?(@arel) && @arel raise ImmutableRelation if defined?(@arel) && @arel
...@@ -1268,7 +1296,9 @@ def arel_column(field) ...@@ -1268,7 +1296,9 @@ def arel_column(field)
arel_attribute(field) arel_attribute(field)
elsif field.match?(/\A\w+\.\w+\z/) elsif field.match?(/\A\w+\.\w+\z/)
table, column = field.split(".") table, column = field.split(".")
predicate_builder.resolve_arel_attribute(table, column) predicate_builder.resolve_arel_attribute(table, column) do
lookup_reflection_from_join_dependencies(table)
end
else else
yield field yield field
end end
......
...@@ -8,12 +8,12 @@ def initialize(klass, predicate_builder) ...@@ -8,12 +8,12 @@ def initialize(klass, predicate_builder)
@predicate_builder = predicate_builder @predicate_builder = predicate_builder
end end
def build(opts, other) def build(opts, other, &block)
case opts case opts
when String, Array when String, Array
parts = [klass.sanitize_sql(other.empty? ? opts : ([opts] + other))] parts = [klass.sanitize_sql(other.empty? ? opts : ([opts] + other))]
when Hash when Hash
parts = predicate_builder.build_from_hash(opts) parts = predicate_builder.build_from_hash(opts, &block)
when Arel::Nodes::Node when Arel::Nodes::Node
parts = [opts] parts = [opts]
else else
......
...@@ -24,19 +24,23 @@ def type(column_name) ...@@ -24,19 +24,23 @@ def type(column_name)
end end
def has_column?(column_name) def has_column?(column_name)
klass && klass.columns_hash.key?(column_name.to_s) klass&.columns_hash.key?(column_name)
end end
def associated_with?(association_name) def associated_with?(table_name)
klass && klass._reflect_on_association(association_name) klass&._reflect_on_association(table_name) || klass&._reflect_on_association(table_name.singularize)
end end
def associated_table(table_name) def associated_table(table_name)
reflection = klass._reflect_on_association(table_name) || klass._reflect_on_association(table_name.to_s.singularize) reflection = klass._reflect_on_association(table_name) || klass._reflect_on_association(table_name.singularize)
if !reflection && table_name == arel_table.name if !reflection && table_name == arel_table.name
self return self
elsif reflection && !reflection.polymorphic? end
reflection ||= yield table_name if block_given?
if reflection && !reflection.polymorphic?
association_klass = reflection.klass association_klass = reflection.klass
arel_table = association_klass.arel_table.alias(table_name) arel_table = association_klass.arel_table.alias(table_name)
TableMetadata.new(association_klass, arel_table, reflection) TableMetadata.new(association_klass, arel_table, reflection)
...@@ -47,32 +51,24 @@ def associated_table(table_name) ...@@ -47,32 +51,24 @@ def associated_table(table_name)
end end
end end
def associated_predicate_builder(table_name)
associated_table(table_name).predicate_builder
end
def polymorphic_association? def polymorphic_association?
reflection&.polymorphic? reflection&.polymorphic?
end end
def aggregated_with?(aggregation_name)
klass && reflect_on_aggregation(aggregation_name)
end
def reflect_on_aggregation(aggregation_name) def reflect_on_aggregation(aggregation_name)
klass.reflect_on_aggregation(aggregation_name) klass&.reflect_on_aggregation(aggregation_name)
end end
alias :aggregated_with? :reflect_on_aggregation
protected def predicate_builder
def predicate_builder if klass
if klass predicate_builder = klass.predicate_builder.dup
predicate_builder = klass.predicate_builder.dup predicate_builder.instance_variable_set(:@table, self)
predicate_builder.instance_variable_set(:@table, self) predicate_builder
predicate_builder else
else PredicateBuilder.new(self)
PredicateBuilder.new(self)
end
end end
end
private private
attr_reader :klass, :types, :arel_table, :reflection attr_reader :klass, :types, :arel_table, :reflection
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
require "support/stubs/strong_parameters" require "support/stubs/strong_parameters"
class CalculationsTest < ActiveRecord::TestCase class CalculationsTest < ActiveRecord::TestCase
fixtures :companies, :accounts, :authors, :topics, :speedometers, :minivans, :books, :posts, :comments fixtures :companies, :accounts, :authors, :author_addresses, :topics, :speedometers, :minivans, :books, :posts, :comments
def test_should_sum_field def test_should_sum_field
assert_equal 318, Account.sum(:credit_limit) assert_equal 318, Account.sum(:credit_limit)
...@@ -751,33 +751,35 @@ def test_pluck_type_cast_with_conflict_column_names ...@@ -751,33 +751,35 @@ def test_pluck_type_cast_with_conflict_column_names
[Date.new(2004, 4, 15), "reading"], [Date.new(2004, 4, 15), "reading"],
[Date.new(2004, 4, 15), "read"], [Date.new(2004, 4, 15), "read"],
] ]
actual = actual = AuthorAddress.joins(author: [:topics, :books]).order(:"books.last_read")
Author.joins(:topics, :books).order(:"books.last_read") .where("books.last_read": [:unread, :reading, :read])
.where.not("books.last_read": nil)
.pluck(:"topics.last_read", :"books.last_read") .pluck(:"topics.last_read", :"books.last_read")
assert_equal expected, actual assert_equal expected, actual
end end
def test_pluck_type_cast_with_joins_without_table_name_qualified_column def test_pluck_type_cast_with_joins_without_table_name_qualified_column
assert_pluck_type_cast_without_table_name_qualified_column(Author.joins(:books)) assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.joins(author: :books))
end end
def test_pluck_type_cast_with_left_joins_without_table_name_qualified_column def test_pluck_type_cast_with_left_joins_without_table_name_qualified_column
assert_pluck_type_cast_without_table_name_qualified_column(Author.left_joins(:books)) assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.left_joins(author: :books))
end end
def test_pluck_type_cast_with_eager_load_without_table_name_qualified_column def test_pluck_type_cast_with_eager_load_without_table_name_qualified_column
assert_pluck_type_cast_without_table_name_qualified_column(Author.eager_load(:books)) assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.eager_load(author: :books))
end end
def assert_pluck_type_cast_without_table_name_qualified_column(authors) def assert_pluck_type_cast_without_table_name_qualified_column(author_addresses)
expected = [ expected = [
[nil, "unread"], [nil, "unread"],
["ebook", "reading"], ["ebook", "reading"],
["paperback", "read"], ["paperback", "read"],
] ]
actual = authors.order(:last_read).where.not("books.last_read": nil).pluck(:format, :last_read) actual = author_addresses.order(:last_read)
.where("books.last_read": [:unread, :reading, :read])
.pluck(:format, :last_read)
assert_equal expected, actual assert_equal expected, actual
end end
private :assert_pluck_type_cast_without_table_name_qualified_column private :assert_pluck_type_cast_without_table_name_qualified_column
......
...@@ -488,8 +488,8 @@ def test_inheritance_without_mapping ...@@ -488,8 +488,8 @@ def test_inheritance_without_mapping
end end
def test_scope_inherited_properly def test_scope_inherited_properly
assert_nothing_raised { Company.of_first_firm } assert_nothing_raised { Company.of_first_firm.to_a }
assert_nothing_raised { Client.of_first_firm } assert_nothing_raised { Client.of_first_firm.to_a }
end end
def test_inheritance_with_default_scope def test_inheritance_with_default_scope
......
...@@ -9,6 +9,7 @@ class Company < AbstractCompany ...@@ -9,6 +9,7 @@ class Company < AbstractCompany
validates_presence_of :name validates_presence_of :name
has_one :account, foreign_key: "firm_id"
has_one :dummy_account, foreign_key: "firm_id", class_name: "Account" has_one :dummy_account, foreign_key: "firm_id", class_name: "Account"
has_many :contracts has_many :contracts
has_many :developers, through: :contracts has_many :developers, through: :contracts
...@@ -16,8 +17,7 @@ class Company < AbstractCompany ...@@ -16,8 +17,7 @@ class Company < AbstractCompany
attribute :metadata, :json attribute :metadata, :json
scope :of_first_firm, lambda { scope :of_first_firm, lambda {
joins(account: :firm). joins(account: :firm).where("companies.id": 1)
where("firms.id" => 1)
} }
def arbitrary_method def arbitrary_method
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册