# postgresql_adapter.rb - Active Record connection adapter for PostgreSQL.
require 'active_record/connection_adapters/abstract_adapter'

# Load a PostgreSQL driver: prefer 'pg'; fall back to the older 'postgres'
# driver, aliasing its PGresult accessors to the names this adapter calls
# (e.g. nfields/ntuples in result_as_array, cmd_tuples in update_sql).
# If neither driver loads, re-raise the LoadError from the 'pg' attempt.
begin
  require_library_or_gem 'pg'
rescue LoadError => e
  begin
    require_library_or_gem 'postgres'
    class PGresult
      alias_method :nfields, :num_fields unless method_defined?(:nfields)
      alias_method :ntuples, :num_tuples unless method_defined?(:ntuples)
      alias_method :ftype, :type unless method_defined?(:ftype)
      alias_method :cmd_tuples, :cmdtuples unless method_defined?(:cmd_tuples)
    end
  rescue LoadError
    raise e
  end
end

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ keys are symbolized first. Recognised keys: :host, :port
    # (default 5432), :username and :password (nil becomes ""), and the
    # mandatory :database.
    #
    # Raises ArgumentError when no :database key is present.
    def self.postgresql_connection(config) # :nodoc:
      config   = config.symbolize_keys
      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username].to_s
      password = config[:password].to_s

      unless config.key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end
      database = config[:database]

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, [host, port, nil, nil, database, username, password], config)
    end
  end

  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      # The raw +default+ reported by PostgreSQL is first reduced to a plain
      # value by extract_value_from_default.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      private
        # bigint and smallint imply a byte size (8 and 2); everything else
        # defers to Column.
        def extract_limit(sql_type)
          return 8 if sql_type =~ /^bigint/i
          return 2 if sql_type =~ /^smallint/i
          super
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          super
        end

        # Escapes binary strings for bytea input to the database.
        #
        # The first call chooses an implementation and installs it as this
        # class's own singleton method, so later calls skip the check. It uses
        # the driver's PGconn.escape_bytea when available, otherwise a
        # pure-Ruby octal escaper. Defining it on the singleton class (rather
        # than on Class) keeps it out of every other class in the process.
        def self.string_to_binary(value)
          metaclass = class << self; self; end
          if PGconn.respond_to?(:escape_bytea)
            metaclass.send(:define_method, :string_to_binary) do |data|
              PGconn.escape_bytea(data) if data
            end
          else
            metaclass.send(:define_method, :string_to_binary) do |data|
              if data
                result = ''
                data.each_byte { |c| result << sprintf('\\\\%03o', c) }
                result
              end
            end
          end
          string_to_binary(value)
        end

        # Unescapes bytea output from a database to the binary string it represents.
        #
        # Like string_to_binary, the first call installs the chosen
        # implementation as a singleton method. In each case, check if the value
        # actually is escaped PostgreSQL bytea output or an unescaped Active
        # Record attribute that was just written.
        def self.binary_to_string(value)
          metaclass = class << self; self; end
          if PGconn.respond_to?(:unescape_bytea)
            metaclass.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                PGconn.unescape_bytea(data)
              else
                data
              end
            end
          else
            metaclass.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                result = ''
                i, max = 0, data.size
                while i < max
                  char = data[i]
                  if char == ?\\
                    if data[i+1] == ?\\
                      char = ?\\
                      i += 1
                    else
                      char = data[i+1..i+3].oct
                      i += 3
                    end
                  end
                  result << char
                  i += 1
                end
                result
              else
                data
              end
            end
          end
          binary_to_string(value)
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Numeric and monetary types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when /^money$/
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when /^bytea$/
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when /^interval$/
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when /^xml$/
              :string
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when /^oid$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Literal defaults look like 'abc'::character varying; the literal part
        # is returned. Anything unrecognised (functions, user types, blank)
        # yields nil.
        def self.extract_value_from_default(default)
          case default
            # Numeric types. This also covers object identifier defaults, which
            # are plain integers (a separate integer-only branch after this one
            # could never be reached).
            when /\A-?\d+(\.\d*)?\z/
              default
            # Character types
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end
  end

  module ConnectionAdapters
    # The PostgreSQL adapter works with both the native C driver (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby driver (available both as a gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944).
    #
    # Options:
    #
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to an empty string.
    # * <tt>:password</tt> - Defaults to an empty string.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client message level that is used in a <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:allow_concurrency</tt> - If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
    class PostgreSQLAdapter < AbstractAdapter
      # Name reported by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'.freeze

      # Maps Rails' abstract column types to PostgreSQL column types;
      # returned by #native_database_types.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key".freeze,
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" }
      }

      # Returns 'PostgreSQL' as adapter name for identification purposes.
      def adapter_name
        ADAPTER_NAME
      end

      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection_parameters+ is the positional argument list assembled by
      # Base.postgresql_connection and +config+ the full configuration hash.
      # Both are stored before #connect is called (presumably #connect reads
      # them; it is also what #reconnect! falls back to).
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters, @config = connection_parameters, config

        connect
      end

      # Is this connection alive and ready for queries?
      #
      # Uses the driver's status flag when it has one; otherwise issues a
      # trivial query directly on the driver. Returns false if that raises
      # PGError or NoMethodError.
      def active?
        if @connection.respond_to?(:status)
          @connection.status == PGconn::CONNECTION_OK
        else
          # We're asking the driver, not ActiveRecord, so use @connection.query instead of #query
          @connection.query 'SELECT 1'
          true
        end
      # postgres-pr raises a NoMethodError when querying if no connection is available.
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      #
      # When the driver connection supports #reset it is reset in place and
      # #configure_connection runs again; otherwise the connection is closed
      # with #disconnect! and a fresh one opened with #connect.
      def reconnect!
        if @connection.respond_to?(:reset)
          @connection.reset
          configure_connection
        else
          disconnect!
          connect
        end
      end

      # Close the connection. Any error raised by the driver while closing is
      # ignored (best-effort; presumably the connection may already be gone).
      def disconnect!
        @connection.close rescue nil
      end

      # Returns the NATIVE_DATABASE_TYPES mapping of abstract Rails types to
      # PostgreSQL column types.
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end

      # Does PostgreSQL support migrations? Always true.
      def supports_migrations?
        true
      end

      # Does PostgreSQL support standard conforming strings?
      #
      # Returns the server's standard_conforming_strings value (a truthy string)
      # when the setting exists, false otherwise. Note that "off" is still a
      # truthy result: it means the server knows the setting.
      def supports_standard_conforming_strings?
        # Temporarily set the client message level above error to prevent unintentional
        # error messages in the logs when working on a PostgreSQL database server that
        # does not support standard conforming strings.
        client_min_messages_old = client_min_messages
        self.client_min_messages = 'panic'

        begin
          # postgres-pr does not raise an exception when client_min_messages is set higher
          # than error and "SHOW standard_conforming_strings" fails, but returns an empty
          # PGresult instead (indexing it then raises, which lands in the rescue).
          query('SHOW standard_conforming_strings')[0][0]
        rescue
          false
        ensure
          # Always restore the caller's message level.
          self.client_min_messages = client_min_messages_old
        end
      end

      # INSERT ... RETURNING is used by #insert on PostgreSQL 8.2 (80200) and
      # later to fetch the new primary key in the same statement.
      def supports_insert_with_returning?
        postgresql_version >= 80200
      end

      # Returns the maximum identifier length configured on the server
      # (SHOW max_identifier_length), or the default of 63 on PostgreSQL 7.x.
      # The value is memoized for the lifetime of the adapter.
      def table_alias_length
        @table_alias_length ||= (postgresql_version >= 80000 ? query('SHOW max_identifier_length')[0][0].to_i : 63)
      end

      # QUOTING ==================================================

      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # * Strings for binary (bytea) columns are escaped via the column class's
      #   string_to_binary and prefixed with quoted_string_prefix.
      # * Strings for xml columns become an <tt>xml '...'</tt> literal.
      # * Numerics for money columns are quoted without escape-string syntax.
      # * Strings for bit columns use B'...' when made only of 0/1 and X'...' when
      #   made only of hex digits. Anything else falls back to generic quoting.
      # Everything else is handled by the superclass.
      def quote(value, column = nil) #:nodoc:
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{column.class.string_to_binary(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # \A and \z anchor the whole string: with ^/$ a multi-line value whose
          # first line is valid would be spliced unescaped into the literal.
          case value
            when /\A[01]*\z/
              "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i
              "X'#{value}'" # Hexadecimal notation
            else
              # Not a valid bit/hex literal: quote it as an ordinary escaped
              # string instead of returning nil (a hole in the generated SQL)
              # and let the server reject the value.
              super
          end
        else
          super
        end
      end

      # Quotes strings for use in SQL input in the postgres driver for better performance.
      #
      # On first use this method replaces itself. If the driver provides
      # PGconn.escape, it becomes a thin wrapper around it. Otherwise the
      # override is removed, so the inherited quote_string (assumed to come
      # from AbstractAdapter) is used from then on.
      def quote_string(s) #:nodoc:
        if PGconn.respond_to?(:escape)
          self.class.instance_eval do
            define_method(:quote_string) do |str|
              PGconn.escape(str)
            end
          end
        else
          # There are some incorrectly compiled postgres drivers out there
          # that don't define PGconn.escape. remove_method (unlike undef_method,
          # which blocks lookup entirely and made the call below raise
          # NoMethodError) lets lookup fall through to the superclass.
          self.class.instance_eval do
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

      # Quotes column names for use in SQL queries.
      #
      # Wraps the name in double quotes, doubling any embedded double quote
      # so the result is a single well-formed quoted identifier.
      def quote_column_name(name) #:nodoc:
        %("#{name.to_s.gsub('"', '""')}")
      end

      # Quote date/time values for use in SQL input. Includes microseconds
      # (zero-padded to six digits) if the value is time-like and responds to
      # usec; otherwise defers entirely to the superclass.
      def quoted_date(value) #:nodoc:
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
      end

      # REFERENTIAL INTEGRITY ====================================

      # True when the server version is 8.1 or later (the threshold this
      # adapter uses for ALTER TABLE ... DISABLE/ENABLE TRIGGER ALL). Returns
      # false if the version cannot be read or parsed.
      def supports_disable_referential_integrity?() #:nodoc:
        version = query("SHOW server_version")[0][0].split('.')
        major, minor = version[0].to_i, version[1].to_i
        # Compare major first: checking "minor >= 1" on its own wrongly
        # rejected releases such as 9.0.
        major > 8 || (major == 8 && minor >= 1)
      rescue
        return false
      end

      # Runs the block with ALTER TABLE ... DISABLE TRIGGER ALL applied to every
      # table from #tables, re-enabling the triggers afterwards even if the block
      # raises. When the server does not support this, only yields. Returns the
      # block's value.
      def disable_referential_integrity(&block) #:nodoc:
        # Ask once; the ensure clause reuses the answer (nil if the check itself raised).
        supported = supports_disable_referential_integrity?
        if supported
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
        yield
      ensure
        if supported
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
      end

      # DATABASE STATEMENTS ======================================

      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values (the last element of what select_raw returns).
      def select_rows(sql, name = nil)
        select_raw(sql, name).last
      end

      # Executes an INSERT query and returns the new record's ID.
      #
      # On servers supporting INSERT ... RETURNING, and when a primary key is
      # known (given or looked up), the statement is run with RETURNING <pk>
      # and that value is returned. Otherwise the superclass performs the
      # insert. If that yields no ID, the ID is read from the primary key's
      # sequence; the result is nil for a table without a primary key.
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # Extract the table from the insert sql. Yuck.
        table = sql.split(" ", 4)[2].gsub('"', '')

        # Try an insert with 'returning id' if available (PG >= 8.2)
        if supports_insert_with_returning?
          pk, sequence_name = *pk_and_sequence_for(table) unless pk
          if pk
            id = select_value("#{sql} RETURNING #{quote_column_name(pk)}")
            clear_query_cache
            return id
          end
        end

        # Otherwise, insert then grab last_insert_id.
        if insert_id = super
          insert_id
        else
          # If neither pk nor sequence name is given, look them up.
          unless pk || sequence_name
            pk, sequence_name = *pk_and_sequence_for(table)
          end

          # If a pk is given, fallback to default sequence name.
          # Don't fetch last insert id for a table without a pk.
          if pk && (sequence_name ||= default_sequence_name(table, pk))
            last_insert_id(table, sequence_name)
          end
        end
      end

      # Create a 2D array representing the result set: one inner array per
      # tuple, holding that tuple's field values in column order.
      def result_as_array(res) #:nodoc:
        (0...res.ntuples).map do |row|
          (0...res.nfields).map { |col| res.getvalue(row, col) }
        end
      end

      # Queries the database and returns the results as an array of rows
      # (see result_as_array).
      #
      # The rows are the block's value, which #log hands back (#execute relies on
      # the same). An explicit +return+ inside the block would unwind straight
      # through #log and skip anything it does after yielding.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          res = @async ? @connection.async_exec(sql) : @connection.exec(sql)
          result_as_array(res)
        end
      end

      # Executes an SQL statement, returning a PGresult object on success
      # or raising a PGError exception otherwise.
      #
      # Uses the driver's async_exec when @async is set (presumably from the
      # :allow_concurrency option), exec otherwise.
      def execute(sql, name = nil)
        log(sql, name) do
          if @async
            @connection.async_exec(sql)
          else
            @connection.exec(sql)
          end
        end
      end

      # Executes an UPDATE query and returns the number of affected tuples
      # (cmd_tuples of the result the superclass returns).
      def update_sql(sql, name = nil)
        super.cmd_tuples
      end

      # Begins a transaction.
      def begin_db_transaction
        execute "BEGIN"
      end

      # Commits a transaction.
      def commit_db_transaction
        execute "COMMIT"
      end

      # Aborts a transaction.
      def rollback_db_transaction
        execute "ROLLBACK"
      end

      # SCHEMA STATEMENTS ========================================

      # Drops and then creates the named database (with default options).
      def recreate_database(name) #:nodoc:
        drop_database(name)
        create_database(name)
      end

      # Create a new PostgreSQL database.  Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
      #
      # Option keys may be strings or symbols. <tt>:encoding</tt> defaults to "utf8";
      # unrecognised keys are ignored.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Symbolize before applying the default: otherwise a string 'encoding'
        # key and the :encoding default both survive and collapse into one key
        # arbitrarily, possibly overriding the caller's encoding.
        options = { :encoding => "utf8" }.merge(options.symbolize_keys)

        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = '#{value}'"
          when :template
            " TEMPLATE = #{value}"
          when :encoding
            " ENCODING = '#{value}'"
          when :tablespace
            " TABLESPACE = #{value}"
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

      # Drops a PostgreSQL database.
      #
      # Uses DROP DATABASE IF EXISTS on 8.2+; on older servers a failure
      # (ActiveRecord::StatementInvalid) is logged as a warning and ignored.
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        if postgresql_version >= 80200
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        else
          begin
            execute "DROP DATABASE #{quote_table_name(name)}"
          rescue ActiveRecord::StatementInvalid
            @logger.warn "#{name} database doesn't exist." if @logger
          end
        end
      end

      # Returns the names of all tables in the schemas on the current schema
      # search path. +name+ is only the log label passed to #query.
      def tables(name = nil)
        # search_path entries may carry whitespace after the commas
        # (e.g. '"$user", public'); strip them or they never match schemaname.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

      # Returns the list of all indexes for a table.
      #
      # Primary-key indexes are excluded, and only indexes in schemas on the
      # search path are considered. One IndexDefinition is built per index.
      # Its columns appear in the order the catalog query returns them
      # (NOTE(review): not guaranteed to be key order). Only the first ten key
      # positions of an index are examined.
      def indexes(table_name, name = nil)
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, a.attname
            FROM pg_class t, pg_class i, pg_index d, pg_attribute a
          WHERE i.relkind = 'i'
            AND d.indexrelid = i.oid
            AND d.indisprimary = 'f'
            AND t.oid = d.indrelid
            AND t.relname = #{quote(table_name.to_s)}
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
            AND a.attrelid = t.oid
            AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
               OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
               OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
               OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
               OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        indexes = []

        # Rows are ordered by index name, so a change of name starts a new index.
        result.each do |row|
          if current_index != row[0]
            indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
            current_index = row[0]
          end

          indexes.last.columns << row[2]
        end

        indexes
      end

      # Returns the list of all column definitions for a table, one
      # PostgreSQLColumn per row of column_definitions. A column is nullable
      # unless its notnull flag is 't'.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        # The block parameter is not called +name+ so it cannot shadow the
        # method's own +name+ argument.
        column_definitions(table_name).collect do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually but set in database.yml.
      # A nil value is ignored. The string is sent verbatim in a SET statement
      # and cached for #schema_search_path.
      def schema_search_path=(schema_csv)
        if schema_csv
          execute "SET search_path TO #{schema_csv}"
          @schema_search_path = schema_csv
        end
      end

      # Returns the active schema search path: the value last set through
      # schema_search_path=, or else the server's (SHOW search_path), memoized.
      def schema_search_path
        @schema_search_path ||= query('SHOW search_path')[0][0]
      end

      # Returns the current client message level (SHOW client_min_messages).
      def client_min_messages
        query('SHOW client_min_messages')[0][0]
      end

      # Changes the session's client_min_messages setting to +level+ by
      # issuing a SET statement.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement)
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      #
      # If pk_and_sequence_for reports a sequence it is returned as-is
      # (NOTE(review): even when +pk+ names a different column). Otherwise the
      # conventional "<table>_<key>_seq" is built from +pk+, then the detected
      # primary key, then "id".
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        default_pk, default_seq = pk_and_sequence_for(table_name)
        default_seq || "#{table_name}_#{pk || default_pk || 'id'}_seq"
      end

      # Resets the sequence of a table's primary key to the maximum value.
      #
      # Runs setval(sequence, MAX(pk) + increment_by, false), using the
      # sequence's min_value when the table is empty. Missing +pk+/+sequence+
      # are looked up via pk_and_sequence_for. Without a pk nothing happens.
      # With a pk but no sequence a warning is logged (when a logger is set).
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        if pk
          if sequence
            quoted_sequence = quote_column_name(sequence)

            select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
            end_sql
          else
            @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          end
        end
      end

      # Returns a table's primary key and belonging sequence as
      # <tt>[primary_key, sequence_name]</tt>, or nil when neither lookup
      # finds a match or a query fails.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key. (pg_namespace is not joined: it had no
        # join condition and only multiplied identical rows.)
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{table}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{table}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # [primary_key, sequence]
        [result.first, result.last]
      rescue
        # Best-effort lookup: a failing query or an empty second lookup
        # (result is nil) both mean "unknown".
        nil
      end

      # Renames a table.
      def rename_table(name, new_name)
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end

      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      #
      # The column is added first. A :default (when present in +options+) and
      # NOT NULL (when :null => false) are then applied separately;
      # change_column_null back-fills NULLs with the default when one is given.
      def add_column(table_name, column_name, type, options = {})
        default = options[:default]
        notnull = options[:null] == false

        # Add the column.
        execute("ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}")

        change_column_default(table_name, column_name, default) if options_include_default?(options)
        change_column_null(table_name, column_name, false, default) if notnull
      end

      # Changes the column of a table.
      #
      # Uses ALTER COLUMN ... TYPE. If that fails on a pre-8.0 server, the
      # column is rebuilt inside a transaction instead: add a temporary column,
      # copy the data with CAST, drop the original, rename the temporary one.
      # Afterwards :default and :null are applied when present in +options+.
      # Raises ActiveRecord::StatementInvalid (or the rebuild's error) on failure.
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
        rescue ActiveRecord::StatementInvalid => e
          # The fallback below exists for PostgreSQL 7.x; on 8.0+ a failure
          # here is a genuine error (e.g. an impossible cast) and must surface.
          raise e if postgresql_version >= 80000

          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{type_to_sql(type, options[:limit], options[:precision], options[:scale])})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue
            # Undo the partial rebuild, then report the failure rather than
            # silently leaving the column unchanged.
            rollback_db_transaction
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end

      # Changes the default value of a table column; +default+ is rendered
      # with #quote (no column context).
      def change_column_default(table_name, column_name, default)
        execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
      end

      # Adds (+null+ false) or drops (+null+ true) the NOT NULL constraint on a
      # column. When adding it with a non-nil +default+, existing NULLs are
      # first updated to that default so the constraint can be applied.
      def change_column_null(table_name, column_name, null, default = nil)
        unless null || default.nil?
          execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
        end
        execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

      # Renames a column in a table.
      def rename_column(table_name, column_name, new_column_name)
782
        execute "ALTER TABLE #{quote_table_name(table_name)} RENAME COLUMN #{quote_column_name(column_name)} TO #{quote_column_name(new_column_name)}"
783
      end
784

785 786
      # Drops an index from a table.
      def remove_index(table_name, options = {})
787
        execute "DROP INDEX #{index_name(table_name, options)}"
788
      end
789

790 791
      # Maps logical Rails types to PostgreSQL-specific data types.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
792 793
        return super unless type.to_s == 'integer'

794 795 796 797
        case limit
          when 1..2;      'smallint'
          when 3..4, nil; 'integer'
          when 5..8;      'bigint'
798 799
        end
      end
800
      
801
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
802 803 804
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
805
      #
806
      #   distinct("posts.id", "posts.created_at desc")
807
      def distinct(columns, order_by) #:nodoc:
808 809
        return "DISTINCT #{columns}" if order_by.blank?

810 811
        # Construct a clean list of column names from the ORDER BY clause, removing
        # any ASC/DESC modifiers
812
        order_columns = order_by.split(',').collect { |s| s.split.first }
813
        order_columns.delete_if &:blank?
814
        order_columns = order_columns.zip((0...order_columns.size).to_a).map { |s,i| "#{s} AS alias_#{i}" }
815

816 817
        # Return a DISTINCT ON() clause that's distinct on the columns we want but includes
        # all the required columns for the ORDER BY to work properly.
818 819
        sql = "DISTINCT ON (#{columns}) #{columns}, "
        sql << order_columns * ', '
820
      end
821
      
822
      # Returns an ORDER BY clause for the passed order option.
823 824
      # 
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
P
Pratik Naik 已提交
825
      # by wrapping the +sql+ string as a sub-select and ordering in that query.
826
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
827 828 829 830 831 832 833 834
        return sql if options[:order].blank?
        
        order = options[:order].split(',').collect { |s| s.strip }.reject(&:blank?)
        order.map! { |s| 'DESC' if s =~ /\bdesc$/i }
        order = order.zip((0...order.size).to_a).map { |s,i| "id_list.alias_#{i} #{s}" }.join(', ')
        
        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{order}"
      end
835

836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852
      protected
        # Returns the version of the connected PostgreSQL server as an integer in the
        # PGconn.server_version format (e.g. 80301 for 8.3.1, 100004 for 10.4), or 0 when
        # it cannot be determined. The result is memoized in @postgresql_version.
        def postgresql_version
          @postgresql_version ||=
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              # Mimic PGconn.server_version behavior for drivers that lack it.
              begin
                if query('SELECT version()')[0][0] =~ /PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?/
                  major, minor, patch = $1.to_i, $2.to_i, $3.to_i
                  # The patch component is optional: pre-releases report e.g. "8.3beta1",
                  # and from PostgreSQL 10 on versions are "major.minor", encoded as
                  # major * 10000 + minor.
                  major >= 10 ? major * 10000 + minor : major * 10000 + minor * 100 + patch
                else
                  0
                end
              rescue
                0
              end
            end
        end

      private
854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875
        # The internal PostgreSQL identifer of the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics, then calls #configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)
          PGconn.translate_results = false if PGconn.respond_to?(:translate_results=)

          # Ignore async_exec and async_query when using postgres-pr.
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Use escape string syntax if available. We cannot do this lazily when encountering
          # the first string, because that could then break any transactions in progress.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          # If PostgreSQL doesn't know the standard_conforming_strings parameter then it doesn't
          # support escape string syntax. Don't override the inherited quoted_string_prefix.
          # NOTE: the method is defined on the adapter class itself (self.class), so it
          # applies to every instance of the class, not only this connection.
          if supports_standard_conforming_strings?
            self.class.instance_eval do
              define_method(:quoted_string_prefix) { 'E' }
            end
          end

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          # NOTE: this redefines extract_precision for all PostgreSQLColumn instances on every
          # connect, so the most recently connected server's version determines the precision.
          money_precision = (postgresql_version >= 80300) ? 19 : 10
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)
              if sql_type =~ /^money$/
                #{money_precision}
              else
                super
              end
            end
          end_eval

          configure_connection
        end

        # Configures the encoding, verbosity, and schema search path of the connection.
        # This is called by #connect and should not be called manually.
        #
        # Reads :encoding, :min_messages and :schema_search_path (falling back to
        # :schema_order) from the adapter configuration.
        def configure_connection
          if @config[:encoding]
            # Use the driver's API when it offers one; otherwise issue the SET directly.
            if @connection.respond_to?(:set_client_encoding)
              @connection.set_client_encoding(@config[:encoding])
            else
              execute("SET client_encoding TO '#{@config[:encoding]}'")
            end
          end
          self.client_min_messages = @config[:min_messages] if @config[:min_messages]
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

908 909
        # Returns the current ID of a table's sequence.
        def last_insert_id(table, sequence_name) #:nodoc:
910
          Integer(select_value("SELECT currval('#{sequence_name}')"))
D
Initial  
David Heinemeier Hansson 已提交
911 912
        end

913
        # Executes a SELECT query and returns the results, performing any data type
914
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
D
Initial  
David Heinemeier Hansson 已提交
915
        def select(sql, name = nil)
916 917 918 919 920 921 922 923 924 925 926 927 928
          fields, rows = select_raw(sql, name)
          result = []
          for row in rows
            row_hash = {}
            fields.each_with_index do |f, i|
              row_hash[f] = row[i]
            end
            result << row_hash
          end
          result
        end

        # Executes +sql+ and returns a two-element array: the result's field names and
        # its rows, each row being an array of values as produced by result_as_array.
        # When the query returns no rows, both elements are empty arrays.
        #
        # Values in money columns have their currency symbols and thousands separators
        # stripped. It would be prettier to do this in PostgreSQLColumn.string_to_decimal,
        # but that would break form input fields that call value_before_type_cast.
        #
        # The result object is cleared even if processing the rows raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            rows = result_as_array(res)
            if res.ntuples > 0
              fields = res.fields
              # Find the money columns once per result instead of checking every cell.
              money_columns = (0...fields.size).select { |i| res.ftype(i) == MONEY_COLUMN_TYPE_OID }
              rows.each do |row|
                money_columns.each { |i| row[i] = strip_money_formatting(row[i]) }
              end
              [fields, rows]
            else
              [[], []]
            end
          ensure
            res.clear
          end
        end

        # Turns a locale-formatted money string into a plain decimal string. Any value
        # matching neither format (including nil for SQL NULL) is returned unchanged.
        #
        # Because money output is formatted according to the locale, there are two
        # cases to consider (note the decimal separators):
        #  (1) $12,345,678.12  => "12345678.12"
        #  (2) $12.345.678,12  => "12345678.12"
        def strip_money_formatting(value)
          case value
          when /^-?\D+[\d,]+\.\d{2}$/  # (1)
            value.gsub(/[^-\d\.]/, '')
          when /^-?\D+[\d\.]+,\d{2}$/  # (2)
            value.gsub(/[^-\d,]/, '').sub(/,/, '.')
          else
            value
          end
        end

964
        # Returns the list of a table's column names, data types, and default values.
965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a function that gives the id for a table name
982
        def column_definitions(table_name) #:nodoc:
983
          query <<-end_sql
984
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
985 986 987 988 989 990
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
D
Initial  
David Heinemeier Hansson 已提交
991 992 993 994
        end
    end
  end
end