postgresql_adapter.rb 37.1 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1 2
require 'active_record/connection_adapters/abstract_adapter'

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Load a PostgreSQL driver: try 'pg' first and fall back to 'postgres'.
# If neither can be loaded, the LoadError from the 'pg' attempt is re-raised.
begin
  require_library_or_gem('pg')
rescue LoadError => pg_load_error
  begin
    require_library_or_gem('postgres')

    # When the 'postgres' driver is in use, add the nfields/ntuples/ftype/
    # cmd_tuples names to PGresult unless the driver already defines them.
    class PGresult
      [
        [:nfields,    :num_fields],
        [:ntuples,    :num_tuples],
        [:ftype,      :type],
        [:cmd_tuples, :cmdtuples]
      ].each do |wanted, existing|
        alias_method wanted, existing unless method_defined?(wanted)
      end
    end
  rescue LoadError
    raise pg_load_error
  end
end

D
Initial  
David Heinemeier Hansson 已提交
19 20 21 22
module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects
    def self.postgresql_connection(config) # :nodoc:
      # Builds a PostgreSQLAdapter from a configuration hash. Keys are
      # normalised with symbolize_keys; :database is required, :port falls
      # back to 5432, and :username / :password are converted with to_s.
      # Raises ArgumentError when the :database key is absent.
      config = config.symbolize_keys

      unless config.has_key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username].to_s
      password = config[:password].to_s
      database = config[:database]

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, [host, port, nil, nil, database, username, password], config)
    end
  end
40

41 42 43 44 45 46 47
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      def initialize(name, default, sql_type = nil, null = true)
        # The raw default expression is converted to a plain value before the
        # superclass constructor sees it.
        parsed_default = self.class.extract_value_from_default(default)
        super(name, parsed_default, sql_type, null)
      end
48

49
      private
50
        def extract_limit(sql_type)
          # Returns the byte size for the integer family of types (matched
          # case-insensitively at the start of the type name); any other type
          # is delegated to the superclass.
          case sql_type
          when /^integer/i  then 4
          when /^bigint/i   then 8
          when /^smallint/i then 2
          else super
          end
        end

59 60 61 62 63
        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2; other types go to the superclass.
          return 2 if sql_type =~ /^money/

          super
        end
64

