diff --git a/activerecord/lib/active_record/associations/collection_association.rb b/activerecord/lib/active_record/associations/collection_association.rb index 2dca6b612ec2192bc39b3c39d7bb221ca06dbb1b..00355f3e892d3e400d779e2ecb4ea70841806a85 100644 --- a/activerecord/lib/active_record/associations/collection_association.rb +++ b/activerecord/lib/active_record/associations/collection_association.rb @@ -246,9 +246,12 @@ def destroy_all end end - # Count all records using SQL. Construct options and pass them with - # scope to the target class's +count+. + # Returns the number of records. If no arguments are given, it counts all + # columns using SQL. If one argument is given, it counts only the passed + # column using SQL. If a block is given, it counts the number of records + # yielding a true value. def count(column_name = nil) + return super if block_given? relation = scope if association_scope.distinct_value # This is needed because 'SELECT count(DISTINCT *)..' is not valid SQL. diff --git a/activerecord/lib/active_record/associations/collection_proxy.rb b/activerecord/lib/active_record/associations/collection_proxy.rb index b9aed05135fa55e671f1c17a8914bd3cf367a4ab..9350064028dc7558437222a50a3ca7d4d0d353c0 100644 --- a/activerecord/lib/active_record/associations/collection_proxy.rb +++ b/activerecord/lib/active_record/associations/collection_proxy.rb @@ -715,12 +715,13 @@ def distinct end alias uniq distinct - # Count all records using SQL. + # Count all records. # # class Person < ActiveRecord::Base # has_many :pets # end # + # # This will perform the count using SQL. # person.pets.count # => 3 # person.pets # # => [ @@ -728,8 +729,13 @@ def distinct # # #, # # # # # ] - def count(column_name = nil) - @association.count(column_name) + # + # Passing a block will select all of a person's pets in SQL and then + # perform the count using Ruby. + # + # person.pets.count { |pet| pet.name.include?('-') } # => 2 + def count(column_name = nil, &block) + @association.count(column_name, &block) end # Returns the size of the collection. If the collection hasn't been loaded, diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index 54c9af48986fa07a0b1a0e865832ad4647c67c23..120f34109e69ce434dc34af5f1f77f82213419bc 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -37,7 +37,11 @@ module Calculations # Note: not all valid {Relation#select}[rdoc-ref:QueryMethods#select] expressions are valid #count expressions. The specifics differ # between databases. In invalid cases, an error from the database is thrown. def count(column_name = nil) - calculate(:count, column_name) + if block_given? + to_a.count { |*block_args| yield(*block_args) } + else + calculate(:count, column_name) + end end # Calculates the average value on a given column. Returns +nil+ if there's diff --git a/activerecord/test/cases/calculations_test.rb b/activerecord/test/cases/calculations_test.rb index 8f2682c781318fbe156a4e1d0be558231d629f4d..cfae7001594630fe84c786e608ba597f6006bc5e 100644 --- a/activerecord/test/cases/calculations_test.rb +++ b/activerecord/test/cases/calculations_test.rb @@ -482,6 +482,10 @@ def test_count_with_where_and_order assert_equal 1, Account.where(firm_name: '37signals').order(:firm_name).reverse_order.count end + def test_count_with_block + assert_equal 4, Account.count { |account| account.credit_limit.modulo(10).zero? } + end + def test_should_sum_expression # Oracle adapter returns floating point value 636.0 after SUM if current_adapter?(:OracleAdapter) diff --git a/activerecord/test/cases/relations_test.rb b/activerecord/test/cases/relations_test.rb index 95e4230a5884a9c033f2ef969607728376eef944..aa5766534f51296eecd8c4453bea63ad272b5def 100644 --- a/activerecord/test/cases/relations_test.rb +++ b/activerecord/test/cases/relations_test.rb @@ -1086,6 +1086,11 @@ def test_count assert_equal 9, posts.where(:comments_count => 0).count end + def test_count_with_block + posts = Post.all + assert_equal 10, posts.count { |p| p.comments_count.even? } + end + def test_count_on_association_relation author = Author.last another_author = Author.first