postgresql_adapter.rb 42.8 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4 5 6

# Make sure we're using a pg release new enough to provide PGresult#values
gem 'pg', '~> 0.11'
7
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
8 9 10 11 12

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ is a database configuration hash whose keys may be strings or
    # symbols. Only :database is mandatory; :host, :port (default 5432),
    # :username and :password are optional.
    #
    # Raises ArgumentError when :database is missing.
    def self.postgresql_connection(config) # :nodoc:
      settings = config.symbolize_keys

      host     = settings[:host]
      port     = settings[:port] || 5432
      username = settings[:username] ? settings[:username].to_s : nil
      password = settings[:password] ? settings[:password].to_s : nil

      database = settings.fetch(:database) do
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      # The postgres drivers don't allow the creation of an unconnected PGconn
      # object, so a nil connection is passed together with the positional
      # connection parameters; the adapter connects itself during initialize.
      params = [host, port, nil, nil, database, username, password]
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, params, settings)
    end
  end
30

31 32 33 34 35 36 37
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression reported for the column; it is
      # reduced to a plain literal by PostgreSQLColumn.extract_value_from_default
      # before being passed on to Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        attr_accessor :money_precision

        # PostgreSQL can store the special timestamps 'infinity' and
        # '-infinity'; map them to positive/negative Float infinity and let
        # Column handle every other string.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end
      end
      # :startdoc:

      private
        # bigint and smallint imply a byte size (8 and 2); everything else is
        # left to Column.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        # For money the class-level +money_precision+ setting is used.
        def extract_precision(sql_type)
          if sql_type == 'money'
            self.class.money_precision
          else
            super
          end
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Floating-point types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when 'money'
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when 'bytea'
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when 'interval'
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when 'xml'
              :xml
            # tsvector type
            when 'tsvector'
              :tsvector
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when 'oid'
              :integer
            # UUID type
            when 'uuid'
              :string
            # Small and big integer types
            when /^(?:small|big)int$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Returns the literal as a String ('true'/'false' become booleans), or
        # nil when there is no default or when the default is not a plain
        # literal (a function call, a user type, ...), since its value cannot
        # be known here.
        def self.extract_value_from_default(default)
          case default
            # This is a performance optimization for Ruby 1.9.2 in development.
            # If the value is nil, we return nil straight away without checking
            # the regular expressions. If we check each regular expression,
            # Regexp#=== will call NilClass#to_str, which will trigger
            # method_missing (defined by whiny nil in ActiveSupport) which
            # makes this method very very slow.
            when NilClass
              nil
            # Numeric types. The number may be wrapped in parentheses, e.g.
            # "(-1)"; the parentheses are kept out of the captured literal.
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            # Object identifier types. This pattern has no capture group, so
            # return the matched default itself rather than an unset $1.
            when /\A-?\d+\z/
              default
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end

    # The PostgreSQL adapter uses the native C driver provided by the +pg+ gem
    # (<tt>~> 0.11</tt>, as required at the top of this file).
    #
    # Options:
    #
    # * <tt>:host</tt> - Defaults to libpq's default: a Unix-domain socket on systems
    #   that support them, otherwise "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
    #   <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client message level that is used in a
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:statement_limit</tt> - Maximum number of prepared statements cached per
    #   connection. Defaults to 1000.
    class PostgreSQLAdapter < AbstractAdapter
208 209 210 211 212
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Adds one or more columns of PostgreSQL's +xml+ type, all sharing the
        # given options:
        #   t.xml :payload, :settings, :null => false
        # Returns self.
        def xml(*args)
          options = args.extract_options!
          args.each { |name| column(name, 'xml', options) }
          self
        end

        # Adds one or more columns of PostgreSQL's +tsvector+ (full-text search
        # document) type, all sharing the given options. Returns self.
        def tsvector(*args)
          options = args.extract_options!
          args.each { |name| column(name, 'tsvector', options) }
          self
        end
      end

220
      ADAPTER_NAME = "PostgreSQL"

      # Maps Rails' abstract column types to PostgreSQL type names; exposed
      # through #native_database_types.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key",
        # character data
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        # numbers
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        # dates and times
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        # everything else
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" },
        :xml         => { :name => "xml" },
        :tsvector    => { :name => "tsvector" }
      }

239
      # Identifies this adapter; returns ADAPTER_NAME ('PostgreSQL').
      def adapter_name
        name = ADAPTER_NAME
        name
      end

244 245
      # Prepared statements are cached by StatementPool, so this is always +true+.
      def supports_statement_cache?; true; end

250 251 252 253
      # PostgreSQL indexes may declare ASC/DESC per column; always +true+.
      def supports_index_sort_order?; true; end

