postgresql_adapter.rb 35.4 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1 2
require 'active_record/connection_adapters/abstract_adapter'

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
begin
  require_library_or_gem 'pg'
rescue LoadError => pg_load_error
  # Fall back to the older 'postgres' driver; if that is unavailable too,
  # report the error from the preferred 'pg' driver.
  begin
    require_library_or_gem 'postgres'
  rescue LoadError
    raise pg_load_error
  end

  # The older driver names a few PGresult methods differently. Alias the names
  # this adapter calls, leaving any that already exist untouched.
  class PGresult
    [[:nfields, :num_fields],
     [:ntuples, :num_tuples],
     [:ftype, :type],
     [:cmd_tuples, :cmdtuples]].each do |wanted, existing|
      alias_method wanted, existing unless method_defined?(wanted)
    end
  end
end

D
Initial  
David Heinemeier Hansson 已提交
19 20 21 22
module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ may use string or symbol keys. :database is mandatory; :port
    # falls back to 5432; :username and :password are converted with to_s.
    # Raises ArgumentError when no :database key is present.
    def self.postgresql_connection(config) # :nodoc:
      options = config.symbolize_keys

      unless options.has_key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      connection_parameters = [
        options[:host],
        options[:port] || 5432,
        nil,
        nil,
        options[:database],
        options[:username].to_s,
        options[:password].to_s
      ]

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so the adapter starts with a nil connection and connects itself from the parameters.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, connection_parameters, options)
    end
  end
40

41 42 43 44 45 46 47
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      # The raw default expression reported by the server is reduced to a
      # plain value with extract_value_from_default before reaching Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      private
        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          super
        end

        # Escapes binary strings for bytea input to the database.
        #
        # The first call chooses an implementation: the driver's
        # PGconn.escape_bytea when it exists, otherwise a pure-Ruby octal
        # escaper. It installs that implementation on this class's own
        # singleton class, so later calls go straight to it.
        #
        # Note: self.class is not used here. Inside a class method self.class
        # is Class, which would leak the method onto every class.
        def self.string_to_binary(value)
          singleton = class << self; self; end
          if PGconn.respond_to?(:escape_bytea)
            singleton.send(:define_method, :string_to_binary) do |data|
              PGconn.escape_bytea(data) if data
            end
          else
            singleton.send(:define_method, :string_to_binary) do |data|
              if data
                escaped = ''
                data.each_byte { |byte| escaped << sprintf('\\\\%03o', byte) }
                escaped
              end
            end
          end
          string_to_binary(value)
        end

        # Unescapes bytea output from a database to the binary string it represents.
        #
        # Like string_to_binary, the first call installs the chosen
        # implementation on this class's singleton class.
        #
        # The pure-Ruby fallback works on raw bytes and packs the result with
        # Array#pack('C*'), so decoded bytes above 127 stay single bytes.
        # String#<< with an Integer would append a codepoint instead.
        def self.binary_to_string(value)
          singleton = class << self; self; end
          # In each case, check if the value actually is escaped PostgreSQL bytea output
          # or an unescaped Active Record attribute that was just written.
          if PGconn.respond_to?(:unescape_bytea)
            singleton.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                PGconn.unescape_bytea(data)
              else
                data
              end
            end
          else
            singleton.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                source = data.unpack('C*')
                decoded = []
                i = 0
                while i < source.size
                  byte = source[i]
                  if byte == 92 # backslash
                    if source[i + 1] == 92
                      # An escaped backslash: keep a single one.
                      i += 1
                    else
                      # Three octal digits encode one byte.
                      byte = source[i + 1, 3].pack('C*').oct
                      i += 3
                    end
                  end
                  decoded << byte
                  i += 1
                end
                decoded.pack('C*')
              else
                data
              end
            end
          end
          binary_to_string(value)
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Numeric and monetary types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when /^money$/
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when /^bytea$/
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when /^interval$/
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when /^xml$/
              :string
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when /^oid$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        # Returns nil for defaults whose value can't be known statically
        # (no default, user-defined types, function calls).
        def self.extract_value_from_default(default)
          case default
            # Numeric types (this also covers integer-valued oid defaults)
            when /\A-?\d+(\.\d*)?\z/
              default
            # Character types; a quote inside the literal is written doubled ('')
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub("''", "'")
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end
  end

  module ConnectionAdapters