65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          # As written here, the method simply defers to the superclass.
          super
        end
  
        # Escapes binary strings for bytea input to the database.
        def self.string_to_binary(value)
          # On first use, pick the escaping strategy (driver-provided
          # PGconn.escape_bytea when available, otherwise a manual octal
          # escape of every byte) and install it as this object's own
          # singleton method, replacing this bootstrap definition. Defining
          # it on the singleton class (rather than on self.class) keeps the
          # method off unrelated classes and makes the replacement effective.
          # Returns the escaped string, or nil when value is nil/false.
          metaclass = class << self; self; end

          if PGconn.respond_to?(:escape_bytea)
            metaclass.send(:define_method, :string_to_binary) do |raw|
              PGconn.escape_bytea(raw) if raw
            end
          else
            metaclass.send(:define_method, :string_to_binary) do |raw|
              if raw
                result = ''
                raw.each_byte { |c| result << sprintf('\\\\%03o', c) }
                result
              end
            end
          end

          # Dispatches to the definition installed above.
          string_to_binary(value)
        end
  
        # Unescapes bytea output from a database to the binary string it represents.
        def self.binary_to_string(value)
          # On first use, pick the unescaping strategy (driver-provided
          # PGconn.unescape_bytea when available, otherwise a manual decoder)
          # and install it as this object's own singleton method, replacing
          # this bootstrap definition.
          #
          # In each case, check if the value actually is escaped PostgreSQL bytea output
          # or an unescaped Active Record attribute that was just written.
          metaclass = class << self; self; end

          if PGconn.respond_to?(:unescape_bytea)
            metaclass.send(:define_method, :binary_to_string) do |escaped|
              if escaped =~ /\\\d{3}/
                PGconn.unescape_bytea(escaped)
              else
                escaped
              end
            end
          else
            metaclass.send(:define_method, :binary_to_string) do |escaped|
              if escaped =~ /\\\d{3}/
                result = ''
                i, max = 0, escaped.size
                while i < max
                  char = escaped[i]
                  if char == ?\\
                    if escaped[i+1] == ?\\
                      # A doubled backslash stands for one literal backslash.
                      char = ?\\
                      i += 1
                    else
                      # Otherwise the next three characters are read as an octal byte value.
                      char = escaped[i+1..i+3].oct
                      i += 3
                    end
                  end
                  result << char
                  i += 1
                end
                result
              else
                escaped
              end
            end
          end

          # Dispatches to the definition installed above.
          binary_to_string(value)
        end
  
        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          # Branches are tried top to bottom; the first matching pattern wins.
          case field_type
          # Floating point types
          when /^(?:real|double precision)$/ then :float
          # Monetary types
          when /^money$/ then :decimal
          # Character types
          when /^(?:character varying|bpchar)(?:\(\d+\))?$/ then :string
          # Binary data types
          when /^bytea$/ then :binary
          # Date/time types
          when /^timestamp with(?:out)? time zone$/ then :datetime
          # Interval, geometric, network address, bit string, XML and array
          # types are all represented as strings.
          when /^interval$/,
               /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/,
               /^(?:cidr|inet|macaddr)$/,
               /^bit(?: varying)?(?:\(\d+\))?$/,
               /^xml$/,
               /^\D+\[\]$/
            :string
          # Object identifier types
          when /^oid$/ then :integer
          # Pass through all types that are not specific to PostgreSQL.
          else super
          end
        end
  
        # Extracts the value from a PostgreSQL column default definition.
        def self.extract_value_from_default(default)
          # Converts a column default expression into a plain value: the
          # literal inside the quotes for recognised typed literals, true/false
          # for the boolean keywords, the input itself for bare numbers, and
          # nil for anything else. Branches are tried top to bottom.
          case default
          # Numeric types. Bare integers (including object identifiers) are
          # matched here, so no separate integer-only branch is needed.
          when /\A-?\d+(\.\d*)?\z/
            default
          # Character types
          when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
            $1
          # Character types (8.1 formatting): octal escapes are decoded.
          when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
          # Binary data types
          when /\A'(.*)'::bytea\z/m
            $1
          # Date/time types
          when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
            $1
          when /\A'(.*)'::interval\z/
            $1
          # Boolean type
          when 'true'
            true
          when 'false'
            false
          # Geometric types
          when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
            $1
          # Network address types
          when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
            $1
          # Bit string types
          when /\AB'(.*)'::"?bit(?: varying)?"?\z/
            $1
          # XML type
          when /\A'(.*)'::xml\z/m
            $1
          # Arrays
          when /\A'(.*)'::"?\D+"?\[\]\z/
            $1
          else
            # Anything else is blank, some user type, or some function
            # and we can't know the value of that, so return nil.
            nil
          end
        end
D
Initial  
David Heinemeier Hansson 已提交
231 232 233 234
    end
  end

  module ConnectionAdapters
235 236
    # The PostgreSQL adapter works both with the native C (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944) drivers.
237 238 239
    #
    # Options:
    #
P
Pratik Naik 已提交
240 241 242 243 244 245 246 247 248
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:allow_concurrency</tt> - If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
249
    class PostgreSQLAdapter < AbstractAdapter
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
      # Name reported by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'.freeze

      # Maps logical column types to PostgreSQL type definitions. The
      # :primary_key entry is a complete SQL fragment; every other entry is a
      # hash with the type :name and, for :string, a default :limit.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key".freeze,
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" }
      }

267
      # Returns 'PostgreSQL' as adapter name for identification purposes.
