# postgresql_adapter.rb
require 'active_record/connection_adapters/abstract_adapter'

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Load a PostgreSQL driver: prefer the 'pg' library and fall back to the older
# 'postgres' library. If neither can be loaded, the LoadError from the 'pg'
# attempt is re-raised so the reported failure names the preferred driver.
begin
  require_library_or_gem 'pg'
rescue LoadError => e
  begin
    require_library_or_gem 'postgres'
    # Compatibility shim for the 'postgres' driver: alias its PGresult method
    # names to the ones this adapter calls (nfields, ntuples, ftype,
    # cmd_tuples), leaving any that are already defined untouched.
    class PGresult
      alias_method :nfields, :num_fields unless self.method_defined?(:nfields)
      alias_method :ntuples, :num_tuples unless self.method_defined?(:ntuples)
      alias_method :ftype, :type unless self.method_defined?(:ftype)
      alias_method :cmd_tuples, :cmdtuples unless self.method_defined?(:cmd_tuples)
    end
  rescue LoadError
    # Report the original 'pg' load failure rather than the fallback's.
    raise e
  end
end

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ may use string or symbol keys (it is symbolized first). Reads
    # :host, :port (defaulting to 5432), :username and :password (each
    # converted with to_s when set) and :database, which must be present or
    # ArgumentError is raised.
    def self.postgresql_connection(config) # :nodoc:
      config = config.symbolize_keys

      host = config[:host]
      port = config[:port] || 5432
      username, password = [:username, :password].map { |key| config[key].to_s if config[key] }

      database = config.fetch(:database) do
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      # The postgres drivers don't allow the creation of an unconnected PGconn
      # object, so no connection is handed over here; the adapter opens one
      # itself from these parameters.
      connection_parameters = [host, port, nil, nil, database, username, password]
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, connection_parameters, config)
    end
  end
40

41 42 43 44 45 46 47
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    #
    # Adds recognition of PostgreSQL type names (simplified_type), parsing of
    # the raw default expressions PostgreSQL reports (extract_value_from_default)
    # and bytea escaping/unescaping (string_to_binary / binary_to_string).
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression from the catalog; it is reduced
      # to a plain value with extract_value_from_default before being passed on.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      private
        # Reports a limit of 8 for bigint and 2 for smallint; every other type
        # is left to Column.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          super
        end

        # Escapes binary strings for bytea input to the database.
        #
        # The first call picks an implementation for the loaded driver and
        # installs it as this class's own singleton method string_to_binary,
        # replacing this bootstrap, so later calls skip the driver check.
        # (Defining it through self.class would put an instance method on Class,
        # i.e. on every class, and leave this bootstrap in place.)
        # Returns nil for a nil value.
        def self.string_to_binary(value)
          singleton = (class << self; self; end)
          # The block parameter is deliberately not named +value+: on Ruby 1.8 a
          # block parameter with the same name as an outer local writes to that
          # captured local, which would be shared by all later (threaded) calls.
          if PGconn.respond_to?(:escape_bytea)
            singleton.send(:define_method, :string_to_binary) do |data|
              PGconn.escape_bytea(data) if data
            end
          else
            singleton.send(:define_method, :string_to_binary) do |data|
              if data
                result = ''
                data.each_byte { |c| result << sprintf('\\\\%03o', c) }
                result
              end
            end
          end
          string_to_binary(value)
        end

        # Unescapes bytea output from a database to the binary string it represents.
        #
        # Installed lazily like string_to_binary. Values that contain no
        # "\ddd" sequence are returned unchanged.
        def self.binary_to_string(value)
          singleton = (class << self; self; end)
          # In each case, check if the value actually is escaped PostgreSQL bytea output
          # or an unescaped Active Record attribute that was just written.
          if PGconn.respond_to?(:unescape_bytea)
            singleton.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                PGconn.unescape_bytea(data)
              else
                data
              end
            end
          else
            singleton.send(:define_method, :binary_to_string) do |data|
              if data =~ /\\\d{3}/
                # "\\" is an escaped backslash and "\ooo" an octal-escaped byte.
                # Integer#chr yields a single byte on every Ruby version, unlike
                # indexing the string (an Integer only on 1.8) and appending a
                # codepoint (multi-byte for values > 127 on 1.9+).
                data.gsub(/\\(\\|[0-7]{3})/) { $1 == '\\' ? '\\' : $1.oct.chr }
              else
                data
              end
            end
          end
          binary_to_string(value)
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Numeric and monetary types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when /^money$/
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when /^bytea$/
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when /^interval$/
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when /^xml$/
              :string
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when /^oid$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Returns the default as a string (true/false for boolean defaults), or
        # nil when the default is blank, a user type, or an expression whose
        # value can't be known here.
        def self.extract_value_from_default(default)
          case default
            # Numeric and object identifier types, e.g. "42", "-1.5" or "(-1)".
            # The optional parentheses sit outside the capture group so they do
            # not leak into the returned value.
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types; a doubled single quote stands for one quote.
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub("''", "'")
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub("''", "'").gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end
  end

  module ConnectionAdapters