226 227
    # The PostgreSQL adapter works both with the native C (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944) drivers.
    #
    # Options:
    #
    # * <tt>:host</tt> -- Defaults to localhost
    # * <tt>:port</tt> -- Defaults to 5432
    # * <tt>:username</tt> -- Defaults to nothing
    # * <tt>:password</tt> -- Defaults to nothing
    # * <tt>:database</tt> -- The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> -- An optional schema search path for the connection, given as a string of comma-separated schema names. This is backward-compatible with the :schema_order option.
    # * <tt>:encoding</tt> -- An optional client encoding that is used in a SET client_encoding TO <encoding> call on the connection.
    # * <tt>:min_messages</tt> -- An optional client message level that is used in a SET client_min_messages TO <min_messages> call on the connection.
    # * <tt>:allow_concurrency</tt> -- If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
240
    class PostgreSQLAdapter < AbstractAdapter
241
      # Returns 'PostgreSQL' as adapter name for identification purposes.
242 243 244 245
      def adapter_name
        # Fixed label identifying this adapter.
        "PostgreSQL"
      end

246 247
      # Initializes and connects a PostgreSQL adapter.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        # Keep the raw connect arguments and the config; #connect uses them to
        # open the connection now, and #reconnect! may call #connect again.
        @connection_parameters = connection_parameters
        @config = config
        connect
      end

254 255 256
      # Is this connection alive and ready for queries?
      def active?
        # Drivers exposing #status can answer without a round trip.
        return @connection.status == PGconn::CONNECTION_OK if @connection.respond_to?(:status)

        # Otherwise probe with a trivial query sent to the driver directly,
        # bypassing this adapter's #query (and therefore its logging).
        @connection.query 'SELECT 1'
        true
      rescue PGError, NoMethodError
        # postgres-pr raises a NoMethodError when querying if no connection is available.
        false
      end

      # Close then reopen the connection.
      def reconnect!
        unless @connection.respond_to?(:reset)
          # No in-place reset available: close and open a fresh connection.
          disconnect!
          return connect
        end

        # Reset reopens the same connection object; run configure_connection on it again.
        @connection.reset
        configure_connection
      end
278

279
      # Close the connection.
280 281 282
      def disconnect!
        begin
          @connection.close
        rescue StandardError
          # Closing is best effort; a failure here is deliberately ignored.
          nil
        end
      end
283

284
      def native_database_types #:nodoc:
        # Maps Rails' abstract column types to PostgreSQL type names (with the
        # default string limit). A new Hash is built on every call.
        types = {}
        types[:primary_key] = "serial primary key"
        types[:string]      = { :name => "character varying", :limit => 255 }
        types[:text]        = { :name => "text" }
        types[:integer]     = { :name => "integer" }
        types[:float]       = { :name => "float" }
        types[:decimal]     = { :name => "decimal" }
        types[:datetime]    = { :name => "timestamp" }
        types[:timestamp]   = { :name => "timestamp" }
        types[:time]        = { :name => "time" }
        types[:date]        = { :name => "date" }
        types[:binary]      = { :name => "bytea" }
        types[:boolean]     = { :name => "boolean" }
        types
      end
300

301
      # Does PostgreSQL support migrations?
302 303
      def supports_migrations?
        true # schema migrations are always available on PostgreSQL
      end

306 307 308 309 310 311 312 313 314 315 316
      # Does PostgreSQL support standard conforming strings?
      def supports_standard_conforming_strings?
        # Temporarily set the client message level above error to prevent unintentional
        # error messages in the logs when working on a PostgreSQL database server that
        # does not support standard conforming strings. The previous level is
        # restored in +ensure+, so it is not left at 'panic' if something raises.
        client_min_messages_old = client_min_messages
        begin
          self.client_min_messages = 'panic'

          # postgres-pr does not raise an exception when client_min_messages is set higher
          # than error and "SHOW standard_conforming_strings" fails, but returns an empty
          # result instead. Indexing into it then raises, which is treated as
          # "not supported" exactly like a server-side error.
          setting = begin
            query('SHOW standard_conforming_strings')[0][0]
          rescue StandardError
            nil
          end

          # Any returned value means the server knows the parameter. The value
          # itself is not inspected, matching the previous truthiness.
          !setting.nil?
        ensure
          self.client_min_messages = client_min_messages_old
        end
      end

      # Returns the maximum identifier length configured on the server
      # (SHOW max_identifier_length), or the default of 63 on PostgreSQL 7.x.