268
      def adapter_name
        # Returns the ADAPTER_NAME constant.
        ADAPTER_NAME
      end

272 273
      # Initializes and connects a PostgreSQL adapter.
      def initialize(connection, logger, connection_parameters, config)
        # Hands connection and logger to the superclass, stores the
        # connection parameters and config, then calls connect.
        super(connection, logger)
        @connection_parameters, @config = connection_parameters, config

        connect
      end

280 281 282
      # Is this connection alive and ready for queries?
      def active?
        # Returns true when the connection reports CONNECTION_OK (drivers that
        # expose #status) or when a trivial query succeeds; returns false when
        # PGError or NoMethodError is raised.
        if @connection.respond_to?(:status)
          @connection.status == PGconn::CONNECTION_OK
        else
          # We're asking the driver, not ActiveRecord, so use @connection.query instead of #query
          @connection.query 'SELECT 1'
          true
        end
      # postgres-pr raises a NoMethodError when querying if no connection is available.
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      def reconnect!
        # Uses the driver's #reset when it exists and re-applies the
        # connection configuration; otherwise disconnects and connects again.
        if @connection.respond_to?(:reset)
          @connection.reset
          configure_connection
        else
          disconnect!
          connect
        end
      end
304

305
      # Close the connection.
306 307 308
      def disconnect!
        # Best effort: any StandardError raised while closing is discarded
        # and nil is returned instead.
        @connection.close
      rescue StandardError
        nil
      end
309

310
      def native_database_types #:nodoc:
        # Returns the NATIVE_DATABASE_TYPES constant.
        NATIVE_DATABASE_TYPES
      end
313

314
      # Does PostgreSQL support migrations?
315 316
      def supports_migrations?
        # Always true for this adapter.
        true
      end

319 320 321 322 323 324 325 326 327 328 329
      # Does PostgreSQL support standard conforming strings?
      def supports_standard_conforming_strings?
        # Returns the value reported by SHOW standard_conforming_strings, or
        # false when the lookup raises.
        #
        # Temporarily set the client message level above error to prevent unintentional
        # error messages in the logs when working on a PostgreSQL database server that
        # does not support standard conforming strings.
        client_min_messages_old = client_min_messages
        self.client_min_messages = 'panic'

        begin
          # postgres-pr does not raise an exception when client_min_messages is set higher
          # than error and "SHOW standard_conforming_strings" fails, but returns an empty
          # PGresult instead.
          query('SHOW standard_conforming_strings')[0][0]
        rescue
          false
        ensure
          # Always put the previous message level back.
          self.client_min_messages = client_min_messages_old
        end
      end

335
      def supports_insert_with_returning?
        # True when postgresql_version is at least 80200.
        postgresql_version >= 80200
      end

339 340
      # Returns the maximum identifier length configured on the PostgreSQL server,
      # or the default of 63 on PostgreSQL 7.x.
341
      def table_alias_length
        # Memoised: asks the server for max_identifier_length when
        # postgresql_version is at least 80000, otherwise uses 63.
        @table_alias_length ||= (postgresql_version >= 80000 ? query('SHOW max_identifier_length')[0][0].to_i : 63)
      end
344

345 346
      # QUOTING ==================================================