254 255 256 257
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          # Statements are cached per process id ($$), presumably so that a
          # forked process does not reuse statement names prepared by its parent.
          @cache   = Hash.new { |h,pid| h[pid] = {} }
        end

        def each(&block); cache.each(&block); end
        def key?(key);    cache.key?(key); end
        def [](key);      cache[key]; end
        def length;       cache.length; end

        # Name to use for the next prepared statement ("a1", "a2", ...). The
        # counter only advances when a statement is stored with []=.
        def next_key
          "a#{@counter + 1}"
        end

        # Stores the statement name +key+ for +sql+, first deallocating the
        # oldest cached statements until the pool is below its +max+ size.
        def []=(sql, key)
          while @max <= cache.size
            dealloc(cache.shift.last)
          end
          @counter += 1
          cache[sql] = key
        end

        # Deallocates every cached statement and empties the cache.
        def clear
          cache.each_value do |stmt_key|
            dealloc stmt_key
          end
          cache.clear
        end

        # Removes +sql_key+ from the cache, deallocating its statement on the
        # server. Unknown keys are ignored: deallocating them would send
        # "DEALLOCATE " without a statement name, which is invalid SQL.
        def delete(sql_key)
          dealloc cache[sql_key] if cache.key?(sql_key)
          cache.delete sql_key
        end

        private
        def cache
          @cache[$$]
        end

        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

306 307
      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection_parameters+ are the positional libpq parameters built by
      # Base.postgresql_connection and +config+ is the symbolized database
      # configuration; its :statement_limit (default 1000) bounds the prepared
      # statement cache.
      #
      # Raises a RuntimeError when the server is older than PostgreSQL 8.2.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters = connection_parameters
        @config                = config
        @visitor               = Arel::Visitors::PostgreSQL.new self

        # connect may read @local_tz, so set it to nil up front to avoid
        # uninitialized instance variable warnings.
        @local_tz           = nil
        @table_alias_length = nil

        connect
        @statements = StatementPool.new(@connection,
                                        config.fetch(:statement_limit) { 1000 })

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

        @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
      end

X
Xavier Noria 已提交
327
      # Deallocates and forgets every cached prepared statement.
      def clear_cache!
        pool = @statements
        pool.clear
      end

332 333
      # Is this connection alive and ready for queries? A PGError raised while
      # asking for the status counts as "not active".
      def active?
        begin
          PGconn::CONNECTION_OK == @connection.status
        rescue PGError
          false
        end
      end

      # Close then reopen the connection. Cached prepared statements are
      # dropped first, and the session is reconfigured after the libpq reset.
      def reconnect!
        clear_cache!
        conn = @connection
        conn.reset
        configure_connection
      end
345

346 347 348 349 350
      # Drops cached prepared statements, then performs the generic reset.
      def reset!
        clear_cache!
        super()
      end

351 352
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing: any StandardError raised by close is swallowed.
      def disconnect!
        clear_cache!
        begin
          @connection.close
        rescue
          nil
        end
      end
357

358
      # Returns the NATIVE_DATABASE_TYPES mapping (the constant itself).
      def native_database_types #:nodoc:
        types = NATIVE_DATABASE_TYPES
        types
      end
361

362
      # Migrations are supported, so this always returns +true+.
      def supports_migrations?; true; end

367
      # Primary keys can be looked up even on tables not created by Active
      # Record, so this always returns +true+.
      def supports_primary_key?; true; end #:nodoc:

372 373 374
      # Enable standard-conforming strings if available.
      #
      # Client messages are silenced ('panic') while trying, any StandardError
      # from the SET is ignored, and the previous message level is always
      # restored afterwards.
      def set_standard_conforming_strings
        previous_level = client_min_messages
        self.client_min_messages = 'panic'
        begin
          execute('SET standard_conforming_strings = on', 'SCHEMA')
        rescue
          nil
        end
      ensure
        self.client_min_messages = previous_level
      end

380
      # INSERT ... RETURNING is available, so this always returns +true+.
      def supports_insert_with_returning?; true; end

384 385 386
      # Schema changes can run inside transactions; always +true+.
      def supports_ddl_transactions?; true; end
387

388
      # Savepoints are supported, so this always returns +true+.
      def supports_savepoints?; true; end
392

393
      # Returns the server's max_identifier_length as an Integer. The value is
      # queried once and memoized.
      def table_alias_length
        return @table_alias_length if @table_alias_length
        @table_alias_length = query('SHOW max_identifier_length')[0][0].to_i
      end
397

398 399
      # QUOTING ==================================================

400
      # Escapes binary strings for bytea input to the database.
      # Returns nil for a nil (or false) +value+.
      def escape_bytea(value)
        return nil unless value
        @connection.escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from the database driver.
      # Returns nil for a nil (or false) +value+.
      def unescape_bytea(value)
        return nil unless value
        @connection.unescape_bytea(value)
      end