324
      def table_alias_length
        # Computed once per adapter and cached.
        @table_alias_length ||=
          if postgresql_version < 80000
            63 # PostgreSQL 7.x cannot be asked, so use the stock limit
          else
            query('SHOW max_identifier_length')[0][0].to_i
          end
      end
327

328 329
      # QUOTING ==================================================

330 331
      # Quotes PostgreSQL-specific data types for SQL input.
      def quote(value, column = nil) #:nodoc:
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{column.class.string_to_binary(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # The value is interpolated without escaping, so the whole string must
          # match. \A...\z is required because ^...$ would accept a value
          # whose first line is valid and whose later lines contain SQL.
          # Anything else falls back to regular string quoting instead of
          # yielding nil (no SQL literal at all).
          case value
            when /\A[01]*\z/
              "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i
              "X'#{value}'" # Hexadecimal notation
            else
              super
          end
        else
          super
        end
      end

351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
      # Quotes strings for use in SQL input in the postgres driver for better performance.
      def quote_string(s) #:nodoc:
        if PGconn.respond_to?(:escape)
          # Replace this method with a direct call to the driver's escape so the
          # respond_to? check only happens on the first call.
          self.class.instance_eval do
            define_method(:quote_string) do |str|
              PGconn.escape(str)
            end
          end
        else
          # There are some incorrectly compiled postgres drivers out there
          # that don't define PGconn.escape. Remove this override so lookup
          # falls through to the superclass's quote_string.
          # undef_method would block the lookup entirely and make the call
          # below raise NoMethodError.
          self.class.instance_eval do
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
        # Embedded double quotes are doubled, per SQL identifier quoting, so a
        # name cannot terminate the quoted identifier early.
        %("#{name.to_s.gsub('"', '""')}")
      end

374 375 376
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        with_usec = value.acts_like?(:time) && value.respond_to?(:usec)
        formatted = super
        # Time-like values get their microseconds appended, zero-padded to six digits.
        with_usec ? "#{formatted}.#{'%06d' % value.usec}" : formatted
      end

384 385
      # REFERENTIAL INTEGRITY ====================================

386 387 388 389 390 391 392
      def supports_disable_referential_integrity?() #:nodoc:
        # ALTER TABLE ... DISABLE TRIGGER ALL needs PostgreSQL 8.1 or later.
        # Compare major first, so 9.0, 10.x, 12.0, ... also qualify; the old
        # "major >= 8 && minor >= 1" test rejected every x.0 release.
        major, minor = query("SHOW server_version")[0][0].split('.').map(&:to_i)
        major > 8 || (major == 8 && minor >= 1)
      rescue StandardError
        # Unparseable or unavailable version: assume unsupported.
        false
      end

393
      def disable_referential_integrity(&block) #:nodoc:
        # Issues "ALTER TABLE <t> <action> TRIGGER ALL" for every table in one
        # ";"-joined statement, when the server supports it.
        toggle_triggers = lambda do |action|
          if supports_disable_referential_integrity?
            statements = tables.map { |table| "ALTER TABLE #{quote_table_name(table)} #{action} TRIGGER ALL" }
            execute(statements.join(";"))
          end
        end

        toggle_triggers.call("DISABLE")
        yield
      ensure
        # Triggers are re-enabled even when the block raises.
        toggle_triggers.call("ENABLE")
      end
403 404 405

      # DATABASE STATEMENTS ======================================

406 407 408 409 410 411
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        raw = select_raw(sql, name)
        raw.last # keep only the final element: the row arrays
      end

412
      # Executes an INSERT query and returns the new record's ID
413
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # The table is the third word of "INSERT INTO <table> ...", minus any quotes.
        table = sql.split(" ", 4)[2].delete('"')
        id = super
        return id if id
        # No id from the generic insert: read it back from the table's sequence.
        pk && last_insert_id(table, sequence_name || default_sequence_name(table, pk))
      end

418 419 420 421 422 423 424 425 426 427 428 429 430 431
      # create a 2D array representing the result set
      def result_as_array(res) #:nodoc:
        # One inner array per tuple, holding every field value of that tuple.
        (0...res.ntuples).map do |row|
          (0...res.nfields).map { |col| res.getvalue(row, col) }
        end
      end


      # Queries the database and returns the results in an Array-like object
432
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          # async_exec when allow_concurrency is on, so other Ruby threads can run.
          res = @async ? @connection.async_exec(sql) : @connection.exec(sql)
          # The block's value flows back through #log (which returns it, as
          # #execute relies on). An explicit `return` here would unwind out of
          # #log and skip whatever it does after the block finishes.
          result_as_array(res)
        end
      end

443
      # Executes an SQL statement, returning a PGresult object on success
444 445
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
        log(sql, name) do
          # With allow_concurrency the non-blocking async_exec is used so other
          # Ruby threads can run while waiting on the server.
          next @connection.async_exec(sql) if @async
          @connection.exec(sql)
        end
      end

455
      # Executes an UPDATE query and returns the number of affected tuples.
456
      def update_sql(sql, name = nil)
        result = super # presumably the PGresult from #execute
        result.cmd_tuples
      end

460 461
      # Begins a transaction.
      def begin_db_transaction
        execute('BEGIN')
      end

465 466
      # Commits a transaction.
      def commit_db_transaction
        execute('COMMIT')
      end
469

470 471
      # Aborts a transaction.
      def rollback_db_transaction
        execute('ROLLBACK')
      end

      # SCHEMA STATEMENTS ========================================

477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520
      def recreate_database(name) #:nodoc:
        drop_database(name)
        create_database(name) # the caller receives the create step's result
      end

      # Create a new PostgreSQL database.  Options include :owner, :template,
      # :encoding, :tablespace, and :connection_limit (note that MySQL uses
      # :charset while PostgreSQL uses :encoding).
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Symbolize first, then apply the utf8 default, so string and symbol
        # keys are treated alike.
        options = { :encoding => "utf8" }.merge(options.symbolize_keys)

        # One " KEY = value" clause per recognised option. Other keys (e.g. the
        # adapter/database entries of a full config hash) contribute nothing.
        # map + join is used instead of Enumerable#sum: core Ruby's sum starts
        # from 0 and raises TypeError for Strings.
        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = '#{value}'"
          when :template
            " TEMPLATE = #{value}"
          when :encoding
            " ENCODING = '#{value}'"
          when :tablespace
            " TABLESPACE = #{value}"
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{name}#{option_string}"
      end

      # Drops a PostgreSQL database
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        # IF EXISTS: dropping a database that isn't there is not an error.
        execute("DROP DATABASE IF EXISTS #{name}")
      end


521 522
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # The stored search path is free-form CSV (e.g. "public, app" from
        # database.yml). Strip each entry so it compares equal to
        # pg_tables.schemaname.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

531 532
      # Returns the list of all indexes for a table.
      def indexes(table_name, name = nil)
        # Strip the blanks around each search-path entry ("public, app") so the
        # quoted names compare equal to pg_namespace.nspname.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        # NOTE(review): only the first ten key columns (indkey[0] .. indkey[9]) are matched.
        result = query(<<-SQL, name)
           SELECT distinct i.relname, d.indisunique, a.attname
             FROM pg_class t, pg_class i, pg_index d, pg_attribute a
           WHERE i.relkind = 'i'
             AND d.indexrelid = i.oid
             AND d.indisprimary = 'f'
             AND t.oid = d.indrelid
             AND t.relname = '#{table_name}'
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
             AND a.attrelid = t.oid
             AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
                OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
                OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
                OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
                OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        indexes = []

        # Rows arrive ordered by index name. Start a new IndexDefinition
        # whenever the name changes, then append the row's column to it.
        result.each do |row|
          if current_index != row[0]
            indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
            current_index = row[0]
          end

          indexes.last.columns << row[2]
        end

        indexes
      end

567 568
      # Returns the list of all column definitions for a table.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are handled by the Column superclass.
        column_definitions(table_name).map do |column_name, sql_type, default, notnull|
          nullable = notnull == 'f'
          PostgreSQLColumn.new(column_name, default, sql_type, nullable)
        end
      end

575 576 577 578 579 580
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually; set it in database.yml instead.
      def schema_search_path=(schema_csv)
        # A nil (or false) path leaves the session's search path alone.
        return unless schema_csv
        execute "SET search_path TO #{schema_csv}"
        # Cached only after the server accepted the new path.
        @schema_search_path = schema_csv
      end

587 588
      # Returns the active schema search path.
      def schema_search_path
        # Ask the server only until a value has been cached.
        return @schema_search_path if @schema_search_path
        @schema_search_path = query('SHOW search_path')[0][0]
      end
591

592 593 594 595 596 597 598 599 600 601 602 603
      # Returns the current client message level.
      def client_min_messages
        rows = query('SHOW client_min_messages')
        rows.first.first
      end

      # Set the client message level.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement)
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        default_pk, default_seq = pk_and_sequence_for(table_name)
        return default_seq if default_seq
        # No sequence found: fall back to the serial naming convention.
        "#{table_name}_#{pk || default_pk || 'id'}_seq"
      end

608 609
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        # Look up whichever of the primary key / sequence the caller didn't supply.
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        if pk
          if sequence
            quoted_sequence = quote_column_name(sequence)

            # Point the sequence one increment past the current maximum key, or
            # at its min_value for an empty table. The final +false+ makes the
            # next nextval() return exactly that value.
            select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
            end_sql
          else
            @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          end
        end
      end

628 629
      # Returns a table's primary key and belonging sequence.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        # NOTE(review): pg_namespace is listed in FROM without a join condition,
        # so it only repeats rows; harmless because only the first row is used.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{table}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          # NOTE(review): pg_attrdef.adsrc was removed in PostgreSQL 12, where
          # pg_get_expr(adbin, adrelid) is needed instead. Confirm target versions.
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{table}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # [primary_key, sequence]
        [result.first, result.last]
      rescue StandardError
        # Best effort: if neither query finds a row (result is nil) or a query
        # fails, report nil instead of raising.
        nil
      end

677
      # Renames a table.
678 679 680
      def rename_table(name, new_name)
        # Quote both names, like the other ALTER TABLE helpers here, so
        # reserved words (e.g. "order") and special characters work.
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end
681

682 683
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
684
      def add_column(table_name, column_name, type, options = {})
        # Create the column first, with no default and no NOT NULL constraint.
        table = quote_table_name(table_name)
        definition = "#{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
        execute("ALTER TABLE #{table} ADD COLUMN #{definition}")

        default = options[:default]
        change_column_default(table_name, column_name, default) if options_include_default?(options)
        # :null => false adds NOT NULL, after NULL rows receive the default.
        change_column_null(table_name, column_name, false, default) if options[:null] == false
      end
D
Initial  
David Heinemeier Hansson 已提交
694

695 696
      # Changes the column of a table.
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"
        rescue ActiveRecord::StatementInvalid
          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it:
          # add a temporary column of the new type, copy the cast values over,
          # drop the original and rename the temporary column into its place.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{sql_type})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue StandardError
            rollback_db_transaction
            # Don't let the column change fail silently: after rolling back,
            # re-raise so the caller (e.g. a migration) sees the error.
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
719

720 721
      # Changes the default value of a table column.
      def change_column_default(table_name, column_name, default)
        table = quote_table_name(table_name)
        column = quote_column_name(column_name)
        execute "ALTER TABLE #{table} ALTER COLUMN #{column} SET DEFAULT #{quote(default)}"
      end
724

725 726
      def change_column_null(table_name, column_name, null, default = nil)
        table = quote_table_name(table_name)
        column = quote_column_name(column_name)
        # Before adding NOT NULL, give existing NULL rows the supplied default.
        if !null && !default.nil?
          execute("UPDATE #{table} SET #{column}=#{quote(default)} WHERE #{column} IS NULL")
        end
        action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{table} ALTER #{column} #{action} NOT NULL")
      end

732 733
      # Renames a column in a table.
      def rename_column(table_name, column_name, new_column_name)
        table = quote_table_name(table_name)
        old_column, new_column = quote_column_name(column_name), quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{old_column} TO #{new_column}"
      end
736

737 738
      # Drops an index from a table.
      def remove_index(table_name, options = {})
        index = index_name(table_name, options)
        execute "DROP INDEX #{index}"
      end
741

742 743
      # Maps logical Rails types to PostgreSQL-specific data types.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        # Everything except integers uses the generic mapping.
        return super unless type.to_s == 'integer'
        # For integers, :limit selects the width: none or 4 => integer,
        # smaller => smallint, larger => bigint.
        return 'integer' if limit.nil? || limit == 4
        limit < 4 ? 'smallint' : 'bigint'
      end
754
      
755
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
756 757 758
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
759
      #
760
      #   distinct("posts.id", "posts.created_at desc")
761
      def distinct(columns, order_by) #:nodoc:
        return "DISTINCT #{columns}" if order_by.blank?

        # Take the leading token of each ORDER BY term (dropping ASC/DESC) and
        # give it a positional alias: "<column> AS alias_<n>".
        aliased = []
        order_by.split(',').each do |term|
          column = term.split.first
          next if column.blank?
          aliased << "#{column} AS alias_#{aliased.size}"
        end

        # DISTINCT ON the wanted columns while also selecting every ORDER BY column.
        "DISTINCT ON (#{columns}) #{columns}, " + aliased.join(', ')
      end
775
      
776
      # Returns an ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the sql as a sub-select and ordering in that query. The outer query orders by
      # id_list.alias_0, id_list.alias_1, ... (the aliases #distinct gives the ORDER BY columns),
      # one per non-blank ORDER BY term and in the same order. A term ending in DESC
      # (case-insensitive) sorts descending; any other term uses the default ascending order.
      #
      # Replaces the contents of +sql+ in place and returns it. When options[:order] is
      # blank, +sql+ is returned untouched.
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
        return sql if options[:order].blank?

        terms = options[:order].split(',').map { |term| term.strip }.reject { |term| term.blank? }
        order = []
        terms.each_with_index do |term, i|
          order << (term =~ /\bdesc$/i ? "id_list.alias_#{i} DESC" : "id_list.alias_#{i}")
        end

        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{order.join(', ')}"
      end

      protected
        # Returns the version of the connected PostgreSQL server as an Integer in
        # PGconn#server_version format (8.2.5 => 80205, 10.4 => 100004). The value
        # is memoized.
        #
        # Uses @connection.server_version when the driver provides it. Otherwise it
        # parses the output of SELECT version(); a missing minor or patch number (as
        # in pre-release builds such as "PostgreSQL 8.3beta1") counts as 0. If the
        # query raises or the string cannot be parsed, 0 is returned.
        def postgresql_version
          @postgresql_version ||=
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              # Mimic PGconn.server_version behavior
              begin
                version = query('SELECT version()')[0][0]
                if version =~ /PostgreSQL (\d+)(?:\.(\d+))?(?:\.(\d+))?/
                  major, minor, patch = $1.to_i, $2.to_i, $3.to_i
                  if major >= 10
                    # Since 10, versions have two parts: major * 10000 + minor.
                    (major * 10000) + minor
                  else
                    (major * 10000) + (minor * 100) + patch
                  end
                else
                  0
                end
              rescue
                0
              end
            end
        end

      private
        # The internal PostgreSQL identifier (type OID) of the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        #
        # Besides opening @connection, this:
        # * turns off PGconn result translation when the driver supports it,
        # * enables async execution only when +allow_concurrency+ is configured and
        #   the connection responds to async_exec,
        # * defines quoted_string_prefix as 'E' on the adapter class (so for every
        #   instance) when the server supports standard_conforming_strings,
        # * redefines PostgreSQLColumn#extract_precision for the money type based on
        #   the server version (this affects every PostgreSQLColumn), and
        # * finally calls #configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)
          PGconn.translate_results = false if PGconn.respond_to?(:translate_results=)

          # Ignore async_exec and async_query when using postgres-pr.
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Use escape string syntax if available. We cannot do this lazily when encountering
          # the first string, because that could then break any transactions in progress.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          # If PostgreSQL doesn't know the standard_conforming_strings parameter then it doesn't
          # support escape string syntax. Don't override the inherited quoted_string_prefix.
          if supports_standard_conforming_strings?
            self.class.instance_eval do
              define_method(:quoted_string_prefix) { 'E' }
            end
          end

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          money_precision = (postgresql_version >= 80300) ? 19 : 10
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)
              if sql_type =~ /^money$/
                #{money_precision}
              else
                super
              end
            end
          end_eval

          configure_connection
        end

        # Configures the encoding, verbosity, and schema search path of the connection.
        # This is called by #connect and should not be called manually.
        #
        # * +encoding+: applied through the driver's set_client_encoding when the
        #   connection has it, otherwise with a SET client_encoding statement.
        # * +min_messages+: assigned to client_min_messages only when configured.
        # * +schema_search_path+ (falling back to +schema_order+): always assigned
        #   to schema_search_path, even when neither key is set (nil).
        def configure_connection
          encoding = @config[:encoding]
          if encoding
            if @connection.respond_to?(:set_client_encoding)
              @connection.set_client_encoding(encoding)
            else
              execute("SET client_encoding TO '#{encoding}'")
            end
          end
          self.client_min_messages = @config[:min_messages] if @config[:min_messages]
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

        # Returns the current ID of a table's sequence: the session's currval of
        # +sequence_name+, converted with Integer(). That conversion raises if the
        # query returns nil or a non-numeric value. +table+ is not used by this
        # implementation.
        def last_insert_id(table, sequence_name) #:nodoc:
          Integer(select_value("SELECT currval('#{sequence_name}')"))
        end

        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        #
        # Returns an Array with one Hash per row, mapping each field name to that
        # row's value. The conversions themselves are done by #select_raw.
        def select(sql, name = nil)
          fields, rows = select_raw(sql, name)
          rows.map do |row|
            row_hash = {}
            fields.each_with_index { |field, i| row_hash[field] = row[i] }
            row_hash
          end
        end

        # Executes +sql+ and returns a two-element Array: the result's field names
        # and its rows (each row an Array of values, in field order). When the query
        # returns no rows, both are empty.
        #
        # Values of money columns have their currency symbols and thousands
        # separators stripped, so "$1,234.50" and "$1.234,50" both become "1234.50".
        # Indeed it would be prettier to do this in PostgreSQLColumn.string_to_decimal,
        # but that would break form input fields that call value_before_type_cast.
        #
        # The driver result is always cleared, even if processing raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            return [], [] unless res.ntuples > 0

            fields = res.fields
            # Column types are the same for every row, so find the money columns once.
            money_columns = (0...fields.size).select { |i| res.ftype(i) == MONEY_COLUMN_TYPE_OID }

            rows = results.map do |row|
              money_columns.each do |i|
                # Because money output is formatted according to the locale, there are two
                # cases to consider (note the decimal separators):
                #  (1) $12,345,678.12
                #  (2) $12.345.678,12
                # NULL (nil) values match neither pattern and are left alone.
                case value = row[i]
                when /^-?\D+[\d,]+\.\d{2}$/  # (1)
                  row[i] = value.gsub(/[^-\d\.]/, '')
                when /^-?\D+[\d\.]+,\d{2}$/  # (2)
                  row[i] = value.gsub(/[^-\d,]/, '').sub(/,/, '.')
                end
              end
              row
            end

            [fields, rows]
          ensure
            res.clear
          end
        end

        # Returns the list of a table's column names, data types, default values and
        # NOT NULL flags, one row per column, in column order.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value, column.not_null
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - the ::regclass cast turns the table name into the table's OID
        #  - pg_get_expr decompiles the stored default expression (adbin); the
        #    adsrc text copy is documented by PostgreSQL as not tracking later
        #    changes and was removed in PostgreSQL 12
        def column_definitions(table_name) #:nodoc:
          query <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
    end
  end
end