347 348
      # Quotes PostgreSQL-specific data types for SQL input.
      def quote(value, column = nil) #:nodoc:
        # Special-cases binary, xml, money and bit columns; everything else
        # (including a bit value that is neither binary nor hexadecimal) is
        # quoted by the superclass.
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{column.class.string_to_binary(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # \A and \z anchor the whole value. With ^ and $ a multi-line value
          # whose first line looked like a bit string would be interpolated
          # into the SQL unescaped.
          case value
          when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
          when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
          else super
          end
        else
          super
        end
      end

368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
      # Quotes strings for use in SQL input in the postgres driver for better performance.
      def quote_string(s) #:nodoc:
        # On first use, replaces this method on the adapter class: with a
        # PGconn.escape-based version when the driver provides it, otherwise
        # by removing this definition so the inherited quote_string is used.
        if PGconn.respond_to?(:escape)
          self.class.instance_eval do
            define_method(:quote_string) do |s|
              PGconn.escape(s)
            end
          end
        else
          # There are some incorrectly compiled postgres drivers out there
          # that don't define PGconn.escape.
          self.class.instance_eval do
            # remove_method (not undef_method) so that lookup continues to
            # the ancestors; undef_method would make the call below fail.
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
        # Wraps the name in double quotes. Embedded double quotes are doubled
        # so the name cannot terminate the quoted identifier early.
        %("#{name.to_s.gsub('"', '""')}")
      end

391 392 393
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        # Appends zero-padded six-digit microseconds to the superclass result
        # for time-like values that respond to usec.
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
      end

401 402
      # REFERENTIAL INTEGRITY ====================================

403 404 405 406 407 408 409
      def supports_disable_referential_integrity?() #:nodoc:
        # True when the server version is 8.1 or newer; false for older
        # servers or when the version cannot be determined.
        major, minor = query("SHOW server_version")[0][0].split('.').map { |part| part.to_i }
        # Compare major and minor together so that e.g. 9.0 counts as newer
        # than 8.1 (checking each component separately rejects it).
        ([major, minor.to_i] <=> [8, 1]) >= 0
      rescue
        return false
      end

410
      def disable_referential_integrity(&block) #:nodoc:
        # Disables all triggers on every table, yields, then re-enables them
        # (also when the block raises). Does nothing around the block when the
        # server does not support it. The capability is probed once so the
        # enable step always mirrors the disable step.
        toggle_triggers = supports_disable_referential_integrity?()
        if toggle_triggers
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
        yield
      ensure
        if toggle_triggers
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
      end
420 421 422

      # DATABASE STATEMENTS ======================================

423 424 425 426 427 428
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        # The rows are the last element of what select_raw returns.
        raw_result = select_raw(sql, name)
        raw_result.last
      end

429
      # Executes an INSERT query and returns the new record's ID
430
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # Extract the table from the insert sql. Yuck.
        table = sql.split(" ", 4)[2].gsub('"', '')

        # Try an insert with 'returning id' if available (PG >= 8.2)
        if supports_insert_with_returning?
          pk, sequence_name = *pk_and_sequence_for(table) unless pk
          if pk
            id = select_value("#{sql} RETURNING #{quote_column_name(pk)}")
            clear_query_cache
            return id
          end
        end

        # Otherwise, insert then grab last_insert_id.
        if insert_id = super
          insert_id
        else
          # If neither pk nor sequence name is given, look them up.
          unless pk || sequence_name
            pk, sequence_name = *pk_and_sequence_for(table)
          end

          # If a pk is given, fallback to default sequence name.
          # Don't fetch last insert id for a table without a pk.
          if pk && sequence_name ||= default_sequence_name(table, pk)
            last_insert_id(table, sequence_name)
          end
        end
      end

461 462 463 464 465 466 467 468 469 470 471 472 473 474
      # create a 2D array representing the result set
      def result_as_array(res) #:nodoc:
        # One inner array per tuple, holding that tuple's field values in order.
        (0...res.ntuples).map do |row|
          (0...res.nfields).map { |col| res.getvalue(row, col) }
        end
      end


      # Queries the database and returns the results in an Array-like object
475
      def query(sql, name = nil) #:nodoc:
        # Runs the statement (async_exec when @async is set, exec otherwise)
        # inside log and returns the result converted by result_as_array.
        log(sql, name) do
          if @async
            res = @connection.async_exec(sql)
          else
            res = @connection.exec(sql)
          end
          # NOTE(review): this returns from #query directly, from inside the
          # block passed to log; confirm that is what log expects.
          return result_as_array(res)
        end
      end

486
      # Executes an SQL statement, returning a PGresult object on success
487 488
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
        # Runs the statement inside log, using async_exec when @async is set
        # and exec otherwise.
        log(sql, name) do
          if @async
            @connection.async_exec(sql)
          else
            @connection.exec(sql)
          end
        end
      end

498
      # Executes an UPDATE query and returns the number of affected tuples.
499
      def update_sql(sql, name = nil)
        # Returns cmd_tuples of whatever the superclass implementation returns.
        super.cmd_tuples
      end

503 504
      # Begins a transaction.
      def begin_db_transaction
        # Sends BEGIN through execute.
        execute "BEGIN"
      end

508 509
      # Commits a transaction.
      def commit_db_transaction
        # Sends COMMIT through execute.
        execute "COMMIT"
      end
512

513 514
      # Aborts a transaction.
      def rollback_db_transaction
        # Sends ROLLBACK through execute.
        execute "ROLLBACK"
      end

      # SCHEMA STATEMENTS ========================================

520 521 522 523 524
      def recreate_database(name) #:nodoc:
        # Drops the named database, then creates it again. create_database is
        # called without options here.
        drop_database(name)
        create_database(name)
      end

525 526 527
      # Create a new PostgreSQL database.  Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Issues CREATE DATABASE for the quoted name followed by one clause
        # per recognised option. :encoding defaults to "utf8" when not given;
        # unrecognised keys are ignored.
        options = options.reverse_merge(:encoding => "utf8")

        # Build the clauses with map/join rather than summing strings, which
        # depends on a string-tolerant Enumerable#sum.
        option_string = options.symbolize_keys.map do |key, value|
          case key
          when :owner            then " OWNER = '#{value}'"
          when :template         then " TEMPLATE = #{value}"
          when :encoding         then " ENCODING = '#{value}'"
          when :tablespace       then " TABLESPACE = #{value}"
          when :connection_limit then " CONNECTION LIMIT = #{value}"
          else ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

      # Drops a PostgreSQL database
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        # Uses DROP DATABASE IF EXISTS when postgresql_version is at least
        # 80200. On older servers a plain DROP DATABASE is issued and a
        # StatementInvalid error is logged as a warning instead of raised.
        if postgresql_version >= 80200
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        else
          begin
            execute "DROP DATABASE #{quote_table_name(name)}"
          rescue ActiveRecord::StatementInvalid
            @logger.warn "#{name} database doesn't exist." if @logger
          end
        end
      end


572 573
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # Lists the table names found in the schemas of the search path.
        # Each entry is stripped first: a path written as "a, b" would
        # otherwise be looked up as the schema " b".
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

582 583
      # Returns the list of all indexes for a table.
      def indexes(table_name, name = nil)
        # Returns one IndexDefinition per non-primary index of the table,
        # with its columns collected from consecutive result rows. Only the
        # first ten key positions of each index are examined by the query.
        #
        # Search path entries are stripped before quoting, and the table name
        # is passed through quote instead of being interpolated raw.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        result = query(<<-SQL, name)
           SELECT distinct i.relname, d.indisunique, a.attname
             FROM pg_class t, pg_class i, pg_index d, pg_attribute a
           WHERE i.relkind = 'i'
             AND d.indexrelid = i.oid
             AND d.indisprimary = 'f'
             AND t.oid = d.indrelid
             AND t.relname = #{quote(table_name.to_s)}
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
             AND a.attrelid = t.oid
             AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
                OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
                OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
                OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
                OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        definitions = []

        # Rows arrive ordered by index name, so a change of name starts a new
        # definition and every row contributes one column to the latest one.
        result.each do |row|
          if current_index != row[0]
            definitions << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
            current_index = row[0]
          end

          definitions.last.columns << row[2]
        end

        definitions
      end

618 619
      # Returns the list of all column definitions for a table.
      def columns(table_name, name = nil)
        # Builds a PostgreSQLColumn for each row of column_definitions; a
        # notnull flag of 'f' marks the column as nullable.
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).collect do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

626 627 628 629 630 631
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        # A nil or false value is ignored. Otherwise the path is sent to the
        # server as given and remembered in @schema_search_path.
        if schema_csv
          execute "SET search_path TO #{schema_csv}"
          @schema_search_path = schema_csv
        end
      end

638 639
      # Returns the active schema search path.
      def schema_search_path
        # Memoised in @schema_search_path; fetched with SHOW search_path on first use.
        @schema_search_path ||= query('SHOW search_path')[0][0]
      end
642

643 644 645 646 647 648 649 650 651 652 653 654
      # Returns the current client message level.
      def client_min_messages
        # First field of the first row returned by SHOW client_min_messages.
        rows = query('SHOW client_min_messages')
        rows[0][0]
      end

      # Set the client message level.
      def client_min_messages=(level)
        # The level is interpolated into the statement as a quoted literal.
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement)
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        # Prefers the sequence reported by pk_and_sequence_for. Otherwise the
        # name is built as "<table>_<key>_seq", where the key is the given pk,
        # the detected primary key, or 'id', in that order.
        default_pk, default_seq = pk_and_sequence_for(table_name)
        default_seq || "#{table_name}_#{pk || default_pk || 'id'}_seq"
      end

659 660
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        # Missing pk / sequence arguments are filled in from
        # pk_and_sequence_for. Nothing happens without a primary key; a
        # primary key without a sequence only produces a warning.
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        if pk
          if sequence
            quoted_sequence = quote_column_name(sequence)

            select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
            end_sql
          else
            @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          end
        end
      end

679 680
      # Returns a table's primary key and belonging sequence.
      def pk_and_sequence_for(table) #:nodoc:
        # Returns [primary_key, sequence] for the table, or nil when neither
        # lookup finds a row or when any error is raised along the way.
        #
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{table}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{table}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # [primary_key, sequence]
        # When no row was found, result is nil here and the rescue below
        # turns the resulting error into a nil return value.
        [result.first, result.last]
      rescue
        nil
      end

728
      # Renames a table.
729
      def rename_table(name, new_name)
        # Both names are passed through quote_table_name.
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end
732

733 734
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
735
      def add_column(table_name, column_name, type, options = {})
        # The column is added without default or NOT NULL; those are applied
        # afterwards by change_column_default and change_column_null when the
        # options ask for them.
        default = options[:default]
        notnull = options[:null] == false

        # Add the column.
        execute("ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}")

        change_column_default(table_name, column_name, default) if options_include_default?(options)
        change_column_null(table_name, column_name, false, default) if notnull
      end
D
Initial  
David Heinemeier Hansson 已提交
745

746 747
      # Changes the column of a table.
      def change_column(table_name, column_name, type, options = {})
        # Changes the column type with ALTER COLUMN ... TYPE. When that
        # statement is rejected, falls back to add / copy / drop / rename
        # inside a transaction. Default and nullability are applied
        # afterwards when the options mention them.
        quoted_table_name = quote_table_name(table_name)

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
        rescue ActiveRecord::StatementInvalid
          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{type_to_sql(type, options[:limit], options[:precision], options[:scale])})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue
            rollback_db_transaction
            # Re-raise so a failed fallback is reported instead of leaving
            # the column silently unchanged.
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
770

