# attribute_assignment.rb

module ActiveRecord
  module AttributeAssignment
    extend ActiveSupport::Concern
    include ActiveModel::DeprecatedMassAssignmentSecurity
    include ActiveModel::ForbiddenAttributesProtection

    # Allows you to set all the attributes by passing in a hash of attributes with
    # keys matching the attribute names (which again matches the column names).
    #
    # If the passed hash responds to the <tt>permitted?</tt> method and the return value
    # of this method is +false+, an <tt>ActiveModel::ForbiddenAttributesError</tt>
    # exception is raised.
    #
    # Does nothing when +new_attributes+ is blank. Otherwise assignment happens
    # in three passes: plain attributes are assigned while iterating, then
    # Hash-valued (nested) attributes, then multiparameter attributes (keys
    # containing "(", e.g. <tt>written_on(1i)</tt>).
    def assign_attributes(new_attributes)
      return if new_attributes.blank?

      attributes                  = new_attributes.stringify_keys
      multi_parameter_attributes  = []
      nested_parameter_attributes = []

      attributes = sanitize_for_mass_assignment(attributes)

      attributes.each do |k, v|
        if k.include?("(")
          multi_parameter_attributes << [k, v]
        elsif v.is_a?(Hash)
          nested_parameter_attributes << [k, v]
        else
          _assign_attribute(k, v)
        end
      end

      assign_nested_parameter_attributes(nested_parameter_attributes) unless nested_parameter_attributes.empty?
      assign_multiparameter_attributes(multi_parameter_attributes) unless multi_parameter_attributes.empty?
    end

    alias attributes= assign_attributes

    private

    # Calls the public writer "#{k}=" with +v+.
    #
    # A NoMethodError is translated into UnknownAttributeError only when the
    # writer does not exist. If the writer exists, the NoMethodError came from
    # inside it and is re-raised untouched so the real bug stays visible.
    def _assign_attribute(k, v)
      public_send("#{k}=", v)
    rescue NoMethodError
      if respond_to?("#{k}=")
        raise
      else
        raise UnknownAttributeError, "unknown attribute: #{k}"
      end
    end

    # Assign any deferred nested attributes after the base attributes have been set.
    def assign_nested_parameter_attributes(pairs)
      pairs.each { |k, v| _assign_attribute(k, v) }
    end

    # Instantiates objects for all attribute classes that need more than one constructor parameter. This is done
    # by calling new on the column type or aggregation type (through composed_of) object with these parameters.
    # So having the pairs written_on(1) = "2004", written_on(2) = "6", written_on(3) = "24", will instantiate
    # written_on (a date type) with Date.new("2004", "6", "24"). You can also specify a typecast character in the
    # parentheses to have the parameters typecast before they're used in the constructor. Use i for Integer and
    # f for Float. If all the values for a given attribute are empty, the attribute will be set to +nil+.
    def assign_multiparameter_attributes(pairs)
      execute_callstack_for_multiparameter_attributes(
        extract_callstack_for_multiparameter_attributes(pairs)
      )
    end

    # Builds and assigns a value for every entry of +callstack+
    # (attribute name => { position => value }).
    #
    # Failures are collected rather than raised immediately, so that every
    # broken attribute is reported together.
    #
    # Raises MultiparameterAssignmentErrors if at least one assignment failed.
    def execute_callstack_for_multiparameter_attributes(callstack)
      errors = []
      callstack.each do |name, values_with_empty_parameters|
        begin
          send("#{name}=", MultiparameterAttribute.new(self, name, values_with_empty_parameters).read_value)
        rescue => ex
          errors << AttributeAssignmentError.new("error on assignment #{values_with_empty_parameters.values.inspect} to #{name} (#{ex.message})", ex, name)
        end
      end

      unless errors.empty?
        error_descriptions = errors.map(&:message).join(",")
        raise MultiparameterAssignmentErrors.new(errors), "#{errors.size} error(s) on assignment of multiparameter attributes [#{error_descriptions}]"
      end
    end

    # Groups multiparameter pairs by attribute name, for example:
    #   [["written_on(1i)", "2004"], ["written_on(2i)", "6"]]
    #   # => { "written_on" => { 1 => 2004, 2 => 6 } }
    #
    # Empty values become +nil+. Because of <tt>||=</tt>, the first non-nil
    # value seen for a given position is kept.
    def extract_callstack_for_multiparameter_attributes(pairs)
      pairs.each_with_object({}) do |(multiparameter_name, value), attributes|
        attribute_name = multiparameter_name.split("(").first
        attributes[attribute_name] ||= {}

        parameter_value = value.empty? ? nil : type_cast_attribute_value(multiparameter_name, value)
        attributes[attribute_name][find_parameter_position(multiparameter_name)] ||= parameter_value
      end
    end

    # Applies the optional typecast suffix of a multiparameter name.
    # "(3i)" calls +to_i+ and "(3f)" calls +to_f+ on +value+; any other name
    # returns +value+ unchanged.
    def type_cast_attribute_value(multiparameter_name, value)
      multiparameter_name =~ /\([0-9]*([if])\)/ ? value.send("to_" + Regexp.last_match(1)) : value
    end

    # Returns the integer position inside the parentheses, e.g. 2 for "written_on(2i)".
    def find_parameter_position(multiparameter_name)
      multiparameter_name.scan(/\(([0-9]*).*\)/).first.first.to_i
    end

    class MultiparameterAttribute #:nodoc:
      attr_reader :object, :name, :values, :column

      # +values+ is the position => value hash built by
      # extract_callstack_for_multiparameter_attributes.
      def initialize(object, name, values)
        @object = object
        @name   = name
        @values = values
      end

      # Builds the attribute value from the collected positional +values+.
      #
      # Returns +nil+ when every position is +nil+. The target class comes from
      # a composed_of aggregation if one is reflected for +name+, otherwise from
      # the column. Time and Date get dedicated handling; any other class is
      # built with +klass.new+.
      def read_value
        return if values.values.compact.empty?

        @column = object.class.reflect_on_aggregation(name.to_sym) || object.column_for_attribute(name)
        klass   = column.klass

        if klass == Time
          read_time
        elsif klass == Date
          read_date
        else
          read_other(klass)
        end
      end

      private

      # Uses Time.zone.local when the model converts this attribute to the
      # current time zone. Otherwise it calls the Time class method named by
      # the model's +default_timezone+.
      def instantiate_time_object(set_values)
        if object.class.send(:create_time_zone_conversion_attribute?, name, column)
          Time.zone.local(*set_values)
        else
          Time.send(object.class.default_timezone, *set_values)
        end
      end

      def read_time
        # If column is a :time (and not :date or :timestamp) there is no need to validate if
        # there are year/month/day fields
        if column.type == :time
          # if the column is a time set the values to their defaults as January 1, 1970, but only if they're nil
          { 1 => 1970, 2 => 1, 3 => 1 }.each do |key, value|
            values[key] ||= value
          end
        else
          # else column is a timestamp, so if Date bits were not provided, error
          validate_required_parameters!([1, 2, 3])

          # If Date bits were provided but blank, then return nil
          return if blank_date_parameter?
        end

        max_position = extract_max_param(6)
        set_values   = values.values_at(*(1..max_position))
        # If Time bits are not there, then default to 0. Array indices 3..5 hold
        # positions 4..6 (hour, minute, second in Time.local argument order).
        (3..5).each { |i| set_values[i] = set_values[i].presence || 0 }
        instantiate_time_object(set_values)
      end

      def read_date
        return if blank_date_parameter?

        set_values = values.values_at(1, 2, 3)
        begin
          Date.new(*set_values)
        rescue ArgumentError # if Date.new raises an exception on an invalid date
          # Build a Time and convert it back to a Date, so Time's handling of
          # invalid dates is used instead.
          instantiate_time_object(set_values).to_date
        end
      end

      # Passes positions 1..max (capped at 100) to +klass.new+. Raises
      # ArgumentError if any position in that range was not supplied.
      def read_other(klass)
        max_position = extract_max_param
        positions    = (1..max_position)
        validate_required_parameters!(positions)

        set_values = values.values_at(*positions)
        klass.new(*set_values)
      end

      # Checks whether some blank date parameter exists. Note that this is different
      # from the validate_required_parameters! method, since it just checks for blank
      # positions instead of missing ones, and does not raise in case one blank position
      # exists. The caller is responsible for handling the case of this returning true.
      def blank_date_parameter?
        (1..3).any? { |position| values[position].blank? }
      end

      # If some position is not provided, it errors out with a missing parameter exception.
      def validate_required_parameters!(positions)
        missing_parameter = positions.find { |position| !values.key?(position) }
        raise ArgumentError, "Missing Parameter - #{name}(#{missing_parameter})" if missing_parameter
      end

      # Highest supplied position, but never more than +upper_cap+.
      def extract_max_param(upper_cap = 100)
        [values.keys.max, upper_cap].min
      end
    end
  end
end