# join_dependency.rb
module ActiveRecord
  module Associations
    # Collects the join parts for a query (a JoinBase for the base class plus one
    # JoinAssociation per joined association), applies them to a relation, and
    # turns the joined result rows back into records with their associations
    # attached.
    class JoinDependency # :nodoc:
      autoload :JoinBase,        'active_record/associations/join_dependency/join_base'
      autoload :JoinAssociation, 'active_record/associations/join_dependency/join_association'

      attr_reader :join_parts, :reflections, :alias_tracker, :base_klass

      # base is the base class on which operation is taking place.
      # associations is the list of associations which are joined using hash, symbol or array.
      # joins is the list of all string join commands and arel nodes.
      #
      #  Example:
      #
      #  class Physician < ActiveRecord::Base
      #    has_many :appointments
      #    has_many :patients, through: :appointments
      #  end
      #
      #  If I execute `@physician.patients.to_a` then
      #    base #=> Physician
      #    associations #=> []
      #    joins #=>  [#<Arel::Nodes::InnerJoin: ...]
      #
      #  However if I execute `Physician.joins(:appointments).to_a` then
      #    base #=> Physician
      #    associations #=> [:appointments]
      #    joins #=>  []
      #
      def initialize(base, associations, joins)
        @base_klass    = base
        @table_joins   = joins
        @join_parts    = [JoinBase.new(base)]
        @associations  = {}
        @reflections   = []
        @alias_tracker = AliasTracker.new(base.connection, joins)
        @alias_tracker.aliased_name_for(base.table_name) # Updates the count for base.table_name to 1
        # The base part is the only join part at this point, so it is the root
        # under which the requested associations are built (as inner joins).
        build(associations, join_base, Arel::InnerJoin)
      end

      # Merges the given join associations into this dependency.
      #
      # An association that already compares equal (association == existing)
      # to one of our join associations is skipped. Any other association is
      # rebuilt here by its reflection name. It goes under the parent returned
      # by association.find_parent_in(self), or under the join base when that
      # returns nil, and keeps its own join type.
      #
      # Returns self so calls can be chained.
      def graft(*associations)
        associations.each do |association|
          next if join_associations.any? { |existing| association == existing }

          parent = association.find_parent_in(self) || join_base
          build(association.reflection.name, parent, association.join_type)
        end
        self
      end

      # Every join part except the base, in the order they were added.
      def join_associations
        join_parts.drop 1
      end

      # The root join part (the JoinBase created for base_klass).
      def join_base
        join_parts.first
      end

      # Applies each join association to +relation+ in turn and returns the
      # resulting relation.
      def join_relation(relation)
        join_associations.inject(relation) do |rel, association|
          association.join_relation(rel)
        end
      end

      # Select list for the joined query. For every join part, each column is
      # taken from its aliased table and aliased to the name given by
      # column_names_with_alias.
      def columns
        join_parts.flat_map do |join_part|
          table = join_part.aliased_table
          join_part.column_names_with_alias.map do |column_name, aliased_name|
            table[column_name].as Arel.sql(aliased_name)
          end
        end
      end

      # Builds base records from the rows of +result_set+.
      #
      # Rows whose type-cast base primary key is the same share one parent
      # record. For every row, the associated records are attached to that
      # parent via #construct. Duplicate parents are removed with uniq, and
      # duplicate members of collection associations are removed with
      # #remove_duplicate_results! before the records are returned.
      def instantiate(result_set)
        primary_key = join_base.aliased_primary_key
        parents = {}

        type_caster = result_set.column_type primary_key

        records = result_set.map { |row_hash|
          primary_id = type_caster.type_cast row_hash[primary_key]
          parent = parents[primary_id] ||= join_base.instantiate(row_hash)
          # join_associations returns a fresh array per row; construct_scalar
          # deletes the parts it consumes from that copy.
          construct(parent, @associations, join_associations, row_hash, result_set)
          parent
        }.uniq

        remove_duplicate_results!(base_klass, records, @associations)
        records
      end

      protected

      # Removes duplicate members from collection-association targets on
      # +records+.
      #
      # +associations+ is either a single association name or a Hash tree of
      # names as built by #cache_joined_association. For a Hash, each
      # association's records are collected (collections de-duplicated) and
      # processed recursively against the associated class.
      def remove_duplicate_results!(base, records, associations)
        case associations
        when Symbol, String
          reflection = base.reflections[associations]
          remove_uniq_by_reflection(reflection, records)
        when Hash
          associations.each_key do |name|
            reflection = base.reflections[name]
            remove_uniq_by_reflection(reflection, records)

            parent_records = []
            records.each do |record|
              descendant = record.send(reflection.name)
              next unless descendant

              if reflection.collection?
                parent_records.concat descendant.target.uniq
              else
                parent_records << descendant
              end
            end

            remove_duplicate_results!(reflection.klass, parent_records, associations[name]) unless parent_records.empty?
          end
        end
      end

      # Adds +association+'s reflection name to the @associations tree.
      #
      # The name is placed at the path formed by the reflection names of its
      # ancestor join parts, excluding the base, and maps to an empty child
      # hash. Ancestors must already be cached; #build guarantees this by
      # caching a parent before building its children.
      def cache_joined_association(association)
        path = []
        parent = association.parent
        while parent != join_base
          path.unshift(parent.reflection.name)
          parent = parent.parent
        end

        node = path.inject(@associations) { |cache, key| cache[key] }
        node[association.reflection.name] ||= {}
      end

      # Recursively turns an associations spec into join associations under
      # +parent+.
      #
      # A spec is one of:
      # * a Symbol or String (one association name),
      # * an Array of specs,
      # * a Hash mapping a spec to the spec of its children.
      #
      # An existing join association with the same reflection and parent is
      # reused. For a Symbol/String, the join association is returned.
      #
      # Raises ConfigurationError for an unknown association name or an
      # unsupported spec type.
      def build(associations, parent, join_type)
        case associations
        when Symbol, String
          reflection = parent.reflections[associations.intern]
          unless reflection
            raise ConfigurationError, "Association named '#{ associations }' was not found on #{ parent.base_klass.name }; perhaps you misspelled it?"
          end

          join_association = find_join_association(reflection, parent)
          unless join_association
            @reflections << reflection
            join_association = build_join_association(reflection, parent, join_type)
            @join_parts << join_association
            cache_joined_association(join_association)
          end
          join_association
        when Array
          associations.each do |association|
            build(association, parent, join_type)
          end
        when Hash
          associations.each do |left, right|
            join_association = build(left, parent, join_type)
            build(right, join_association, join_type)
          end
        else
          raise ConfigurationError, associations.inspect
        end
      end

      # Returns the join association whose reflection == +name_or_reflection+
      # and whose parent == +parent+, or nil.
      #
      # NOTE(review): a String argument is converted to a Symbol, but the
      # comparison is against the join association's reflection object. The
      # only caller in this file (#build) passes a reflection.
      def find_join_association(name_or_reflection, parent)
        if String === name_or_reflection
          name_or_reflection = name_or_reflection.to_sym
        end

        join_associations.detect { |j|
          j.reflection == name_or_reflection && j.parent == parent
        }
      end

      # For a collection reflection, de-duplicates (uniq!) the target of that
      # association on each record. Does nothing for nil or for non-collection
      # reflections.
      def remove_uniq_by_reflection(reflection, records)
        if reflection && reflection.collection?
          records.each { |record| record.send(reflection.name).target.uniq! }
        end
      end

      # Creates the JoinAssociation for +reflection+ under +parent+.
      #
      # The new part's index is the current number of join parts. Raises
      # EagerLoadPolymorphicError for polymorphic reflections, after
      # reflection.check_validity! has run.
      def build_join_association(reflection, parent, join_type)
        reflection.check_validity!

        if reflection.options[:polymorphic]
          raise EagerLoadPolymorphicError.new(reflection)
        end

        JoinAssociation.new(reflection, join_parts.length, parent, join_type, alias_tracker)
      end

      # Attaches, from +row+, each association in the +associations+ tree to
      # +parent+ (processed in name order), then recurses into each child.
      def construct(parent, associations, join_parts, row, rs)
        associations.sort_by { |name, _| name.to_s }.each do |association_name, children|
          association = construct_scalar(parent, association_name, join_parts, row, rs)
          construct(association, children, join_parts, row, rs) if association
        end
      end

      # Finds the join part for the association named +associations+ whose
      # parent table is +parent+'s table, removes it from +join_parts+ and
      # builds the associated record from +row+.
      #
      # Raises ConfigurationError when no such join part exists.
      def construct_scalar(parent, associations, join_parts, row, rs)
        name = associations.to_s

        join_part = join_parts.detect { |j|
          j.reflection.name.to_s == name &&
            j.parent_table_name == parent.class.table_name
        }

        raise(ConfigurationError, "No such association") unless join_part

        join_parts.delete(join_part)
        construct_association(parent, join_part, row, rs)
      end

      # Builds the record for +join_part+ from +row+ and attaches it to +record+.
      #
      # Returns nil when the row's parent key (type-cast) does not equal
      # record.id. Otherwise it returns the associated record, or nil when the
      # row's own key column is nil.
      #
      # For has_one, a target already present in record.association_cache is
      # returned as-is.
      #
      # Raises ConfigurationError for macros other than has_one, has_many and
      # belongs_to.
      def construct_association(record, join_part, row, rs)
        parent_key = join_part.parent.aliased_primary_key
        row_id = rs.column_type(parent_key).type_cast(row[parent_key])
        return if record.id != row_id

        reflection = join_part.reflection
        macro = reflection.macro
        if macro == :has_one
          return record.association(reflection.name).target if record.association_cache.key?(reflection.name)
          association = join_part.instantiate(row) unless row[join_part.aliased_primary_key].nil?
          set_target_and_inverse(join_part, association, record)
        else
          association = join_part.instantiate(row) unless row[join_part.aliased_primary_key].nil?
          case macro
          when :has_many
            other = record.association(reflection.name)
            other.loaded!
            other.target.push(association) if association
            other.set_inverse_instance(association)
          when :belongs_to
            set_target_and_inverse(join_part, association, record)
          else
            raise ConfigurationError, "unknown macro: #{macro}"
          end
        end
        association
      end

      # Sets +association+ (possibly nil) as the target of +join_part+'s
      # association on +record+, and passes it to set_inverse_instance.
      def set_target_and_inverse(join_part, association, record)
        other = record.association(join_part.reflection.name)
        other.target = association
        other.set_inverse_instance(association)
      end
    end
  end
end