postgresql_adapter.rb 38.3 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1 2
require 'active_record/connection_adapters/abstract_adapter'

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Load a PostgreSQL driver: prefer the 'pg' library and fall back to the
# older 'postgres' library when 'pg' cannot be loaded.
begin
  require_library_or_gem 'pg'
rescue LoadError => e
  begin
    require_library_or_gem 'postgres'
    # The 'postgres' driver's PGresult uses older method names; alias them to
    # the names this adapter calls (nfields, ntuples, ftype, cmd_tuples),
    # skipping any alias whose name is already defined.
    class PGresult
      alias_method :nfields, :num_fields unless self.method_defined?(:nfields)
      alias_method :ntuples, :num_tuples unless self.method_defined?(:ntuples)
      alias_method :ftype, :type unless self.method_defined?(:ftype)
      alias_method :cmd_tuples, :cmdtuples unless self.method_defined?(:cmd_tuples)
    end
  rescue LoadError
    # Neither driver is available: surface the original error from the
    # 'pg' attempt rather than the fallback's.
    raise e
  end
end

D
Initial  
David Heinemeier Hansson 已提交
19 20 21 22
module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ is a connection specification hash (string or symbol keys).
    # Recognized keys: :host, :port (defaults to 5432), :username, :password
    # and the mandatory :database.
    #
    # Raises ArgumentError when no :database key is present.
    def self.postgresql_connection(config) # :nodoc:
      config = config.symbolize_keys

      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username] ? config[:username].to_s : nil
      password = config[:password] ? config[:password].to_s : nil

      raise ArgumentError, "No database specified. Missing argument: database." unless config.has_key?(:database)
      database = config[:database]

      # The postgres drivers don't allow the creation of an unconnected PGconn
      # object, so the adapter gets a nil connection plus the parameter list
      # (host, port, two nil placeholders, database, username, password) and
      # connects by itself.
      connection_parameters = [host, port, nil, nil, database, username, password]
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, connection_parameters, config)
    end
  end
40

41 42 43 44 45 46 47
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    #
    # Adjusts limit/scale extraction for PostgreSQL types, maps PostgreSQL type
    # names onto Rails' logical types and decodes catalog default expressions
    # into plain values.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression from the catalog; it is decoded
      # with PostgreSQLColumn.extract_value_from_default before being handed to
      # Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      private
        # Reports a limit of 8 for bigint and 2 for smallint; every other type
        # is handled by Column.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          super
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Numeric and monetary types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when /^money$/
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when /^bytea$/
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when /^interval$/
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when /^xml$/
              :string
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when /^oid$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # +default+ is the default expression as stored in the catalog, e.g.
        # "'foo'::character varying", "(-1)" or "nextval('seq'::regclass)".
        # Returns the literal as a String (true/false for booleans), or nil for
        # defaults that are absent or whose value cannot be known here.
        #
        # `private` above does not apply to singleton methods, so this stays
        # callable as PostgreSQLColumn.extract_value_from_default.
        def self.extract_value_from_default(default)
          case default
            # Numeric types, optionally wrapped in parentheses as in "(-1)".
            # The closing parenthesis is matched outside the capture so it does
            # not leak into the value. Plain integers (e.g. oid defaults) are
            # also covered here.
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types; single quotes inside the literal are stored
            # doubled and are unescaped here.
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub("''", "'")
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end
  end

  module ConnectionAdapters
168 169
    # The PostgreSQL adapter uses the 'pg' library when it is available (it is tried first) and
    # otherwise a driver loadable as 'postgres': the native C one (http://ruby.scripting.ca/postgres/)
    # or the pure Ruby one (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944).
170 171 172
    #
    # Options:
    #
P
Pratik Naik 已提交
173 174 175 176 177 178 179 180 181
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:allow_concurrency</tt> - If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
182
    class PostgreSQLAdapter < AbstractAdapter
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
      # Name returned by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'.freeze

      # PostgreSQL type definitions for Rails' abstract column types, exposed
      # through #native_database_types. :string carries a default limit of 255,
      # and :primary_key is a complete column definition rather than a bare
      # type name.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key".freeze,
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" }
      }

200
      # Identifies this adapter; the value is the frozen ADAPTER_NAME
      # constant, i.e. 'PostgreSQL'.
      def adapter_name
        ADAPTER_NAME
      end

