# query_methods.rb
require 'active_support/core_ext/array/wrap'
require 'active_support/core_ext/object/blank'

module ActiveRecord
  module QueryMethods
    extend ActiveSupport::Concern

    # State written by the query methods of this module. Each *_values
    # attribute holds a list of entries; each *_value attribute holds a
    # single setting.
    attr_accessor :includes_values, :eager_load_values, :preload_values,
                  :select_values, :group_values, :order_values, :joins_values,
                  :where_values, :having_values, :bind_values,
                  :references_values, :extending_values

    attr_accessor :limit_value, :offset_value, :lock_value, :readonly_value,
                  :create_with_value, :from_value, :reordering_value,
                  :reverse_order_value, :uniq_value

    # +extensions+ is a second name for the +extending_values+ reader.
    alias_method :extensions, :extending_values

    # Returns a clone whose includes_values also contain +args+, or the
    # receiver itself when called with no arguments.
    def includes(*args)
      return self if args.empty?

      clone.includes!(*args)
    end

    # Adds +args+ to includes_values in place and returns self. Blank
    # arguments are ignored, and the combined list is flattened and
    # de-duplicated.
    def includes!(*args)
      wanted = args.reject { |a| a.blank? }
      self.includes_values = (includes_values + wanted).flatten.uniq
      self
    end

    # Returns a clone with +args+ added to eager_load_values, or the
    # receiver itself when called with no arguments.
    def eager_load(*args)
      if args.blank?
        self
      else
        clone.eager_load!(*args)
      end
    end

    # Appends +args+ (not flattened) to eager_load_values; returns self.
    def eager_load!(*args)
      self.eager_load_values = eager_load_values + args
      self
    end

    # Returns a clone with +args+ added to preload_values, or the receiver
    # itself when called with no arguments.
    def preload(*args)
      return self if args.blank?

      clone.preload!(*args)
    end

    # Appends +args+ (not flattened) to preload_values; returns self.
    def preload!(*args)
      self.preload_values = preload_values + args
      self
    end

    # Declares that an association is referenced from an SQL string, so it
    # should be JOINed into the query rather than loaded separately.
    #
    # For example:
    #
    #   User.includes(:posts).where("posts.name = 'foo'")
    #   # => Doesn't JOIN the posts table, resulting in an error.
    #
    #   User.includes(:posts).where("posts.name = 'foo'").references(:posts)
    #   # => The query now knows the string references posts, so it adds a JOIN
    #
    # Returns the receiver unchanged when called with no arguments.
    def references(*args)
      return self if args.blank?

      clone.references!(*args)
    end

    # Adds the flattened +args+, converted to strings, to references_values
    # without introducing duplicates. Returns self.
    def references!(*args)
      names = args.flatten.map { |name| name.to_s }
      self.references_values = (references_values + names).uniq
      self
    end

    # Works in two distinct ways.
    #
    # First: given a block, it behaves like Array#select:
    #
    #   Model.scoped.select { |m| m.field == value }
    #
    # This loads every record in the scope into an array and filters that
    # array in Ruby with Array#select.
    #
    # Second: given a field, it modifies the SELECT statement so that only
    # that field is retrieved:
    #
    #   >> Model.select(:field)
    #   => [#<Model field:value>]
    #
    # Although the example above looks as though it returns an array, it
    # actually returns a relation object, so other methods from
    # ActiveRecord::QueryMethods can be chained onto it.
    #
    # The argument can also be an array of fields:
    #
    #   >> Model.select([:field, :other_field, :and_one_more])
    #   => [#<Model field: "value", other_field: "value", and_one_more: "value">]
    #
    # Reading an attribute whose field was not retrieved raises
    # <tt>ActiveModel::MissingAttributeError</tt>:
    #
    #   >> Model.select(:field).first.other_field
    #   => ActiveModel::MissingAttributeError: missing attribute: other_field
    def select(value = nil, &block)
      # The block is captured explicitly. The former `value = Proc.new`
      # default relied on Proc.new implicitly grabbing the method's block,
      # which Ruby 2.7 deprecated and Ruby 3.0 turned into an ArgumentError.
      if block
        to_a.select(&block)
      else
        clone.select!(value)
      end
    end

    # In-place variant of #select.
    #
    # With a block: filters the array returned by to_a in place and returns
    # that array. Without a block: appends +value+ (a single field or an
    # array of fields) to select_values and returns self.
    def select!(value = nil, &block)
      # The block is captured explicitly; `value = Proc.new` no longer
      # captures the method's block on Ruby 3.0+.
      if block
        # TODO: add a test for the block form.
        records = to_a
        # Array#select! returns nil when nothing is removed, so return the
        # array itself to get a consistent result either way.
        records.select!(&block)
        records
      else
        self.select_values += Array.wrap(value)
        self
      end
    end

    # Returns a clone with +args+ added to group_values, or the receiver
    # itself when called with no arguments.
    def group(*args)
      if args.blank?
        self
      else
        clone.group!(*args)
      end
    end

    # Appends the flattened +args+ to group_values; returns self.
    def group!(*args)
      self.group_values = group_values + args.flatten
      self
    end

    # Returns a clone with +args+ appended to its order, or the receiver
    # itself when called with no arguments.
    def order(*args)
      return self if args.blank?

      clone.order!(*args)
    end

    # Appends the flattened +args+ to order_values and returns self.
    #
    # For each clause that is not an Arel node and starts with
    # "table.column", the leading table name is passed to references!.
    def order!(*args)
      args = args.flatten

      tables = []
      args.each do |arg|
        next if Arel::Node === arg
        tables << $1 if arg =~ /^([a-zA-Z]\w*)\.(\w+)/
      end
      references!(tables) unless tables.empty?

      self.order_values += args
      self
    end

    # Replaces any order already defined on the relation with the given one:
    #
    #   User.order('email DESC').reorder('id ASC') # generated SQL has 'ORDER BY id ASC'
    #
    # Later calls to #order on the returned relation append to the new order:
    #
    #   User.order('email DESC').reorder('id ASC').order('name ASC')
    #
    # generates a query with 'ORDER BY id ASC, name ASC'.
    #
    # Called with no arguments, it returns the receiver unchanged; the
    # existing order is kept.
    def reorder(*args)
      if args.blank?
        self
      else
        clone.reorder!(*args)
      end
    end

    # Replaces order_values with the flattened +args+, flags the relation
    # as reordered and returns self.
    def reorder!(*args)
      self.order_values = args.flatten
      self.reordering_value = true
      self
    end

    # Returns a clone with +args+ added to joins_values. When every argument
    # is nil (or there are none), the receiver is returned unchanged.
    # Otherwise the original +args+, nils included, are forwarded to joins!.
    def joins(*args)
      return self if args.compact.blank?

      clone.joins!(*args)
    end

    # Appends the flattened +args+ to joins_values and returns self.
    #
    # nil entries are dropped. #joins forwards its arguments without
    # compacting them, and build_joins has no bucket for nil: it raises
    # 'unknown class: NilClass' when it meets one.
    def joins!(*args)
      args = args.flatten.compact

      self.joins_values += args
      self
    end

    # Returns a clone with +value+ appended to its bind_values.
    def bind(value)
      relation = clone
      relation.bind!(value)
    end

    # Appends +value+ as a single entry to bind_values; returns self.
    def bind!(value)
      tap { self.bind_values = bind_values + [value] }
    end

    # Returns a clone with the given conditions added, or the receiver
    # itself when +opts+ is blank (e.g. nil or an empty string).
    def where(opts, *rest)
      return self if opts.blank?

      clone.where!(opts, *rest)
    end

    # Adds the conditions built from +opts+ and +rest+ to where_values and
    # returns self. For a Hash, the names reported by
    # PredicateBuilder.references are first passed to references!.
    def where!(opts, *rest)
      if Hash === opts
        references!(PredicateBuilder.references(opts))
      end

      conditions = build_where(opts, rest)
      self.where_values += conditions
      self
    end

    # Returns a clone with the given HAVING conditions added, or the
    # receiver itself when +opts+ is blank (e.g. nil or an empty string).
    def having(opts, *rest)
      if opts.blank?
        self
      else
        clone.having!(opts, *rest)
      end
    end

    # Adds the conditions built from +opts+ and +rest+ to having_values and
    # returns self. For a Hash, the names reported by
    # PredicateBuilder.references are first passed to references!.
    def having!(opts, *rest)
      if Hash === opts
        references!(PredicateBuilder.references(opts))
      end

      conditions = build_where(opts, rest)
      self.having_values += conditions
      self
    end

    # Returns a clone whose limit_value is +value+.
    def limit(value)
      relation = clone
      relation.limit!(value)
    end

    # Sets limit_value to +value+, replacing any previous limit; returns self.
    def limit!(value)
      tap { self.limit_value = value }
    end

    # Returns a clone whose offset_value is +value+.
    def offset(value)
      relation = clone
      relation.offset!(value)
    end

    # Sets offset_value to +value+, replacing any previous offset; returns self.
    def offset!(value)
      tap { self.offset_value = value }
    end

    # Returns a clone with its lock setting updated via lock!.
    def lock(locks = true)
      relation = clone
      relation.lock!(locks)
    end

    # Sets lock_value from +locks+ and returns self.
    #
    # A String is stored as-is (presumably a custom lock clause; confirm
    # with the Arel lock implementation). true and nil both store true.
    # Any other value stores false, and build_arel skips the lock for a
    # falsy lock_value.
    def lock!(locks = true)
      accepted = case locks
                 when String, TrueClass, NilClass then true
                 else false
                 end

      self.lock_value = accepted ? (locks || true) : false
      self
    end

    # Returns a chainable relation with zero records: an instance of
    # NullRelation built from this relation's @klass and @table.
    #
    # NullRelation inherits from Relation and follows the Null Object
    # pattern. It always returns an empty array of records and never
    # queries the database. Any condition chained onto it also produces an
    # empty relation without running a query.
    #
    # Useful when a method or scope may return zero records but its result
    # still needs to be chainable. For example:
    #
    #   @posts = current_user.visible_posts.where(:name => params[:name])
    #   # => the visible_posts method is expected to return a chainable Relation
    #
    #   def visible_posts
    #     case role
    #     when 'Country Manager'
    #       Post.where(:country => country)
    #     when 'Reviewer'
    #       Post.published
    #     when 'Bad User'
    #       Post.none # => returning [] instead breaks the previous code
    #     end
    #   end
    def none
      NullRelation.new(@klass, @table)
    end

    # Returns a clone whose readonly_value is +value+ (true by default).
    def readonly(value = true)
      relation = clone
      relation.readonly!(value)
    end

    # Sets readonly_value to +value+ (true by default); returns self.
    def readonly!(value = true)
      tap { self.readonly_value = value }
    end

    # Returns a clone whose create_with_value has been updated via
    # create_with!.
    def create_with(value)
      relation = clone
      relation.create_with!(value)
    end

    # Merges +value+ into create_with_value. A falsy +value+ resets it to
    # an empty Hash. Returns self.
    def create_with!(value)
      self.create_with_value = if value
                                 create_with_value.merge(value)
                               else
                                 {}
                               end
      self
    end

    # Returns a clone whose from_value is +value+.
    def from(value)
      relation = clone
      relation.from!(value)
    end

    # Sets from_value to +value+, replacing any previous source; returns self.
    def from!(value)
      tap { self.from_value = value }
    end

    # Specifies whether the records should be unique. For example:
    #
    #   User.select(:name)
    #   # => Might return two records with the same name
    #
    #   User.select(:name).uniq
    #   # => Returns one record per unique name
    #
    #   User.select(:name).uniq.uniq(false)
    #   # => Uniqueness can be switched off again
    def uniq(value = true)
      relation = clone
      relation.uniq!(value)
    end

    # Sets uniq_value to +value+ (true by default); returns self.
    def uniq!(value = true)
      tap { self.uniq_value = value }
    end

    # Extends a scope with additional methods, supplied either as modules
    # or as a block. The result is a relation, which can be extended
    # further.
    #
    # === Using a module
    #
    #   module Pagination
    #     def page(number)
    #       # pagination code goes here
    #     end
    #   end
    #
    #   scope = Model.scoped.extending(Pagination)
    #   scope.page(params[:page])
    #
    # A list of modules can be passed as well:
    #
    #   scope = Model.scoped.extending(Pagination, SomethingElse)
    #
    # === Using a block
    #
    #   scope = Model.scoped.extending do
    #     def page(number)
    #       # pagination code goes here
    #     end
    #   end
    #   scope.page(params[:page])
    #
    # A block and a module list can also be combined:
    #
    #   scope = Model.scoped.extending(Pagination) do
    #     def per_page(number)
    #       # pagination code goes here
    #     end
    #   end
    #
    # Returns the receiver unchanged when there is no block and no truthy
    # module argument.
    def extending(*modules, &block)
      return self unless modules.any? || block

      clone.extending!(*modules, &block)
    end

    # Adds +modules+ (flattened) to extending_values, plus an anonymous
    # module built from the block when one is given. Then extends this
    # relation with every module in that list. Returns self.
    #
    # The new modules are appended rather than assigned. With assignment, a
    # chained call such as extending(A).extending(B) left only B in
    # extending_values, even though the cloned relation still carried A's
    # methods.
    def extending!(*modules, &block)
      modules << Module.new(&block) if block

      self.extending_values += modules.flatten
      extend(*extending_values) if extending_values.any?

      self
    end

    # Returns a clone with its reverse_order_value toggled.
    def reverse_order
      reversed = clone
      reversed.reverse_order!
    end

    # Toggles reverse_order_value (a nil value becomes true); returns self.
    def reverse_order!
      self.reverse_order_value = reverse_order_value ? false : true
      self
    end

    # Returns the Arel manager for this relation, memoized in @arel. It is
    # built by with_default_scope.build_arel (presumably this relation with
    # its default scope applied; confirm in Relation).
    def arel
      return @arel if @arel

      @arel = with_default_scope.build_arel
    end

    # Assembles and returns the Arel select manager for this relation,
    # starting from +table+ and applying the stored query values in this
    # order:
    # - joins
    # - WHERE conditions (empty strings and duplicates removed)
    # - HAVING
    # - LIMIT / OFFSET
    # - GROUP BY
    # - ORDER BY (reversed when reverse_order_value is set)
    # - projection
    # - DISTINCT, FROM and lock
    def build_arel
      manager = table.from table

      # De-duplicates a list of SQL fragments and drops blank entries.
      usable = lambda { |values| values.uniq.reject { |value| value.blank? } }

      build_joins(manager, @joins_values) unless @joins_values.empty?

      collapse_wheres(manager, (@where_values - ['']).uniq)

      manager.having(*usable.call(@having_values)) unless @having_values.empty?

      manager.take(connection.sanitize_limit(@limit_value)) if @limit_value
      manager.skip(@offset_value.to_i) if @offset_value

      manager.group(*usable.call(@group_values)) unless @group_values.empty?

      # reverse_sql_order falls back to the primary key when there is no
      # order, so a reversed relation always ends up ordered.
      orders = @reverse_order_value ? reverse_sql_order(@order_values) : @order_values
      manager.order(*usable.call(orders)) unless orders.empty?

      build_select(manager, @select_values.uniq)

      manager.distinct(@uniq_value)
      manager.from(@from_value) if @from_value
      manager.lock(@lock_value) if @lock_value

      manager
    end

    private

    # Turns raw SQL join fragments into string-join nodes created by
    # +table+ (the caller passes its select manager here). Blank fragments
    # are skipped, and [] is returned when none remain. Otherwise the
    # relation is marked implicitly read-only.
    #
    # Strings, and arrays made only of strings (joined with spaces), are
    # wrapped with Arel.sql. Any other entry is passed through unchanged.
    def custom_join_ast(table, joins)
      present = joins.reject { |join| join.blank? }
      return [] if present.empty?

      @implicit_readonly = true

      present.map do |join|
        node =
          case join
          when String then Arel.sql(join)
          when Array  then array_of_strings?(join) ? Arel.sql(join.join(' ')) : join
          else join
          end
        table.create_string_join(node)
      end
    end

    # Applies +wheres+ to +arel+. All Arel::Nodes::Equality conditions are
    # combined into one Arel::Nodes::And. Every remaining condition (raw
    # String SQL is wrapped with Arel.sql first) is added on its own inside
    # an Arel::Nodes::Grouping.
    def collapse_wheres(arel, wheres)
      equalities = wheres.select { |node| Arel::Nodes::Equality === node }
      arel.where(Arel::Nodes::And.new(equalities)) if equalities.any?

      remaining = wheres - equalities
      remaining.each do |where|
        condition = String === where ? Arel.sql(where) : where
        arel.where(Arel::Nodes::Grouping.new(condition))
      end
    end

    # Converts where/having arguments into a list of conditions:
    # - String / Array: sanitized through @klass's sanitize_sql, with
    #   +other+ appended when present (presumably values for placeholders).
    # - Hash: expanded for aggregates and turned into predicates by
    #   PredicateBuilder.build_from_hash.
    # - Anything else: wrapped in a one-element array unchanged.
    def build_where(opts, other = [])
      if String === opts || Array === opts
        conditions = other.empty? ? opts : [opts] + other
        [@klass.send(:sanitize_sql, conditions)]
      elsif Hash === opts
        attributes = @klass.send(:expand_hash_conditions_for_aggregates, opts)
        PredicateBuilder.build_from_hash(table.engine, attributes, table)
      else
        [opts]
      end
    end

    # Adds +joins+ to +manager+ and returns it.
    #
    # Each join is classified as one of:
    # - a raw SQL String
    # - an association given as a Hash, Symbol or Array
    # - a stashed JoinDependency::JoinAssociation
    # - an Arel::Nodes::Join node
    # Anything else raises RuntimeError ('unknown class: ...').
    #
    # Association joins are resolved through a JoinDependency and joined
    # onto +manager+. Arel join nodes and the stripped, de-duplicated string
    # joins are appended to manager.join_sources. Any association join marks
    # the relation implicitly read-only.
    def build_joins(manager, joins)
      buckets = Hash.new { |hash, kind| hash[kind] = [] }

      joins.each do |join|
        kind =
          case join
          when String
            :string_join
          when Hash, Symbol, Array
            :association_join
          when ActiveRecord::Associations::JoinDependency::JoinAssociation
            :stashed_join
          when Arel::Nodes::Join
            :join_node
          else
            raise 'unknown class: %s' % join.class.name
          end
        buckets[kind] << join
      end

      association_joins         = buckets[:association_join]
      stashed_association_joins = buckets[:stashed_join]
      join_nodes                = buckets[:join_node].uniq
      string_joins              = buckets[:string_join].map(&:strip).uniq

      join_list = join_nodes + custom_join_ast(manager, string_joins)

      join_dependency = ActiveRecord::Associations::JoinDependency.new(
        @klass,
        association_joins,
        join_list
      )

      join_dependency.graft(*stashed_association_joins)

      has_association_joins = !association_joins.empty? || !stashed_association_joins.empty?
      @implicit_readonly = true if has_association_joins

      # FIXME: refactor this to build an AST
      join_dependency.join_associations.each { |association| association.join_to(manager) }

      manager.join_sources.concat(join_list)

      manager
    end

    # Sets the projection on +arel+. With no +selects+ it projects every
    # column of @klass's table (Arel.star). Otherwise it projects +selects+
    # and clears the implicit read-only flag.
    def build_select(arel, selects)
      if selects.empty?
        arel.project(@klass.arel_table[Arel.star])
      else
        @implicit_readonly = false
        arel.project(*selects)
      end
    end

    # Returns +order_query+ with every clause's direction reversed. An
    # empty order defaults to the quoted primary key ascending, so the
    # result orders by it descending.
    #
    # How each clause is reversed:
    # - Arel::Nodes::Ordering nodes: via their #reverse.
    # - Strings and Symbols: split on commas; each fragment has a trailing
    #   ASC/DESC flipped, or " DESC" appended when it has neither.
    # - Anything else: passed through unchanged.
    #
    # NOTE(review): splitting on ',' breaks clauses that contain commas,
    # e.g. function calls such as COALESCE(a, b).
    def reverse_sql_order(order_query)
      if order_query.empty?
        order_query = ["#{quoted_table_name}.#{quoted_primary_key} ASC"]
      end

      reversed = order_query.map do |clause|
        case clause
        when Arel::Nodes::Ordering
          clause.reverse
        when String, Symbol
          clause.to_s.split(',').map do |fragment|
            column = fragment.strip
            if column =~ /\sasc\Z/i
              column.gsub(/\sasc\Z/i, ' DESC')
            elsif column =~ /\sdesc\Z/i
              column.gsub(/\sdesc\Z/i, ' ASC')
            else
              "#{column} DESC"
            end
          end
        else
          clause
        end
      end

      reversed.flatten
    end

    # True when +o+ is an Array whose elements are all Strings (including
    # the empty array); false for anything that is not an Array.
    def array_of_strings?(o)
      return false unless o.is_a?(Array)

      o.all? { |element| element.is_a?(String) }
    end
  end
end