234 235
    # The PostgreSQL adapter works with both the native C (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944) drivers.
    #
    # Options:
    #
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client message level that is used in a <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:allow_concurrency</tt> - If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
    class PostgreSQLAdapter < AbstractAdapter
249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
      # Name reported by #adapter_name.
      ADAPTER_NAME = "PostgreSQL".freeze

      # Column type mapping used when generating DDL. :primary_key holds a
      # complete column definition; every other entry is a type hash.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key".freeze,
        :string => { :name => "character varying", :limit => 255 },
        :text => { :name => "text" },
        :integer => { :name => "integer" },
        :float => { :name => "float" },
        :decimal => { :name => "decimal" },
        :datetime => { :name => "timestamp" },
        :timestamp => { :name => "timestamp" },
        :time => { :name => "time" },
        :date => { :name => "date" },
        :binary => { :name => "bytea" },
        :boolean => { :name => "boolean" }
      }

266
      # Identifies this adapter; returns the ADAPTER_NAME constant ('PostgreSQL').
      def adapter_name
        ADAPTER_NAME
      end

271 272
      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection_parameters+ and +config+ are kept for #connect, which is
      # called immediately to open the real connection.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters = connection_parameters
        @config = config
        connect
      end

279 280 281
      # Is this connection alive and ready for queries?
      #
      # Drivers exposing #status are asked directly; otherwise a trivial query
      # is sent. A PGError, or the NoMethodError postgres-pr raises when no
      # connection is available, means the connection is not active.
      def active?
        return @connection.status == PGconn::CONNECTION_OK if @connection.respond_to?(:status)

        # We're asking the driver, not ActiveRecord, so use @connection.query instead of #query
        @connection.query 'SELECT 1'
        true
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      #
      # Uses the driver's #reset when available (then re-applies the session
      # configuration); otherwise disconnects and connects from scratch.
      def reconnect!
        unless @connection.respond_to?(:reset)
          disconnect!
          return connect
        end

        @connection.reset
        configure_connection
      end
303

304
      # Close the connection. Any StandardError raised while closing (for
      # example because there is no usable connection) is ignored.
      def disconnect!
        begin
          @connection.close
        rescue
          nil
        end
      end
308

309
      # The Rails-type to PostgreSQL-type mapping (NATIVE_DATABASE_TYPES).
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end
312

313
      # PostgreSQL supports migrations; always true.
      def supports_migrations?
        true
      end

318 319 320 321 322 323 324 325 326 327 328
      # Does PostgreSQL support standard conforming strings?
      #
      # Returns the value reported by SHOW standard_conforming_strings when the
      # server knows that setting (note that any such value, "off" included, is
      # truthy), or false when the lookup fails.
      def supports_standard_conforming_strings?
        # Temporarily set the client message level above error to prevent unintentional
        # error messages in the logs when working on a PostgreSQL database server that
        # does not support standard conforming strings.
        client_min_messages_old = client_min_messages
        self.client_min_messages = 'panic'
        begin
          # postgres-pr does not raise an exception when client_min_messages is set higher
          # than error and "SHOW standard_conforming_strings" fails, but returns an empty
          # PGresult instead; the [0][0] lookup then fails and is rescued to false.
          query('SHOW standard_conforming_strings')[0][0] rescue false
        ensure
          # Always restore the previous level, even when the probe is aborted by
          # an exception the rescue modifier above does not catch.
          self.client_min_messages = client_min_messages_old
        end
      end