205 206
      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection+ is passed on to the superclass (postgresql_connection
      # hands in nil); the parameter list and configuration are kept so that
      # #connect can open the real connection right away.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters = connection_parameters
        @config = config
        connect
      end

213 214 215
      # Is this connection alive and ready for queries?
      #
      # Drivers exposing +status+ are asked directly; otherwise a trivial query
      # is sent straight to the driver connection, bypassing #query. A PGError,
      # or the NoMethodError postgres-pr raises when no connection is
      # available, means the connection is not active.
      def active?
        return @connection.status == PGconn::CONNECTION_OK if @connection.respond_to?(:status)

        @connection.query 'SELECT 1'
        true
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      #
      # Drivers that support +reset+ re-establish the connection in place,
      # after which #configure_connection reapplies the session settings;
      # otherwise the connection is dropped and a new one is opened.
      def reconnect!
        unless @connection.respond_to?(:reset)
          disconnect!
          return connect
        end

        @connection.reset
        configure_connection
      end
237

238
      # Close the connection. Any StandardError raised while closing
      # (including a missing connection) is ignored and nil is returned.
      def disconnect!
        @connection.close
      rescue StandardError
        nil
      end
242

243
      # Exposes the NATIVE_DATABASE_TYPES mapping (the constant itself, not a copy).
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end
246

247
      # Migrations are always supported by this adapter.
      def supports_migrations?
        true
      end

252 253 254 255 256 257 258 259 260 261 262
      # Does PostgreSQL support standard conforming strings?
      #
      # Returns true when the server knows the standard_conforming_strings
      # setting (whatever its current value) and false otherwise. While probing,
      # the client message level is raised to 'panic' to prevent unintentional
      # error messages in the logs on servers without the setting; the previous
      # level is restored in an ensure clause, so it is reset even if the probe
      # raises.
      def supports_standard_conforming_strings?
        previous_level = client_min_messages
        self.client_min_messages = 'panic'

        # postgres-pr does not raise an exception when client_min_messages is set higher
        # than error and "SHOW standard_conforming_strings" fails, but returns an empty
        # PGresult instead, so an empty result also means "not supported".
        begin
          row = query('SHOW standard_conforming_strings')[0]
          !(row.nil? || row[0].nil?)
        rescue StandardError
          false
        end
      ensure
        self.client_min_messages = previous_level if previous_level
      end

268
      # INSERT ... RETURNING is available from server version 8.2 (80200) on.
      def supports_insert_with_returning?
        80200 <= postgresql_version
      end

272 273 274 275
      # Schema changes can run inside transactions on PostgreSQL.
      def supports_ddl_transactions?
        true
      end

276 277
      # Maximum identifier length usable for table aliases. On servers from
      # 8.0 (80000) on it is read from max_identifier_length; older servers
      # report 63. The value is memoized after the first call.
      def table_alias_length
        @table_alias_length ||=
          if postgresql_version >= 80000
            query('SHOW max_identifier_length')[0][0].to_i
          else
            63
          end
      end
281

282 283
      # QUOTING ==================================================

284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
      # Escapes binary strings for bytea input to the database.
      #
      # The first call replaces this method on the receiver's class (via
      # define_method) with a driver-specific implementation and then delegates
      # to it, so the PGconn capability check runs only once:
      # * when the driver provides PGconn.escape_bytea, that is used;
      # * otherwise a pure-Ruby fallback writes every byte as two literal
      #   backslashes followed by its three-digit octal code.
      # Both implementations return nil when +value+ is nil.
      # NOTE(review): the doubled backslash presumably targets the
      # quoted_string_prefix literal that #quote wraps around the result — confirm.
      def escape_bytea(value)
        if PGconn.respond_to?(:escape_bytea)
          self.class.instance_eval do
            define_method(:escape_bytea) do |value|
              PGconn.escape_bytea(value) if value
            end
          end
        else
          self.class.instance_eval do
            define_method(:escape_bytea) do |value|
              if value
                result = ''
                # '\\\\%03o' is the Ruby literal for the text \\%03o.
                value.each_byte { |c| result << sprintf('\\\\%03o', c) }
                result
              end
            end
          end
        end
        escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from the database driver.
      #
      # Like #escape_bytea, the first call replaces this method on the
      # receiver's class (PGconn.unescape_bytea when the driver provides it,
      # otherwise a pure-Ruby decoder) and then delegates to the replacement.
      # Values without a backslash followed by three digits, including nil,
      # are returned unchanged.
      def unescape_bytea(value)
        # In each case, check if the value actually is escaped PostgreSQL bytea output
        # or an unescaped Active Record attribute that was just written.
        if PGconn.respond_to?(:unescape_bytea)
          self.class.instance_eval do
            define_method(:unescape_bytea) do |value|
              if value =~ /\\\d{3}/
                PGconn.unescape_bytea(value)
              else
                value
              end
            end
          end
        else
          self.class.instance_eval do
            define_method(:unescape_bytea) do |value|
              if value =~ /\\\d{3}/
                # A pair of backslashes decodes to one backslash; a backslash
                # followed by three digits decodes to the octal code they spell;
                # every other character is copied as is.
                # NOTE(review): this relies on value[i] and ?\\ having the same
                # type (Integer on Ruby 1.8, one-character String on 1.9+) and
                # on String#<< accepting the Integer returned by #oct. On 1.9+
                # an Integer above 127 is appended as an encoded codepoint
                # rather than a raw byte unless +result+ is binary — verify on
                # the Ruby versions in use.
                result = ''
                i, max = 0, value.size
                while i < max
                  char = value[i]
                  if char == ?\\
                    if value[i+1] == ?\\
                      char = ?\\
                      i += 1
                    else
                      char = value[i+1..i+3].oct
                      i += 3
                    end
                  end
                  result << char
                  i += 1
                end
                result
              else
                value
              end
            end
          end
        end
        unescape_bytea(value)
      end

352 353
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Binary (bytea), xml, money and bit-string columns get special
      # treatment; everything else goes through the superclass quoting.
      # Bit-string values are emitted verbatim only when the WHOLE value is a
      # binary or hexadecimal literal; any other value falls back to the
      # regular (escaped) string quoting instead of returning nil.
      def quote(value, column = nil) #:nodoc:
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{escape_bytea(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # \A and \z anchor the entire value. With ^ and $ a multi-line value
          # such as "01\n'; DROP TABLE ..." matched on its first line and was
          # interpolated unescaped.
          case value
          when /\A[01]*\z/
            "B'#{value}'" # Bit-string notation
          when /\A[0-9A-F]*\z/i
            "X'#{value}'" # Hexadecimal notation
          else
            super
          end
        else
          super
        end
      end

373 374 375 376 377 378 379 380 381 382 383 384
      # Quotes strings for use in SQL input in the postgres driver for better performance.
      #
      # The first call rewires this method on the receiver's class. When the
      # driver provides PGconn.escape, the method is redefined to delegate
      # there. Otherwise the override is removed so the inherited
      # implementation applies; some incorrectly compiled postgres drivers
      # don't define PGconn.escape. The current call is then re-dispatched to
      # whichever implementation is now in place.
      def quote_string(s) #:nodoc:
        use_driver_escape = PGconn.respond_to?(:escape)
        self.class.instance_eval do
          if use_driver_escape
            define_method(:quote_string) { |str| PGconn.escape(str) }
          else
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

      # Quotes column names for use in SQL queries.
      #
      # The name is wrapped in double quotes. Embedded double quotes are
      # doubled, which is how PostgreSQL escapes them inside a quoted
      # identifier, so a name containing '"' can no longer end the identifier
      # early.
      def quote_column_name(name) #:nodoc:
        %("#{name.to_s.gsub('"', '""')}")
      end

396 397 398
      # Quote date/time values for use in SQL input. Time-like values that
      # respond to +usec+ get a six-digit fractional-second suffix appended to
      # the superclass formatting; other values are formatted by the
      # superclass alone.
      def quoted_date(value) #:nodoc:
        with_usec = value.acts_like?(:time) && value.respond_to?(:usec)
        formatted = super
        with_usec ? "#{formatted}.#{format('%06d', value.usec)}" : formatted
      end

406 407
      # REFERENTIAL INTEGRITY ====================================

408 409 410 411 412 413 414
      # Whether triggers can be disabled per table (ALTER TABLE ... DISABLE
      # TRIGGER ALL), which requires server version 8.1 or newer.
      #
      # Major and minor versions are compared as a pair. Testing the minor
      # number on its own would reject releases such as 9.0 or 10.0. Any
      # failure while reading or parsing server_version yields false.
      def supports_disable_referential_integrity? #:nodoc:
        major, minor = query("SHOW server_version")[0][0].split('.').map { |part| part.to_i }
        major > 8 || (major == 8 && minor.to_i >= 1)
      rescue StandardError
        false
      end

415
      # Runs the block with DISABLE TRIGGER ALL applied to every table and
      # re-enables the triggers afterwards, even if the block raises. On
      # servers that cannot disable triggers the block simply runs. Returns
      # the block's value.
      #
      # Support is checked once, so the disable and re-enable steps always
      # agree. The table list is re-read for each step, and no statement is
      # sent when there are no tables.
      def disable_referential_integrity(&block) #:nodoc:
        toggle_triggers = lambda do |action|
          statements = tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} #{action} TRIGGER ALL" }
          execute(statements.join(";")) unless statements.empty?
        end

        supported = supports_disable_referential_integrity?
        toggle_triggers.call("DISABLE") if supported
        yield
      ensure
        toggle_triggers.call("ENABLE") if supported
      end
425 426 427

      # DATABASE STATEMENTS ======================================

428 429 430 431 432 433
      # Executes a SELECT query and returns an array of rows, each row being an
      # array of field values. The rows are the last element of #select_raw's
      # result (presumably [fields, rows] — confirm against select_raw).
      def select_rows(sql, name = nil)
        raw_result = select_raw(sql, name)
        raw_result.last
      end

434
      # Executes an INSERT query and returns the new record's ID.
      #
      # On servers supporting INSERT ... RETURNING (8.2+), the statement is run
      # with a RETURNING clause for the primary key; the key is looked up when
      # +pk+ is not given. Otherwise the generic insert runs, and when that
      # yields no id, the id is read with last_insert_id from the primary key's
      # sequence. Returns nil when no id can be determined.
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # Extract the table from the insert sql. Yuck.
        table = sql.split(" ", 4)[2].gsub('"', '')

        # Try an insert with 'returning id' if available (PG >= 8.2)
        if supports_insert_with_returning?
          pk, sequence_name = *pk_and_sequence_for(table) unless pk
          if pk
            # Forward +name+ so this statement is logged like any other query.
            id = select_value("#{sql} RETURNING #{quote_column_name(pk)}", name)
            clear_query_cache
            return id
          end
        end

        # Otherwise, insert then grab last_insert_id.
        # The bare `super` forwards the current parameter values, including
        # pk/sequence_name as possibly filled in above.
        if insert_id = super
          insert_id
        else
          # If neither pk nor sequence name is given, look them up.
          unless pk || sequence_name
            pk, sequence_name = *pk_and_sequence_for(table)
          end

          # If a pk is given, fallback to default sequence name.
          # Don't fetch last insert id for a table without a pk.
          if pk && sequence_name ||= default_sequence_name(table, pk)
            last_insert_id(table, sequence_name)
          end
        end
      end

466 467
      # Converts a driver result into a 2D array: one inner array per row and
      # one value per field. String values from bytea columns (type OID 17)
      # are decoded with #unescape_bytea.
      def result_as_array(res) #:nodoc:
        bytea_oid = 17
        binary_fields = (0...res.nfields).map { |field| res.ftype(field) == bytea_oid }

        (0...res.ntuples).map do |row|
          (0...res.nfields).map do |field|
            value = res.getvalue(row, field)
            binary_fields[field] && value.is_a?(String) ? unescape_bytea(value) : value
          end
        end
      end


      # Queries the database and returns the results as an array of rows
      # (see #result_as_array).
      #
      # The converted result is the value of the block passed to #log, just as
      # in #execute. A `return` inside that block would leave this method
      # non-locally and skip any work #log does after yielding, such as timing
      # or logging the statement.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          res = @async ? @connection.async_exec(sql) : @connection.exec(sql)
          result_as_array(res)
        end
      end

500
      # Executes an SQL statement, returning a PGresult object on success
      # or raising a PGError exception otherwise. The statement runs inside
      # #log, using the driver's async_exec when @async is set and exec
      # otherwise.
      def execute(sql, name = nil)
        log(sql, name) do
          @async ? @connection.async_exec(sql) : @connection.exec(sql)
        end
      end

512
      # Executes an UPDATE query and returns the number of affected tuples,
      # taken from the driver result returned by the superclass implementation.
      def update_sql(sql, name = nil)
        driver_result = super
        driver_result.cmd_tuples
      end

517 518
      # Opens a transaction by issuing BEGIN.
      def begin_db_transaction
        execute("BEGIN")
      end

522 523
      # Commits the current transaction by issuing COMMIT.
      def commit_db_transaction
        execute("COMMIT")
      end
526

527 528
      # Aborts the current transaction by issuing ROLLBACK.
      def rollback_db_transaction
        execute("ROLLBACK")
      end

J
Jonathan Viney 已提交
532 533 534 535 536 537 538 539 540 541 542
      # Creates a savepoint named after #current_savepoint_name.
      def create_savepoint
        execute "SAVEPOINT #{current_savepoint_name}"
      end

      # Rolls back to the savepoint named after #current_savepoint_name.
      def rollback_to_savepoint
        execute "ROLLBACK TO SAVEPOINT #{current_savepoint_name}"
      end

      # Releases the savepoint named after #current_savepoint_name.
      # +savepoint_number+ is accepted for interface compatibility but unused.
      def release_savepoint(savepoint_number)
        execute "RELEASE SAVEPOINT #{current_savepoint_name}"
      end
543

544 545
      # SCHEMA STATEMENTS ========================================

546 547 548 549 550
      # Drops the named database and creates it again with default options.
      def recreate_database(name) #:nodoc:
        drop_database name
        create_database name
      end

551 552 553
      # Create a new PostgreSQL database.  Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
      #
      # The encoding defaults to utf8. Owner, template and tablespace are
      # emitted as quoted identifiers and the encoding as a string literal,
      # with embedded quote characters doubled. The connection limit must be
      # an integer (Integer() raises otherwise). Unknown options are ignored.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        options = options.reverse_merge(:encoding => "utf8")

        quote_ident = lambda { |value| "\"#{value.to_s.gsub('"', '""')}\"" }

        fragments = options.symbolize_keys.map do |key, value|
          case key
          when :owner
            " OWNER = #{quote_ident.call(value)}"
          when :template
            " TEMPLATE = #{quote_ident.call(value)}"
          when :encoding
            " ENCODING = '#{value.to_s.gsub("'", "''")}'"
          when :tablespace
            " TABLESPACE = #{quote_ident.call(value)}"
          when :connection_limit
            " CONNECTION LIMIT = #{Integer(value)}"
          else
            ""
          end
        end
        option_string = fragments.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

      # Drops a PostgreSQL database.
      #
      # Servers from 8.2 (80200) on use DROP DATABASE IF EXISTS. Older servers
      # run a plain DROP DATABASE; a StatementInvalid (e.g. a missing database)
      # only logs a warning.
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        if postgresql_version < 80200
          begin
            execute "DROP DATABASE #{quote_table_name(name)}"
          rescue ActiveRecord::StatementInvalid
            @logger.warn "#{name} database doesn't exist." if @logger
          end
        else
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        end
      end


598 599
      # Returns the names of all tables in the schemas on the current search
      # path. +name+ is only used as the log label for the query.
      #
      # Each search-path entry is stripped of surrounding whitespace before
      # quoting. The server may report the path as e.g. '"$user", public', and
      # a quoted ' public' would never match a schema name.
      def tables(name = nil)
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

608 609
      # Returns the indexes of +table_name+ as IndexDefinition objects, leaving
      # out the primary key index.
      #
      # Only indexes in schemas on the current search path are considered.
      # Search-path entries are stripped of whitespace and the table name is
      # quoted as a string literal. Only the first ten key columns of each
      # index are examined.
      # NOTE(review): rows are ordered by index name only, so the column order
      # within a multi-column index is not guaranteed to match its definition.
      def indexes(table_name, name = nil)
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, a.attname
            FROM pg_class t, pg_class i, pg_index d, pg_attribute a
           WHERE i.relkind = 'i'
             AND d.indexrelid = i.oid
             AND d.indisprimary = 'f'
             AND t.oid = d.indrelid
             AND t.relname = #{quote(table_name.to_s)}
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
             AND a.attrelid = t.oid
             AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
                OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
                OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
                OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
                OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        indexes = []

        result.each do |row|
          if current_index != row[0]
            indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
            current_index = row[0]
          end

          indexes.last.columns << row[2]
        end

        indexes
      end

644 645
      # Returns the list of all column definitions for a table as
      # PostgreSQLColumn objects; a column is nullable unless its notnull flag
      # is 't'. Limit, precision, and scale are all handled by the superclass.
      def columns(table_name, name = nil)
        column_definitions(table_name).map do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

652 653 654 655 656 657 658 659 660 661 662 663 664
      # Name of the database this connection is using, as reported by the
      # server's current_database().
      def current_database
        query('select current_database()').first.first
      end

      # Returns the current database encoding format, as converted to text by
      # pg_encoding_to_char.
      #
      # The database row is matched exactly using the server's
      # current_database() function. A LIKE comparison would treat "_" and "%"
      # in database names as wildcards, and interpolating the name would break
      # on names containing quotes.
      def encoding
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

665 666 667 668 669 670
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should be not be called manually but set in database.yml.
      # A nil/false value leaves the current path untouched; otherwise the path
      # is set on the server and cached for #schema_search_path.
      def schema_search_path=(schema_csv)
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = schema_csv
      end

677 678
      # Returns the active schema search path, asking the server only when no
      # value has been cached yet.
      def schema_search_path
        @schema_search_path ||= query('SHOW search_path').first.first
      end
681

682 683 684 685 686 687 688 689 690 691 692 693
      # Current client_min_messages level as reported by the server.
      def client_min_messages
        query('SHOW client_min_messages').first.first
      end

      # Changes the session's client_min_messages level; +level+ is
      # interpolated into a quoted SET statement as given.
      def client_min_messages=(level)
        execute "SET client_min_messages TO '#{level}'"
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      # A sequence found by #pk_and_sequence_for wins; otherwise the name is
      # built as "<table>_<key>_seq", where key is +pk+, the detected primary
      # key, or 'id', in that order.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        default_pk, default_seq = pk_and_sequence_for(table_name)
        return default_seq if default_seq

        key = pk || default_pk || 'id'
        "#{table_name}_#{key}_seq"
      end

698 699
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
700 701 702 703 704 705 706
        unless pk and sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        if pk
          if sequence
707 708
            quoted_sequence = quote_column_name(sequence)

709
            select_value <<-end_sql, 'Reset sequence'
710
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
711 712 713 714
            end_sql
          else
            @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          end
715 716 717
        end
      end

718 719
      # Returns a table's primary key and belonging sequence as
      # [primary_key, sequence_name]. Returns nil when neither lookup finds a
      # row or when the catalog queries fail (for instance because the table
      # does not exist).
      def pk_and_sequence_for(table) #:nodoc:
        # String literal for the '...'::regclass casts below. Embedded single
        # quotes are doubled so the table name cannot end the literal early.
        table_literal = "'#{table.to_s.gsub("'", "''")}'"

        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = #{table_literal}::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = #{table_literal}::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # Neither lookup produced a row.
        return nil if result.nil? || result.empty?

        # [primary_key, sequence]
        [result.first, result.last]
      rescue StandardError
        # Best effort: e.g. the table does not exist.
        nil
      end

767
      # Renames table +name+ to +new_name+; both names are quoted as table names.
      def rename_table(name, new_name)
        execute("ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}")
      end
771

772 773
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      #
      # The column is first added without constraints. A :default option is
      # then applied with #change_column_default, and :null => false with
      # #change_column_null, passing the default along.
      def add_column(table_name, column_name, type, options = {})
        default = options[:default]
        not_null = (options[:null] == false)

        execute("ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}")

        change_column_default(table_name, column_name, default) if options_include_default?(options)
        change_column_null(table_name, column_name, false, default) if not_null
      end
D
Initial  
David Heinemeier Hansson 已提交
784

785 786
      # Changes the column of a table.
      #
      # Uses ALTER COLUMN ... TYPE, which PostgreSQL supports from 8.0
      # (80000) on. If that fails on an older server, the data is copied through
      # a temporary column inside a transaction instead. A failure of that
      # fallback is rolled back and re-raised, so the column is never left
      # silently unchanged. Afterwards the :default and :null options are
      # applied.
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
        rescue ActiveRecord::StatementInvalid => e
          # From 8.0 on ALTER COLUMN ... TYPE exists, so the failure is genuine.
          raise e if postgresql_version >= 80000
          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{type_to_sql(type, options[:limit], options[:precision], options[:scale])})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue
            rollback_db_transaction
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
810

811 812
      # Sets the DEFAULT of +column_name+ in +table_name+ to +default+. The value
      # is rendered with #quote before being placed in the statement.
      def change_column_default(table_name, column_name, default)
        quoted_table = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        quoted_default = quote(default)
        execute("ALTER TABLE #{quoted_table} ALTER COLUMN #{quoted_column} SET DEFAULT #{quoted_default}")
      end
815

816 817
      # Adds (+null+ false/nil) or drops (+null+ truthy) the NOT NULL constraint on
      # +column_name+. When the constraint is being added and a non-nil +default+
      # is given, rows currently holding NULL are first updated to that default,
      # so that SET NOT NULL can succeed.
      def change_column_null(table_name, column_name, null, default = nil)
        quoted_table = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{quoted_table} SET #{quoted_column}=#{quote(default)} WHERE #{quoted_column} IS NULL")
        end

        constraint_action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{quoted_table} ALTER #{quoted_column} #{constraint_action} NOT NULL")
      end

823 824
      # Renames +column_name+ in +table_name+ to +new_column_name+. The table goes
      # through quote_table_name and both column names through quote_column_name.
      def rename_column(table_name, column_name, new_column_name)
        quoted_table = quote_table_name(table_name)
        from_column = quote_column_name(column_name)
        to_column = quote_column_name(new_column_name)
        execute("ALTER TABLE #{quoted_table} RENAME COLUMN #{from_column} TO #{to_column}")
      end
827

828 829
      # Drops an index from a table.
      #
      # The index name comes from #index_name(table_name, options). It is quoted
      # with #quote_column_name, like every other identifier in this adapter's
      # DDL, so mixed-case or dotted names are used verbatim. Without quoting,
      # PostgreSQL would case-fold them or split them at the dot.
      # NOTE(review): presumably matches how the generic add_index quotes the
      # name when creating it -- confirm for indexes created by hand-written SQL.
      def remove_index(table_name, options = {})
        execute "DROP INDEX #{quote_column_name(index_name(table_name, options))}"
      end
832

833 834
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only :integer is special-cased. Its +limit+ is treated as a byte size:
      # 1..2 gives smallint, 3..4 or no limit gives integer, and 5..8 gives bigint.
      # Every other type is handed to the generic implementation via super.
      # Raises ActiveRecordError for an integer byte size outside those ranges.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super if type.to_s != 'integer'

        integer_type =
          case limit
          when nil, 3..4 then 'integer'
          when 1..2      then 'smallint'
          when 5..8      then 'bigint'
          end

        integer_type || raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
      end
844

845
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", "posts.created_at desc")
      #   # => "DISTINCT ON (posts.id) posts.id, posts.created_at AS alias_0"
      #
      # Each ORDER BY item is reduced to its first word, which drops ASC/DESC.
      # The items are then aliased alias_0, alias_1, ... by position. Falls back
      # to a plain <tt>DISTINCT columns</tt> when +order_by+ is blank or holds no
      # items at all.
      def distinct(columns, order_by) #:nodoc:
        return "DISTINCT #{columns}" if order_by.blank?

        # String#split with no argument never yields empty strings, so a
        # whitespace-only item maps to nil and is removed by compact.
        order_columns = order_by.split(',').map { |item| item.split.first }.compact

        # An ORDER BY made only of separators (e.g. ",") would otherwise produce
        # "DISTINCT ON (...) columns, " with a dangling comma.
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased_columns = []
        order_columns.each_with_index { |column, i| aliased_columns << "#{column} AS alias_#{i}" }

        # Distinct on the requested columns while still selecting every ORDER BY
        # expression, so the ordering can be applied to the result.
        "DISTINCT ON (#{columns}) #{columns}, #{aliased_columns.join(', ')}"
      end
865
      
866
      # Returns an ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the +sql+ string as a sub-select and ordering in that query.
      #
      # +sql+ is modified in place and returned. Each non-blank item of
      # <tt>options[:order]</tt> becomes <tt>id_list.alias_N</tt>, where N is its
      # position, followed by DESC when the item ends in "desc" (any case). This
      # relies on the inner select exposing the order expressions as alias_0,
      # alias_1, ... in the same order, as #distinct does. +sql+ is left untouched
      # when there is nothing to order by.
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
        return sql if options[:order].blank?

        order_items = options[:order].split(',').map { |item| item.strip }.reject(&:blank?)
        # An :order made only of separators (e.g. ",") leaves no items, and
        # "ORDER BY " followed by nothing would be invalid SQL.
        return sql if order_items.empty?

        order_clause = []
        order_items.each_with_index do |item, i|
          direction = (item =~ /\bdesc$/i) ? 'DESC' : ''
          order_clause << "id_list.alias_#{i} #{direction}"
        end

        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{order_clause.join(', ')}"
      end
879

880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896
      protected
        # Returns the version of the connected PostgreSQL server as an integer
        # (e.g. 80205 for 8.2.5), memoized in @postgresql_version.
        #
        # Uses the driver's server_version when @connection provides it. Otherwise
        # it parses SELECT version() into major * 10000 + minor * 100 + patch,
        # mimicking PGconn#server_version. The patch level is optional, so
        # development releases such as "8.3beta1" are recognised (as 80300)
        # instead of being reported as 0. Returns 0 when the version string cannot
        # be parsed or the query fails.
        def postgresql_version
          @postgresql_version ||=
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              begin
                version_string = query('SELECT version()')[0][0]
                match = /PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?/.match(version_string.to_s)
                match ? (match[1].to_i * 10000) + (match[2].to_i * 100) + match[3].to_i : 0
              rescue
                0
              end
            end
        end

D
Initial  
David Heinemeier Hansson 已提交
897
      private
P
Pratik Naik 已提交
898
        # The internal PostgreSQL identifier of the money data type.
899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919
        # select_raw compares PGresult#ftype of each column against this value to
        # find money columns whose currency formatting must be stripped.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        #
        # Steps, in order (the later ones need the live @connection):
        # 1. open @connection with PGconn.connect(*@connection_parameters);
        # 2. set @async from @config[:allow_concurrency] and driver support;
        # 3. possibly redefine quoted_string_prefix to return 'E';
        # 4. redefine PostgreSQLColumn#extract_precision for the money type, using
        #    postgresql_version (which may query the server);
        # 5. call #configure_connection.
        #
        # NOTE(review): steps 3 and 4 redefine methods on shared classes (self.class
        # and PostgreSQLColumn). They therefore affect every adapter instance and
        # column in the process, not only this connection. Confirm this is
        # acceptable when connecting to servers of different versions.
        def connect
          @connection = PGconn.connect(*@connection_parameters)
          # Class-level driver setting, only touched when the driver provides it.
          # NOTE(review): presumably turns off the driver's own result type
          # translation -- confirm against the pg/postgres driver docs.
          PGconn.translate_results = false if PGconn.respond_to?(:translate_results=)

          # Ignore async_exec and async_query when using postgres-pr.
          # @async is only truthy when allow_concurrency is configured AND the
          # driver offers async_exec.
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Use escape string syntax if available. We cannot do this lazily when encountering
          # the first string, because that could then break any transactions in progress.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          # If PostgreSQL doesn't know the standard_conforming_strings parameter then it doesn't
          # support escape string syntax. Don't override the inherited quoted_string_prefix.
          if supports_standard_conforming_strings?
            # Defines an instance method on the adapter class itself (not on this
            # object's singleton), so every instance of the class is affected.
            self.class.instance_eval do
              define_method(:quoted_string_prefix) { 'E' }
            end
          end
920

921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939
          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          # postgresql_version may run a query, so this must stay after the
          # connection is established.
          money_precision = (postgresql_version >= 80300) ? 19 : 10
          # The heredoc below is evaluated as Ruby source on PostgreSQLColumn; the
          # chosen precision is interpolated as a literal, and other SQL types
          # fall through to the inherited extract_precision via super.
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)
              if sql_type =~ /^money$/
                #{money_precision}
              else
                super
              end
            end
          end_eval

          configure_connection
        end

        # Configures the encoding, verbosity, and schema search path of the connection.
        # This is called by #connect and should not be called manually.
        #
        # * <tt>:encoding</tt> - applied with the driver's set_client_encoding when
        #   @connection has it, otherwise with a SET client_encoding statement.
        # * <tt>:min_messages</tt> - assigned to client_min_messages, only if given.
        # * <tt>:schema_search_path</tt>, falling back to <tt>:schema_order</tt> -
        #   always assigned to schema_search_path, even when both are nil.
        def configure_connection
          encoding = @config[:encoding]
          if encoding && @connection.respond_to?(:set_client_encoding)
            @connection.set_client_encoding(encoding)
          elsif encoding
            execute("SET client_encoding TO '#{encoding}'")
          end

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages

          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

952 953
        # Returns the current ID of a table's sequence.
        #
        # Reads currval(+sequence_name+) and converts the answer with Kernel#Integer,
        # so a non-numeric answer raises. +table+ is not used here.
        def last_insert_id(table, sequence_name) #:nodoc:
          current_value = select_value("SELECT currval('#{sequence_name}')")
          Integer(current_value)
        end

957
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        #
        # Returns an array with one hash per row. Each hash maps the field names
        # reported by #select_raw to that row's value at the same position.
        def select(sql, name = nil)
          fields, rows = select_raw(sql, name)
          rows.map do |row|
            row_hash = {}
            fields.each_with_index { |field, i| row_hash[field] = row[i] }
            row_hash
          end
        end

        # Executes +sql+ and returns <tt>[fields, rows]</tt>: the column names and
        # the row arrays produced by #result_as_array. Both are empty when the query
        # returns no tuples.
        #
        # Money columns (PGresult#ftype == MONEY_COLUMN_TYPE_OID) have their
        # currency symbols stripped here. Doing it in
        # PostgreSQLColumn.string_to_decimal would be prettier, but would break
        # form input fields that call value_before_type_cast. The result object is
        # cleared even if processing raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            fields = []
            rows = []
            if res.ntuples > 0
              fields = res.fields
              # Column types are the same for every row, so find the money
              # columns once instead of calling ftype for every cell.
              money_columns = (0...fields.size).select { |i| res.ftype(i) == MONEY_COLUMN_TYPE_OID }
              results.each do |row|
                money_columns.each do |i|
                  # Money output is formatted according to the locale, so there
                  # are two cases to consider (note the decimal separators):
                  #  (1) $12,345,678.12
                  #  (2) $12.345.678,12
                  case value = row[i]
                  when /^-?\D+[\d,]+\.\d{2}$/ # (1)
                    row[i] = value.gsub(/[^-\d\.]/, '')
                  when /^-?\D+[\d\.]+,\d{2}$/ # (2)
                    row[i] = value.gsub(/[^-\d,]/, '').sub(/,/, '.')
                  end
                end
                rows << row
              end
            end
            [fields, rows]
          ensure
            res.clear if res
          end
        end

1008
        # Returns the list of a table's column names, data types, and default values.
        #
        # Produces one row per live column of +table_name+ (attnum > 0 and not
        # dropped), in column order, with four values:
        # attname, format_type(...), pg_attrdef.adsrc and attnotnull.
        #
        # Query notes:
        #  - ::regclass turns the table name into the table's id. An unqualified
        #    name is therefore resolved through the schema search path.
        #  - format_type includes the column size constraint, e.g. varchar(50).
        #  - pg_attrdef is LEFT JOINed, so adsrc is NULL for columns without a
        #    default.
        def column_definitions(table_name) #:nodoc:
          sql = <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
          query(sql)
        end
    end
  end
end