771 772
      # Changes the default value of a table column.
      def change_column_default(table_name, column_name, default)
        # The default value is passed through quote before interpolation.
        execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
      end
775

776 777
      def change_column_null(table_name, column_name, null, default = nil)
        # When the column is being made NOT NULL and a default is given,
        # existing NULLs are first replaced with that default. Then the NOT
        # NULL constraint is dropped (null truthy) or set (null falsy).
        unless null || default.nil?
          execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
        end
        execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

783 784
      # Renames a column in a table.
      def rename_column(table_name, column_name, new_column_name)
        # All three identifiers are quoted before use. Returns whatever #execute returns.
        table = quote_table_name(table_name)
        old_name = quote_column_name(column_name)
        new_name = quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{old_name} TO #{new_name}"
      end
787

788 789
      # Drops an index from a table.
      def remove_index(table_name, options = {})
        # The index to drop is resolved by #index_name from the table name and +options+.
        # Returns whatever #execute returns.
        execute "DROP INDEX #{index_name(table_name, options)}"
      end
792

793 794
      # Maps logical Rails types to PostgreSQL-specific data types.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        # Only integers get PostgreSQL-specific handling; every other type is left to super.
        return super unless type.to_s == 'integer'

        # +limit+ is the integer's size in bytes; nil selects the plain integer type.
        # Any other size raises ActiveRecordError.
        case limit
        when 1..2      then 'smallint'
        when 3..4, nil then 'integer'
        when 5..8      then 'bigint'
        else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end