334
      # INSERT ... RETURNING is available from PostgreSQL 8.2 (version 80200).
      def supports_insert_with_returning?
        80200 <= postgresql_version
      end

338 339 340 341
      # DDL statements can run inside transactions; always true.
      def supports_ddl_transactions?
        true
      end

342 343
      # Returns the identifier length limit configured on the server
      # (SHOW max_identifier_length, PostgreSQL 8.0+), or 63 on PostgreSQL 7.x.
      # The answer is memoized per adapter.
      def table_alias_length
        @table_alias_length ||= if postgresql_version >= 80000
          query('SHOW max_identifier_length')[0][0].to_i
        else
          63
        end
      end
347

348 349
      # QUOTING ==================================================

350 351
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Strings for bytea, xml and bit columns and numerics for money columns
      # get PostgreSQL-specific literals. Everything else, including bit-column
      # strings that are neither binary nor hexadecimal digits, goes to the
      # generic quoting in the superclass.
      def quote(value, column = nil) #:nodoc:
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{column.class.string_to_binary(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # The value is interpolated without escaping, so the patterns must
          # cover the whole string (\A...\z); with ^...$ a multi-line value
          # whose first line is binary digits would be injected verbatim.
          case value
            when /\A[01]*\z/
              "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i
              "X'#{value}'" # Hexadecimal notation
            else
              super
          end
        else
          super
        end
      end

371 372 373 374 375 376 377 378 379 380 381 382
      # Quotes strings for use in SQL input in the postgres driver for better performance.
      #
      # The first call replaces this method on the adapter class: with a thin
      # wrapper around PGconn.escape when the driver provides it, or otherwise
      # by removing this override so the superclass implementation is used.
      def quote_string(s) #:nodoc:
        if PGconn.respond_to?(:escape)
          self.class.instance_eval do
            # The block parameter must not be named +s+: on Ruby 1.8 it would
            # assign to the captured outer +s+, shared by every later call.
            define_method(:quote_string) do |str|
              PGconn.escape(str)
            end
          end
        else
          # There are some incorrectly compiled postgres drivers out there
          # that don't define PGconn.escape.
          self.class.instance_eval do
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

      # Wraps a column name in double quotes for use in SQL queries. The name is
      # not otherwise escaped.
      def quote_column_name(name) #:nodoc:
        "\"#{name}\""
      end

394 395 396
      # Quote date/time values for use in SQL input. Time-like values that
      # respond to usec get a six-digit microsecond suffix; anything else is
      # quoted by the superclass unchanged.
      def quoted_date(value) #:nodoc:
        return super unless value.acts_like?(:time) && value.respond_to?(:usec)
        "#{super}.#{sprintf("%06d", value.usec)}"
      end

404 405
      # REFERENTIAL INTEGRITY ====================================

406 407 408 409 410 411 412
      # Can triggers be disabled per table (ALTER TABLE ... DISABLE TRIGGER ALL)?
      # True for server version 8.1 and anything later, including 9.x and 10+.
      # Any StandardError while reading the version means "not supported".
      def supports_disable_referential_integrity?() #:nodoc:
        major, minor = query("SHOW server_version")[0][0].split('.').map { |part| part.to_i }
        # Compare the major version first: checking "major >= 8 && minor >= 1"
        # wrongly rejected releases such as 9.0.
        major > 8 || (major == 8 && minor.to_i >= 1)
      rescue
        false
      end

413
      # Runs the block with all triggers disabled (ALTER TABLE ... DISABLE
      # TRIGGER ALL) on every table, if the server supports it, and re-enables
      # them afterwards even when the block raises. Returns the block's value.
      def disable_referential_integrity(&block) #:nodoc:
        set_triggers = lambda do |state|
          if supports_disable_referential_integrity?
            statements = tables.collect { |table| "ALTER TABLE #{quote_table_name(table)} #{state} TRIGGER ALL" }
            execute(statements.join(";"))
          end
        end

        set_triggers.call('DISABLE')
        yield
      ensure
        set_triggers.call('ENABLE')
      end
423 424 425

      # DATABASE STATEMENTS ======================================

426 427 428 429 430 431
      # Executes a SELECT query and returns its rows, each an array of field
      # values (the last element of select_raw's result).
      def select_rows(sql, name = nil)
        raw = select_raw(sql, name)
        raw.last
      end

432
      # Executes an INSERT query and returns the new record's ID.
      #
      # When INSERT ... RETURNING is supported and the table has a primary key,
      # the id comes back from the insert itself. Otherwise the statement goes
      # through the superclass; if that yields no id, the last value of the
      # primary key's sequence is read back. Returns nil for a table without a
      # primary key.
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # Extract the table from the insert sql ("INSERT INTO table ..."). Yuck.
        table = sql.split(" ", 4)[2].gsub('"', '')

        if supports_insert_with_returning?
          pk, sequence_name = *pk_and_sequence_for(table) unless pk
          if pk
            new_id = select_value("#{sql} RETURNING #{quote_column_name(pk)}")
            clear_query_cache
            return new_id
          end
        end

        # Bare super forwards the *current* pk and sequence_name, which may
        # have been filled in by the lookup above.
        insert_id = super
        return insert_id if insert_id

        # If neither pk nor sequence name is given, look them up.
        pk, sequence_name = *pk_and_sequence_for(table) unless pk || sequence_name

        # With a pk, fall back to the default sequence name; without one there
        # is no last insert id to fetch.
        sequence_name ||= default_sequence_name(table, pk) if pk
        last_insert_id(table, sequence_name) if pk && sequence_name
      end

464 465 466 467 468 469 470 471 472 473 474 475 476 477
      # Converts a PGresult into a 2D array: one array of field values per row.
      def result_as_array(res) #:nodoc:
        (0...res.ntuples).map do |row|
          (0...res.nfields).map { |col| res.getvalue(row, col) }
        end
      end


      # Queries the database and returns the results as an array of rows, each
      # an array of field values (see result_as_array). Uses async_exec when
      # @async is set, exec otherwise.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          # Let the value flow back out of the block instead of using `return`:
          # a non-local return would unwind #log before it finishes, skipping
          # whatever #log does after yielding.
          res = @async ? @connection.async_exec(sql) : @connection.exec(sql)
          result_as_array(res)
        end
      end

489
      # Executes an SQL statement, returning a PGresult object on success
      # or raising a PGError exception otherwise. Uses async_exec when @async
      # is set, exec otherwise.
      def execute(sql, name = nil)
        log(sql, name) do
          @async ? @connection.async_exec(sql) : @connection.exec(sql)
        end
      end

501
      # Executes an UPDATE query and returns the number of affected tuples
      # (cmd_tuples of the superclass's result).
      def update_sql(sql, name = nil)
        result = super
        result.cmd_tuples
      end

506 507
      # Begins a transaction by sending BEGIN.
      def begin_db_transaction
        execute("BEGIN")
      end

511 512
      # Commits a transaction by sending COMMIT.
      def commit_db_transaction
        execute("COMMIT")
      end
515

516 517
      # Aborts a transaction by sending ROLLBACK.
      def rollback_db_transaction
        execute("ROLLBACK")
      end

521 522 523 524 525 526 527 528 529 530 531
      # ruby-pg defines Ruby constants for transaction status,
      # ruby-postgres does not; 0 is used as the idle status there.
      PQTRANS_IDLE = if defined?(PGconn::PQTRANS_IDLE)
                       PGconn::PQTRANS_IDLE
                     else
                       0
                     end

      # Check whether a transaction is active, i.e. the driver reports any
      # transaction status other than PQTRANS_IDLE.
      def transaction_active?
        status = @connection.transaction_status
        status != PQTRANS_IDLE
      end

      # Wrap a block in a transaction.  Returns result of block.
      #
      # With +start_db_transaction+ false nothing is begun, committed or rolled
      # back here. Otherwise the transaction is committed after the block
      # returns, or rolled back if it raises. ActiveRecord::Rollback is
      # swallowed after the rollback (the method then returns nil); any other
      # exception is re-raised. If the commit fails, a rollback is attempted
      # and the commit error re-raised.
      def transaction(start_db_transaction = true)
        opened = false
        begin
          if block_given?
            if start_db_transaction
              begin_db_transaction
              opened = true
            end
            yield
          end
        rescue Exception => error
          # Roll back on any exception (hence Exception, not StandardError),
          # then propagate it unless it was a deliberate Rollback.
          if opened && transaction_active?
            opened = false
            rollback_db_transaction
          end
          raise unless error.is_a?(ActiveRecord::Rollback)
        end
      ensure
        # A transaction we opened that is still active here was not rolled
        # back above, so commit it.
        if opened && transaction_active?
          begin
            commit_db_transaction
          rescue Exception
            rollback_db_transaction
            raise
          end
        end
      end


560 561
      # SCHEMA STATEMENTS ========================================

562 563 564 565 566
      # Drops the named database, then creates it again with default options.
      def recreate_database(name) #:nodoc:
        drop_database(name)
        create_database(name)
      end

567 568 569
      # Create a new PostgreSQL database.  Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
      # Option keys may be strings or symbols; <tt>:encoding</tt> defaults to
      # "utf8" and unrecognised keys are ignored.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Symbolize before applying the default. Merging the :encoding default
        # first left both :encoding and a caller's "encoding" key in the hash,
        # and which one survived symbolize_keys depended on hash iteration
        # order (unspecified on Ruby 1.8).
        options = { :encoding => "utf8" }.merge(options.symbolize_keys)

        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = \"#{value}\""
          when :template
            " TEMPLATE = \"#{value}\""
          when :encoding
            " ENCODING = '#{value}'"
          when :tablespace
            " TABLESPACE = \"#{value}\""
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

      # Drops a PostgreSQL database.
      #
      # Uses DROP DATABASE IF EXISTS on 8.2+. On older servers a missing
      # database (ActiveRecord::StatementInvalid) only produces a logger warning.
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        if postgresql_version < 80200
          begin
            execute "DROP DATABASE #{quote_table_name(name)}"
          rescue ActiveRecord::StatementInvalid
            @logger.warn "#{name} database doesn't exist." if @logger
          end
        else
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        end
      end


614 615
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # Strip each entry: search paths are often written with a space after
        # the comma ("public, app"), and ' app' would never match a schemaname.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

624 625
      # Returns the list of all indexes for a table.
      #
      # Primary-key indexes are excluded. Only indexes in schemas on the search
      # path are considered, and only the first ten key columns of each index
      # (indkey[0]..indkey[9]) are matched.
      # NOTE(review): rows are ordered by index name only, so the order of
      # columns within a multi-column index is not guaranteed by this query.
      def indexes(table_name, name = nil)
        # Strip each entry so "public, app" style search paths still match.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        # The table name is quoted as a string literal rather than interpolated raw.
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, a.attname
            FROM pg_class t, pg_class i, pg_index d, pg_attribute a
           WHERE i.relkind = 'i'
             AND d.indexrelid = i.oid
             AND d.indisprimary = 'f'
             AND t.oid = d.indrelid
             AND t.relname = #{quote(table_name.to_s)}
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
             AND a.attrelid = t.oid
             AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
                OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
                OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
                OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
                OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        indexes = []

        # Rows arrive grouped by index name; start a new definition whenever
        # the name changes and append the column to the current one.
        result.each do |index_name, unique, column_name|
          if current_index != index_name
            indexes << IndexDefinition.new(table_name, index_name, unique == "t", [])
            current_index = index_name
          end

          indexes.last.columns << column_name
        end

        indexes
      end

660 661
      # Returns the list of all column definitions for a table, as
      # PostgreSQLColumn objects (nullable unless the catalog's notnull is 't').
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

668 669 670 671 672 673 674 675 676 677 678 679 680
      # Returns the current database name, as reported by current_database().
      def current_database
        rows = query('select current_database()')
        rows[0][0]
      end

      # Returns the current database encoding format (e.g. "UTF8").
      def encoding
        # Match on current_database() server-side instead of interpolating the
        # name into a LIKE pattern, where "_" and "%" in the database name act
        # as wildcards and quotes in it are not escaped.
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

681 682 683 684 685 686
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      # A nil value is ignored.
      #
      # This should be not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = schema_csv
      end

693 694
      # Returns the active schema search path, querying the server once and
      # memoizing the answer.
      def schema_search_path
        @schema_search_path ||= query('SHOW search_path').first.first
      end
697

698 699 700 701 702 703 704 705 706 707 708 709
      # Returns the current client message level (SHOW client_min_messages).
      def client_min_messages
        rows = query('SHOW client_min_messages')
        rows.first.first
      end

      # Set the client message level for this session.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement)
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      #
      # Prefers the sequence found by pk_and_sequence_for; otherwise builds the
      # conventional "<table>_<key>_seq" name, using +pk+, the detected primary
      # key, or "id", in that order.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        default_pk, default_seq = pk_and_sequence_for(table_name)
        return default_seq if default_seq

        key = pk || default_pk || 'id'
        "#{table_name}_#{key}_seq"
      end

714 715
      # Resets the sequence of a table's primary key to the maximum value.
      #
      # Missing +pk+/+sequence+ are looked up with pk_and_sequence_for. The
      # sequence is set so its next value is MAX(pk) + increment_by, or its
      # min_value for an empty table. A table whose primary key has no sequence
      # only gets a logger warning; a table without a primary key is skipped.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        return unless pk

        if sequence
          quoted_sequence = quote_column_name(sequence)

          select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
          end_sql
        else
          @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
        end
      end

734 735
      # Returns a table's primary key and belonging sequence.
      #
      # Returns [primary_key, sequence_name], or nil when neither lookup finds
      # a row or any StandardError occurs (for example when the ::regclass cast
      # rejects the table name).
      # NOTE(review): +table+ is interpolated into the SQL unescaped.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{table}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{table}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # [primary_key, sequence]; a nil result raises NoMethodError here,
        # which the rescue below turns into nil.
        [result.first, result.last]
      rescue
        nil
      end

783
      # Renames a table (ALTER TABLE ... RENAME TO ...).
      def rename_table(name, new_name)
        old_quoted = quote_table_name(name)
        new_quoted = quote_table_name(new_name)
        execute "ALTER TABLE #{old_quoted} RENAME TO #{new_quoted}"
      end
787

788 789
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      #
      # The column is added first; a default and NOT NULL (when
      # <tt>:null => false</tt>) are applied afterwards as separate statements.
      def add_column(table_name, column_name, type, options = {})
        default = options[:default]
        not_null = options[:null] == false

        # Add the column.
        quoted_table = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        column_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        execute("ALTER TABLE #{quoted_table} ADD COLUMN #{quoted_column} #{column_type}")

        change_column_default(table_name, column_name, default) if options_include_default?(options)
        change_column_null(table_name, column_name, false, default) if not_null
      end

801 802
      # Changes the column of a table.
      #
      # Uses ALTER COLUMN ... TYPE, which PostgreSQL supports from 8.0
      # (version 80000). On 7.x, where that statement fails, the column is
      # rebuilt through a temporary column inside a transaction. If the rebuild
      # fails, the transaction is rolled back and the error re-raised instead
      # of being silently dropped. Afterwards the :default and :null options,
      # when given, are applied.
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"
        rescue ActiveRecord::StatementInvalid => e
          # 8.0 and later understand ALTER COLUMN ... TYPE, so a failure there
          # is genuine (">=", not ">": 8.0.0 itself reports 80000).
          raise e if postgresql_version >= 80000
          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{sql_type})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue
            rollback_db_transaction
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
826

827 828
      # Changes the default value of a table column.
      # +default+ is passed through quote before being embedded in the statement.
      def change_column_default(table_name, column_name, default)
        target = "#{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)}"
        execute "ALTER TABLE #{target} SET DEFAULT #{quote(default)}"
      end
831

832 833
      # Drops (+null+ truthy) or adds (+null+ falsy) the NOT NULL constraint on a column.
      #
      # When the constraint is being added and a non-nil +default+ is given, rows whose
      # value is currently NULL are first set to +default+, so existing NULLs do not
      # block the new constraint.
      def change_column_null(table_name, column_name, null, default = nil)
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{quoted_table} SET #{quoted_column}=#{quote(default)} WHERE #{quoted_column} IS NULL")
        end

        action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{quoted_table} ALTER #{quoted_column} #{action} NOT NULL")
      end

839 840
      # Renames +column_name+ in +table_name+ to +new_column_name+.
      # The table name is quoted with quote_table_name, both column names with quote_column_name.
      def rename_column(table_name, column_name, new_column_name)
        table      = quote_table_name(table_name)
        old_column = quote_column_name(column_name)
        new_column = quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{old_column} TO #{new_column}"
      end
843

844 845
      # Drops an index from a table.
      #
      # +options+ is resolved to an index name with index_name. The name is quoted
      # with quote_column_name because PostgreSQL folds unquoted identifiers to
      # lower case. Without quoting, an index whose name contains upper-case or
      # special characters could not be dropped.
      def remove_index(table_name, options = {})
        execute "DROP INDEX #{quote_column_name(index_name(table_name, options))}"
      end
848

849 850
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only integer types are handled here. The +limit+ is read as a byte size:
      # 1-2 => smallint, 3-4 or no limit => integer, 5-8 => bigint. Any other
      # limit raises ActiveRecordError. Every other type is delegated to super.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'

        sql_type =
          case limit
          when nil, 3..4 then 'integer'
          when 1..2      then 'smallint'
          when 5..8      then 'bigint'
          end
        sql_type || raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
      end
860

861
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", "posts.created_at desc")
      #   # => "DISTINCT ON (posts.id) posts.id, posts.created_at AS alias_0"
      #
      # The ORDER BY columns are aliased alias_0, alias_1, ... in order of appearance;
      # add_order_by_for_association_limiting! orders by those same aliases.
      def distinct(columns, order_by) #:nodoc:
        return "DISTINCT #{columns}" if order_by.blank?

        # Keep only the column expression of each ORDER BY term, dropping ASC/DESC.
        order_columns = order_by.split(',').map { |term| term.split.first }.reject(&:blank?)
        aliased = order_columns.each_with_index.map { |column, index| "#{column} AS alias_#{index}" }

        # DISTINCT ON the requested columns while still selecting every ORDER BY column.
        "DISTINCT ON (#{columns}) #{columns}, #{aliased.join(', ')}"
      end
881
      
882
      # Returns an ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the +sql+ string as a sub-select and ordering in that query.
      #
      # The outer query orders by the alias_N columns produced by #distinct. Of each original
      # term, only a trailing DESC is carried over; every other term gets no direction.
      # +sql+ is rewritten in place and returned. It is returned untouched when
      # options[:order] is blank.
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
        return sql if options[:order].blank?

        terms = options[:order].split(',').map(&:strip).reject(&:blank?)
        order = terms.each_with_index.map do |term, index|
          direction = term =~ /\bdesc$/i ? 'DESC' : nil
          "id_list.alias_#{index} #{direction}"
        end.join(', ')

        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{order}"
      end
895

896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912
      protected
        # Returns the version of the connected PostgreSQL server as an integer in
        # PGconn#server_version form (major * 10000 + minor * 100 + patch, e.g. 80304
        # for 8.3.4). The value is memoized.
        #
        # Drivers that respond to +server_version+ are asked directly. Otherwise the
        # output of SELECT version() is parsed. The patch component is optional, so
        # pre-release or two-part versions ("PostgreSQL 8.4beta1") still yield a
        # usable value (80400) instead of 0. If the query fails or its output cannot
        # be parsed, 0 is returned.
        def postgresql_version
          @postgresql_version ||=
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              # Mimic PGconn.server_version behavior
              begin
                version_string = query('SELECT version()')[0][0].to_s
                match = /PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?/.match(version_string)
                if match
                  # A missing patch capture is nil, and nil.to_i is 0.
                  major, minor, patch = match.captures.map { |part| part.to_i }
                  (major * 10000) + (minor * 100) + patch
                else
                  0
                end
              rescue
                0
              end
            end
        end

D
Initial  
David Heinemeier Hansson 已提交
913
      private
P
Pratik Naik 已提交
914
        # The internal PostgreSQL identifier of the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Opens the PGconn described by @connection_parameters and adapts the adapter
        # to the connected server. The steps below run in order:
        #
        # * disable driver-side result translation when the driver supports toggling it;
        # * enable async_exec only when :allow_concurrency is set and the driver offers it;
        # * switch the quoted string prefix to 'E' when standard_conforming_strings is known;
        # * redefine PostgreSQLColumn#extract_precision for money columns for this server version;
        # * apply encoding, message level and search path via configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)
          if PGconn.respond_to?(:translate_results=)
            PGconn.translate_results = false
          end

          # Ignore async_exec and async_query when using postgres-pr.
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Escape string syntax must be chosen up front rather than lazily on the first
          # string: switching later could break transactions already in progress.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          # Servers without standard_conforming_strings don't support escape string syntax,
          # so the inherited quoted_string_prefix is left alone for them.
          if supports_standard_conforming_strings?
            self.class.send(:define_method, :quoted_string_prefix) { 'E' }
          end

          # Money has a fixed precision of 10 up to PostgreSQL 8.2 and 19 from 8.3 on.
          # PostgreSQLColumn.extract_precision can't detect the server version itself,
          # so the right value is baked in here.
          money_digits = postgresql_version >= 80300 ? 19 : 10
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)
              if sql_type =~ /^money$/
                #{money_digits}
              else
                super
              end
            end
          end_eval

          configure_connection
        end

        # Applies per-connection settings from the adapter config:
        # * :encoding is sent through the driver's set_client_encoding when it has one,
        #   otherwise with a SET client_encoding statement;
        # * :min_messages, when given, is assigned to client_min_messages;
        # * the schema search path is always assigned from :schema_search_path, falling
        #   back to :schema_order (possibly nil).
        # This is called by #connect and should not be called manually.
        def configure_connection
          encoding = @config[:encoding]
          if encoding && @connection.respond_to?(:set_client_encoding)
            @connection.set_client_encoding(encoding)
          elsif encoding
            execute("SET client_encoding TO '#{encoding}'")
          end

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages

          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