412 413
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Handled here:
      # * infinite Floats for datetime columns ('infinity' / '-infinity'),
      # * Numerics for money columns,
      # * Strings for bytea, xml and bit columns.
      # Everything else, including calls without a +column+, goes to the generic
      # implementation.
      #
      # For bit columns a String made only of 0/1 becomes a bit-string literal
      # (B'...') and one made only of hex digits a hex literal (X'...'); any
      # other String yields nil.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Float
          return super unless value.infinite? && column.type == :datetime
          "'#{value.to_s.downcase}'"
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # \A and \z anchor the whole string. ^ and $ only anchor a single
            # line, so a multi-line value like "0\n'; DROP TABLE t; --" would
            # pass the check and be emitted unescaped inside the literal.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            end
          else
            super
          end
        else
          super
        end
      end

441 442 443 444 445 446
      # Strings bound to bytea columns are sent as binary parameters
      # (format 1), so they need no escaping; everything else uses the
      # generic type cast.
      def type_cast(value, column)
        return super unless column

        if String === value && column.sql_type == 'bytea'
          { :value => value, :format => 1 }
        else
          super
        end
      end

453 454 455
      # Escapes +s+ for use inside a single-quoted SQL string literal, using
      # the connection's own escaping.
      def quote_string(s) #:nodoc:
        conn = @connection
        conn.escape(s)
      end

458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476
      # Quotes a table name, optionally schema-qualified. Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, name_part = extract_pg_identifier_from_name(name.to_s)
        return quote_column_name(schema) unless name_part

        table_name, _rest = extract_pg_identifier_from_name(name_part)
        [schema, table_name].map { |part| quote_column_name(part) }.join('.')
      end

477 478
      # Quotes an identifier (column name) for use in SQL queries via
      # PGconn.quote_ident; +name+ may be any object responding to to_s.
      def quote_column_name(name) #:nodoc:
        identifier = name.to_s
        PGconn.quote_ident(identifier)
      end

482 483 484
      # Quote date/time values for use in SQL input. Time-like values that
      # respond to usec get six digits of microseconds appended.
      def quoted_date(value) #:nodoc:
        return super unless value.acts_like?(:time) && value.respond_to?(:usec)
        "#{super}.#{'%06d' % value.usec}"
      end

492 493
      # Set the authorized user for this session. Cached prepared statements
      # are cleared before switching.
      # NOTE(review): +user+ is interpolated unquoted; callers must pass a
      # trusted role name (or DEFAULT).
      def session_auth=(user)
        clear_cache!
        statement = "SET SESSION AUTHORIZATION #{user}"
        exec_query statement
      end

498 499
      # REFERENTIAL INTEGRITY ====================================

500
      # Triggers can be disabled per table, so this always returns +true+.
      def supports_disable_referential_integrity?; true; end #:nodoc:

504
      # Runs the block with all triggers disabled on every table returned by
      # #tables, then re-enables them, even if the block raises. Returns the
      # block's value.
      def disable_referential_integrity #:nodoc:
        toggle_triggers = lambda do |state|
          if supports_disable_referential_integrity?
            statements = tables.map { |name| "ALTER TABLE #{quote_table_name(name)} #{state} TRIGGER ALL" }
            execute(statements.join(";"))
          end
        end

        toggle_triggers.call('DISABLE')
        yield
      ensure
        toggle_triggers.call('ENABLE')
      end
514 515 516

      # DATABASE STATEMENTS ======================================

X
Xavier Noria 已提交
517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558
      # Runs EXPLAIN for +arel+ and returns the plan formatted like psql output.
      def explain(arel)
        plan = exec_query("EXPLAIN #{to_sql(arel)}")
        ExplainPrettyPrinter.new.pp(plan)
      end

      class ExplainPrettyPrinter # :nodoc:
        # Renders the result of an EXPLAIN the way the PostgreSQL shell does:
        #
        #                                     QUERY PLAN
        #   ------------------------------------------------------------------------------
        #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
        #      Join Filter: (posts.user_id = users.id)
        #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
        #            Index Cond: (id = 1)
        #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
        #            Filter: (posts.user_id = 1)
        #   (6 rows)
        #
        def pp(result)
          header = result.columns.first
          plan   = result.rows.map(&:first)

          # One character of padding on each side of the widest entry, hence
          # the rule being two hyphens wider than the longest line.
          width = ([header] + plan).map(&:length).max + 2

          output = [header.center(width).rstrip, '-' * width]
          plan.each { |line| output << " #{line}" }

          count = result.rows.length
          output << "(#{count} #{count == 1 ? 'row' : 'rows'})"

          output.join("\n") + "\n"
        end
      end

