diff --git a/activerecord/lib/active_record/association_preload.rb b/activerecord/lib/active_record/association_preload.rb index 0f96ee95d5795e357b1a24728fb0f700900cc590..c38b98258a5026c755f26508c13978507a18ba5c 100644 --- a/activerecord/lib/active_record/association_preload.rb +++ b/activerecord/lib/active_record/association_preload.rb @@ -4,28 +4,6 @@ def self.included(base) base.extend(ClassMethods) end - class HasManyAssociationStrategy - def initialize(through_reflection) - @through_reflection = through_reflection - end - - def primary_key - if @through_reflection && @through_reflection.macro == :belongs_to - @through_reflection.klass.primary_key - else - @through_reflection.primary_key_name - end - end - - def primary_key_name - if @through_reflection && @through_reflection.macro == :belongs_to - @through_reflection.primary_key_name - else - nil - end - end - end - module ClassMethods # Loads the named associations for the activerecord record (or records) given @@ -153,9 +131,9 @@ def preload_has_one_association(records, reflection, preload_options={}) def preload_has_many_association(records, reflection, preload_options={}) options = reflection.options - through_reflection = reflections[options[:through]] - strat = HasManyAssociationStrategy.new(through_reflection) - id_to_record_map, ids = construct_id_map(records, strat.primary_key_name) + + primary_key_name = reflection.through_reflection_primary_key_name + id_to_record_map, ids = construct_id_map(records, primary_key_name) records.each {|record| record.send(reflection.name).loaded} if options[:through] @@ -165,7 +143,7 @@ def preload_has_many_association(records, reflection, preload_options={}) source = reflection.source_reflection.name through_records.first.class.preload_associations(through_records, source, options) through_records.each do |through_record| - through_record_id = through_record[strat.primary_key].to_s + through_record_id = through_record[reflection.through_reflection_primary_key].to_s add_preloaded_records_to_collection(id_to_record_map[through_record_id], reflection.name, through_record.send(source)) end end diff --git a/activerecord/lib/active_record/reflection.rb b/activerecord/lib/active_record/reflection.rb index 77d03493dc83927eb8be86197517f23b3f21e78f..dbff4f24d6d3c025f8485bb3db59de9fad44921b 100644 --- a/activerecord/lib/active_record/reflection.rb +++ b/activerecord/lib/active_record/reflection.rb @@ -117,6 +117,11 @@ def sanitized_conditions #:nodoc: @sanitized_conditions ||= klass.send(:sanitize_sql, options[:conditions]) if options[:conditions] end + # Returns +true+ if +self+ is a +belongs_to+ reflection. + def belongs_to? + macro == :belongs_to + end + private def derive_class_name name.to_s.camelize @@ -200,6 +205,9 @@ def through_reflection false end + def through_reflection_primary_key_name + end + def source_reflection nil end @@ -212,7 +220,7 @@ def derive_class_name end def derive_primary_key_name - if macro == :belongs_to + if belongs_to? "#{name}_id" elsif options[:as] "#{options[:as]}_id" @@ -281,6 +289,14 @@ def check_validity! end end + def through_reflection_primary_key + through_reflection.belongs_to? ? through_reflection.klass.primary_key : through_reflection.primary_key_name + end + + def through_reflection_primary_key_name + through_reflection.primary_key_name if through_reflection.belongs_to? + end + private def derive_class_name # get the class_name of the belongs_to association of the through reflection