968 969
        # Returns the current value of +sequence_name+ as reported by
        # SELECT currval(...), converted with Integer() (which raises if the value
        # is not an integer). +table+ is accepted but not used here.
        def last_insert_id(table, sequence_name) #:nodoc:
          current = select_value("SELECT currval('#{sequence_name}')")
          Integer(current)
        end

973
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        #
        # Returns an Array with one Hash per row, mapping each field name to that
        # row's value as produced by select_raw.
        def select(sql, name = nil)
          fields, rows = select_raw(sql, name)
          rows.map { |row| Hash[fields.zip(row)] }
        end

        # Runs +sql+ and returns <tt>[fields, rows]</tt>: the result's field names and its
        # rows as arrays of values (as produced by result_as_array). When the query
        # returns no rows, both arrays are empty.
        #
        # Values in money columns have their currency formatting stripped (see below).
        # The driver result is always cleared, even when processing raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          results = result_as_array(res)
          fields = []
          rows = []
          if res.ntuples > 0
            fields = res.fields
            # Column types are identical for every row, so look them up once.
            money_indices = (0...fields.size).select { |index| res.ftype(index) == MONEY_COLUMN_TYPE_OID }
            results.each do |row|
              # Currency symbols are stripped here rather than in
              # PostgreSQLColumn.string_to_decimal, because doing it there would break
              # form input fields that call value_before_type_cast.
              money_indices.each do |index|
                # Money output is formatted according to the locale, so there are two
                # cases to consider (note the decimal separators):
                #  (1) $12,345,678.12
                #  (2) $12.345.678,12
                case value = row[index]
                when /^-?\D+[\d,]+\.\d{2}$/  # (1)
                  row[index] = value.gsub(/[^-\d\.]/, '')
                when /^-?\D+[\d\.]+,\d{2}$/  # (2)
                  row[index] = value.gsub(/[^-\d,]/, '').sub(/,/, '.')
                end
              end
              rows << row
            end
          end
          return fields, rows
        ensure
          # +res+ is nil only when execute itself raised.
          res.clear if res
        end

1024
        # Returns the list of a table's column names, data types, and default values.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Each returned row holds: attname, the formatted type, the default expression
        # source (adsrc, NULL when the column has no default because of the LEFT JOIN)
        # and attnotnull.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - '<name>'::regclass casts the table name to the table's OID
        #  - attnum > 0 skips system columns; attisdropped skips dropped columns
        def column_definitions(table_name) #:nodoc:
          query <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
    end
  end
end