559 560 561 562 563 564
      # Executes a SELECT query and returns its rows, each an Array of field
      # values (the last element of select_raw's return value).
      def select_rows(sql, name = nil)
        raw = select_raw(sql, name)
        raw.last
      end

565
      # Executes an INSERT query and returns the new record's ID.
      #
      # Without an explicit +pk+, the primary key of the table named in the
      # INSERT is looked up. When a key is known the statement is run with
      # RETURNING so the id comes back in the same round trip; otherwise the
      # generic implementation is used.
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        return super unless pk
        select_value("#{sql} RETURNING #{quote_column_name(pk)}")
      end
      alias :create :insert
580

581 582
      # Converts a PGresult into a 2D Array (one Array of field values per row).
      #
      # bytea fields are unescaped, and money fields are stripped of currency
      # symbols and grouping separators (with a decimal comma turned into a
      # dot) so they can be parsed as decimals. Values are modified in place.
      def result_as_array(res) #:nodoc:
        # Pair each field index with its type OID to spot binary/money columns.
        ftypes = Array.new(res.nfields) do |i|
          [i, res.ftype(i)]
        end

        rows = res.values
        return rows unless ftypes.any? { |_, x|
          x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # bytea values arrive escaped; decode them to raw binary strings.
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            case data
            when /^-?\D+[\d,]+\.\d{2}$/  # (1)
              data.gsub!(/[^-\d.]/, '')
            when /^-?\D+[\d.]+,\d{2}$/  # (2)
              # gsub! returns nil when it removes nothing (e.g. "-12,34"), so
              # the comma-to-dot swap must not be chained onto its result.
              data.gsub!(/[^-\d,]/, '')
              data.sub!(/,/, '.')
            end
          end
        end
      end


      # Queries the database (logged under +name+) and returns the rows as a
      # 2D Array built by result_as_array.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) { result_as_array(@connection.async_exec(sql)) }
      end

631
      # Executes an SQL statement (logged under +name+) and returns the raw
      # PGresult; driver errors propagate as PGError.
      def execute(sql, name = nil)
        log(sql, name) { @connection.async_exec(sql) }
      end

639 640
      # Bind placeholder for the zero-based +index+: PostgreSQL numbers its
      # parameters from 1 ($1, $2, ...).
      def substitute_at(column, index)
        position = index + 1
        Arel.sql("$#{position}")
      end

A
Aaron Patterson 已提交
643
      # Runs +sql+ and wraps the rows in an ActiveRecord::Result. Without
      # +binds+ the statement runs directly; with binds a cached prepared
      # statement is used.
      #
      # The PGresult is cleared in an +ensure+, so its memory is released even
      # when converting the rows raises.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            ActiveRecord::Result.new(result.fields, result_as_array(result))
          ensure
            result.clear
          end
        end
      end

654 655 656 657 658 659 660 661 662
      # Runs a DELETE (or, via exec_update, an UPDATE) and returns the number
      # of affected rows. The PGresult is always cleared, even if reading the
      # count raises.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            result.cmd_tuples
          ensure
            result.clear
          end
        end
      end
      alias :exec_update :exec_delete
664

665 666
      # Returns [sql, binds] for an INSERT. When a primary key is given or can
      # be derived from the table named in the INSERT, "RETURNING <pk>" is
      # appended so the new id comes back with the insert.
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        pk ||= begin
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          primary_key(table_ref) if table_ref
        end

        return [sql, binds] unless pk
        ["#{sql} RETURNING #{quote_column_name(pk)}", binds]
      end

677
      # Executes an UPDATE query and returns the number of affected tuples
      # (cmd_tuples of the result returned by the generic implementation).
      def update_sql(sql, name = nil)
        result = super
        result.cmd_tuples
      end

682 683
      # Opens a transaction with BEGIN.
      def begin_db_transaction
        execute 'BEGIN'
      end

687 688
      # Commits the current transaction with COMMIT.
      def commit_db_transaction
        execute 'COMMIT'
      end
691

692 693
      # Aborts the current transaction with ROLLBACK.
      def rollback_db_transaction
        execute 'ROLLBACK'
      end
696

697 698
      # True when libpq reports the session idle, i.e. not inside a transaction block.
      def outside_transaction?
        PGconn::PQTRANS_IDLE == @connection.transaction_status
      end
700

J
Jonathan Viney 已提交
701 702 703 704 705 706 707 708
      # Creates a savepoint named after current_savepoint_name.
      def create_savepoint
        statement = "SAVEPOINT #{current_savepoint_name}"
        execute(statement)
      end

      # Rolls back to the savepoint named after current_savepoint_name.
      def rollback_to_savepoint
        statement = "ROLLBACK TO SAVEPOINT #{current_savepoint_name}"
        execute(statement)
      end

709
      # Releases the savepoint named after current_savepoint_name.
      def release_savepoint
        statement = "RELEASE SAVEPOINT #{current_savepoint_name}"
        execute(statement)
      end
712

713 714
      # SCHEMA STATEMENTS ========================================

715 716 717
      # Drops the database +name+ (if it exists) and creates it again with
      # the given +options+ (see #create_database).
      def recreate_database(name, options = {}) #:nodoc:
        drop_database name
        create_database name, options
      end

722
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>). The encoding defaults to utf8;
      # unknown options are ignored.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        settings = options.reverse_merge(:encoding => "utf8").symbolize_keys

        clauses = settings.map do |key, value|
          case key
          when :owner            then " OWNER = \"#{value}\""
          when :template         then " TEMPLATE = \"#{value}\""
          when :encoding         then " ENCODING = '#{value}'"
          when :tablespace       then " TABLESPACE = \"#{value}\""
          when :connection_limit then " CONNECTION LIMIT = #{value}"
          else                        ""
          end
        end

        execute "CREATE DATABASE #{quote_table_name(name)}#{clauses.join}"
      end