804

805
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", "posts.created_at desc")
811
      def distinct(columns, order_by) #:nodoc:
        # Without an ORDER BY a plain DISTINCT is enough.
        return "DISTINCT #{columns}" if order_by.blank?

        # Construct a clean list of column names from the ORDER BY clause, removing
        # any ASC/DESC modifiers
        order_columns = order_by.split(',').collect { |term| term.split.first }
        order_columns.delete_if { |column| column.blank? }

        # Alias each ordering column by position (alias_0, alias_1, ...).
        aliased_columns = []
        order_columns.each_with_index do |column, index|
          aliased_columns << "#{column} AS alias_#{index}"
        end

        # Return a DISTINCT ON() clause that's distinct on the columns we want but includes
        # all the required columns for the ORDER BY to work properly.
        "DISTINCT ON (#{columns}) #{columns}, #{aliased_columns.join(', ')}"
      end
825
      
826
      # Returns an ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the +sql+ string as a sub-select and ordering in that query.
830
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
        # Nothing to do without an :order option; +sql+ is returned untouched.
        return sql if options[:order].blank?

        terms = options[:order].split(',').collect { |term| term.strip }.reject { |term| term.blank? }

        # Each ordering term is referenced through its positional alias (alias_0, alias_1, ...)
        # in the wrapped query. Only a trailing DESC is carried over; any other term gets no
        # explicit direction.
        clauses = []
        terms.each_with_index do |term, index|
          direction = (term =~ /\bdesc$/i) ? 'DESC' : nil
          clauses << "id_list.alias_#{index} #{direction}"
        end

        # +sql+ is modified in place and also returned.
        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{clauses.join(', ')}"
      end
