# join_dependency.rb
module ActiveRecord
  module Associations
    # Builds a tree of join parts -- a JoinBase for the base class followed by
    # one JoinAssociation per joined association -- and turns result rows back
    # into base records with those joined associations populated.
    class JoinDependency # :nodoc:
      autoload :JoinPart,        'active_record/associations/join_dependency/join_part'
      autoload :JoinBase,        'active_record/associations/join_dependency/join_base'
      autoload :JoinAssociation, 'active_record/associations/join_dependency/join_association'

      attr_reader :join_parts, :reflections, :alias_tracker, :base_klass

      # base is the base class on which operation is taking place.
      # associations is the list of associations which are joined using hash, symbol or array.
      # joins is the list of all string join commands and arel nodes.
      #
      #  Example:
      #
      #  class Physician < ActiveRecord::Base
      #    has_many :appointments
      #    has_many :patients, through: :appointments
      #  end
      #
      #  If I execute `@physician.patients.to_a` then
      #    base #=> Physician
      #    associations #=> []
      #    joins #=>  [#<Arel::Nodes::InnerJoin: ...]
      #
      #  However if I execute `Physician.joins(:appointments).to_a` then
      #    base #=> Physician
      #    associations #=> [:appointments]
      #    joins #=>  []
      #
      def initialize(base, associations, joins)
        @base_klass    = base
        @table_joins   = joins
        @join_parts    = [JoinBase.new(base)]
        @associations  = {}
        @reflections   = []
        @alias_tracker = AliasTracker.new(base.connection, joins)
        @alias_tracker.aliased_name_for(base.table_name) # Updates the count for base.table_name to 1
        build(associations)
      end

      # Adds each given join association to this dependency unless one equal
      # to it is already present. A missing one is rebuilt here by its
      # reflection name, under the parent returned by its +find_parent_in+
      # (falling back to the join base), keeping its original join type.
      #
      # Returns +self+ so calls can be chained.
      def graft(*associations)
        associations.each do |association|
          next if join_associations.any? { |existing| association == existing }

          build(association.reflection.name, association.find_parent_in(self) || join_base, association.join_type)
        end
        self
      end

      # All join parts except the base, in the order they were built.
      # A fresh array is returned on every call; +construct+ relies on this
      # because it deletes entries from the array it is given.
      def join_associations
        join_parts.drop(1)
      end

      # The JoinBase for the base class (always the first join part).
      def join_base
        join_parts.first
      end

      # Arel projections for every column of every join part. Each column is
      # read from the part's aliased table and aliased to the name supplied by
      # +column_names_with_alias+.
      def columns
        join_parts.flat_map do |join_part|
          table = join_part.aliased_table
          join_part.column_names_with_alias.map do |column_name, aliased_name|
            table[column_name].as(Arel.sql(aliased_name))
          end
        end
      end

      # Turns +rows+ into base records. Rows that share the join base's aliased
      # primary key value reuse the same parent record. Every row's joined
      # associations are attached to its parent. Duplicate members are then
      # removed from the loaded collections.
      #
      # Returns the distinct parent records, in order of first appearance.
      def instantiate(rows)
        primary_key = join_base.aliased_primary_key
        parents     = {}

        records = rows.map { |model|
          primary_id = model[primary_key]
          parent     = parents[primary_id] ||= join_base.instantiate(model)
          construct(parent, @associations, join_associations, model)
          parent
        }.uniq

        remove_duplicate_results!(base_klass, records, @associations)
        records
      end

      protected

      # Removes duplicate members from the loaded collection targets named by
      # +associations+. That argument may be a Symbol/String, an Array of
      # specs, or a Hash of name => nested spec. For a Hash, the method
      # recurses into the associated records found on +records+.
      def remove_duplicate_results!(base, records, associations)
        case associations
        when Symbol, String
          remove_uniq_by_reflection(base.reflections[associations], records)
        when Array
          associations.each do |association|
            remove_duplicate_results!(base, records, association)
          end
        when Hash
          associations.each_key do |name|
            reflection = base.reflections[name]
            remove_uniq_by_reflection(reflection, records)

            parent_records = []
            records.each do |record|
              descendant = record.send(reflection.name)
              next unless descendant

              if reflection.collection?
                parent_records.concat(descendant.target.uniq)
              else
                parent_records << descendant
              end
            end

            unless parent_records.empty?
              remove_duplicate_results!(reflection.klass, parent_records, associations[name])
            end
          end
        end
      end

      # Records +association+ in the @associations tree. The entry is stored
      # under the chain of its ancestors' reflection names (excluding the join
      # base), so +instantiate+ can later walk the same shape via +construct+.
      def cache_joined_association(association)
        path   = []
        parent = association.parent
        while parent != join_base
          path.unshift(parent.reflection.name)
          parent = parent.parent
        end
        node = path.inject(@associations) { |ref, key| ref[key] }
        node[association.reflection.name] ||= {}
      end

      # Adds join parts for +associations+ beneath +parent+. The +parent+
      # defaults to the most recently built part. +associations+ may be a
      # Symbol/String, an Array, or a Hash of name => nested associations.
      # Reuses an existing join part for the same reflection and parent.
      #
      # Raises ConfigurationError for an unknown association name or an
      # unsupported +associations+ type.
      def build(associations, parent = nil, join_type = Arel::InnerJoin)
        parent ||= join_parts.last
        case associations
        when Symbol, String
          reflection = parent.reflections[associations.intern]
          unless reflection
            raise ConfigurationError, "Association named '#{ associations }' was not found on #{ parent.base_klass.name }; perhaps you misspelled it?"
          end

          join_association = find_join_association(reflection, parent)
          unless join_association
            @reflections << reflection
            join_association = build_join_association(reflection, parent)
            join_association.join_type = join_type
            @join_parts << join_association
            cache_joined_association(join_association)
          end
          join_association
        when Array
          associations.each do |association|
            build(association, parent, join_type)
          end
        when Hash
          associations.keys.sort_by(&:to_s).each do |name|
            join_association = build(name, parent, join_type)
            build(associations[name], join_association, join_type)
          end
        else
          raise ConfigurationError, associations.inspect
        end
      end

      # Finds the existing join association under +parent+ for either a
      # reflection object or an association name (Symbol or String).
      # Returns nil when there is none.
      def find_join_association(name_or_reflection, parent)
        if String === name_or_reflection
          name_or_reflection = name_or_reflection.to_sym
        end
        by_name = Symbol === name_or_reflection

        join_associations.find { |j|
          (j.reflection == name_or_reflection ||
            (by_name && j.reflection.name == name_or_reflection)) &&
            j.parent == parent
        }
      end

      # Removes duplicate members, in place, from the loaded target of a
      # collection +reflection+ on each record. A nil or non-collection
      # reflection is ignored.
      def remove_uniq_by_reflection(reflection, records)
        if reflection && reflection.collection?
          records.each { |record| record.send(reflection.name).target.uniq! }
        end
      end

      # Factory for the join part representing +reflection+ under +parent+.
      def build_join_association(reflection, parent)
        JoinAssociation.new(reflection, self, parent)
      end

      # Attaches the associations described by +associations+ (same shapes as
      # +build+) from +row+ onto +parent+. Matched entries are removed from
      # +join_parts+ as they are consumed.
      #
      # Raises ConfigurationError when no join part matches a name, or for an
      # unsupported +associations+ type.
      def construct(parent, associations, join_parts, row)
        case associations
        when Symbol, String
          name = associations.to_s

          join_part = join_parts.find { |j|
            j.reflection.name.to_s == name &&
              j.parent_table_name == parent.class.table_name
          }
          raise(ConfigurationError, "No such association") unless join_part

          # Drop the matched part from this row's working list so later
          # lookups in the same row cannot match it again.
          join_parts.delete(join_part)
          construct_association(parent, join_part, row)
        when Array
          associations.each do |association|
            construct(parent, association, join_parts, row)
          end
        when Hash
          associations.sort_by { |k, _| k.to_s }.each do |association_name, assoc|
            association = construct(parent, association_name, join_parts, row)
            construct(association, assoc, join_parts, row) if association
          end
        else
          raise ConfigurationError, associations.inspect
        end
      end

      # Instantiates the record for +join_part+ from +row+ and attaches it to
      # +record+. Returns nil without doing anything when +row+ belongs to a
      # different owner. The row's value at the aliased primary key being nil
      # means no associated record, so nothing is instantiated.
      #
      # For has_one, an association already present in +record+'s association
      # cache is returned as-is instead of being rebuilt.
      #
      # Raises ConfigurationError for an unsupported macro.
      def construct_association(record, join_part, row)
        return if record.id.to_s != join_part.parent.record_id(row).to_s

        reflection = join_part.reflection
        macro      = reflection.macro

        if macro == :has_one && record.association_cache.key?(reflection.name)
          return record.association(reflection.name).target
        end

        association = join_part.instantiate(row) unless row[join_part.aliased_primary_key].nil?

        case macro
        when :has_one, :belongs_to
          set_target_and_inverse(join_part, association, record)
        when :has_many, :has_and_belongs_to_many
          other = record.association(reflection.name)
          other.loaded!
          other.target.push(association) if association
          other.set_inverse_instance(association)
        else
          raise ConfigurationError, "unknown macro: #{macro}"
        end

        association
      end

      # Sets +association+ (possibly nil) as the target of +record+'s
      # association for +join_part+ and wires up its inverse.
      def set_target_and_inverse(join_part, association, record)
        other = record.association(join_part.reflection.name)
        other.target = association
        other.set_inverse_instance(association)
      end
    end
  end
end