752
      # Drops a PostgreSQL database; does nothing if it does not exist.
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        statement = "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        execute statement
      end

760 761
      # Returns the names of all tables in the schemas of the current search
      # path (current_schemas(false), i.e. excluding implicit schemas).
      def tables(name = nil)
        rows = query(<<-SQL, 'SCHEMA')
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
        rows.map(&:first)
      end

769
      # Returns true if a table or view named +name+ exists.
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        binds = [[nil, table]]
        binds << [nil, schema] if schema
        # An explicit schema is bound as $2; otherwise search the schema path.
        namespace = schema ? '$2' : 'ANY (current_schemas(false))'

        matches = exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0].to_i
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = $1
            AND n.nspname = #{namespace}
        SQL
        matches > 0
      end

789 790 791 792 793 794 795 796
      # Returns true if a schema named +name+ exists.
      def schema_exists?(name)
        matches = exec_query(<<-SQL, 'SCHEMA', [[nil, name]]).rows.first[0].to_i
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL
        matches > 0
      end
797

798
      # Returns an array of IndexDefinitions for the non-primary-key indexes of
      # +table_name+ in the current schema search path. Indexes whose columns
      # cannot be resolved (e.g. expression indexes) are skipped.
      def indexes(table_name, name = nil)
        index_rows = query(<<-SQL, name)
           SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
           FROM pg_class t
           INNER JOIN pg_index d ON t.oid = d.indrelid
           INNER JOIN pg_class i ON d.indexrelid = i.oid
           WHERE i.relkind = 'i'
             AND d.indisprimary = 'f'
             AND t.relname = '#{table_name}'
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        definitions = index_rows.map do |index_name, indisunique, indkey_list, inddef, oid|
          indkey = indkey_list.split(" ")

          # Resolve the attribute numbers in indkey to column names.
          columns = Hash[query(<<-SQL, "Columns for index #{index_name} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          column_names = columns.values_at(*indkey).compact
          next nil if column_names.empty?

          # Only DESC is spelled out in the index definition; ASC is the default.
          orders = {}
          inddef.scan(/(\w+) DESC/).flatten.each { |column| orders[column] = :desc }

          IndexDefinition.new(table_name, index_name, indisunique == 't', column_names, [], orders)
        end

        definitions.compact
      end

837 838
      # Returns a PostgreSQLColumn for every column of +table_name+.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull|
          nullable = notnull == 'f'
          PostgreSQLColumn.new(column_name, default, type, nullable)
        end
      end

845 846 847 848 849
      # Returns the name of the database this connection is using.
      def current_database
        rows = query('select current_database()')
        rows[0][0]
      end

850 851 852 853 854
      # Returns the name of the current schema (first schema of the search path).
      def current_schema
        rows = query('SELECT current_schema', 'SCHEMA')
        rows[0][0]
      end

855 856 857 858 859 860 861 862
      # Returns the current database encoding format (e.g. "UTF8").
      #
      # The database row is matched with "=" against current_database() on the
      # server. A LIKE pattern built from the database name would treat "_"
      # and "%" in the name as wildcards and could return another database's
      # encoding.
      def encoding
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

863 864 865 866 867 868
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually; set it in database.yml instead.
      def schema_search_path=(schema_csv)
        # No configured path (nil/false): leave the server's search_path untouched.
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        # Keep the memoized value read by #schema_search_path in sync.
        @schema_search_path = schema_csv
      end

875 876
      # Returns the active schema search path.
      def schema_search_path
        # Memoized after the first SHOW; the setter overwrites the cached value.
        @schema_search_path ||= query('SHOW search_path', 'SCHEMA')[0][0]
      end
879

880 881
      # Returns the current client message level.
      def client_min_messages
        # Read the current setting from the server each time (not cached).
        query('SHOW client_min_messages', 'SCHEMA')[0][0]
      end

      # Set the client message level.
      def client_min_messages=(level)
        # +level+ is interpolated as a quoted literal; it comes from the connection config.
        execute("SET client_min_messages TO '#{level}'", 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        column = pk || 'id'
        sequence = serial_sequence(table_name, column)
        # serial_sequence yields nil when the column owns no sequence; report that as
        # nil instead of failing with NoMethodError on nil.split.
        sequence && sequence.split('.').last
      rescue ActiveRecord::StatementInvalid
        # Lookup failed (e.g. unknown table): fall back to the conventional serial name.
        "#{table_name}_#{column}_seq"
      end

      # Asks the server for the sequence owned by +column+ of +table+ via
      # pg_get_serial_sequence, passing both names as bind parameters.
      # Returns the (possibly schema-qualified) sequence name, or nil when the
      # server returns NULL for the lookup.
      def serial_sequence(table, column)
        result = exec_query(<<-eosql, 'SCHEMA', [[nil, table], [nil, column]])
          SELECT pg_get_serial_sequence($1, $2)
        eosql
        result.rows.first.first
      end

904 905
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        # Fill in whichever of pk/sequence the caller did not supply.
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)

          pk ||= default_pk
          sequence ||= default_sequence
        end

        if @logger && pk && !sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence"
        end

        if pk && sequence
          quoted_sequence = quote_table_name(sequence)

          # setval(..., false) makes the next nextval return MAX(pk) + increment_by,
          # or the sequence's min_value when the table is empty.
          select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
          end_sql
        end
      end

926 927
      # Returns a table's primary key and its associated sequence.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        row = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT attr.attname, ns.nspname, seq.relname
          FROM pg_class seq
          INNER JOIN pg_depend dep ON seq.oid = dep.objid
          INNER JOIN pg_attribute attr ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          INNER JOIN pg_namespace ns ON seq.relnamespace = ns.oid
          WHERE seq.relkind  = 'S'
            AND cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql
        return nil unless row

        # [primary_key, sequence]; sequences in the "public" schema are left unqualified.
        pk, schema, sequence_name = row
        sequence = schema == 'public' ? sequence_name : "#{schema}.#{sequence_name}"
        [pk, sequence]
      rescue
        # NOTE(review): deliberately best-effort. Any error (e.g. the ::regclass cast
        # failing for an unknown table) yields nil, which reset_pk_sequence! destructures
        # as "no pk and no sequence".
        nil
      end

954 955
      # Returns just the name of a table's primary key column (nil if none is found).
      def primary_key(table)
        # The table name is bound as $1 and cast with ::regclass. Only conkey[1] (the
        # constraint's first column) is joined, so for a composite key this yields the
        # first key column. Returns nil when no row matches.
        row = exec_query(<<-end_sql, 'SCHEMA', [[nil, table]]).rows.first
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql

        row && row.first
      end

968
      # Renames a table.
969 970 971
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
972
      def rename_table(name, new_name)
        # Drop cached schema information before the table's name changes.
        clear_cache!
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end
976

977 978
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
979
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        add_column_sql = "ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{sql_type}"
        # add_column_options! is expected to append to add_column_sql in place
        # (its return value is ignored).
        add_column_options!(add_column_sql, options)

        execute add_column_sql
      end
D
Initial  
David Heinemeier Hansson 已提交
986

987 988
      # Changes the column of a table.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"

        # Default and nullability are separate ALTERs, applied only when requested.
        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
997

998 999
      # Changes the default value of a table column.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        # The default is passed through #quote, so nil becomes whatever quote(nil) yields.
        execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
      end
1003

1004
      # Toggles the NOT NULL constraint on a column. When forbidding NULLs (+null+ is
      # false) and a +default+ is given, existing NULLs are first overwritten with it.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        # Backfill before SET NOT NULL so rows that are still NULL do not block the constraint.
        if !null && !default.nil?
          execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
        end
        execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

1012 1013
      # Renames a column in a table.
      def rename_column(table_name, column_name, new_column_name)
        # Cached column information is stale once the column is renamed.
        clear_cache!
        execute "ALTER TABLE #{quote_table_name(table_name)} RENAME COLUMN #{quote_column_name(column_name)} TO #{quote_column_name(new_column_name)}"
      end
1017

1018 1019 1020 1021
      # Drops the index named +index_name+. +table_name+ is part of the
      # adapter interface but is unused here.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index_name = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index_name}"
      end

1022 1023 1024 1025
      # Renames index +old_name+ to +new_name+. +table_name+ is unused.
      #
      # In ALTER INDEX ... RENAME TO the existing index may be schema-qualified, but
      # the new name is a bare identifier (the index stays in its schema). So the old
      # name goes through quote_table_name, which understands "schema.name", and the
      # new name through quote_column_name, which quotes a single identifier.
      # Previously the two were swapped, which broke schema-qualified old names.
      def rename_index(table_name, old_name, new_name)
        execute "ALTER INDEX #{quote_table_name(old_name)} RENAME TO #{quote_column_name(new_name)}"
      end

1026 1027
      # Maximum index name length this adapter allows.
      # NOTE(review): 63 matches PostgreSQL's default identifier limit (NAMEDATALEN - 1);
      # servers built with a different NAMEDATALEN would differ.
      def index_name_length
        63
      end
1029

1030 1031
      # Maps logical Rails types to PostgreSQL-specific data types.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        # Only integers get PostgreSQL-specific handling; everything else goes to the superclass.
        return super unless type.to_s == 'integer'
        return 'integer' unless limit

        # +limit+ is a byte size: 2-byte smallint, 4-byte integer, 8-byte bigint.
        case limit
        when 1, 2 then 'smallint'
        when 3, 4 then 'integer'
        when 5..8 then 'bigint'
        else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end
1042

1043
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
1044 1045 1046
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
1047
      #
1048
      #   distinct("posts.id", "posts.created_at desc")
1049 1050
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Construct a clean list of column names from the ORDER BY clause, removing
        # any ASC/DESC modifiers (and a trailing NULLS FIRST/LAST, which would
        # otherwise end up in the select list and produce invalid SQL).
        order_columns = orders.map { |s| s.gsub(/\s+(?:ASC|DESC)\s*(?:NULLS\s+(?:FIRST|LAST)\s*)?/i, '') }
        order_columns = order_columns.reject { |c| c.blank? }
        aliased_columns = order_columns.each_with_index.map { |column, i| "#{column} AS alias_#{i}" }

        "DISTINCT #{columns}, #{aliased_columns.join(', ')}"
      end
1060

1061
      module Utils
        extend self

        # Returns an array of <tt>[schema_name, table_name]</tt> extracted from +name+.
        # +schema_name+ is nil if not specified in +name+.
        # +schema_name+ and +table_name+ exclude surrounding quotes (regardless of whether provided in +name+)
        # +name+ supports the range of schema/table references understood by PostgreSQL, for example:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        def extract_schema_and_table(name)
          # Each token is either a bare identifier or a double-quoted one; only the
          # first two are used. Quotes are stripped at the string boundaries (\A / \z)
          # rather than at line boundaries.
          table, schema = name.scan(/[^".\s]+|"[^"]*"/)[0..1].map { |m| m.gsub(/(\A"|"\z)/, '') }.reverse
          [schema, table]
        end
      end

1080
      protected
1081
        # Returns the version of the connected PostgreSQL server.
1082
        def postgresql_version
          # Integer version as reported by the driver; #connect compares it against 80300 (8.3).
          @connection.server_version
        end

1086 1087 1088
        def translate_exception(exception, message)
          # NOTE(review): this matches on the server's English message text, so it only
          # translates errors when the server reports messages in English. The SQLSTATE
          # codes (23505 unique_violation, 23503 foreign_key_violation) would be
          # locale-independent.
          case exception.message
          when /duplicate key value violates unique constraint/
            RecordNotUnique.new(message, exception)
          when /violates foreign key constraint/
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1097
      private
1098 1099
        # SQLSTATE "feature_not_supported". exec_cache compares the error code of a failed
        # prepared-statement execution against it to decide whether to evict the cached
        # statement and re-prepare it.
        FEATURE_NOT_SUPPORTED = "0A000" # :nodoc:

1100 1101
        # Runs +sql+ directly, without a prepared statement, and returns the driver result.
        # +binds+ is accepted for signature parity with #exec_cache but is not used here.
        def exec_no_cache(sql, binds)
          @connection.async_exec(sql)
        end
1103

1104
        # Executes +sql+ through a cached prepared statement, type-casting each
        # <tt>[column, value]</tt> pair in +binds+, and returns the last result.
        # If execution fails with FEATURE_NOT_SUPPORTED, the cached statement is
        # evicted and the query is retried once; any other error is re-raised.
        def exec_cache(sql, binds)
          retried = false
          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            #
            # e.result may be nil; then there is no SQLSTATE and the original error is
            # re-raised instead of being masked by a NoMethodError.
            code = e.result && e.result.result_error_field(PGresult::PG_DIAG_SQLSTATE)

            # 0A000 is a generic code that genuinely unsupported SQL also raises, and
            # re-preparing cannot fix that case, so retry at most once instead of forever.
            raise if retried || FEATURE_NOT_SUPPORTED != code

            @statements.delete sql_key(sql)
            retried = true
            retry
          end
        end

        # Returns the statement identifier for the client side cache
        # of statements
        def sql_key(sql)
          # The search path is part of the key, so identical SQL run under a different
          # schema_search_path gets its own prepared statement.
          [schema_search_path, sql].join('-')
        end

        # Prepare the statement if it hasn't been prepared, return
        # the statement key.
        def prepare_statement(sql)
          key = sql_key(sql)
          unless @statements.key?(key)
            statement_name = @statements.next_key
            @connection.prepare(statement_name, sql)
            # Only cache the name once PREPARE succeeded; a failing prepare leaves no entry.
            @statements[key] = statement_name
          end
          @statements[key]
        end
1147

P
Pratik Naik 已提交
1148
        # The internal PostgreSQL identifier of the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
        # The internal PostgreSQL identifier of the BYTEA data type.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1152 1153 1154 1155 1156 1157 1158 1159 1160

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        def connect
          @connection = PGconn.connect(*@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          PostgreSQLColumn.money_precision = (postgresql_version >= 80300) ? 19 : 10

          configure_connection
        end

1166
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
1167
        # This is called by #connect and should not be called manually.
1168 1169
        def configure_connection
          @connection.set_client_encoding(@config[:encoding]) if @config[:encoding]
          self.client_min_messages = @config[:min_messages] if @config[:min_messages]
          # :schema_order is accepted as an alternative key for the search path.
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # If using Active Record's time zone support configure the connection to return
          # TIMESTAMP WITH ZONE types in UTC.
          if ActiveRecord::Base.default_timezone == :utc
            execute("SET time zone 'UTC'", 'SCHEMA')
          elsif @local_tz
            execute("SET time zone '#{@local_tz}'", 'SCHEMA')
          end
        end

1187
        # Returns the current value (currval) of the sequence named +sequence_name+, as an Integer.
1188 1189 1190
        def last_insert_id(sequence_name) #:nodoc:
          # The sequence name is bound as $1; Integer() raises if the value is not numeric.
          r = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]])
          Integer(r.rows.first.first)
        end

1193
        # Executes a SELECT query and returns the results, performing any data type
1194
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
1195
        def select(sql, name = nil, binds = [])
          # Delegates to exec_query and converts its result with #to_a.
          exec_query(sql, name, binds).to_a
        end

        # Executes +sql+ and returns <tt>[field_names, rows]</tt>, where the rows come
        # from #result_as_array. The raw result is cleared before returning.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            fields = res.fields
            [fields, results]
          ensure
            # Clear the result even if conversion raises, so it is never left unreleased.
            res.clear
          end
        end

1207
        # Returns the list of a table's column names, data types, default values, and NOT NULL flags.
1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a cast that resolves a table name to the table's OID (its internal id)
1225
        def column_definitions(table_name) #:nodoc:
          # Default expressions are rendered with pg_get_expr(adbin, adrelid) instead
          # of the deprecated pg_attrdef.adsrc column. adsrc was not kept up to date
          # (e.g. after renames) and was removed in PostgreSQL 12.
          # The LEFT JOIN still yields NULL for columns without a default.
          exec_query(<<-end_sql, 'SCHEMA').rows
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
1235 1236

        # Splits the leading identifier off a possibly schema-qualified +name+.
        # Returns <tt>[identifier, rest]</tt>: a double-quoted identifier is returned
        # without its quotes, and +rest+ is whatever follows the separating dot (nil
        # when nothing follows). Returns nil when +name+ does not begin with an identifier.
        def extract_pg_identifier_from_name(name)
          # Anchored at \A: the remainder is sliced off using the match length, which is
          # only correct when the match starts at position 0.
          pattern = name.start_with?('"') ? /\A"([^"]+)"/ : /\A([^.]+)/
          match_data = name.match(pattern)
          return unless match_data

          rest = name[match_data[0].length..-1]
          rest = rest[1..-1] if rest.start_with?('.')
          [match_data[1], rest.empty? ? nil : rest]
        end
1245

1246 1247 1248 1249 1250
        # Extracts the table reference from an INSERT statement: the text after INTO
        # up to the first "(", stripped, provided a VALUES ( clause follows. Returns nil
        # when +sql+ does not match that shape.
        def extract_table_ref_from_insert_sql(sql)
          table_ref = sql[/into\s+([^\(]*).*values\s*\(/i, 1]
          table_ref.strip if table_ref
        end

1251 1252 1253
        # Returns a new TableDefinition constructed with this adapter as its argument.
        def table_definition
          adapter = self
          TableDefinition.new(adapter)
        end
D
Initial  
David Heinemeier Hansson 已提交
1254 1255 1256
    end
  end
end