839

840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856
      protected
        # Returns the version of the connected PostgreSQL server.
        def postgresql_version
          # The value is computed once and then served from @postgresql_version.
          return @postgresql_version if @postgresql_version

          @postgresql_version =
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              # Mimic PGconn.server_version behavior: parse "PostgreSQL X.Y.Z" out of the
              # version() banner and encode it as X * 10000 + Y * 100 + Z. Anything that
              # goes wrong, or a banner that does not match, yields 0.
              begin
                banner = query('SELECT version()')[0][0]
                parts = /PostgreSQL (\d+)\.(\d+)\.(\d+)/.match(banner)
                if parts
                  (parts[1].to_i * 10000) + (parts[2].to_i * 100) + parts[3].to_i
                else
                  0
                end
              rescue
                0
              end
            end
        end

D
Initial  
David Heinemeier Hansson 已提交
857
      private
858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879
        # The internal PostgreSQL identifier of the money data type. Result columns whose
        # ftype equals this value get their currency formatting stripped in select_raw.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        def connect
          # Opens the connection from @connection_parameters, adapts the adapter and
          # PostgreSQLColumn to what the server supports, then calls configure_connection.
          @connection = PGconn.connect(*@connection_parameters)
          PGconn.translate_results = false if PGconn.respond_to?(:translate_results=)

          # Ignore async_exec and async_query when using postgres-pr.
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Use escape string syntax if available. We cannot do this lazily when encountering
          # the first string, because that could then break any transactions in progress.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          # If PostgreSQL doesn't know the standard_conforming_strings parameter then it doesn't
          # support escape string syntax. Don't override the inherited quoted_string_prefix.
          if supports_standard_conforming_strings?
            self.class.instance_eval do
              define_method(:quoted_string_prefix) { 'E' }
            end
          end

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          money_precision = (postgresql_version >= 80300) ? 19 : 10
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)
              if sql_type =~ /^money$/
                #{money_precision}
              else
                super
              end
            end
          end_eval

          configure_connection
        end

        # Configures the encoding, verbosity, and schema search path of the connection.
        # This is called by #connect and should not be called manually.
900 901
        def configure_connection
          # Client encoding: use the driver's setter when it has one, otherwise issue
          # the SET statement ourselves.
          encoding = @config[:encoding]
          if encoding
            if @connection.respond_to?(:set_client_encoding)
              @connection.set_client_encoding(encoding)
            else
              execute("SET client_encoding TO '#{encoding}'")
            end
          end

          self.client_min_messages = @config[:min_messages] if @config[:min_messages]

          # The search path is always assigned; :schema_search_path takes precedence
          # over :schema_order, and the value is nil when neither is configured.
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

912 913
        # Returns the current ID of a table's sequence.
        def last_insert_id(table, sequence_name) #:nodoc:
          # +table+ is not consulted; only the sequence's currval is queried. Integer()
          # raises if the value returned by select_value is not a valid integer.
          Integer(select_value("SELECT currval('#{sequence_name}')"))
        end

917
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
D
Initial  
David Heinemeier Hansson 已提交
919
        def select(sql, name = nil)
          # Returns an array with one hash per row, keyed by field name. The field names
          # and the row arrays come from select_raw.
          fields, rows = select_raw(sql, name)
          rows.map do |row|
            row_hash = {}
            fields.each_with_index { |field, index| row_hash[field] = row[index] }
            row_hash
          end
        end

        def select_raw(sql, name = nil)
          # Executes +sql+ and returns two values: the list of field names and the list of
          # row arrays. Both are empty when the result holds no tuples. The result object
          # is cleared before returning.
          res = execute(sql, name)
          results = result_as_array(res)
          fields = []
          rows = []
          if res.ntuples > 0
            fields = res.fields
            results.each do |row|
              row.each_index do |cell_index|
                # If this is a money type column and there are any currency symbols,
                # then strip them off. Indeed it would be prettier to do this in
                # PostgreSQLColumn.string_to_decimal but would break form input
                # fields that call value_before_type_cast.
                next unless res.ftype(cell_index) == MONEY_COLUMN_TYPE_OID

                # Because money output is formatted according to the locale, there are two
                # cases to consider (note the decimal separators):
                #  (1) $12,345,678.12
                #  (2) $12.345.678,12
                # Values matching neither pattern are left as they are.
                case column = row[cell_index]
                  when /^-?\D+[\d,]+\.\d{2}$/  # (1)
                    row[cell_index] = column.gsub(/[^-\d\.]/, '')
                  when /^-?\D+[\d\.]+,\d{2}$/  # (2)
                    row[cell_index] = column.gsub(/[^-\d,]/, '').sub(/,/, '.')
                end
              end
              rows << row
            end
          end
          res.clear
          return fields, rows
        end

968
        # Returns the list of a table's column names, data types, and default values.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a cast that resolves a table name to its id
986
        def column_definitions(table_name) #:nodoc:
          # Selects, per column and in attnum order: its name, its formatted type, the
          # source text of its default (NULL when it has none) and its NOT NULL flag.
          # +table_name+ is interpolated into the statement as given.
          query <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
    end
  end
end